PageRenderTime 51ms CodeModel.GetById 10ms app.highlight 35ms RepoModel.GetById 1ms app.codeStats 0ms

/gdata/tlslite/messages.py

http://radioappz.googlecode.com/
Python | 561 lines | 480 code | 75 blank | 6 comment | 53 complexity | 5fb035859840cd589da445ae21c9f2e7 MD5 | raw file
  1"""Classes representing TLS messages."""
  2
  3from utils.compat import *
  4from utils.cryptomath import *
  5from errors import *
  6from utils.codec import *
  7from constants import *
  8from X509 import X509
  9from X509CertChain import X509CertChain
 10
 11import sha
 12import md5
 13
 14class RecordHeader3:
 15    def __init__(self):
 16        self.type = 0
 17        self.version = (0,0)
 18        self.length = 0
 19        self.ssl2 = False
 20
 21    def create(self, version, type, length):
 22        self.type = type
 23        self.version = version
 24        self.length = length
 25        return self
 26
 27    def write(self):
 28        w = Writer(5)
 29        w.add(self.type, 1)
 30        w.add(self.version[0], 1)
 31        w.add(self.version[1], 1)
 32        w.add(self.length, 2)
 33        return w.bytes
 34
 35    def parse(self, p):
 36        self.type = p.get(1)
 37        self.version = (p.get(1), p.get(1))
 38        self.length = p.get(2)
 39        self.ssl2 = False
 40        return self
 41
 42class RecordHeader2:
 43    def __init__(self):
 44        self.type = 0
 45        self.version = (0,0)
 46        self.length = 0
 47        self.ssl2 = True
 48
 49    def parse(self, p):
 50        if p.get(1)!=128:
 51            raise SyntaxError()
 52        self.type = ContentType.handshake
 53        self.version = (2,0)
 54        #We don't support 2-byte-length-headers; could be a problem
 55        self.length = p.get(1)
 56        return self
 57
 58
 59class Msg:
 60    def preWrite(self, trial):
 61        if trial:
 62            w = Writer()
 63        else:
 64            length = self.write(True)
 65            w = Writer(length)
 66        return w
 67
 68    def postWrite(self, w, trial):
 69        if trial:
 70            return w.index
 71        else:
 72            return w.bytes
 73
 74class Alert(Msg):
 75    def __init__(self):
 76        self.contentType = ContentType.alert
 77        self.level = 0
 78        self.description = 0
 79
 80    def create(self, description, level=AlertLevel.fatal):
 81        self.level = level
 82        self.description = description
 83        return self
 84
 85    def parse(self, p):
 86        p.setLengthCheck(2)
 87        self.level = p.get(1)
 88        self.description = p.get(1)
 89        p.stopLengthCheck()
 90        return self
 91
 92    def write(self):
 93        w = Writer(2)
 94        w.add(self.level, 1)
 95        w.add(self.description, 1)
 96        return w.bytes
 97
 98
 99class HandshakeMsg(Msg):
