// modexp - arbitrary precision modular exponentiation (plain square-and-multiply
// and a Montgomery-form variant), plus a small self test.
//
// Numbers are stored as little-endian vectors of 32-bit words: element 0 is the
// least significant word.  Leading (most significant) zero words are allowed
// unless a function says otherwise; use shrink() to remove them.
//
// Define MODEXP_NO_MAIN to compile the arithmetic without the command line
// front end (e.g. to link it into a unit-test harness).

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

#ifdef _WIN32
#include <windows.h>
#endif

#ifndef _DBGLEVEL
#define _DBGLEVEL 0
#endif

// modexp2 0857bde36f680cda9535784bc4aec5f4344131071b419f732ac9c74d0e61db49dd958c7344236e0279df009c6e66aec6ba574c2820d4aeb0c4d814c8e184c6ea7e6d8aa3e15d1c251c78c5364ea2b3edb3c19e90739afa765506242e78fcdc71a87efdfe2df6ce6039fc62cb3b360cb77cd5574292282df352886cbc3fcfbff2 10001 F765A3A0C9C291D81A56FE73794A746B8DA23DBE155D0D495B49D581B5C6545F449A10FDF1C26A92FBD1F43A0687044927A6A21B69A73999E6083D03ACDAFFA6409F1BC71D810628F6E18F76231ED6E22D54ED2502E66F8A33D0D5F07B3EB605F7418110E2EF9A5EE77B070F4EADFCF3D70C53E870F29C9D4F229F2CB6C25383
// = 0001FF..FF003021300906052B0E03021A050004147020B30D42ED7CDF0E849F30E8B2E137941E9691
//   (the run of F digits is 180 hex digits long; see run_tests)
// see also
// c:\local\cvsprj\secphone\trunk\machine\echocancel\fir_fxp\nr_class.cpp
// c:\local\cvsprj\secphone\trunk\machine\echocancel\fir_fxp\nr_class.h

// carry_t must be signed and strictly wider than base_t: every routine below
// keeps "one word plus carry/borrow" in a carry_t.  Fixed-width types are used
// because 'unsigned long' / 'long long' are the same width on LP64 platforms.
typedef int64_t carry_t;
typedef uint32_t base_t;
typedef std::vector<base_t> Number;

#define BITSINBASE (8*sizeof(base_t))

void writenum(const Number& num);

// bitmask[i] == 1<<i ; filled by initmask().
Number bitmask;

// (Re)builds the bitmask table.  Safe to call more than once.
void initmask()
{
    bitmask.clear();
    for (size_t i=0 ; i<BITSINBASE ; i++)
        bitmask.push_back(static_cast<base_t>(1) << i);
}

// num = value (a single-word number).
void SetNumber(Number& num, base_t value)
{
    num.clear();
    num.push_back(value);
}

// val >>= 1.  The word count is left unchanged.
void rightshift(Number& val)
{
    carry_t carry= 0;
    for (Number::reverse_iterator i=val.rbegin() ; i!=val.rend() ; ++i)
    {
        // bit BITSINBASE of carry holds the bit shifted out of the word above
        carry |= static_cast<carry_t>(*i);
        (*i) = static_cast<base_t>(carry>>1);
        carry &= 1;
        carry <<= BITSINBASE;
    }
}

// val <<= 1.  Appends a word when a bit is shifted out of the top word.
void leftshift(Number& val)
{
    carry_t carry= 0;
    for (Number::iterator i=val.begin() ; i!=val.end() ; ++i)
    {
        carry |= static_cast<carry_t>(*i) << 1;
        (*i) = static_cast<base_t>(carry);
        carry >>= BITSINBASE;
        carry &= 1;
    }
    if (carry)
        val.push_back(static_cast<base_t>(carry));
}

