/gdata/tlslite/utils/cryptomath.py
http://radioappz.googlecode.com/ · Python · 404 lines · 275 code · 58 blank · 71 comment · 75 complexity · b6b832d9aaddee7c4093248b96db86b1 MD5 · raw file
- """cryptomath module
- This module has basic math/crypto code."""
- import os
- import sys
- import math
- import base64
- import binascii
- if sys.version_info[:2] <= (2, 4):
- from sha import sha as sha1
- else:
- from hashlib import sha1
- from compat import *
- # **************************************************************************
- # Load Optional Modules
- # **************************************************************************
# Probe each optional acceleration backend.  Every probe leaves a module-level
# "<backend>Loaded" flag behind; gmpyLoaded and cryptlibpyLoaded are consulted
# later in this module, and the others are presumably read by other tlslite
# modules.

# M2Crypto (OpenSSL bindings)
m2cryptoLoaded = False
try:
    from M2Crypto import m2
    m2cryptoLoaded = True
except ImportError:
    pass

# cryptlib
cryptlibpyLoaded = False
try:
    import cryptlib_py
    try:
        cryptlib_py.cryptInit()
    except cryptlib_py.CryptException:
        # tlslite and cryptoIDlib may both be present and each try to
        # initialise cryptlib, so an "already initialised" error is tolerated;
        # any other cryptlib error is re-raised.
        e = sys.exc_info()[1]
        if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
            raise
    cryptlibpyLoaded = True
except ImportError:
    pass

# GMPY (fast big-integer arithmetic)
gmpyLoaded = False
try:
    import gmpy
    gmpyLoaded = True
except ImportError:
    pass

# pycrypto
pycryptoLoaded = False
try:
    import Crypto.Cipher.AES
    pycryptoLoaded = True
except ImportError:
    pass
- # **************************************************************************
- # PRNG Functions
- # **************************************************************************
# Choose a source of random bytes, trying in order: os.urandom, cryptlib,
# the UNIX /dev/urandom device, and the Win32 CryptoAPI (via win32prng).
# The chosen source's name is stored in prngName.  If none is available,
# importing still succeeds but getRandomBytes() raises NotImplementedError.
try:
    os.urandom(1)

    def getRandomBytes(howMany):
        """Return howMany random bytes from os.urandom as a byte array."""
        return stringToBytes(os.urandom(howMany))
    prngName = "os.urandom"
except Exception:
    # os.urandom is missing or has no entropy source on this platform.
    # This used to be a bare "except:", which (on Python 2.5+) also swallowed
    # KeyboardInterrupt/SystemExit raised during import.
    if cryptlibpyLoaded:
        def getRandomBytes(howMany):
            """Return howMany random bytes generated with cryptlib.

            An AES context in OFB mode is keyed with a fresh random key and
            used to encrypt an all-zero buffer.  In OFB mode that ciphertext
            is the raw keystream.  cryptEncrypt() is relied on to encrypt the
            buffer in place.
            """
            randomKey = cryptlib_py.cryptCreateContext(
                cryptlib_py.CRYPT_UNUSED, cryptlib_py.CRYPT_ALGO_AES)
            try:
                cryptlib_py.cryptSetAttribute(randomKey,
                                              cryptlib_py.CRYPT_CTXINFO_MODE,
                                              cryptlib_py.CRYPT_MODE_OFB)
                cryptlib_py.cryptGenerateKey(randomKey)
                randomBytes = createByteArrayZeros(howMany)
                cryptlib_py.cryptEncrypt(randomKey, randomBytes)
            finally:
                # Previously the context was never destroyed, leaking one
                # cryptlib object per call.
                cryptlib_py.cryptDestroyContext(randomKey)
            return randomBytes
        prngName = "cryptlib"
    else:
        # Else read the UNIX /dev/urandom device directly.  The file is
        # opened once here and kept open for the life of the process.
        try:
            devRandomFile = open("/dev/urandom", "rb")

            def getRandomBytes(howMany):
                """Return howMany bytes read from /dev/urandom."""
                return stringToBytes(devRandomFile.read(howMany))
            prngName = "/dev/urandom"
        except IOError:
            # Else use the Win32 CryptoAPI through the win32prng extension.
            try:
                import win32prng

                def getRandomBytes(howMany):
                    """Return howMany bytes from the Win32 CryptoAPI."""
                    s = win32prng.getRandomBytes(howMany)
                    # Never hand back a short read as if it were complete.
                    if len(s) != howMany:
                        raise AssertionError()
                    return stringToBytes(s)
                prngName = "CryptoAPI"
            except ImportError:
                # No PRNG at all: fail loudly when randomness is requested.
                def getRandomBytes(howMany):
                    raise NotImplementedError("No Random Number Generator "\
                                              "available.")
                prngName = "None"