100    def preWrite(self, handshakeType, trial):
101        if trial:
102            w = Writer()
103            w.add(handshakeType, 1)
104            w.add(0, 3)
105        else:
106            length = self.write(True)
107            w = Writer(length)
108            w.add(handshakeType, 1)
109            w.add(length-4, 3)
110        return w
111
112
class ClientHello(HandshakeMsg):
    """A ClientHello handshake message.

    Parses either the normal encoding or, when ssl2 is True, an SSLv2-format
    ClientHello. Two extensions are understood: type 6 (SRP username) and
    type 7 (certificate types). Any other extension is skipped.
    """

    def __init__(self, ssl2=False):
        self.contentType = ContentType.handshake
        self.ssl2 = ssl2
        self.client_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suites = []         # a list of 16-bit values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []   # a list of 8-bit values
        self.srp_username = None        # a string

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srp_username=None):
        """Fill in the fields for an outgoing ClientHello; return self."""
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        self.srp_username = srp_username
        return self

    def parse(self, p):
        """Parse a ClientHello from parser p and return self.

        Raises SyntaxError if the extensions block is malformed: either an
        understood extension does not fill exactly its declared length, or
        the extension lengths do not add up to the declared total.
        """
        if self.ssl2:
            self.client_version = (p.get(1), p.get(1))
            cipherSpecsLength = p.get(2)
            sessionIDLength = p.get(2)
            randomLength = p.get(2)
            # SSLv2 cipher specs are 3 bytes each.
            self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
            self.session_id = p.getFixBytes(sessionIDLength)
            self.random = p.getFixBytes(randomLength)
            # Left-pad a short challenge with zeros up to 32 bytes.
            if len(self.random) < 32:
                zeroBytes = 32-len(self.random)
                self.random = createByteArrayZeros(zeroBytes) + self.random
            self.compression_methods = [0]#Fake this value

            #We're not doing a stopLengthCheck() for SSLv2, oh well..
        else:
            p.startLengthCheck(3)
            self.client_version = (p.get(1), p.get(1))
            self.random = p.getFixBytes(32)
            self.session_id = p.getVarBytes(1)
            self.cipher_suites = p.getVarList(2, 2)
            self.compression_methods = p.getVarList(1, 1)
            if not p.atLengthCheck():
                self._parseExtensions(p)
            p.stopLengthCheck()
        return self

    def _parseExtensions(self, p):
        """Parse the extensions block: a 2-byte total length, then entries."""
        totalExtLength = p.get(2)
        soFar = 0
        while soFar < totalExtLength:
            extType = p.get(2)
            extLength = p.get(2)
            extStart = p.index
            if extType == 6:
                self.srp_username = bytesToString(p.getVarBytes(1))
            elif extType == 7:
                self.certificate_types = p.getVarList(1, 1)
            else:
                p.getFixBytes(extLength)
            # The decoded contents must fill exactly the declared length.
            # Otherwise every following extension would be read from the
            # wrong offset.
            if p.index - extStart != extLength:
                raise SyntaxError()
            soFar += 4 + extLength
        if soFar != totalExtLength:
            raise SyntaxError()

    def write(self, trial=False):
        """Serialize the ClientHello (trial semantics as in Msg.preWrite)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
        w.add(self.client_version[0], 1)
        w.add(self.client_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.addVarSeq(self.cipher_suites, 2, 2)
        w.addVarSeq(self.compression_methods, 1, 1)

        # The certificate-type extension is only sent when it says something
        # other than the X.509-only default.
        sendCertTypes = bool(self.certificate_types) and \
            self.certificate_types != [CertificateType.x509]
        sendSrp = bool(self.srp_username)

        extLength = 0
        if sendCertTypes:
            # 4-byte extension header + 1-byte list length + 1 byte per type
            extLength += 5 + len(self.certificate_types)
        if sendSrp:
            # 4-byte extension header + 1-byte name length + name bytes
            extLength += 5 + len(self.srp_username)
        if extLength > 0:
            w.add(extLength, 2)

        if sendCertTypes:
            w.add(7, 2)
            w.add(len(self.certificate_types)+1, 2)
            w.addVarSeq(self.certificate_types, 1, 1)
        if sendSrp:
            w.add(6, 2)
            w.add(len(self.srp_username)+1, 2)
            w.addVarSeq(stringToBytes(self.srp_username), 1, 1)

        return HandshakeMsg.postWrite(self, w, trial)
203
204
class ServerHello(HandshakeMsg):
    """A ServerHello handshake message.

    The only understood extension is type 7 (certificate type). Any other
    extension is skipped.
    """

    def __init__(self):
        self.contentType = ContentType.handshake
        self.server_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0

    def create(self, version, random, session_id, cipher_suite,
               certificate_type):
        """Fill in the fields for an outgoing ServerHello; return self."""
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        return self

    def parse(self, p):
        """Parse a ServerHello from parser p and return self.

        Raises SyntaxError if an extension's contents do not fill exactly its
        declared length, or if the extension lengths do not add up to the
        declared total.
        """
        p.startLengthCheck(3)
        self.server_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            totalExtLength = p.get(2)
            soFar = 0
            while soFar < totalExtLength:
                extType = p.get(2)
                extLength = p.get(2)
                extStart = p.index
                if extType == 7:
                    self.certificate_type = p.get(1)
                else:
                    p.getFixBytes(extLength)
                # Guard against desynchronizing on a mis-sized extension.
                if p.index - extStart != extLength:
                    raise SyntaxError()
                soFar += 4 + extLength
            if soFar != totalExtLength:
                raise SyntaxError()
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the ServerHello (trial semantics as in Msg.preWrite)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
        w.add(self.server_version[0], 1)
        w.add(self.server_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.add(self.cipher_suite, 2)
        w.add(self.compression_method, 1)

        # Only announce a certificate type when it is not the X.509 default.
        sendCertType = bool(self.certificate_type) and \
            self.certificate_type != CertificateType.x509
        if sendCertType:
            w.add(5, 2)     # total extensions length: 4-byte header + 1 byte
            w.add(7, 2)     # extension type 7: certificate type
            w.add(1, 2)     # extension data length
            w.add(self.certificate_type, 1)

        return HandshakeMsg.postWrite(self, w, trial)
270
class Certificate(HandshakeMsg):
    """A Certificate handshake message carrying a certificate chain.

    The encoding depends on certificateType:

    - x509: a 3-byte total length, then each certificate as
      3-byte-length-prefixed bytes.
    - cryptoID: one 2-byte-length-prefixed string.
    """

    def __init__(self, certificateType):
        self.certificateType = certificateType
        self.contentType = ContentType.handshake
        self.certChain = None

    def create(self, certChain):
        """Set the certificate chain to send; return self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Parse the chain from parser p and return self.

        Raises SyntaxError if the x509 entries do not exactly fill the
        declared chain length, or if a cryptoID chain arrives and
        cryptoIDlib cannot be imported. Raises AssertionError for an
        unknown certificateType.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index < chainLength:
                certBytes = p.getVarBytes(3)
                x509 = X509()
                x509.parseBinary(certBytes)
                certificate_list.append(x509)
                index += len(certBytes)+3
            # The entries must add up to exactly the declared chain length.
            if index != chainLength:
                raise SyntaxError()
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        elif self.certificateType == CertificateType.cryptoID:
            s = bytesToString(p.getVarBytes(2))
            if s:
                try:
                    import cryptoIDlib.CertChain
                except ImportError:
                    raise SyntaxError(\
                    "cryptoID cert chain received, cryptoIDlib not present")
                self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
        else:
            raise AssertionError()

        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the chain (trial semantics as in Msg.preWrite).

        Raises AssertionError for an unknown certificateType.
        """
        w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
        if self.certificateType == CertificateType.x509:
            if self.certChain:
                certificate_list = self.certChain.x509List
            else:
                certificate_list = []
            # Serialize each certificate once. The encodings are needed both
            # for the chain length and for the output itself.
            encodedCerts = [cert.writeBytes() for cert in certificate_list]
            chainLength = 0
            for certBytes in encodedCerts:
                chainLength += len(certBytes)+3
            w.add(chainLength, 3)
            for certBytes in encodedCerts:
                w.addVarSeq(certBytes, 1, 3)
        elif self.certificateType == CertificateType.cryptoID:
            if self.certChain:
                chainBytes = stringToBytes(self.certChain.write())
            else:
                chainBytes = createByteArraySequence([])
            w.addVarSeq(chainBytes, 1, 2)
        else:
            raise AssertionError()
        return HandshakeMsg.postWrite(self, w, trial)
336
class CertificateRequest(HandshakeMsg):
    """A CertificateRequest handshake message.

    Holds the acceptable certificate types (1 byte each) and the
    certificate_authorities field, which is kept as opaque bytes.
    """

    def __init__(self):
        self.contentType = ContentType.handshake
        self.certificate_types = []
        # Not decoded into individual names; kept as raw bytes.
        self.certificate_authorities = createByteArraySequence([])

    def create(self, certificate_types, certificate_authorities):
        """Set both fields and return self."""
        self.certificate_types, self.certificate_authorities = \
            certificate_types, certificate_authorities
        return self

    def parse(self, p):
        """Read the 1-byte-prefixed type list and 2-byte-prefixed CA bytes."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        self.certificate_authorities = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the request (trial semantics as in Msg.preWrite)."""
        msgType = HandshakeType.certificate_request
        out = HandshakeMsg.preWrite(self, msgType, trial)
        out.addVarSeq(self.certificate_types, 1, 1)
        out.addVarSeq(self.certificate_authorities, 1, 2)
        return HandshakeMsg.postWrite(self, out, trial)