// sum += val.
void big_add(Number& sum, const Number& val)
{
    carry_t carry= 0;
#if _DBGLEVEL > 3
    writenum(sum); printf(" + "); writenum(val);
#endif
    if (sum.size() < val.size())
        sum.resize(val.size());

    Number::const_iterator i_val= val.begin();
    for (Number::iterator i_sum= sum.begin() ; i_sum!=sum.end() ; ++i_sum)
    {
        const carry_t addend= (i_val!=val.end()) ? static_cast<carry_t>(*i_val) : 0;
        carry += static_cast<carry_t>(*i_sum) + addend;
        (*i_sum) = static_cast<base_t>(carry);
        carry >>= BITSINBASE;
        if (i_val!=val.end())
            ++i_val;
    }
    if (carry)
        sum.push_back(static_cast<base_t>(carry));
#if _DBGLEVEL > 3
    printf(" = "); writenum(sum); printf("\n");
#endif
}

// sum -= val.  Intended for sum >= val.  When sum < val the result wraps and an
// extra all-ones word is appended (this matches the historical behaviour).
void big_sub(Number& sum, const Number& val)
{
    carry_t borrow= 0;
    if (sum.size() < val.size())
        sum.resize(val.size());
#if _DBGLEVEL > 3
    writenum(sum); printf(" - "); writenum(val);
#endif
    Number::const_iterator i_val= val.begin();
    for (Number::iterator i_sum= sum.begin() ; i_sum!=sum.end() ; ++i_sum)
    {
        const carry_t subtrahend= (i_val!=val.end()) ? static_cast<carry_t>(*i_val) : 0;
        const carry_t diff= borrow + static_cast<carry_t>(*i_sum) - subtrahend;
        (*i_sum) = static_cast<base_t>(diff);
        // explicit borrow instead of right-shifting a negative value
        borrow= (diff<0) ? -1 : 0;
        if (i_val!=val.end())
            ++i_val;
    }
    if (borrow)
        sum.push_back(static_cast<base_t>(borrow));
#if _DBGLEVEL > 3
    printf(" = "); writenum(sum); printf("\n");
#endif
}

// Removes leading (most significant) zero words; zero becomes the empty vector.
void shrink(Number& a)
{
    while (a.size() && a.back()==0)
        a.pop_back();
}

// Three-way magnitude comparison: -1 if a<b, 0 if a==b, 1 if a>b.
// Leading zero words are ignored and words are compared from the most
// significant end (the storage order is little-endian).
int compare(const Number& a, const Number& b)
{
    size_t na= a.size();
    while (na>0 && a[na-1]==0)
        na--;
    size_t nb= b.size();
    while (nb>0 && b[nb-1]==0)
        nb--;
    if (na!=nb)
        return na<nb ? -1 : 1;
    while (na>0)
    {
        na--;
        if (a[na]!=b[na])
            return a[na]<b[na] ? -1 : 1;
    }
    return 0;
}

// val %= mod, by shift-and-subtract.  val may be arbitrarily larger than mod.
// A zero modulus leaves val untouched.
void big_mod(Number& val, const Number& mod)
{
    const Number zero;
    if (compare(mod, zero)==0)
        return;
    if (compare(val, mod)<0)
        return;

    // find the smallest mod*2^k that exceeds val ...
    Number shifted= mod;
    size_t nshifts= 0;
    while (compare(shifted, val)<=0)
    {
        leftshift(shifted);
        nshifts++;
    }
    // ... then walk back down, subtracting wherever it fits.
    while (nshifts>0)
    {
        rightshift(shifted);
        nshifts--;
        if (compare(val, shifted)>=0)
            big_sub(val, shifted);
    }
    shrink(val);
}