- # **************************************************************************
- # Converter Functions
- # **************************************************************************
- def bytesToNumber(bytes):
- total = 0L
- multiplier = 1L
- for count in range(len(bytes)-1, -1, -1):
- byte = bytes[count]
- total += multiplier * byte
- multiplier *= 256
- return total
def numberToBytes(n):
    """Encode integer n as a big-endian byte array of numBytes(n) bytes.

    numBytes(0) is 0, so 0 encodes as a zero-length array.  Presumably n is
    non-negative; negative input is not handled specially.
    """
    length = numBytes(n)
    result = createByteArrayZeros(length)
    # Fill from the last (least significant) position towards the front.
    index = length - 1
    while index >= 0:
        result[index] = int(n % 256)
        n >>= 8
        index -= 1
    return result
def bytesToBase64(bytes):
    """Base64-encode a byte array as a single line with no newlines."""
    return stringToBase64(bytesToString(bytes))
def base64ToBytes(s):
    """Decode base64 text s into a byte array.

    Malformed input raises SyntaxError, which base64ToString() raises for
    decoding errors.
    """
    return stringToBytes(base64ToString(s))
def numberToBase64(n):
    """Base64-encode integer n via its big-endian byte representation."""
    return bytesToBase64(numberToBytes(n))
def base64ToNumber(s):
    """Decode base64 text s into the big-endian integer it encodes."""
    return bytesToNumber(base64ToBytes(s))
def stringToNumber(s):
    """Interpret byte string s as a big-endian integer."""
    return bytesToNumber(stringToBytes(s))
def numberToString(s):
    """Encode an integer as a big-endian byte string.

    Despite its name, s is the integer to convert.  The parameter name is
    kept for compatibility with keyword callers.
    """
    return bytesToString(numberToBytes(s))
def base64ToString(s):
    """Decode base64 text s and return the decoded raw string.

    Raises SyntaxError if s is not valid base64 (e.g. incorrect padding).
    """
    try:
        # On Python 2, base64.decodestring() is just binascii.a2b_base64().
        # decodestring is deprecated and gone in newer Pythons, so call the
        # underlying function directly.
        return binascii.a2b_base64(s)
    except (binascii.Error, binascii.Incomplete):
        # sys.exc_info() rather than "except X, e" / "except X as e" keeps
        # this valid on every Python version the module supports.
        raise SyntaxError(sys.exc_info()[1])
- def stringToBase64(s):
- return base64.encodestring(s).replace("\n", "")
def mpiToNumber(mpi):
    """Convert an OpenSSL-format MPI string to a non-negative integer.

    The string is a 4-byte length header followed by the big-endian
    magnitude (the layout numberToMPI produces).  The length header itself is
    not checked.  If the sign bit of the first magnitude byte is set (a
    negative number), AssertionError is raised.
    """
    magnitude = mpi[4:]
    # OpenSSL encodes zero as just the 4-byte header with no magnitude bytes.
    # The old code read mpi[4] unconditionally and raised IndexError there.
    if magnitude and (ord(magnitude[0]) & 0x80) != 0:
        raise AssertionError()
    return bytesToNumber(stringToBytes(magnitude))
def numberToMPI(n):
    """Encode integer n as an OpenSSL-format MPI string.

    Layout: a 4-byte big-endian length, then the big-endian magnitude.  If
    the magnitude's top bit would be set, one zero byte is prepended so the
    value is not read as negative, and the length includes that byte.
    """
    magnitude = numberToBytes(n)
    # For n > 0, a bit count that is a multiple of 8 means the leading
    # magnitude byte has its high bit set.
    if (numBits(n) & 0x7) == 0:
        padLen = 1
    else:
        padLen = 0
    length = numBytes(n) + padLen
    encoded = concatArrays(createByteArrayZeros(4 + padLen), magnitude)
    # Header bytes 0..3 hold the length, most significant byte first.
    for index in range(4):
        encoded[index] = (length >> (8 * (3 - index))) & 0xFF
    return bytesToString(encoded)
