PageRenderTime 115ms CodeModel.GetById 60ms app.highlight 34ms RepoModel.GetById 16ms app.codeStats 1ms

/Lib/ssl.py

http://unladen-swallow.googlecode.com/
Python | 451 lines | 375 code | 12 blank | 64 comment | 6 complexity | a29337a19f06c1c1c3eccae9835ad91f MD5 | raw file
  1# Wrapper module for _ssl, providing some additional facilities
  2# implemented in Python.  Written by Bill Janssen.
  3
  4"""\
  5This module provides some more Pythonic support for SSL.
  6
  7Object types:
  8
  9  SSLSocket -- subtype of socket.socket which does SSL over the socket
 10
 11Exceptions:
 12
 13  SSLError -- exception raised for I/O errors
 14
 15Functions:
 16
 17  cert_time_to_seconds -- convert time string used for certificate
 18                          notBefore and notAfter functions to integer
 19                          seconds past the Epoch (the time values
 20                          returned from time.time())
 21
 22  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
 23                          by the server running on HOST at port PORT.  No
 24                          validation of the certificate is performed.
 25
 26Integer constants:
 27
 28SSL_ERROR_ZERO_RETURN
 29SSL_ERROR_WANT_READ
 30SSL_ERROR_WANT_WRITE
 31SSL_ERROR_WANT_X509_LOOKUP
 32SSL_ERROR_SYSCALL
 33SSL_ERROR_SSL
 34SSL_ERROR_WANT_CONNECT
 35
 36SSL_ERROR_EOF
 37SSL_ERROR_INVALID_ERROR_CODE
 38
 39The following group define certificate requirements that one side is
 40allowing/requiring from the other side:
 41
 42CERT_NONE - no certificates from the other side are required (or will
 43            be looked at if provided)
 44CERT_OPTIONAL - certificates are not required, but if provided will be
 45                validated, and if validation fails, the connection will
 46                also fail
 47CERT_REQUIRED - certificates are required, and will be validated, and
 48                if validation fails, the connection will also fail
 49
 50The following constants identify various SSL protocol variants:
 51
 52PROTOCOL_SSLv2
 53PROTOCOL_SSLv3
 54PROTOCOL_SSLv23
 55PROTOCOL_TLSv1
 56"""
 57
 58import textwrap
 59
 60import _ssl             # if we can't import it, let the error propagate
 61
 62from _ssl import SSLError
 63from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
 64from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
 65from _ssl import RAND_status, RAND_egd, RAND_add
 66from _ssl import \
 67     SSL_ERROR_ZERO_RETURN, \
 68     SSL_ERROR_WANT_READ, \
 69     SSL_ERROR_WANT_WRITE, \
 70     SSL_ERROR_WANT_X509_LOOKUP, \
 71     SSL_ERROR_SYSCALL, \
 72     SSL_ERROR_SSL, \
 73     SSL_ERROR_WANT_CONNECT, \
 74     SSL_ERROR_EOF, \
 75     SSL_ERROR_INVALID_ERROR_CODE
 76
 77from socket import socket, _fileobject
 78from socket import getnameinfo as _getnameinfo
 79import base64        # for DER-to-PEM translation
 80
 81class SSLSocket (socket):
 82
 83    """This class implements a subtype of socket.socket that wraps
 84    the underlying OS socket in an SSL context when necessary, and
 85    provides read and write methods over that channel."""
 86
 87    def __init__(self, sock, keyfile=None, certfile=None,
 88                 server_side=False, cert_reqs=CERT_NONE,
 89                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
 90                 do_handshake_on_connect=True,
 91                 suppress_ragged_eofs=True):
 92        socket.__init__(self, _sock=sock._sock)
 93        # the initializer for socket trashes the methods (tsk, tsk), so...
 94        self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
 95        self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
 96        self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
 97        self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
 98        self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
 99        self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)
