PageRenderTime 26ms CodeModel.GetById 1ms app.highlight 21ms RepoModel.GetById 1ms app.codeStats 0ms

/gdata/tlslite/utils/cryptomath.py

http://radioappz.googlecode.com/
Python | 404 lines | 353 code | 27 blank | 24 comment | 8 complexity | b6b832d9aaddee7c4093248b96db86b1 MD5 | raw file
  1"""cryptomath module
  2
  3This module has basic math/crypto code."""
  4
  5import os
  6import sys
  7import math
  8import base64
  9import binascii
 10if sys.version_info[:2] <= (2, 4):
 11  from sha import sha as sha1
 12else:
 13  from hashlib import sha1
 14
 15from compat import *
 16
 17
 18# **************************************************************************
 19# Load Optional Modules
 20# **************************************************************************
 21
 22# Try to load M2Crypto/OpenSSL
 23try:
 24    from M2Crypto import m2
 25    m2cryptoLoaded = True
 26
 27except ImportError:
 28    m2cryptoLoaded = False
 29
 30
 31# Try to load cryptlib
 32try:
 33    import cryptlib_py
 34    try:
 35        cryptlib_py.cryptInit()
 36    except cryptlib_py.CryptException, e:
 37        #If tlslite and cryptoIDlib are both present,
 38        #they might each try to re-initialize this,
 39        #so we're tolerant of that.
 40        if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
 41            raise
 42    cryptlibpyLoaded = True
 43
 44except ImportError:
 45    cryptlibpyLoaded = False
 46
 47#Try to load GMPY
 48try:
 49    import gmpy
 50    gmpyLoaded = True
 51except ImportError:
 52    gmpyLoaded = False
 53
 54#Try to load pycrypto
 55try:
 56    import Crypto.Cipher.AES
 57    pycryptoLoaded = True
 58except ImportError:
 59    pycryptoLoaded = False
 60
 61
 62# **************************************************************************
 63# PRNG Functions
 64# **************************************************************************
 65
 66# Get os.urandom PRNG
 67try:
 68    os.urandom(1)
 69    def getRandomBytes(howMany):
 70        return stringToBytes(os.urandom(howMany))
 71    prngName = "os.urandom"
 72
 73except:
 74    # Else get cryptlib PRNG
 75    if cryptlibpyLoaded:
 76        def getRandomBytes(howMany):
 77            randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
 78                                                       cryptlib_py.CRYPT_ALGO_AES)
 79            cryptlib_py.cryptSetAttribute(randomKey,
 80                                          cryptlib_py.CRYPT_CTXINFO_MODE,
 81                                          cryptlib_py.CRYPT_MODE_OFB)
 82            cryptlib_py.cryptGenerateKey(randomKey)
 83            bytes = createByteArrayZeros(howMany)
 84            cryptlib_py.cryptEncrypt(randomKey, bytes)
 85            return bytes
 86        prngName = "cryptlib"
 87
 88    else:
 89        #Else get UNIX /dev/urandom PRNG
 90        try:
 91            devRandomFile = open("/dev/urandom", "rb")
 92            def getRandomBytes(howMany):
 93                return stringToBytes(devRandomFile.read(howMany))
 94            prngName = "/dev/urandom"
 95        except IOError:
 96            #Else get Win32 CryptoAPI PRNG
 97            try:
 98                import win32prng
 99                def getRandomBytes(howMany):
100                    s = win32prng.getRandomBytes(howMany)
101                    if len(s) != howMany:
102                        raise AssertionError()
103                    return stringToBytes(s)
104                prngName ="CryptoAPI"
105            except ImportError:
106                #Else no PRNG :-(
107                def getRandomBytes(howMany):
108                    raise NotImplementedError("No Random Number Generator "\
109                                              "available.")
110            prngName = "None"
111
112# **************************************************************************
113# Converter Functions
114# **************************************************************************
115
def bytesToNumber(bytes):
    """Interpret a big-endian sequence of byte values as a non-negative integer.

    The elements are expected to be in 0..255; the first element is the most
    significant.  An empty sequence yields 0.
    """
    total = 0
    # Horner's rule: shift the accumulator one byte left, then add the next
    # (less significant) byte.  Python integers grow without overflow.
    for byte in bytes:
        total = total * 256 + byte
    return total
