/gdata/tlslite/messages.py

http://radioappz.googlecode.com/ · Python · 561 lines · 480 code · 75 blank · 6 comment · 67 complexity · 5fb035859840cd589da445ae21c9f2e7 MD5 · raw file

  1. """Classes representing TLS messages."""
  2. from utils.compat import *
  3. from utils.cryptomath import *
  4. from errors import *
  5. from utils.codec import *
  6. from constants import *
  7. from X509 import X509
  8. from X509CertChain import X509CertChain
  9. import sha
  10. import md5
  class RecordHeader3:
      """Five-byte SSLv3/TLS record header: type, version, length."""

      def __init__(self):
          # Start as an all-zero header; this header form always has
          # ssl2 == False (RecordHeader2 sets it to True).
          self.type = 0
          self.version = (0,0)
          self.length = 0
          self.ssl2 = False

      def create(self, version, type, length):
          """Set the header fields and return self.

          version -- (major, minor) pair, each written as one byte
          type -- content type, written as one byte
          length -- payload length, written as two bytes
          """
          self.type, self.version, self.length = type, version, length
          return self

      def write(self):
          """Return the 5-byte encoding: type, major, minor, length(2)."""
          out = Writer(5)
          fields = ((self.type, 1), (self.version[0], 1),
                    (self.version[1], 1), (self.length, 2))
          for value, width in fields:
              out.add(value, width)
          return out.bytes

      def parse(self, p):
          """Read the five header bytes from parser p and return self."""
          self.type = p.get(1)
          major = p.get(1)
          minor = p.get(1)
          self.version = (major, minor)
          self.length = p.get(2)
          self.ssl2 = False
          return self
  class RecordHeader2:
      """Two-byte SSLv2 record header (used for SSLv2-format hellos)."""

      def __init__(self):
          self.type = 0
          self.version = (0,0)
          self.length = 0
          self.ssl2 = True

      def parse(self, p):
          """Parse a two-byte SSLv2 record header from parser p.

          The high bit of the first byte must be set (two-byte header
          form); the record length is the remaining 15 bits of the two
          bytes: ((first & 0x7F) << 8) | second.  The three-byte header
          form (high bit clear) is not supported.

          Returns self.  Raises SyntaxError if the high bit is clear.
          """
          first = p.get(1)
          if not first & 0x80:
              raise SyntaxError()
          # Records in this format are treated as handshake data.
          self.type = ContentType.handshake
          self.version = (2,0)
          # Use all 15 length bits so hellos longer than 255 bytes parse.
          # NOTE(review): whatever code selects RecordHeader2 may also
          # need to accept first bytes 0x81-0xFF, not just 0x80 -- confirm.
          self.length = ((first & 0x7F) << 8) | p.get(1)
          return self
  class Msg:
      """Base for messages serialized in two passes.

      write(True) is a trial pass; its result sizes the Writer used by
      the real pass, write(False).
      """

      def preWrite(self, trial):
          """Return the Writer to use for this pass of write()."""
          if not trial:
              # Size the real buffer from the result of a trial write().
              return Writer(self.write(True))
          return Writer()

      def postWrite(self, w, trial):
          """Return w.index on a trial pass, otherwise w.bytes."""
          return w.index if trial else w.bytes
  class Alert(Msg):
      """Alert record: one level byte followed by one description byte."""

      def __init__(self):
          self.contentType = ContentType.alert
          self.level = 0
          self.description = 0

      def create(self, description, level=AlertLevel.fatal):
          """Set the alert fields (level defaults to fatal); return self."""
          self.level = level
          self.description = description
          return self

      def parse(self, p):
          """Read the level and description bytes from p; return self."""
          p.setLengthCheck(2)
          self.level = p.get(1)
          self.description = p.get(1)
          p.stopLengthCheck()
          return self

      def write(self):
          """Return the two-byte encoding of this alert."""
          out = Writer(2)
          for field in (self.level, self.description):
              out.add(field, 1)
          return out.bytes
  class HandshakeMsg(Msg):
      """Base for handshake messages, which begin with a 4-byte header."""

      def preWrite(self, handshakeType, trial):
          """Return a Writer that already holds the handshake header.

          The header is the handshake type (1 byte) then the body length
          (3 bytes).  A trial pass writes 0 as a placeholder length; the
          real pass sizes the Writer from a trial write() and subtracts
          the 4 header bytes to get the body length.
          """
          if trial:
              writer = Writer()
              bodyLength = 0
          else:
              totalLength = self.write(True)
              writer = Writer(totalLength)
              bodyLength = totalLength - 4
          writer.add(handshakeType, 1)
          writer.add(bodyLength, 3)
          return writer
  class ClientHello(HandshakeMsg):
      """ClientHello handshake message, in TLS or SSLv2-compatible form."""

      def __init__(self, ssl2=False):
          self.contentType = ContentType.handshake
          # When True, parse() reads the SSLv2-compatible layout.
          self.ssl2 = ssl2
          self.client_version = (0,0)
          self.random = createByteArrayZeros(32)
          self.session_id = createByteArraySequence([])
          self.cipher_suites = [] # a list of 16-bit values
          self.certificate_types = [CertificateType.x509]
          self.compression_methods = [] # a list of 8-bit values
          self.srp_username = None # a string

      def create(self, version, random, session_id, cipher_suites,
                 certificate_types=None, srp_username=None):
          """Fill in a ClientHello to send and return self.

          compression_methods is always set to [0].  If certificate_types
          is empty/None or exactly [CertificateType.x509], write() omits
          the certificate-type extension; if srp_username is empty/None,
          write() omits the SRP extension.
          """
          self.client_version = version
          self.random = random
          self.session_id = session_id
          self.cipher_suites = cipher_suites
          self.certificate_types = certificate_types
          self.compression_methods = [0]
          self.srp_username = srp_username
          return self

      def parse(self, p):
          """Parse a ClientHello body from parser p and return self.

          Raises SyntaxError on malformed TLS extension lengths.
          """
          if self.ssl2:
              self._parseSSL2(p)
          else:
              self._parseTLS(p)
          return self

      def _parseSSL2(self, p):
          """Parse the SSLv2-compatible ClientHello layout."""
          self.client_version = (p.get(1), p.get(1))
          cipherSpecsLength = p.get(2)
          sessionIDLength = p.get(2)
          randomLength = p.get(2)
          # SSLv2 cipher specs are 3 bytes each.
          self.cipher_suites = p.getFixList(3, cipherSpecsLength // 3)
          self.session_id = p.getFixBytes(sessionIDLength)
          self.random = p.getFixBytes(randomLength)
          if len(self.random) < 32:
              # Left-pad a short challenge with zero bytes to 32 bytes.
              zeroBytes = 32 - len(self.random)
              self.random = createByteArrayZeros(zeroBytes) + self.random
          # This layout has no compression list; fake null compression.
          self.compression_methods = [0]
          # No start/stopLengthCheck() is done on this path.

      def _parseTLS(self, p):
          """Parse the standard TLS ClientHello layout, with extensions."""
          p.startLengthCheck(3)
          self.client_version = (p.get(1), p.get(1))
          self.random = p.getFixBytes(32)
          self.session_id = p.getVarBytes(1)
          self.cipher_suites = p.getVarList(2, 2)
          self.compression_methods = p.getVarList(1, 1)
          if not p.atLengthCheck():
              self._parseExtensions(p)
          p.stopLengthCheck()

      def _parseExtensions(self, p):
          """Parse the extensions block, validating every declared length.

          Raises SyntaxError if a recognised extension body does not span
          exactly its declared length, or if the extensions do not add up
          to the declared total; otherwise a malformed hello could leave
          the parser misaligned and be silently misread.
          """
          totalExtLength = p.get(2)
          soFar = 0
          while soFar < totalExtLength:
              extType = p.get(2)
              extLength = p.get(2)
              start = p.index
              # NOTE(review): 6/7 are this module's numbers for SRP and
              # certificate-type; the IANA registry assigns 6/7 to
              # user_mapping/client_authz (SRP is 12, cert_type is 9).
              if extType == 6:
                  self.srp_username = bytesToString(p.getVarBytes(1))
              elif extType == 7:
                  self.certificate_types = p.getVarList(1, 1)
              else:
                  p.getFixBytes(extLength)
              if p.index - start != extLength:
                  raise SyntaxError()
              soFar += 4 + extLength
          if soFar != totalExtLength:
              raise SyntaxError()

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
          w.add(self.client_version[0], 1)
          w.add(self.client_version[1], 1)
          w.addFixSeq(self.random, 1)
          w.addVarSeq(self.session_id, 1, 1)
          w.addVarSeq(self.cipher_suites, 2, 2)
          w.addVarSeq(self.compression_methods, 1, 1)

          # The certificate-type extension is only sent when something
          # other than plain X.509 is offered.
          sendCertTypes = self.certificate_types and \
              self.certificate_types != [CertificateType.x509]
          extLength = 0
          if sendCertTypes:
              # type(2) + length(2) + list length(1) + the list itself
              extLength += 5 + len(self.certificate_types)
          if self.srp_username:
              extLength += 5 + len(self.srp_username)
          if extLength > 0:
              w.add(extLength, 2)
          if sendCertTypes:
              w.add(7, 2)
              w.add(len(self.certificate_types)+1, 2)
              w.addVarSeq(self.certificate_types, 1, 1)
          if self.srp_username:
              w.add(6, 2)
              w.add(len(self.srp_username)+1, 2)
              w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
          return HandshakeMsg.postWrite(self, w, trial)
  class ServerHello(HandshakeMsg):
      """ServerHello handshake message."""

      def __init__(self):
          self.contentType = ContentType.handshake
          self.server_version = (0,0)
          self.random = createByteArrayZeros(32)
          self.session_id = createByteArraySequence([])
          self.cipher_suite = 0
          self.certificate_type = CertificateType.x509
          self.compression_method = 0

      def create(self, version, random, session_id, cipher_suite,
                 certificate_type):
          """Fill in a ServerHello to send and return self.

          compression_method is always set to 0.
          """
          self.server_version = version
          self.random = random
          self.session_id = session_id
          self.cipher_suite = cipher_suite
          self.certificate_type = certificate_type
          self.compression_method = 0
          return self

      def parse(self, p):
          """Parse a ServerHello body from parser p and return self.

          Raises SyntaxError if the certificate-type extension is not
          exactly one byte long, or if the extension lengths do not add
          up to the declared total.
          """
          p.startLengthCheck(3)
          self.server_version = (p.get(1), p.get(1))
          self.random = p.getFixBytes(32)
          self.session_id = p.getVarBytes(1)
          self.cipher_suite = p.get(2)
          self.compression_method = p.get(1)
          if not p.atLengthCheck():
              totalExtLength = p.get(2)
              soFar = 0
              while soFar < totalExtLength:
                  extType = p.get(2)
                  extLength = p.get(2)
                  # NOTE(review): 7 is this module's certificate-type
                  # number; IANA assigns 7 to client_authz (cert_type is 9).
                  if extType == 7:
                      # The server's choice is a single byte; any other
                      # length would misalign the rest of the parse.
                      if extLength != 1:
                          raise SyntaxError()
                      self.certificate_type = p.get(1)
                  else:
                      p.getFixBytes(extLength)
                  soFar += 4 + extLength
              if soFar != totalExtLength:
                  raise SyntaxError()
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
          w.add(self.server_version[0], 1)
          w.add(self.server_version[1], 1)
          w.addFixSeq(self.random, 1)
          w.addVarSeq(self.session_id, 1, 1)
          w.add(self.cipher_suite, 2)
          w.add(self.compression_method, 1)
          # Only a non-X.509 certificate type needs an extension.
          if self.certificate_type and \
                  self.certificate_type != CertificateType.x509:
              w.add(5, 2)     # total extensions length
              w.add(7, 2)     # extension type
              w.add(1, 2)     # extension body length
              w.add(self.certificate_type, 1)
          return HandshakeMsg.postWrite(self, w, trial)
  class Certificate(HandshakeMsg):
      """Certificate handshake message carrying a certificate chain.

      certificateType selects the wire format: an X.509 chain or a
      cryptoID chain; any other value raises AssertionError on parse or
      write.
      """

      def __init__(self, certificateType):
          self.certificateType = certificateType
          self.contentType = ContentType.handshake
          self.certChain = None

      def create(self, certChain):
          """Store the chain to send and return self."""
          self.certChain = certChain
          return self

      def parse(self, p):
          """Parse the certificate chain from parser p and return self.

          For X.509, raises SyntaxError if the certificate entries do not
          exactly fill the declared chain length.  For cryptoID, raises
          SyntaxError if a non-empty chain arrives but cryptoIDlib is not
          importable.
          """
          p.startLengthCheck(3)
          if self.certificateType == CertificateType.x509:
              chainLength = p.get(3)
              index = 0
              certificate_list = []
              # '<' (not '!=') so an entry overrunning the declared chain
              # length stops the loop instead of reading past it.
              while index < chainLength:
                  certBytes = p.getVarBytes(3)
                  x509 = X509()
                  x509.parseBinary(certBytes)
                  certificate_list.append(x509)
                  index += len(certBytes)+3
              if index != chainLength:
                  raise SyntaxError()
              if certificate_list:
                  self.certChain = X509CertChain(certificate_list)
          elif self.certificateType == CertificateType.cryptoID:
              s = bytesToString(p.getVarBytes(2))
              if s:
                  try:
                      import cryptoIDlib.CertChain
                  except ImportError:
                      raise SyntaxError(\
                      "cryptoID cert chain received, cryptoIDlib not present")
                  self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
          else:
              raise AssertionError()
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
          if self.certificateType == CertificateType.x509:
              if self.certChain:
                  certificate_list = self.certChain.x509List
              else:
                  certificate_list = []
              # Encode each certificate once: the encodings are needed for
              # both the chain-length prefix and the body.
              encodedCerts = [cert.writeBytes() for cert in certificate_list]
              chainLength = 0
              for certBytes in encodedCerts:
                  chainLength += len(certBytes)+3
              w.add(chainLength, 3)
              for certBytes in encodedCerts:
                  w.addVarSeq(certBytes, 1, 3)
          elif self.certificateType == CertificateType.cryptoID:
              if self.certChain:
                  chainBytes = stringToBytes(self.certChain.write())
              else:
                  chainBytes = createByteArraySequence([])
              w.addVarSeq(chainBytes, 1, 2)
          else:
              raise AssertionError()
          return HandshakeMsg.postWrite(self, w, trial)
  class CertificateRequest(HandshakeMsg):
      """CertificateRequest handshake message."""

      def __init__(self):
          self.contentType = ContentType.handshake
          self.certificate_types = []
          # The CA list is kept as opaque bytes rather than decoded.
          self.certificate_authorities = createByteArraySequence([])

      def create(self, certificate_types, certificate_authorities):
          """Set the requested types and raw CA bytes; return self."""
          self.certificate_types, self.certificate_authorities = \
              certificate_types, certificate_authorities
          return self

      def parse(self, p):
          """Read the certificate-type list and raw CA bytes from p."""
          p.startLengthCheck(3)
          self.certificate_types = p.getVarList(1, 1)
          self.certificate_authorities = p.getVarBytes(2)
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          out = HandshakeMsg.preWrite(self,
                                      HandshakeType.certificate_request,
                                      trial)
          out.addVarSeq(self.certificate_types, 1, 1)
          out.addVarSeq(self.certificate_authorities, 1, 2)
          return HandshakeMsg.postWrite(self, out, trial)
  318. class ServerKeyExchange(HandshakeMsg):
  319. def __init__(self, cipherSuite):
  320. self.cipherSuite = cipherSuite
  321. self.contentType = ContentType.handshake
  322. self.srp_N = 0L
  323. self.srp_g = 0L
  324. self.srp_s = createByteArraySequence([])
  325. self.srp_B = 0L
  326. self.signature = createByteArraySequence([])
  327. def createSRP(self, srp_N, srp_g, srp_s, srp_B):
  328. self.srp_N = srp_N
  329. self.srp_g = srp_g
  330. self.srp_s = srp_s
  331. self.srp_B = srp_B
  332. return self
  333. def parse(self, p):
  334. p.startLengthCheck(3)
  335. self.srp_N = bytesToNumber(p.getVarBytes(2))
  336. self.srp_g = bytesToNumber(p.getVarBytes(2))
  337. self.srp_s = p.getVarBytes(1)
  338. self.srp_B = bytesToNumber(p.getVarBytes(2))
  339. if self.cipherSuite in CipherSuite.srpRsaSuites:
  340. self.signature = p.getVarBytes(2)
  341. p.stopLengthCheck()
  342. return self
  343. def write(self, trial=False):
  344. w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
  345. trial)
  346. w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
  347. w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
  348. w.addVarSeq(self.srp_s, 1, 1)
  349. w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
  350. if self.cipherSuite in CipherSuite.srpRsaSuites:
  351. w.addVarSeq(self.signature, 1, 2)
  352. return HandshakeMsg.postWrite(self, w, trial)
  353. def hash(self, clientRandom, serverRandom):
  354. oldCipherSuite = self.cipherSuite
  355. self.cipherSuite = None
  356. try:
  357. bytes = clientRandom + serverRandom + self.write()[4:]
  358. s = bytesToString(bytes)
  359. return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
  360. finally:
  361. self.cipherSuite = oldCipherSuite
  class ServerHelloDone(HandshakeMsg):
      """ServerHelloDone handshake message; its body has no fields."""

      def __init__(self):
          self.contentType = ContentType.handshake

      def create(self):
          """Nothing to configure; return self for chaining."""
          return self

      def parse(self, p):
          """Open and immediately close the length check; return self."""
          p.startLengthCheck(3)
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize as just the handshake header."""
          writer = HandshakeMsg.preWrite(self,
                                         HandshakeType.server_hello_done,
                                         trial)
          return HandshakeMsg.postWrite(self, writer, trial)
  class ClientKeyExchange(HandshakeMsg):
      """ClientKeyExchange for SRP or RSA key exchange.

      Any cipher suite outside the SRP and RSA lists, or an RSA exchange
      at a version other than (3,0), (3,1) or (3,2), raises
      AssertionError on parse or write.
      """

      def __init__(self, cipherSuite, version=None):
          self.cipherSuite = cipherSuite
          # (major, minor); selects how the RSA secret is framed.
          self.version = version
          self.contentType = ContentType.handshake
          self.srp_A = 0
          self.encryptedPreMasterSecret = createByteArraySequence([])

      def createSRP(self, srp_A):
          """Store the SRP public value A; return self."""
          self.srp_A = srp_A
          return self

      def createRSA(self, encryptedPreMasterSecret):
          """Store the encrypted premaster secret; return self."""
          self.encryptedPreMasterSecret = encryptedPreMasterSecret
          return self

      def _usesSRP(self):
          """True when cipherSuite is in srpSuites or srpRsaSuites."""
          return self.cipherSuite in \
              CipherSuite.srpSuites + CipherSuite.srpRsaSuites

      def parse(self, p):
          """Read the key-exchange value from p and return self."""
          p.startLengthCheck(3)
          if self._usesSRP():
              self.srp_A = bytesToNumber(p.getVarBytes(2))
          elif self.cipherSuite in CipherSuite.rsaSuites:
              if self.version in ((3,1), (3,2)):
                  # TLS 1.0/1.1 length-prefix the encrypted secret.
                  self.encryptedPreMasterSecret = p.getVarBytes(2)
              elif self.version == (3,0):
                  # SSLv3 has no prefix: take every remaining byte.
                  remaining = len(p.bytes) - p.index
                  self.encryptedPreMasterSecret = p.getFixBytes(remaining)
              else:
                  raise AssertionError()
          else:
              raise AssertionError()
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
                                    trial)
          if self._usesSRP():
              w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
          elif self.cipherSuite in CipherSuite.rsaSuites:
              if self.version in ((3,1), (3,2)):
                  w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
              elif self.version == (3,0):
                  w.addFixSeq(self.encryptedPreMasterSecret, 1)
              else:
                  raise AssertionError()
          else:
              raise AssertionError()
          return HandshakeMsg.postWrite(self, w, trial)
  class CertificateVerify(HandshakeMsg):
      """CertificateVerify handshake message carrying a signature."""

      def __init__(self):
          self.contentType = ContentType.handshake
          self.signature = createByteArraySequence([])

      def create(self, signature):
          """Store the signature bytes and return self."""
          self.signature = signature
          return self

      def parse(self, p):
          """Read the signature bytes from p and return self."""
          p.startLengthCheck(3)
          self.signature = p.getVarBytes(2)
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          out = HandshakeMsg.preWrite(self,
                                      HandshakeType.certificate_verify,
                                      trial)
          out.addVarSeq(self.signature, 1, 2)
          return HandshakeMsg.postWrite(self, out, trial)
  class ChangeCipherSpec(Msg):
      """ChangeCipherSpec record: a single type byte (created as 1)."""

      def __init__(self):
          self.contentType = ContentType.change_cipher_spec
          self.type = 1

      def create(self):
          """Reset the type byte to 1 and return self."""
          self.type = 1
          return self

      def parse(self, p):
          """Read the single type byte from p and return self."""
          p.setLengthCheck(1)
          self.type = p.get(1)
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          out = Msg.preWrite(self, trial)
          out.add(self.type, 1)
          return Msg.postWrite(self, out, trial)
  class Finished(HandshakeMsg):
      """Finished handshake message carrying verify_data.

      verify_data is 36 bytes for SSLv3 (3,0) and 12 bytes for TLS
      (3,1)/(3,2); any other version raises AssertionError on parse.
      """

      def __init__(self, version):
          self.contentType = ContentType.handshake
          self.version = version
          self.verify_data = createByteArraySequence([])

      def create(self, verify_data):
          """Store verify_data and return self."""
          self.verify_data = verify_data
          return self

      def parse(self, p):
          """Read the version-dependent verify_data from p; return self."""
          p.startLengthCheck(3)
          if self.version == (3,0):
              size = 36
          elif self.version in ((3,1), (3,2)):
              size = 12
          else:
              raise AssertionError()
          self.verify_data = p.getFixBytes(size)
          p.stopLengthCheck()
          return self

      def write(self, trial=False):
          """Serialize; on a trial pass returns the writer index instead."""
          out = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
          out.addFixSeq(self.verify_data, 1)
          return HandshakeMsg.postWrite(self, out, trial)
  class ApplicationData(Msg):
      """Application data record; the payload is carried as-is."""

      def __init__(self):
          self.contentType = ContentType.application_data
          self.bytes = createByteArraySequence([])

      def create(self, bytes):
          """Store the payload and return self."""
          self.bytes = bytes
          return self

      def parse(self, p):
          """Take the parser's entire buffer (ignoring p.index) as payload."""
          self.bytes = p.bytes
          return self

      def write(self):
          """Return the payload unchanged."""
          return self.bytes