100
101        if certfile and not keyfile:
102            keyfile = certfile
103        # see if it's connected
104        try:
105            socket.getpeername(self)
106        except:
107            # no, no connection yet
108            self._sslobj = None
109        else:
110            # yes, create the SSL object
111            self._sslobj = _ssl.sslwrap(self._sock, server_side,
112                                        keyfile, certfile,
113                                        cert_reqs, ssl_version, ca_certs)
114            if do_handshake_on_connect:
115                timeout = self.gettimeout()
116                try:
117                    self.settimeout(None)
118                    self.do_handshake()
119                finally:
120                    self.settimeout(timeout)
121        self.keyfile = keyfile
122        self.certfile = certfile
123        self.cert_reqs = cert_reqs
124        self.ssl_version = ssl_version
125        self.ca_certs = ca_certs
126        self.do_handshake_on_connect = do_handshake_on_connect
127        self.suppress_ragged_eofs = suppress_ragged_eofs
128        self._makefile_refs = 0
129
130    def read(self, len=1024):
131
132        """Read up to LEN bytes and return them.
133        Return zero-length string on EOF."""
134
135        try:
136            return self._sslobj.read(len)
137        except SSLError, x:
138            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
139                return ''
140            else:
141                raise
142
143    def write(self, data):
144
145        """Write DATA to the underlying SSL channel.  Returns
146        number of bytes of DATA actually transmitted."""
147
148        return self._sslobj.write(data)
149
150    def getpeercert(self, binary_form=False):
151
152        """Returns a formatted version of the data in the
153        certificate provided by the other end of the SSL channel.
154        Return None if no certificate was provided, {} if a
155        certificate was provided, but not validated."""
156
157        return self._sslobj.peer_certificate(binary_form)
158
159    def cipher (self):
160
161        if not self._sslobj:
162            return None
163        else:
164            return self._sslobj.cipher()
165
166    def send (self, data, flags=0):
167        if self._sslobj:
168            if flags != 0:
169                raise ValueError(
170                    "non-zero flags not allowed in calls to send() on %s" %
171                    self.__class__)
172            while True:
173                try:
174                    v = self._sslobj.write(data)
175                except SSLError, x:
176                    if x.args[0] == SSL_ERROR_WANT_READ:
177                        return 0
178                    elif x.args[0] == SSL_ERROR_WANT_WRITE:
179                        return 0
180                    else:
181                        raise
182                else:
183                    return v
184        else:
185            return socket.send(self, data, flags)
186
187    def sendto (self, data, addr, flags=0):
188        if self._sslobj:
189            raise ValueError("sendto not allowed on instances of %s" %
190                             self.__class__)
191        else:
192            return socket.sendto(self, data, addr, flags)
193
194    def sendall (self, data, flags=0):
195        if self._sslobj:
196            if flags != 0:
197                raise ValueError(
198                    "non-zero flags not allowed in calls to sendall() on %s" %
199                    self.__class__)
200            amount = len(data)
201            count = 0
202            while (count < amount):
203                v = self.send(data[count:])
204                count += v
205            return amount
206        else:
207            return socket.sendall(self, data, flags)
208
209    def recv (self, buflen=1024, flags=0):
210        if self._sslobj:
211            if flags != 0:
212                raise ValueError(
213                    "non-zero flags not allowed in calls to sendall() on %s" %
214                    self.__class__)
215            while True:
216                try:
217                    return self.read(buflen)
218                except SSLError, x:
219                    if x.args[0] == SSL_ERROR_WANT_READ:
220                        continue
221                    else:
222                        raise x
223        else:
224            return socket.recv(self, buflen, flags)
225
226    def recv_into (self, buffer, nbytes=None, flags=0):
227        if buffer and (nbytes is None):
228            nbytes = len(buffer)
229        elif nbytes is None:
230            nbytes = 1024
231        if self._sslobj:
232            if flags != 0:
233                raise ValueError(
234                  "non-zero flags not allowed in calls to recv_into() on %s" %
235                  self.__class__)
236            while True:
237                try:
238                    tmp_buffer = self.read(nbytes)
239                    v = len(tmp_buffer)
240                    buffer[:v] = tmp_buffer
241                    return v
242                except SSLError as x:
243                    if x.args[0] == SSL_ERROR_WANT_READ:
244                        continue
245                    else:
246                        raise x
247        else:
248            return socket.recv_into(self, buffer, nbytes, flags)
249
250    def recvfrom (self, addr, buflen=1024, flags=0):
251        if self._sslobj:
252            raise ValueError("recvfrom not allowed on instances of %s" %
253                             self.__class__)
254        else:
255            return socket.recvfrom(self, addr, buflen, flags)
256
257    def recvfrom_into (self, buffer, nbytes=None, flags=0):
258        if self._sslobj:
259            raise ValueError("recvfrom_into not allowed on instances of %s" %
260                             self.__class__)
261        else:
262            return socket.recvfrom_into(self, buffer, nbytes, flags)
263
264    def pending (self):
265        if self._sslobj:
266            return self._sslobj.pending()
267        else:
268            return 0
269
270    def unwrap (self):
271        if self._sslobj:
272            s = self._sslobj.shutdown()
273            self._sslobj = None
274            return s
275        else:
276            raise ValueError("No SSL wrapper around " + str(self))
277
278    def shutdown (self, how):
279        self._sslobj = None
280        socket.shutdown(self, how)
281
282    def close (self):
283        if self._makefile_refs < 1:
284            self._sslobj = None
285            socket.close(self)
286        else:
287            self._makefile_refs -= 1
288
289    def do_handshake (self):
290
291        """Perform a TLS/SSL handshake."""
292
293        self._sslobj.do_handshake()
294
295    def connect(self, addr):
296
297        """Connects to remote ADDR, and then wraps the connection in
298        an SSL channel."""
299
300        # Here we assume that the socket is client-side, and not
301        # connected at the time of the call.  We connect it, then wrap it.
302        if self._sslobj:
303            raise ValueError("attempt to connect already-connected SSLSocket!")
304        socket.connect(self, addr)
305        self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
306                                    self.cert_reqs, self.ssl_version,
307                                    self.ca_certs)
308        if self.do_handshake_on_connect:
309            self.do_handshake()
310
311    def accept(self):
312
313        """Accepts a new connection from a remote client, and returns
314        a tuple containing that new connection wrapped with a server-side
315        SSL channel, and the address of the remote client."""
316
317        newsock, addr = socket.accept(self)
318        return (SSLSocket(newsock,
319                          keyfile=self.keyfile,
320                          certfile=self.certfile,
321                          server_side=True,
322                          cert_reqs=self.cert_reqs,
323                          ssl_version=self.ssl_version,
324                          ca_certs=self.ca_certs,
325                          do_handshake_on_connect=self.do_handshake_on_connect,
326                          suppress_ragged_eofs=self.suppress_ragged_eofs),
327                addr)
328
329    def makefile(self, mode='r', bufsize=-1):
330
331        """Make and return a file-like object that
332        works with the SSL connection.  Just use the code
333        from the socket module."""
334
335        self._makefile_refs += 1
336        return _fileobject(self, mode, bufsize)
337
338
339
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True):
    """Return an SSLSocket wrapping SOCK.

    Convenience front end: every option is forwarded unchanged, by
    keyword, to the SSLSocket constructor."""
    options = dict(keyfile=keyfile,
                   certfile=certfile,
                   server_side=server_side,
                   cert_reqs=cert_reqs,
                   ssl_version=ssl_version,
                   ca_certs=ca_certs,
                   do_handshake_on_connect=do_handshake_on_connect,
                   suppress_ragged_eofs=suppress_ragged_eofs)
    return SSLSocket(sock, **options)
