/gdata/tlslite/utils/cryptomath.py

http://radioappz.googlecode.com/ · Python · 404 lines · 275 code · 58 blank · 71 comment · 75 complexity · b6b832d9aaddee7c4093248b96db86b1 MD5 · raw file

  1. """cryptomath module
  2. This module has basic math/crypto code."""
  3. import os
  4. import sys
  5. import math
  6. import base64
  7. import binascii
  8. if sys.version_info[:2] <= (2, 4):
  9. from sha import sha as sha1
  10. else:
  11. from hashlib import sha1
  12. from compat import *
  13. # **************************************************************************
  14. # Load Optional Modules
  15. # **************************************************************************
  16. # Try to load M2Crypto/OpenSSL
  17. try:
  18. from M2Crypto import m2
  19. m2cryptoLoaded = True
  20. except ImportError:
  21. m2cryptoLoaded = False
  22. # Try to load cryptlib
  23. try:
  24. import cryptlib_py
  25. try:
  26. cryptlib_py.cryptInit()
  27. except cryptlib_py.CryptException, e:
  28. #If tlslite and cryptoIDlib are both present,
  29. #they might each try to re-initialize this,
  30. #so we're tolerant of that.
  31. if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
  32. raise
  33. cryptlibpyLoaded = True
  34. except ImportError:
  35. cryptlibpyLoaded = False
  36. #Try to load GMPY
  37. try:
  38. import gmpy
  39. gmpyLoaded = True
  40. except ImportError:
  41. gmpyLoaded = False
  42. #Try to load pycrypto
  43. try:
  44. import Crypto.Cipher.AES
  45. pycryptoLoaded = True
  46. except ImportError:
  47. pycryptoLoaded = False
  48. # **************************************************************************
  49. # PRNG Functions
  50. # **************************************************************************
  51. # Get os.urandom PRNG
  52. try:
  53. os.urandom(1)
  54. def getRandomBytes(howMany):
  55. return stringToBytes(os.urandom(howMany))
  56. prngName = "os.urandom"
  57. except:
  58. # Else get cryptlib PRNG
  59. if cryptlibpyLoaded:
  60. def getRandomBytes(howMany):
  61. randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
  62. cryptlib_py.CRYPT_ALGO_AES)
  63. cryptlib_py.cryptSetAttribute(randomKey,
  64. cryptlib_py.CRYPT_CTXINFO_MODE,
  65. cryptlib_py.CRYPT_MODE_OFB)
  66. cryptlib_py.cryptGenerateKey(randomKey)
  67. bytes = createByteArrayZeros(howMany)
  68. cryptlib_py.cryptEncrypt(randomKey, bytes)
  69. return bytes
  70. prngName = "cryptlib"
  71. else:
  72. #Else get UNIX /dev/urandom PRNG
  73. try:
  74. devRandomFile = open("/dev/urandom", "rb")
  75. def getRandomBytes(howMany):
  76. return stringToBytes(devRandomFile.read(howMany))
  77. prngName = "/dev/urandom"
  78. except IOError:
  79. #Else get Win32 CryptoAPI PRNG
  80. try:
  81. import win32prng
  82. def getRandomBytes(howMany):
  83. s = win32prng.getRandomBytes(howMany)
  84. if len(s) != howMany:
  85. raise AssertionError()
  86. return stringToBytes(s)
  87. prngName ="CryptoAPI"
  88. except ImportError:
  89. #Else no PRNG :-(
  90. def getRandomBytes(howMany):
  91. raise NotImplementedError("No Random Number Generator "\
  92. "available.")
  93. prngName = "None"
  94. # **************************************************************************
  95. # Converter Functions
  96. # **************************************************************************
  97. def bytesToNumber(bytes):
  98. total = 0L
  99. multiplier = 1L
  100. for count in range(len(bytes)-1, -1, -1):
  101. byte = bytes[count]
  102. total += multiplier * byte
  103. multiplier *= 256
  104. return total
  105. def numberToBytes(n):
  106. howManyBytes = numBytes(n)
  107. bytes = createByteArrayZeros(howManyBytes)
  108. for count in range(howManyBytes-1, -1, -1):
  109. bytes[count] = int(n % 256)
  110. n >>= 8
  111. return bytes
  112. def bytesToBase64(bytes):
  113. s = bytesToString(bytes)
  114. return stringToBase64(s)
  115. def base64ToBytes(s):
  116. s = base64ToString(s)
  117. return stringToBytes(s)
  118. def numberToBase64(n):
  119. bytes = numberToBytes(n)
  120. return bytesToBase64(bytes)
  121. def base64ToNumber(s):
  122. bytes = base64ToBytes(s)
  123. return bytesToNumber(bytes)
  124. def stringToNumber(s):
  125. bytes = stringToBytes(s)
  126. return bytesToNumber(bytes)
  127. def numberToString(s):
  128. bytes = numberToBytes(s)
  129. return bytesToString(bytes)
  130. def base64ToString(s):
  131. try:
  132. return base64.decodestring(s)
  133. except binascii.Error, e:
  134. raise SyntaxError(e)
  135. except binascii.Incomplete, e:
  136. raise SyntaxError(e)
  137. def stringToBase64(s):
  138. return base64.encodestring(s).replace("\n", "")
  139. def mpiToNumber(mpi): #mpi is an openssl-format bignum string
  140. if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
  141. raise AssertionError()
  142. bytes = stringToBytes(mpi[4:])
  143. return bytesToNumber(bytes)
  144. def numberToMPI(n):
  145. bytes = numberToBytes(n)
  146. ext = 0
  147. #If the high-order bit is going to be set,
  148. #add an extra byte of zeros
  149. if (numBits(n) & 0x7)==0:
  150. ext = 1
  151. length = numBytes(n) + ext
  152. bytes = concatArrays(createByteArrayZeros(4+ext), bytes)
  153. bytes[0] = (length >> 24) & 0xFF
  154. bytes[1] = (length >> 16) & 0xFF
  155. bytes[2] = (length >> 8) & 0xFF
  156. bytes[3] = length & 0xFF
  157. return bytesToString(bytes)
  158. # **************************************************************************
  159. # Misc. Utility Functions
  160. # **************************************************************************
  161. def numBytes(n):
  162. if n==0:
  163. return 0
  164. bits = numBits(n)
  165. return int(math.ceil(bits / 8.0))
  166. def hashAndBase64(s):
  167. return stringToBase64(sha1(s).digest())
  168. def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce
  169. bytes = getRandomBytes(numChars)
  170. bytesStr = "".join([chr(b) for b in bytes])
  171. return stringToBase64(bytesStr)[:numChars]
  172. # **************************************************************************
  173. # Big Number Math
  174. # **************************************************************************
  175. def getRandomNumber(low, high):
  176. if low >= high:
  177. raise AssertionError()
  178. howManyBits = numBits(high)
  179. howManyBytes = numBytes(high)
  180. lastBits = howManyBits % 8
  181. while 1:
  182. bytes = getRandomBytes(howManyBytes)
  183. if lastBits:
  184. bytes[0] = bytes[0] % (1 << lastBits)
  185. n = bytesToNumber(bytes)
  186. if n >= low and n < high:
  187. return n
  188. def gcd(a,b):
  189. a, b = max(a,b), min(a,b)
  190. while b:
  191. a, b = b, a % b
  192. return a
  193. def lcm(a, b):
  194. #This will break when python division changes, but we can't use // cause
  195. #of Jython
  196. return (a * b) / gcd(a, b)
  197. #Returns inverse of a mod b, zero if none
  198. #Uses Extended Euclidean Algorithm
  199. def invMod(a, b):
  200. c, d = a, b
  201. uc, ud = 1, 0
  202. while c != 0:
  203. #This will break when python division changes, but we can't use //
  204. #cause of Jython
  205. q = d / c
  206. c, d = d-(q*c), c
  207. uc, ud = ud - (q * uc), uc
  208. if d == 1:
  209. return ud % b
  210. return 0
  211. if gmpyLoaded:
  212. def powMod(base, power, modulus):
  213. base = gmpy.mpz(base)
  214. power = gmpy.mpz(power)
  215. modulus = gmpy.mpz(modulus)
  216. result = pow(base, power, modulus)
  217. return long(result)
  218. else:
  219. #Copied from Bryan G. Olson's post to comp.lang.python
  220. #Does left-to-right instead of pow()'s right-to-left,
  221. #thus about 30% faster than the python built-in with small bases
  222. def powMod(base, power, modulus):
  223. nBitScan = 5
  224. """ Return base**power mod modulus, using multi bit scanning
  225. with nBitScan bits at a time."""
  226. #TREV - Added support for negative exponents
  227. negativeResult = False
  228. if (power < 0):
  229. power *= -1
  230. negativeResult = True
  231. exp2 = 2**nBitScan
  232. mask = exp2 - 1
  233. # Break power into a list of digits of nBitScan bits.
  234. # The list is recursive so easy to read in reverse direction.
  235. nibbles = None
  236. while power:
  237. nibbles = int(power & mask), nibbles
  238. power = power >> nBitScan
  239. # Make a table of powers of base up to 2**nBitScan - 1
  240. lowPowers = [1]
  241. for i in xrange(1, exp2):
  242. lowPowers.append((lowPowers[i-1] * base) % modulus)
  243. # To exponentiate by the first nibble, look it up in the table
  244. nib, nibbles = nibbles
  245. prod = lowPowers[nib]
  246. # For the rest, square nBitScan times, then multiply by
  247. # base^nibble
  248. while nibbles:
  249. nib, nibbles = nibbles
  250. for i in xrange(nBitScan):
  251. prod = (prod * prod) % modulus
  252. if nib: prod = (prod * lowPowers[nib]) % modulus
  253. #TREV - Added support for negative exponents
  254. if negativeResult:
  255. prodInv = invMod(prod, modulus)
  256. #Check to make sure the inverse is correct
  257. if (prod * prodInv) % modulus != 1:
  258. raise AssertionError()
  259. return prodInv
  260. return prod
  261. #Pre-calculate a sieve of the ~100 primes < 1000:
  262. def makeSieve(n):
  263. sieve = range(n)
  264. for count in range(2, int(math.sqrt(n))):
  265. if sieve[count] == 0:
  266. continue
  267. x = sieve[count] * 2
  268. while x < len(sieve):
  269. sieve[x] = 0
  270. x += sieve[count]
  271. sieve = [x for x in sieve[2:] if x]
  272. return sieve
  273. sieve = makeSieve(1000)
  274. def isPrime(n, iterations=5, display=False):
  275. #Trial division with sieve
  276. for x in sieve:
  277. if x >= n: return True
  278. if n % x == 0: return False
  279. #Passed trial division, proceed to Rabin-Miller
  280. #Rabin-Miller implemented per Ferguson & Schneier
  281. #Compute s, t for Rabin-Miller
  282. if display: print "*",
  283. s, t = n-1, 0
  284. while s % 2 == 0:
  285. s, t = s/2, t+1
  286. #Repeat Rabin-Miller x times
  287. a = 2 #Use 2 as a base for first iteration speedup, per HAC
  288. for count in range(iterations):
  289. v = powMod(a, s, n)
  290. if v==1:
  291. continue
  292. i = 0
  293. while v != n-1:
  294. if i == t-1:
  295. return False
  296. else:
  297. v, i = powMod(v, 2, n), i+1
  298. a = getRandomNumber(2, n)
  299. return True
  300. def getRandomPrime(bits, display=False):
  301. if bits < 10:
  302. raise AssertionError()
  303. #The 1.5 ensures the 2 MSBs are set
  304. #Thus, when used for p,q in RSA, n will have its MSB set
  305. #
  306. #Since 30 is lcm(2,3,5), we'll set our test numbers to
  307. #29 % 30 and keep them there
  308. low = (2L ** (bits-1)) * 3/2
  309. high = 2L ** bits - 30
  310. p = getRandomNumber(low, high)
  311. p += 29 - (p % 30)
  312. while 1:
  313. if display: print ".",
  314. p += 30
  315. if p >= high:
  316. p = getRandomNumber(low, high)
  317. p += 29 - (p % 30)
  318. if isPrime(p, display=display):
  319. return p
  320. #Unused at the moment...
  321. def getRandomSafePrime(bits, display=False):
  322. if bits < 10:
  323. raise AssertionError()
  324. #The 1.5 ensures the 2 MSBs are set
  325. #Thus, when used for p,q in RSA, n will have its MSB set
  326. #
  327. #Since 30 is lcm(2,3,5), we'll set our test numbers to
  328. #29 % 30 and keep them there
  329. low = (2 ** (bits-2)) * 3/2
  330. high = (2 ** (bits-1)) - 30
  331. q = getRandomNumber(low, high)
  332. q += 29 - (q % 30)
  333. while 1:
  334. if display: print ".",
  335. q += 30
  336. if (q >= high):
  337. q = getRandomNumber(low, high)
  338. q += 29 - (q % 30)
  339. #Ideas from Tom Wu's SRP code
  340. #Do trial division on p and q before Rabin-Miller
  341. if isPrime(q, 0, display=display):
  342. p = (2 * q) + 1
  343. if isPrime(p, display=display):
  344. if isPrime(q, display=display):
  345. return p