// Single conditional subtraction: brings a value known to be < 2*mod back
// below mod.  Terminates the program when val has two or more significant
// words more than mod, because one subtraction cannot reduce such a value.
void modtrunc(Number& val, const Number& mod)
{
#if _DBGLEVEL > 2
    printf("modtrunc "); writenum(val); printf(" : "); writenum(mod); printf("\n");
#endif
    Number::reverse_iterator i_val= val.rbegin();
    while (i_val!=val.rend() && (*i_val)==0)
        ++i_val;
    Number::const_reverse_iterator i_mod= mod.rbegin();
    while (i_mod!=mod.rend() && (*i_mod)==0)
        ++i_mod;

    // difference in number of significant words
    const int diff= static_cast<int>(std::distance(i_val, val.rend()))
                  - static_cast<int>(std::distance(i_mod, mod.rend()));
#if _DBGLEVEL > 2
    printf(" diff=%d ", diff);
#endif
    if (diff<0)
        return;
    if (diff>1)
    {
        printf("!!!diff = %d\n", diff);
        exit(1);
    }
    if (diff==1)
    {
#if _DBGLEVEL > 2
        printf("mod trunc %d: %0*X >= 00\n", static_cast<int>(i_val-val.rbegin()),
                static_cast<int>(2*sizeof(base_t)), static_cast<unsigned int>(*i_val));
#endif
        big_sub(val, mod);
        return;
    }

    // same number of significant words: skip the common most significant part
    while (i_val!=val.rend() && (*i_val)==(*i_mod))
    {
        ++i_val;
        ++i_mod;
    }
    // val == mod (ran off the end) must be reduced as well, giving zero
    if (i_val==val.rend() || (*i_val)>(*i_mod))
    {
#if _DBGLEVEL > 2
        if (i_val!=val.rend())
            printf("mod trunc %d: %0*X >= %0*X\n", static_cast<int>(i_val-val.rbegin()),
                    static_cast<int>(2*sizeof(base_t)), static_cast<unsigned int>(*i_val),
                    static_cast<int>(2*sizeof(base_t)), static_cast<unsigned int>(*i_mod));
#endif
        big_sub(val, mod);
    }
}

// val = (val * mult) % mod, by binary shift-and-add.
// val and mult may refer to the same object.  The result can carry leading
// zero words.
void big_mulmod(Number& val, const Number& mult, const Number& mod)
{
#if _DBGLEVEL > 1
    printf("calc mulmod "); writenum(val); printf(" x "); writenum(mult);
    printf(" mod "); writenum(mod); printf("\n");
#endif
    Number bits= val;
    shrink(bits);               // leading zero words contribute nothing
    Number lshifter= mult;
    big_mod(lshifter, mod);     // modtrunc below relies on lshifter < mod
    val.clear();

    for (Number::iterator i_bits= bits.begin() ; i_bits!=bits.end() ; ++i_bits)
    {
        const base_t bitval= (*i_bits);
        for (size_t i=0 ; i<BITSINBASE ; i++)
        {
            if ((bitval>>i) & 1)
            {
                big_add(val, lshifter);
                modtrunc(val, mod);
            }
#if _DBGLEVEL > 1
            printf(" lsh="); writenum(lshifter); printf(" ");
            printf(" val="); writenum(val); printf("\n");
#endif
            // lshifter = (mult * 2^bitnumber) % mod
            leftshift(lshifter);
            modtrunc(lshifter, mod);
        }
    }
#if _DBGLEVEL > 1
    printf("mulmodresult="); writenum(val); printf("\n");
#endif
}

// Returns true when the given bit of num is set; bits beyond the stored words
// read as zero.
bool big_bittest(const Number& num, size_t bit)
{
    const size_t word= bit/BITSINBASE;
    if (word>=num.size())
        return false;
    return ((num[word] >> (bit%BITSINBASE)) & 1)!=0;
}