351
352
353# some utility functions
354
def cert_time_to_seconds(cert_time):

    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The string is expected to end in "GMT".  It is converted with
    calendar.timegm(), which treats the parsed fields as UTC;
    time.mktime() would treat them as local time and shift the result
    by the local UTC offset.  The value is returned as a float, as
    before, and is directly comparable with time.time()."""

    import calendar
    import time
    parsed = time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")
    return float(calendar.timegm(parsed))
363
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):

    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The output is the header line, the base64 body, the footer line,
    each terminated by a newline, so the footer always starts on its
    own line."""

    if hasattr(base64, 'standard_b64encode'):
        # preferred because older API gets line-length wrong
        f = base64.standard_b64encode(der_cert_bytes)
        if not isinstance(f, str):
            # where b64encode returns bytes, textwrap needs text
            f = f.decode('ascii')
        # textwrap.fill() wraps at 64 columns but does not end with a
        # newline; add one so the footer is not glued onto the last line
        # of base64 data.
        return (PEM_HEADER + '\n' +
                textwrap.fill(f, 64) + '\n' +
                PEM_FOOTER + '\n')
    else:
        # encodestring() already terminates its output with a newline
        return (PEM_HEADER + '\n' +
                base64.encodestring(der_cert_bytes) +
                PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):

    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Surrounding whitespace is ignored.  Raises ValueError if the
    string does not start with PEM_HEADER or end with PEM_FOOTER."""

    import binascii

    # strip once so the header and footer checks treat whitespace alike
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    # a2b_base64 is the function base64.decodestring() wraps; it accepts
    # multi-line input.  Calling it directly also works where the
    # decodestring alias has been removed.
    return binascii.a2b_base64(d)
396
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):

    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.

    ADDR is a (host, port) pair.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt."""

    # fail early, before creating a socket, unless addr is a 2-item pair
    host, port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # release the socket even if connecting or the handshake fails
        s.close()
    return DER_cert_to_PEM_cert(dercert)
415
def get_protocol_name(protocol_code):
    """Return the short name ("TLSv1", "SSLv23", "SSLv2", "SSLv3") of a
    PROTOCOL_* constant, or "<unknown>" if PROTOCOL_CODE matches none."""
    # checked in this order; the first equal constant wins
    candidates = ((PROTOCOL_TLSv1, "TLSv1"),
                  (PROTOCOL_SSLv23, "SSLv23"),
                  (PROTOCOL_SSLv2, "SSLv2"),
                  (PROTOCOL_SSLv3, "SSLv3"))
    for code, label in candidates:
        if protocol_code == code:
            return label
    return "<unknown>"
427
428
429# a replacement for the old socket.ssl function
430
def sslwrap_simple(sock, keyfile=None, certfile=None):

    """A replacement for the old socket.ssl function.  Designed
    for compatibility with Python 2.5 and earlier.  Will disappear in
    Python 3.0.

    Wraps SOCK (or its underlying _sock, if it has one) as an SSL client
    with CERT_NONE and PROTOCOL_SSLv23, and performs the handshake
    immediately if the socket is already connected.  Returns the raw
    SSL object."""

    from socket import error as socket_error

    if hasattr(sock, "_sock"):
        sock = sock._sock

    ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE,
                            PROTOCOL_SSLv23, None)
    try:
        sock.getpeername()
    except socket_error:
        # no, no connection yet (only socket errors mean that; anything
        # else, e.g. KeyboardInterrupt, must propagate)
        pass
    else:
        # yes, do the handshake
        ssl_sock.do_handshake()

    return ssl_sock