362
class ServerKeyExchange(HandshakeMsg):
    """A ServerKeyExchange message carrying SRP parameters.

    Fields: srp_N and srp_g (encoded as 2-byte-length-prefixed numbers),
    the salt srp_s (1-byte-length-prefixed), and srp_B (2-byte-length-
    prefixed). For cipher suites in CipherSuite.srpRsaSuites, a
    2-byte-length-prefixed signature follows.
    """

    def __init__(self, cipherSuite):
        self.cipherSuite = cipherSuite
        self.contentType = ContentType.handshake
        # Plain int literals rather than 0L: these match ClientKeyExchange's
        # srp_A = 0 and keep the class syntactically valid on Python 3.
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = createByteArraySequence([])
        self.srp_B = 0
        self.signature = createByteArraySequence([])

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Set the SRP parameters and return self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def parse(self, p):
        """Parse the SRP parameters (and signature, if any); return self."""
        p.startLengthCheck(3)
        self.srp_N = bytesToNumber(p.getVarBytes(2))
        self.srp_g = bytesToNumber(p.getVarBytes(2))
        self.srp_s = p.getVarBytes(1)
        self.srp_B = bytesToNumber(p.getVarBytes(2))
        if self.cipherSuite in CipherSuite.srpRsaSuites:
            self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the message (trial semantics as in Msg.preWrite)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
                                  trial)
        w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
        w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
        w.addVarSeq(self.srp_s, 1, 1)
        w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
        if self.cipherSuite in CipherSuite.srpRsaSuites:
            w.addVarSeq(self.signature, 1, 2)
        return HandshakeMsg.postWrite(self, w, trial)

    def hash(self, clientRandom, serverRandom):
        """Return MD5 digest + SHA-1 digest of randoms + parameters.

        The hashed input is clientRandom + serverRandom + the serialized
        parameters, without the 4-byte handshake header. cipherSuite is
        temporarily set to None so write() omits the signature. It is
        restored even if serialization raises.
        """
        oldCipherSuite = self.cipherSuite
        self.cipherSuite = None
        try:
            # [4:] drops the handshake type + 3-byte length header.
            signedBytes = clientRandom + serverRandom + self.write()[4:]
            s = bytesToString(signedBytes)
            return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
        finally:
            self.cipherSuite = oldCipherSuite