- # **************************************************************************
- # Misc. Utility Functions
- # **************************************************************************
def numBytes(n):
    """Return how many whole bytes are needed to hold numBits(n) bits.

    Zero needs no bytes, so numBytes(0) == 0.
    """
    if n != 0:
        bitCount = numBits(n)
        # Round the bit count up to the next multiple of 8.
        return int(math.ceil(bitCount / 8.0))
    return 0
def hashAndBase64(s):
    """Return the SHA-1 digest of string s, base64-encoded on a single line."""
    digest = sha1(s).digest()
    return stringToBase64(digest)
def getBase64Nonce(numChars=22):
    """Return a random base64 nonce that is numChars characters long.

    Each base64 character carries 6 bits, so the default of 22 characters
    gives a 132-bit nonce.  numChars random bytes are drawn (more than
    strictly needed), and their base64 encoding is truncated to numChars
    characters.
    """
    randomBytes = getRandomBytes(numChars)
    asString = "".join(map(chr, randomBytes))
    return stringToBase64(asString)[:numChars]
- # **************************************************************************
- # Big Number Math
- # **************************************************************************
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Uses rejection sampling: draw numBits(high) random bits and retry until
    the value falls in range.  The result is uniform if getRandomBytes() is.
    Raises AssertionError if low >= high.
    """
    if low >= high:
        raise AssertionError()
    bitCount = numBits(high)
    byteCount = numBytes(high)
    topBits = bitCount % 8
    while True:
        candidateBytes = getRandomBytes(byteCount)
        if topBits:
            # Keep only the low topBits bits of the leading byte, so the
            # candidate has at most bitCount bits.
            candidateBytes[0] = candidateBytes[0] % (1 << topBits)
        candidate = bytesToNumber(candidateBytes)
        if low <= candidate < high:
            return candidate
def gcd(a,b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    gcd(x, 0) == x.  Intended for non-negative arguments.
    """
    larger, smaller = max(a, b), min(a, b)
    while smaller != 0:
        larger, smaller = smaller, larger % smaller
    return larger