// result = (num ^ exp) % modulus, right-to-left square-and-multiply.
void modexp(const Number& num, const Number& exp, const Number& modulus, Number& result)
{
#if _DBGLEVEL > 0
    printf("calc modexp "); writenum(num); printf(" ^ "); writenum(exp);
    printf(" mod "); writenum(modulus); printf("\n");
#endif
    Number num2;
    SetNumber(num2, 1);
    big_mulmod(num2, num, modulus);     // num2 = (num2*num) % modulus

    SetNumber(result, 1);
    big_mod(result, modulus);           // 1 % modulus, i.e. 0 when modulus is 1

    // only bits up to the highest set bit of the exponent matter
    size_t nbits= exp.size()*BITSINBASE;
    while (nbits>0 && !big_bittest(exp, nbits-1))
        nbits--;

    for (size_t expbit=0 ; expbit<nbits ; expbit++)
    {
#if _DBGLEVEL > 0
        printf("%d: num2=", static_cast<int>(expbit)); writenum(num2); printf("\t");
        printf("result="); writenum(result); printf("\n");
#endif
        if (big_bittest(exp, expbit))
        {
            big_mulmod(result, num2, modulus);
#if _DBGLEVEL > 0
            printf(" ***");
#endif
        }
#if _DBGLEVEL > 0
        printf("\n");
#endif
        big_mulmod(num2, num2, modulus);    // num2 = (num2*num2) % modulus
    }
#if _DBGLEVEL > 0
    printf("modexpresult="); writenum(result); printf("\n");
#endif
}

// Value of one hexadecimal digit; any other character counts as 0.
int digit2val(char c)
{
    if (c>='0' && c<='9')
        return c-'0';
    if (c>='A' && c<='F')
        return c-'A'+10;
    if (c>='a' && c<='f')
        return c-'a'+10;
    return 0;
}

// Parses a hexadecimal string (most significant digit first) into num.
// An empty string yields the empty number (zero).
void hexstr2num(Number& num, const char* hexstr)
{
    num.clear();
    const size_t digits_per_word= sizeof(base_t)*2;
    const size_t len= strlen(hexstr);
    for (size_t j=0 ; j<len ; j++)
    {
        // widen before shifting: a nibble can land in the top bits of a word
        const base_t digit= static_cast<base_t>(digit2val(hexstr[len-1-j]));
        const size_t shift= j % digits_per_word;
        if (shift)
            num.back() |= digit << (4*shift);
        else
            num.push_back(digit);
    }
}

// Prints num in upper case hexadecimal without leading zeros ("0" for zero).
void writenum(const Number& num)
{
    Number::const_reverse_iterator i= num.rbegin();
    while (i!=num.rend() && (*i)==0)
        ++i;
    if (i==num.rend())
    {
        printf("0");
        return;
    }
    printf("%lX", static_cast<unsigned long>(*i));
    for (++i ; i!=num.rend() ; ++i)
        printf("%0*lX", static_cast<int>(2*sizeof(base_t)), static_cast<unsigned long>(*i));
}

bool is_even(base_t a)
{
    return (a&1)==0;
}

bool is_even(const Number& a)
{
    return (a.size()==0) || is_even(a[0]);
}

bool is_zero(const Number& a)
{
    for (Number::const_iterator i= a.begin() ; i!=a.end() ; ++i)
        if (*i)
            return false;
    return true;
}

namespace {

// Sign-magnitude integer; only used for the Bezout coefficients in gcd_calc,
// which become negative while Number itself is unsigned.
struct SignedNumber {
    Number mag;
    bool neg;
    SignedNumber() : neg(false) { }
};

// acc += (negative ? -mag : +mag)
void signed_accumulate(SignedNumber& acc, const Number& mag, bool negative)
{
    if (acc.neg==negative)
    {
        big_add(acc.mag, mag);
    }
    else if (compare(acc.mag, mag)>=0)
    {
        big_sub(acc.mag, mag);
    }
    else
    {
        Number diff= mag;
        big_sub(diff, acc.mag);
        acc.mag.swap(diff);
        acc.neg= negative;
    }
    shrink(acc.mag);
    if (acc.mag.empty())
        acc.neg= false;     // no negative zero
}

// out = least non-negative residue of val modulo mod
void signed_residue(const SignedNumber& val, const Number& mod, Number& out)
{
    Number r= val.mag;
    big_mod(r, mod);
    shrink(r);
    if (val.neg && !r.empty())
    {
        Number t= mod;
        big_sub(t, r);
        shrink(t);
        r.swap(t);
    }
    out.swap(r);
}

} // namespace