124
def numberToBytes(n):
    """Return n as a big-endian byte array of numBytes(n) entries.

    The least significant byte is written last.  numBytes(0) is 0, so the
    array for 0 has no entries.
    """
    size = numBytes(n)
    out = createByteArrayZeros(size)
    index = size - 1
    while index >= 0:
        out[index] = int(n & 0xFF)
        n >>= 8
        index -= 1
    return out
132
def bytesToBase64(bytes):
    """Encode a byte array as base64 via bytesToString and stringToBase64."""
    return stringToBase64(bytesToString(bytes))
136
def base64ToBytes(s):
    """Decode base64 text into a byte array.

    Malformed input raises SyntaxError (from base64ToString).
    """
    return stringToBytes(base64ToString(s))
140
def numberToBase64(n):
    """Encode the integer n (big-endian, minimal length) as base64."""
    return bytesToBase64(numberToBytes(n))
144
def base64ToNumber(s):
    """Decode base64 text holding a big-endian integer and return the integer."""
    return bytesToNumber(base64ToBytes(s))
148
def stringToNumber(s):
    """Interpret the raw string s as a big-endian unsigned integer."""
    return bytesToNumber(stringToBytes(s))
152
def numberToString(s):
    """Convert an integer to a big-endian raw string.

    Despite its name, the parameter s is the number to convert.
    """
    return bytesToString(numberToBytes(s))
156
def base64ToString(s):
    """Decode base64 text s and return the raw decoded string.

    Raises SyntaxError if s is not valid base64 (for example, bad padding or
    truncated data).
    """
    try:
        # binascii.a2b_base64 is exactly what base64.decodestring wrapped.
        # Calling it directly avoids that alias, which was removed in
        # Python 3.9.
        return binascii.a2b_base64(s)
    except (binascii.Error, binascii.Incomplete):
        # sys.exc_info() keeps this parseable on every Python version.
        raise SyntaxError(sys.exc_info()[1])
164
def stringToBase64(s):
    """Return the base64 encoding of s as a single line with no newlines."""
    # b64encode never inserts line breaks, so no post-processing is needed.
    # It replaces encodestring(), which was removed in Python 3.9.
    return base64.b64encode(s)
167
def mpiToNumber(mpi): #mpi is an openssl-format bignum string
    """Decode an OpenSSL MPI string into a non-negative integer.

    The first 4 bytes (the length header) are skipped, and the rest is read
    as a big-endian magnitude.  Raises AssertionError if the sign bit of the
    first magnitude byte is set (a negative number).
    """
    # A header-only MPI (length 0, presumably what OpenSSL emits for zero)
    # has no magnitude byte, so mpi[4] would raise IndexError.  It decodes
    # to 0 below.
    if len(mpi) > 4 and (ord(mpi[4]) & 0x80) != 0:
        raise AssertionError()
    return bytesToNumber(stringToBytes(mpi[4:]))
173
def numberToMPI(n):
    """Encode the integer n as an OpenSSL-style MPI string.

    The result is a 4-byte big-endian length followed by the big-endian
    magnitude.  The magnitude gets an extra leading zero byte whenever its
    top bit would otherwise be set.
    """
    magnitude = numberToBytes(n)
    # A bit length that is a multiple of 8 means the first magnitude byte
    # has its high bit set.  Pad with a zero byte so the value is not read
    # back as negative.
    if (numBits(n) & 0x7) == 0:
        padding = 1
    else:
        padding = 0
    length = numBytes(n) + padding
    encoded = concatArrays(createByteArrayZeros(4 + padding), magnitude)
    # Fill in the big-endian 4-byte length header.
    for index, shift in enumerate((24, 16, 8, 0)):
        encoded[index] = (length >> shift) & 0xFF
    return bytesToString(encoded)