411
class ServerHelloDone(HandshakeMsg):
    """A ServerHelloDone handshake message, which has no body fields."""

    def __init__(self):
        self.contentType = ContentType.handshake

    def create(self):
        """Nothing to set; return self for chaining."""
        return self

    def parse(self, p):
        """Read the 3-byte length and verify that no body bytes follow.

        The check is presumably done by stopLengthCheck() after zero reads
        (Parser is defined elsewhere).
        """
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize as just the handshake header."""
        header = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done,
                                       trial)
        return HandshakeMsg.postWrite(self, header, trial)
427
class ClientKeyExchange(HandshakeMsg):
    """A ClientKeyExchange handshake message.

    For SRP (and SRP+RSA) cipher suites it carries the client value srp_A,
    encoded as a 2-byte-length-prefixed number. For RSA suites it carries
    the encrypted premaster secret. That secret is 2-byte-length-prefixed
    for versions (3,1)/(3,2) and sent unprefixed for (3,0). Any other
    suite/version combination raises AssertionError.
    """

    def __init__(self, cipherSuite, version=None):
        self.cipherSuite = cipherSuite
        self.version = version
        self.contentType = ContentType.handshake
        self.srp_A = 0
        self.encryptedPreMasterSecret = createByteArraySequence([])

    def createSRP(self, srp_A):
        """Set the SRP client value and return self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Set the encrypted premaster secret and return self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def _isSRP(self):
        """True when cipherSuite is an SRP or SRP+RSA suite."""
        return self.cipherSuite in CipherSuite.srpSuites + \
                                   CipherSuite.srpRsaSuites

    def parse(self, p):
        """Parse the key-exchange payload from parser p; return self."""
        p.startLengthCheck(3)
        if self._isSRP():
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite not in CipherSuite.rsaSuites:
            raise AssertionError()
        elif self.version in ((3,1), (3,2)):
            self.encryptedPreMasterSecret = p.getVarBytes(2)
        elif self.version == (3,0):
            # No length prefix in this version: take everything remaining
            # in the parser's buffer.
            self.encryptedPreMasterSecret = \
                p.getFixBytes(len(p.bytes)-p.index)
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the payload (trial semantics as in Msg.preWrite)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
                                  trial)
        if self._isSRP():
            w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
        elif self.cipherSuite not in CipherSuite.rsaSuites:
            raise AssertionError()
        elif self.version in ((3,1), (3,2)):
            w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
        elif self.version == (3,0):
            w.addFixSeq(self.encryptedPreMasterSecret, 1)
        else:
            raise AssertionError()
        return HandshakeMsg.postWrite(self, w, trial)