// Binary extended gcd (Handbook of Applied Cryptography, algorithm 14.61).
//
// v = gcd(x,y).  The algorithm finds integers A,B with A*x + B*y = v; one of
// them is usually negative, which a Number cannot hold, so they are returned
// as least non-negative residues:
//     a = A mod y   ->  a*x == v  (mod y)
//     b = B mod x   ->  b*y == v  (mod x)
// In particular, when v==1, a is the inverse of x modulo y.
// If x or y is zero the gcd is the other value and a,b are the trivial 0/1
// coefficients.
void gcd_calc(const Number& x, Number& a, const Number& y, Number& b, Number& v)
{
    if (is_zero(y))
    {
        Number g= x;
        shrink(g);
        a.clear();
        b.clear();
        if (!g.empty())
            SetNumber(a, 1);
        v.swap(g);
        return;
    }
    if (is_zero(x))
    {
        Number g= y;
        shrink(g);
        a.clear();
        SetNumber(b, 1);
        v.swap(g);
        return;
    }

    // strip the common factors of two
    size_t nshifts= 0;
    Number xx= x;
    Number yy= y;
    while (is_even(xx) && is_even(yy))
    {
        rightshift(xx);
        rightshift(yy);
        nshifts++;
    }

    // invariants:  A*xx + B*yy == u   and   C*xx + D*yy == g
    Number u= xx;
    Number g= yy;
    SignedNumber A; SetNumber(A.mag, 1);
    SignedNumber B;
    SignedNumber C;
    SignedNumber D; SetNumber(D.mag, 1);

    while (!is_zero(u))
    {
        while (is_even(u))
        {
            rightshift(u);
            if (!(is_even(A.mag) && is_even(B.mag)))
            {
                signed_accumulate(A, yy, false);    // A += yy
                signed_accumulate(B, xx, true);     // B -= xx
            }
            rightshift(A.mag);
            rightshift(B.mag);
        }
        while (is_even(g))
        {
            rightshift(g);
            if (!(is_even(C.mag) && is_even(D.mag)))
            {
                signed_accumulate(C, yy, false);    // C += yy
                signed_accumulate(D, xx, true);     // D -= xx
            }
            rightshift(C.mag);
            rightshift(D.mag);
        }
        if (compare(u, g)>=0)
        {
            big_sub(u, g);
            signed_accumulate(A, C.mag, !C.neg);    // A -= C
            signed_accumulate(B, D.mag, !D.neg);    // B -= D
        }
        else
        {
            big_sub(g, u);
            signed_accumulate(C, A.mag, !A.neg);    // C -= A
            signed_accumulate(D, B.mag, !B.neg);    // D -= B
        }
    }

    // restore the common factors of two
    for (size_t i=0 ; i<nshifts ; i++)
        leftshift(g);
    shrink(g);

    // compute into temporaries first, so outputs may alias inputs
    Number ra;
    Number rb;
    signed_residue(C, y, ra);
    signed_residue(D, x, rb);
    a.swap(ra);
    b.swap(rb);
    v.swap(g);
}