188
189
190
191# **************************************************************************
192# Misc. Utility Functions
193# **************************************************************************
194
def numBytes(n):
    """Return how many whole bytes are needed to hold the bits of n.

    Returns 0 for n == 0.
    """
    if n == 0:
        return 0
    # Round the bit count up to the next multiple of 8, then divide by 8.
    return int((numBits(n) + 7) >> 3)
200
def hashAndBase64(s):
    """Return the SHA-1 digest of s, encoded with stringToBase64."""
    digest = sha1(s).digest()
    return stringToBase64(digest)
203
def getBase64Nonce(numChars=22):
    """Return a random base64 nonce exactly numChars characters long.

    numChars random bytes are drawn, base64-encoded, and the encoding is
    truncated to numChars characters.  Each base64 character carries 6
    bits, so the default of 22 characters gives a 132-bit nonce.
    """
    randomBytes = getRandomBytes(numChars)
    rawString = "".join(map(chr, randomBytes))
    return stringToBase64(rawString)[:numChars]
208
209
210# **************************************************************************
211# Big Number Math
212# **************************************************************************
213
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Candidates of numBits(high) bits are drawn from getRandomBytes and
    rejected until one falls in range.  Raises AssertionError if
    low >= high.
    """
    if not low < high:
        raise AssertionError()
    bitCount = numBits(high)
    byteCount = numBytes(high)
    topBits = bitCount % 8
    while True:
        candidateBytes = getRandomBytes(byteCount)
        if topBits:
            # Clear the unused high-order bits of the leading byte so the
            # candidate never has more bits than high.
            candidateBytes[0] %= 1 << topBits
        candidate = bytesToNumber(candidateBytes)
        if low <= candidate < high:
            return candidate
227
def gcd(a,b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    if a >= b:
        big, small = a, b
    else:
        big, small = b, a
    while small != 0:
        big, small = small, big % small
    return big
233
def lcm(a, b):
    """Return the least common multiple of a and b (0 if either is 0)."""
    if a == 0 or b == 0:
        # gcd(0, 0) is 0, so the division below would fail.
        return 0
    # divmod() floors on every Python version and needs no "//" (which old
    # Jython lacked).  Plain "/" is true division on Python 3 and would
    # return an inexact float.  gcd(a, b) divides a exactly, so dividing
    # first also keeps the intermediate value small.
    return divmod(a, gcd(a, b))[0] * b
238
#Returns inverse of a mod b, zero if none
#Uses Extended Euclidean Algorithm
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if a has no inverse.

    Uses the extended Euclidean algorithm.  For positive b a nonzero result
    lies in [1, b).
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        # divmod() floors on every Python version and needs no "//" (which
        # old Jython lacked).  Plain "/" is true division on Python 3 and
        # would corrupt the quotients.
        q, r = divmod(d, c)
        c, d = r, c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0
253
254
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power % modulus, computed with gmpy and returned as a long."""
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        return long(result)

else:
    #Copied from Bryan G. Olson's post to comp.lang.python
    #Does left-to-right instead of pow()'s right-to-left,
    #thus about 30% faster than the python built-in with small bases
    def powMod(base, power, modulus):
        """ Return base**power mod modulus, using multi bit scanning
        with nBitScan bits at a time.

        A negative power yields the modular inverse of base**(-power).
        AssertionError is raised if that inverse does not exist.
        """
        nBitScan = 5

        #TREV - Added support for negative exponents
        negativeResult = False
        if (power < 0):
            power *= -1
            negativeResult = True

        # With a zero exponent the digit list built below stays None, and
        # the first "nib, nibbles = nibbles" unpacking would fail.
        # x**0 == 1; reducing it makes modulus 1 give 0, like built-in pow().
        if power == 0:
            return 1 % modulus

        exp2 = 2**nBitScan
        mask = exp2 - 1

        # Break power into a list of digits of nBitScan bits.
        # The list is recursive so easy to read in reverse direction.
        nibbles = None
        while power:
            nibbles = int(power & mask), nibbles
            power = power >> nBitScan

        # Make a table of powers of base up to 2**nBitScan - 1
        lowPowers = [1]
        for i in xrange(1, exp2):
            lowPowers.append((lowPowers[i-1] * base) % modulus)

        # To exponentiate by the first nibble, look it up in the table
        nib, nibbles = nibbles
        prod = lowPowers[nib]

        # For the rest, square nBitScan times, then multiply by
        # base^nibble
        while nibbles:
            nib, nibbles = nibbles
            for i in xrange(nBitScan):
                prod = (prod * prod) % modulus
            if nib: prod = (prod * lowPowers[nib]) % modulus

        #TREV - Added support for negative exponents
        if negativeResult:
            prodInv = invMod(prod, modulus)
            #Check to make sure the inverse is correct
            if (prod * prodInv) % modulus != 1:
                raise AssertionError()
            return prodInv
        return prod