478
class CertificateVerify(HandshakeMsg):
    """A CertificateVerify message: one 2-byte-length-prefixed signature."""

    def __init__(self):
        self.contentType = ContentType.handshake
        self.signature = createByteArraySequence([])

    def create(self, signature):
        """Set the signature bytes and return self."""
        self.signature = signature
        return self

    def parse(self, p):
        """Read the signature from parser p and return self."""
        p.startLengthCheck(3)
        self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the signature (trial semantics as in Msg.preWrite)."""
        msgType = HandshakeType.certificate_verify
        out = HandshakeMsg.preWrite(self, msgType, trial)
        out.addVarSeq(self.signature, 1, 2)
        return HandshakeMsg.postWrite(self, out, trial)
499
class ChangeCipherSpec(Msg):
    """A ChangeCipherSpec message: a single type byte (1 when created)."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Set the type byte to 1 and return self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read the 1-byte body from parser p and return self."""
        p.setLengthCheck(1)
        self.type = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the type byte (trial semantics as in Msg.preWrite)."""
        writer = Msg.preWrite(self, trial)
        writer.add(self.type, 1)
        return Msg.postWrite(self, writer, trial)
519
520
class Finished(HandshakeMsg):
    """A Finished handshake message carrying fixed-length verify_data.

    verify_data is 36 bytes for version (3,0) and 12 bytes for (3,1)/(3,2).
    """

    def __init__(self, version):
        self.contentType = ContentType.handshake
        self.version = version
        self.verify_data = createByteArraySequence([])

    def create(self, verify_data):
        """Set verify_data and return self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read verify_data from parser p; AssertionError on other versions."""
        p.startLengthCheck(3)
        if self.version == (3,0):
            dataLength = 36
        elif self.version in ((3,1), (3,2)):
            dataLength = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(dataLength)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize verify_data as-is (trial semantics as in Msg.preWrite)."""
        out = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
        out.addFixSeq(self.verify_data, 1)
        return HandshakeMsg.postWrite(self, out, trial)
546
class ApplicationData(Msg):
    """Application data: an opaque byte payload passed through unchanged."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = createByteArraySequence([])

    def create(self, bytes):
        """Set the payload and return self."""
        self.bytes = bytes
        return self

    def parse(self, p):
        """Use the parser's entire underlying buffer as the payload."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the payload bytes unchanged."""
        payload = self.bytes
        return payload