// Modular arithmetic in Montgomery form with radix 'base'.
//
// A value t is represented as t*base mod modulus; multiply() of two such
// representations yields the representation of the product.
//
// NOTE: multiply()/reduce() are evaluated with big_mulmod and a precomputed
// base^-1 rather than with a division-free word-by-word reduction loop, so
// this class gives correct results but is not faster than the plain modexp().
class Montgomery {
public:
    // Throws std::invalid_argument when modulus is zero or when base and
    // modulus are not coprime (no base^-1 exists then).
    Montgomery(const Number& modulus, const Number& base)
        : _modulus(modulus), _base(base)
    {
        shrink(_modulus);
        if (_modulus.empty())
            throw std::invalid_argument("Montgomery: modulus must be non-zero");
        big_mod(_base, _modulus);

        Number g;
        gcd_calc(base, _inversebase, _modulus, _inversemodulus, g);
        shrink(g);
        if (g.size()!=1 || g[0]!=1)
            throw std::invalid_argument("Montgomery: base and modulus must be coprime");
    }

    // a = x * y * base^-1 ( mod modulus ).  a may alias x or y.
    void multiply(const Number& x, const Number& y, Number& a) const
    {
        Number product= x;
        big_mulmod(product, y, _modulus);
        big_mulmod(product, _inversebase, _modulus);
        shrink(product);
        a.swap(product);
    }

    // returns value * base^-1 ( mod modulus )
    void reduce(const Number& val, Number& a) const
    {
        Number t= val;
        big_mulmod(t, _inversebase, _modulus);
        shrink(t);
        a.swap(t);
    }

    // result = (num ^ exp) % modulus, left-to-right square-and-multiply on
    // Montgomery representations.
    void modexp(const Number& num, const Number& exp, Number& result) const
    {
        // locate the highest set exponent bit, checking the bound before
        // looking at the bit
        size_t nbits= exp.size()*BITSINBASE;
        while (nbits>0 && !big_bittest(exp, nbits-1))
            nbits--;
        if (nbits==0)
        {
            SetNumber(result, 1);
            big_mod(result, _modulus);
            return;
        }

        Number a= _base;                    // representation of 1
        Number b= num;
        big_mulmod(b, _base, _modulus);     // representation of num

        while (nbits>0)
        {
            nbits--;
            multiply(a, a, a);
            if (big_bittest(exp, nbits))
                multiply(a, b, a);
        }

        // multiplying by plain 1 strips the factor 'base' again
        Number one;
        SetNumber(one, 1);
        multiply(a, one, result);
    }

private:
    Number _modulus;
    Number _base;               // base % modulus
    Number _inversebase;        // base^-1 mod modulus
    Number _inversemodulus;     // modulus^-1 mod base; not used by the methods above
};

// Millisecond tick counter used to time the self test.
static unsigned long tick_count()
{
#ifdef _WIN32
    return static_cast<unsigned long>(GetTickCount());
#else
    return static_cast<unsigned long>(clock()*1000.0/CLOCKS_PER_SEC);
#endif
}