314
315
#Pre-calculate a sieve of the 168 primes < 1000:
def makeSieve(n):
    """Return, in increasing order, every prime p with 2 <= p < n.

    This is the sieve of Eratosthenes: entry i starts out equal to i and is
    zeroed once i is known to be composite.
    """
    sieve = list(range(n))
    candidate = 2
    # Every composite m < n has a prime factor p with p*p <= m < n, so
    # crossing out can stop at the first candidate whose square reaches n.
    # This includes candidate 31 for n == 1000, since 31*31 == 961.
    while candidate * candidate < n:
        if sieve[candidate]:
            # Smaller multiples were already crossed out by smaller primes.
            for multiple in range(candidate * candidate, n, candidate):
                sieve[multiple] = 0
        candidate += 1
    return [x for x in sieve[2:] if x]

sieve = makeSieve(1000)
330
331def isPrime(n, iterations=5, display=False):
332    #Trial division with sieve
333    for x in sieve:
334        if x >= n: return True
335        if n % x == 0: return False
336    #Passed trial division, proceed to Rabin-Miller
337    #Rabin-Miller implemented per Ferguson & Schneier
338    #Compute s, t for Rabin-Miller
339    if display: print "*",
340    s, t = n-1, 0
341    while s % 2 == 0:
342        s, t = s/2, t+1
343    #Repeat Rabin-Miller x times
344    a = 2 #Use 2 as a base for first iteration speedup, per HAC
345    for count in range(iterations):
346        v = powMod(a, s, n)
347        if v==1:
348            continue
349        i = 0
350        while v != n-1:
351            if i == t-1:
352                return False
353            else:
354                v, i = powMod(v, 2, n), i+1
355        a = getRandomNumber(2, n)
356    return True
357
358def getRandomPrime(bits, display=False):
359    if bits < 10:
360        raise AssertionError()
361    #The 1.5 ensures the 2 MSBs are set
362    #Thus, when used for p,q in RSA, n will have its MSB set
363    #
364    #Since 30 is lcm(2,3,5), we'll set our test numbers to
365    #29 % 30 and keep them there
366    low = (2L ** (bits-1)) * 3/2
367    high = 2L ** bits - 30
368    p = getRandomNumber(low, high)
369    p += 29 - (p % 30)
370    while 1:
371        if display: print ".",
372        p += 30
373        if p >= high:
374            p = getRandomNumber(low, high)
375            p += 29 - (p % 30)
376        if isPrime(p, display=display):
377            return p
378
379#Unused at the moment...
380def getRandomSafePrime(bits, display=False):
381    if bits < 10:
382        raise AssertionError()
383    #The 1.5 ensures the 2 MSBs are set
384    #Thus, when used for p,q in RSA, n will have its MSB set
385    #
386    #Since 30 is lcm(2,3,5), we'll set our test numbers to
387    #29 % 30 and keep them there
388    low = (2 ** (bits-2)) * 3/2
389    high = (2 ** (bits-1)) - 30
390    q = getRandomNumber(low, high)
391    q += 29 - (q % 30)
392    while 1:
393        if display: print ".",
394        q += 30
395        if (q >= high):
396            q = getRandomNumber(low, high)
397            q += 29 - (q % 30)
398        #Ideas from Tom Wu's SRP code
399        #Do trial division on p and q before Rabin-Miller
400        if isPrime(q, 0, display=display):
401            p = (2 * q) + 1
402            if isPrime(p, display=display):
403                if isPrime(q, display=display):
404                    return p