def lcm(a, b):
    """Return the least common multiple of a and b (0 if either is 0).

    divmod() is used instead of "/" or "//".  "/" changes meaning under true
    division (Python 3, or "from __future__ import division"), and "//" is
    avoided for Jython compatibility.
    """
    if a == 0 or b == 0:
        # gcd(0, 0) is 0, so lcm(0, 0) used to raise ZeroDivisionError.
        return 0
    quotient, _ = divmod(a * b, gcd(a, b))
    return quotient
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if none exists.

    Uses the extended Euclidean algorithm.  Presumably a and b are
    non-negative with b > 1.  If a shares a factor with b (including
    a == 0), the result is 0.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        # divmod() is floor division on every Python version.  The old
        # "d / c" becomes float division under true division, and "//" is
        # avoided for Jython compatibility.
        q, remainder = divmod(d, c)
        c, d = remainder, c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus using gmpy mpz arithmetic, as a long."""
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        return long(result)
else:
    # Pure-Python fallback, copied from Bryan G. Olson's post to
    # comp.lang.python.  It scans the exponent left-to-right instead of
    # pow()'s right-to-left; the original author reports it is about 30%
    # faster than the built-in with small bases.
    def powMod(base, power, modulus):
        """Return base**power mod modulus, scanning nBitScan bits at a time.

        A negative power returns the modular inverse of base**(-power) mod
        modulus.  If that inverse does not exist, AssertionError is raised.
        power == 0 returns 1 % modulus.  It previously crashed while
        unpacking an empty digit list.
        """
        nBitScan = 5

        #TREV - Added support for negative exponents
        negativeResult = False
        if (power < 0):
            power *= -1
            negativeResult = True

        # x**0 == 1.  Reduce it so that modulus 1 gives 0, like pow(x, 0, 1).
        if power == 0:
            return 1 % modulus

        exp2 = 2**nBitScan
        mask = exp2 - 1

        # Break power into digits of nBitScan bits, stored as a nested pair
        # chain (mostSignificant, (next, (..., (leastSignificant, None)))).
        # The outermost pair is therefore the most significant digit.
        nibbles = None
        while power:
            nibbles = int(power & mask), nibbles
            power = power >> nBitScan

        # Make a table of powers of base up to 2**nBitScan - 1
        lowPowers = [1]
        for i in range(1, exp2):
            lowPowers.append((lowPowers[i-1] * base) % modulus)

        # To exponentiate by the first nibble, look it up in the table
        nib, nibbles = nibbles
        prod = lowPowers[nib]

        # For the rest, square nBitScan times, then multiply by
        # base^nibble
        while nibbles:
            nib, nibbles = nibbles
            for i in range(nBitScan):
                prod = (prod * prod) % modulus
            if nib: prod = (prod * lowPowers[nib]) % modulus

        #TREV - Added support for negative exponents
        if negativeResult:
            prodInv = invMod(prod, modulus)
            #Check to make sure the inverse is correct
            if (prod * prodInv) % modulus != 1:
                raise AssertionError()
            return prodInv
        return prod
def makeSieve(n):
    """Return all primes p with 2 <= p < n, in increasing order.

    Implements the Sieve of Eratosthenes.
    """
    sieve = list(range(n))
    # Every composite below n has a prime factor <= sqrt(n), so multiples
    # must be struck out for every p up to and including int(sqrt(n)).  The
    # old bound, range(2, int(sqrt(n))), stopped one short and let prime
    # squares through (9 for n=10, 961 = 31*31 for n=1000).
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        x = sieve[count] * 2
        while x < len(sieve):
            sieve[x] = 0
            x += sieve[count]
    sieve = [x for x in sieve[2:] if x]
    return sieve

#Pre-calculate a sieve of the 168 primes < 1000 (used for trial division):
sieve = makeSieve(1000)
- def isPrime(n, iterations=5, display=False):
- #Trial division with sieve
- for x in sieve:
- if x >= n: return True
- if n % x == 0: return False
- #Passed trial division, proceed to Rabin-Miller
- #Rabin-Miller implemented per Ferguson & Schneier
- #Compute s, t for Rabin-Miller
- if display: print "*",
- s, t = n-1, 0
- while s % 2 == 0:
- s, t = s/2, t+1
- #Repeat Rabin-Miller x times
- a = 2 #Use 2 as a base for first iteration speedup, per HAC
- for count in range(iterations):
- v = powMod(a, s, n)
- if v==1:
- continue
- i = 0
- while v != n-1:
- if i == t-1:
- return False
- else:
- v, i = powMod(v, 2, n), i+1
- a = getRandomNumber(2, n)
- return True
- def getRandomPrime(bits, display=False):
- if bits < 10:
- raise AssertionError()
- #The 1.5 ensures the 2 MSBs are set
- #Thus, when used for p,q in RSA, n will have its MSB set
- #
- #Since 30 is lcm(2,3,5), we'll set our test numbers to
- #29 % 30 and keep them there
- low = (2L ** (bits-1)) * 3/2
- high = 2L ** bits - 30
- p = getRandomNumber(low, high)
- p += 29 - (p % 30)
- while 1:
- if display: print ".",
- p += 30
- if p >= high:
- p = getRandomNumber(low, high)
- p += 29 - (p % 30)
- if isPrime(p, display=display):
- return p
- #Unused at the moment...
- def getRandomSafePrime(bits, display=False):
- if bits < 10:
- raise AssertionError()
- #The 1.5 ensures the 2 MSBs are set
- #Thus, when used for p,q in RSA, n will have its MSB set
- #
- #Since 30 is lcm(2,3,5), we'll set our test numbers to
- #29 % 30 and keep them there
- low = (2 ** (bits-2)) * 3/2
- high = (2 ** (bits-1)) - 30
- q = getRandomNumber(low, high)
- q += 29 - (q % 30)
- while 1:
- if display: print ".",
- q += 30
- if (q >= high):
- q = getRandomNumber(low, high)
- q += 29 - (q % 30)
- #Ideas from Tom Wu's SRP code
- #Do trial division on p and q before Rabin-Miller
- if isPrime(q, 0, display=display):
- p = (2 * q) + 1
- if isPrime(p, display=display):
- if isPrime(q, display=display):
- return p