// Self test; terminates the program when a shift or add/sub check fails.
void run_tests()
{
    // test left+right shift
    const size_t nshift= BITSINBASE*16;
    Number x;
    SetNumber(x, 1);
    for (size_t i=0 ; i<nshift ; i++)
        leftshift(x);
    if (x.size()!=17 || x.back()!=1)
    {
        printf("shift test failed\n");
        printf(" << %d =", static_cast<int>(nshift)); writenum(x); printf("\n");
        exit(1);
    }
    for (size_t i=0 ; i<nshift ; i++)
        rightshift(x);
    shrink(x);
    if (x.size()!=1 || x[0]!=1)
    {
        printf("shift test failed\n");
        printf(" >> %d =", static_cast<int>(nshift)); writenum(x); printf("\n");
        exit(1);
    }
    rightshift(x);
    shrink(x);
    if (x.size()!=0)
    {
        printf("shift test failed\n");
        printf(" >> 1 = "); writenum(x); printf("\n");
        exit(1);
    }

    // test add/sub: (x+y)-y must give x back
    Number y;
    x.clear();
    srand(0);
    for (int i=0 ; i<8 ; i++)
        x.push_back(static_cast<base_t>(rand()));
    for (int i=0 ; i<8 ; i++)
        y.push_back(static_cast<base_t>(rand()));
    const Number x_before= x;
    big_add(x, y);
    big_sub(x, y);
    if (compare(x, x_before)!=0)
    {
        printf("add/sub test failed\n");
        printf("got:      "); writenum(x); printf("\n");
        printf("expected: "); writenum(x_before); printf("\n");
        exit(1);
    }

    // the example from the comment at the top of this file
    Number z;
    hexstr2num(x, "0857BDE36F680CDA9535784BC4AEC5F4344131071B419F732AC9C74D0E61DB49DD958C7344236E0279DF009C6E66AEC6BA574C2820D4AEB0C4D814C8E184C6EA7E6D8AA3E15D1C251C78C5364EA2B3EDB3C19E90739AFA765506242E78FCDC71A87EFDFE2DF6CE6039FC62CB3B360CB77CD5574292282DF352886CBC3FCFBFF2");
    hexstr2num(y, "10001");
    hexstr2num(z, "F765A3A0C9C291D81A56FE73794A746B8DA23DBE155D0D495B49D581B5C6545F449A10FDF1C26A92FBD1F43A0687044927A6A21B69A73999E6083D03ACDAFFA6409F1BC71D810628F6E18F76231ED6E22D54ED2502E66F8A33D0D5F07B3EB605F7418110E2EF9A5EE77B070F4EADFCF3D70C53E870F29C9D4F229F2CB6C25383");
    Number q;
    const unsigned long t0= tick_count();
    modexp(x, y, z, q);
    printf("%lu ticks\n", tick_count()-t0);

    // expected: 1, then 180 F digits, then the tail below
    const std::string expected= "1" + std::string(180, 'F')
        + "003021300906052B0E03021A050004147020B30D42ED7CDF0E849F30E8B2E137941E9691";
    Number r;
    hexstr2num(r, expected.c_str());
    shrink(r);
    shrink(q);
    if (compare(r,q)!=0)
    {
        printf("got: "); writenum(q); printf("\n");
        printf("expected: "); writenum(r); printf("\n");
    }

    // gcd example: x=693, y=609
    Number a;
    Number b;
    Number g;
    SetNumber(x, 693);
    SetNumber(y, 609);
    gcd_calc(x, a, y, b, g);
    printf("a="); writenum(a); printf("\n");
    printf("b="); writenum(b); printf("\n");
    printf("g="); writenum(g); printf("\n");
}

#ifndef MODEXP_NO_MAIN
// Usage: modexp num exp mod   (all three in hexadecimal)
// Runs the self test, then prints num^exp mod mod computed by the plain and by
// the Montgomery implementation.
int main(int argc, char **argv)
{
    initmask();
    run_tests();
    if (argc!=4)
    {
        printf("Usage: modexp num exp mod\n");
        return 1;
    }
    char *numstr= argv[1];
    char *expstr= argv[2];
    char *modstr= argv[3];

    // strip leading zeros
    while (*numstr=='0') numstr++;
    while (*expstr=='0') expstr++;
    while (*modstr=='0') modstr++;

    Number num; hexstr2num(num, numstr);
    Number exp; hexstr2num(exp, expstr);
    Number mod; hexstr2num(mod, modstr);

    if (is_zero(mod))
    {
        printf("mod must be non-zero\n");
        return 1;
    }
    if (num.size() > mod.size())
    {
        printf("number cannot be larger than mod\n");
        return 1;
    }

    Number result;
    try
    {
        modexp(num, exp, mod, result);
        writenum(result); printf("\n");

        printf(" now trying montgomery\n");
        Number base;
        SetNumber(base, 0x10000);
        Montgomery m(mod, base);
        m.modexp(num, exp, result);
        writenum(result); printf(".\n");
    }
    catch (const std::exception& e)
    {
        printf("EXCEPTION: %s\n", e.what());
        return 1;
    }
    return 0;
}
#endif // MODEXP_NO_MAIN