/Lib/ssl.py

http://unladen-swallow.googlecode.com/ · Python · 451 lines · 300 code · 61 blank · 90 comment · 65 complexity · a29337a19f06c1c1c3eccae9835ad91f MD5 · raw file

  1. # Wrapper module for _ssl, providing some additional facilities
  2. # implemented in Python. Written by Bill Janssen.
  3. """\
  4. This module provides some more Pythonic support for SSL.
  5. Object types:
  6. SSLSocket -- subtype of socket.socket which does SSL over the socket
  7. Exceptions:
  8. SSLError -- exception raised for I/O errors
  9. Functions:
  10. cert_time_to_seconds -- convert time string used for certificate
  11. notBefore and notAfter functions to integer
  12. seconds past the Epoch (the time values
  13. returned from time.time())
  14. get_server_certificate (ADDR) -- fetch the certificate provided
  15. by the server at ADDR, a (HOST, PORT) pair. The certificate
  16. is validated only when CA_CERTS is supplied.
  17. Integer constants:
  18. SSL_ERROR_ZERO_RETURN
  19. SSL_ERROR_WANT_READ
  20. SSL_ERROR_WANT_WRITE
  21. SSL_ERROR_WANT_X509_LOOKUP
  22. SSL_ERROR_SYSCALL
  23. SSL_ERROR_SSL
  24. SSL_ERROR_WANT_CONNECT
  25. SSL_ERROR_EOF
  26. SSL_ERROR_INVALID_ERROR_CODE
  27. The following group defines the certificate requirements that one
  28. side allows/requires from the other side:
  29. CERT_NONE - no certificates from the other side are required (or will
  30. be looked at if provided)
  31. CERT_OPTIONAL - certificates are not required, but if provided will be
  32. validated, and if validation fails, the connection will
  33. also fail
  34. CERT_REQUIRED - certificates are required, and will be validated, and
  35. if validation fails, the connection will also fail
  36. The following constants identify various SSL protocol variants:
  37. PROTOCOL_SSLv2
  38. PROTOCOL_SSLv3
  39. PROTOCOL_SSLv23
  40. PROTOCOL_TLSv1
  41. """
  42. import textwrap
  43. import _ssl # if we can't import it, let the error propagate
  44. from _ssl import SSLError
  45. from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
  46. from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
  47. from _ssl import RAND_status, RAND_egd, RAND_add
  48. from _ssl import \
  49. SSL_ERROR_ZERO_RETURN, \
  50. SSL_ERROR_WANT_READ, \
  51. SSL_ERROR_WANT_WRITE, \
  52. SSL_ERROR_WANT_X509_LOOKUP, \
  53. SSL_ERROR_SYSCALL, \
  54. SSL_ERROR_SSL, \
  55. SSL_ERROR_WANT_CONNECT, \
  56. SSL_ERROR_EOF, \
  57. SSL_ERROR_INVALID_ERROR_CODE
  58. from socket import socket, _fileobject
  59. from socket import getnameinfo as _getnameinfo
  60. import base64 # for DER-to-PEM translation
  61. class SSLSocket (socket):
  62. """This class implements a subtype of socket.socket that wraps
  63. the underlying OS socket in an SSL context when necessary, and
  64. provides read and write methods over that channel."""
  65. def __init__(self, sock, keyfile=None, certfile=None,
  66. server_side=False, cert_reqs=CERT_NONE,
  67. ssl_version=PROTOCOL_SSLv23, ca_certs=None,
  68. do_handshake_on_connect=True,
  69. suppress_ragged_eofs=True):
  70. socket.__init__(self, _sock=sock._sock)
  71. # the initializer for socket trashes the methods (tsk, tsk), so...
  72. self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
  73. self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
  74. self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
  75. self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
  76. self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
  77. self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)
  78. if certfile and not keyfile:
  79. keyfile = certfile
  80. # see if it's connected
  81. try:
  82. socket.getpeername(self)
  83. except:
  84. # no, no connection yet
  85. self._sslobj = None
  86. else:
  87. # yes, create the SSL object
  88. self._sslobj = _ssl.sslwrap(self._sock, server_side,
  89. keyfile, certfile,
  90. cert_reqs, ssl_version, ca_certs)
  91. if do_handshake_on_connect:
  92. timeout = self.gettimeout()
  93. try:
  94. self.settimeout(None)
  95. self.do_handshake()
  96. finally:
  97. self.settimeout(timeout)
  98. self.keyfile = keyfile
  99. self.certfile = certfile
  100. self.cert_reqs = cert_reqs
  101. self.ssl_version = ssl_version
  102. self.ca_certs = ca_certs
  103. self.do_handshake_on_connect = do_handshake_on_connect
  104. self.suppress_ragged_eofs = suppress_ragged_eofs
  105. self._makefile_refs = 0
  106. def read(self, len=1024):
  107. """Read up to LEN bytes and return them.
  108. Return zero-length string on EOF."""
  109. try:
  110. return self._sslobj.read(len)
  111. except SSLError, x:
  112. if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
  113. return ''
  114. else:
  115. raise
  116. def write(self, data):
  117. """Write DATA to the underlying SSL channel. Returns
  118. number of bytes of DATA actually transmitted."""
  119. return self._sslobj.write(data)
  120. def getpeercert(self, binary_form=False):
  121. """Returns a formatted version of the data in the
  122. certificate provided by the other end of the SSL channel.
  123. Return None if no certificate was provided, {} if a
  124. certificate was provided, but not validated."""
  125. return self._sslobj.peer_certificate(binary_form)
  126. def cipher (self):
  127. if not self._sslobj:
  128. return None
  129. else:
  130. return self._sslobj.cipher()
  131. def send (self, data, flags=0):
  132. if self._sslobj:
  133. if flags != 0:
  134. raise ValueError(
  135. "non-zero flags not allowed in calls to send() on %s" %
  136. self.__class__)
  137. while True:
  138. try:
  139. v = self._sslobj.write(data)
  140. except SSLError, x:
  141. if x.args[0] == SSL_ERROR_WANT_READ:
  142. return 0
  143. elif x.args[0] == SSL_ERROR_WANT_WRITE:
  144. return 0
  145. else:
  146. raise
  147. else:
  148. return v
  149. else:
  150. return socket.send(self, data, flags)
  151. def sendto (self, data, addr, flags=0):
  152. if self._sslobj:
  153. raise ValueError("sendto not allowed on instances of %s" %
  154. self.__class__)
  155. else:
  156. return socket.sendto(self, data, addr, flags)
  157. def sendall (self, data, flags=0):
  158. if self._sslobj:
  159. if flags != 0:
  160. raise ValueError(
  161. "non-zero flags not allowed in calls to sendall() on %s" %
  162. self.__class__)
  163. amount = len(data)
  164. count = 0
  165. while (count < amount):
  166. v = self.send(data[count:])
  167. count += v
  168. return amount
  169. else:
  170. return socket.sendall(self, data, flags)
  171. def recv (self, buflen=1024, flags=0):
  172. if self._sslobj:
  173. if flags != 0:
  174. raise ValueError(
  175. "non-zero flags not allowed in calls to sendall() on %s" %
  176. self.__class__)
  177. while True:
  178. try:
  179. return self.read(buflen)
  180. except SSLError, x:
  181. if x.args[0] == SSL_ERROR_WANT_READ:
  182. continue
  183. else:
  184. raise x
  185. else:
  186. return socket.recv(self, buflen, flags)
  187. def recv_into (self, buffer, nbytes=None, flags=0):
  188. if buffer and (nbytes is None):
  189. nbytes = len(buffer)
  190. elif nbytes is None:
  191. nbytes = 1024
  192. if self._sslobj:
  193. if flags != 0:
  194. raise ValueError(
  195. "non-zero flags not allowed in calls to recv_into() on %s" %
  196. self.__class__)
  197. while True:
  198. try:
  199. tmp_buffer = self.read(nbytes)
  200. v = len(tmp_buffer)
  201. buffer[:v] = tmp_buffer
  202. return v
  203. except SSLError as x:
  204. if x.args[0] == SSL_ERROR_WANT_READ:
  205. continue
  206. else:
  207. raise x
  208. else:
  209. return socket.recv_into(self, buffer, nbytes, flags)
  210. def recvfrom (self, addr, buflen=1024, flags=0):
  211. if self._sslobj:
  212. raise ValueError("recvfrom not allowed on instances of %s" %
  213. self.__class__)
  214. else:
  215. return socket.recvfrom(self, addr, buflen, flags)
  216. def recvfrom_into (self, buffer, nbytes=None, flags=0):
  217. if self._sslobj:
  218. raise ValueError("recvfrom_into not allowed on instances of %s" %
  219. self.__class__)
  220. else:
  221. return socket.recvfrom_into(self, buffer, nbytes, flags)
  222. def pending (self):
  223. if self._sslobj:
  224. return self._sslobj.pending()
  225. else:
  226. return 0
  227. def unwrap (self):
  228. if self._sslobj:
  229. s = self._sslobj.shutdown()
  230. self._sslobj = None
  231. return s
  232. else:
  233. raise ValueError("No SSL wrapper around " + str(self))
  234. def shutdown (self, how):
  235. self._sslobj = None
  236. socket.shutdown(self, how)
  237. def close (self):
  238. if self._makefile_refs < 1:
  239. self._sslobj = None
  240. socket.close(self)
  241. else:
  242. self._makefile_refs -= 1
  243. def do_handshake (self):
  244. """Perform a TLS/SSL handshake."""
  245. self._sslobj.do_handshake()
  246. def connect(self, addr):
  247. """Connects to remote ADDR, and then wraps the connection in
  248. an SSL channel."""
  249. # Here we assume that the socket is client-side, and not
  250. # connected at the time of the call. We connect it, then wrap it.
  251. if self._sslobj:
  252. raise ValueError("attempt to connect already-connected SSLSocket!")
  253. socket.connect(self, addr)
  254. self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
  255. self.cert_reqs, self.ssl_version,
  256. self.ca_certs)
  257. if self.do_handshake_on_connect:
  258. self.do_handshake()
  259. def accept(self):
  260. """Accepts a new connection from a remote client, and returns
  261. a tuple containing that new connection wrapped with a server-side
  262. SSL channel, and the address of the remote client."""
  263. newsock, addr = socket.accept(self)
  264. return (SSLSocket(newsock,
  265. keyfile=self.keyfile,
  266. certfile=self.certfile,
  267. server_side=True,
  268. cert_reqs=self.cert_reqs,
  269. ssl_version=self.ssl_version,
  270. ca_certs=self.ca_certs,
  271. do_handshake_on_connect=self.do_handshake_on_connect,
  272. suppress_ragged_eofs=self.suppress_ragged_eofs),
  273. addr)
  274. def makefile(self, mode='r', bufsize=-1):
  275. """Make and return a file-like object that
  276. works with the SSL connection. Just use the code
  277. from the socket module."""
  278. self._makefile_refs += 1
  279. return _fileobject(self, mode, bufsize)
  280. def wrap_socket(sock, keyfile=None, certfile=None,
  281. server_side=False, cert_reqs=CERT_NONE,
  282. ssl_version=PROTOCOL_SSLv23, ca_certs=None,
  283. do_handshake_on_connect=True,
  284. suppress_ragged_eofs=True):
  285. return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
  286. server_side=server_side, cert_reqs=cert_reqs,
  287. ssl_version=ssl_version, ca_certs=ca_certs,
  288. do_handshake_on_connect=do_handshake_on_connect,
  289. suppress_ragged_eofs=suppress_ragged_eofs)
  290. # some utility functions
  291. def cert_time_to_seconds(cert_time):
  292. """Takes a date-time string in standard ASN1_print form
  293. ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
  294. a Python time value in seconds past the epoch."""
  295. import time
  296. return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
  297. PEM_HEADER = "-----BEGIN CERTIFICATE-----"
  298. PEM_FOOTER = "-----END CERTIFICATE-----"
  299. def DER_cert_to_PEM_cert(der_cert_bytes):
  300. """Takes a certificate in binary DER format and returns the
  301. PEM version of it as a string."""
  302. if hasattr(base64, 'standard_b64encode'):
  303. # preferred because older API gets line-length wrong
  304. f = base64.standard_b64encode(der_cert_bytes)
  305. return (PEM_HEADER + '\n' +
  306. textwrap.fill(f, 64) +
  307. PEM_FOOTER + '\n')
  308. else:
  309. return (PEM_HEADER + '\n' +
  310. base64.encodestring(der_cert_bytes) +
  311. PEM_FOOTER + '\n')
  312. def PEM_cert_to_DER_cert(pem_cert_string):
  313. """Takes a certificate in ASCII PEM format and returns the
  314. DER-encoded version of it as a byte sequence"""
  315. if not pem_cert_string.startswith(PEM_HEADER):
  316. raise ValueError("Invalid PEM encoding; must start with %s"
  317. % PEM_HEADER)
  318. if not pem_cert_string.strip().endswith(PEM_FOOTER):
  319. raise ValueError("Invalid PEM encoding; must end with %s"
  320. % PEM_FOOTER)
  321. d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
  322. return base64.decodestring(d)
  323. def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
  324. """Retrieve the certificate from the server at the specified address,
  325. and return it as a PEM-encoded string.
  326. If 'ca_certs' is specified, validate the server cert against it.
  327. If 'ssl_version' is specified, use it in the connection attempt."""
  328. host, port = addr
  329. if (ca_certs is not None):
  330. cert_reqs = CERT_REQUIRED
  331. else:
  332. cert_reqs = CERT_NONE
  333. s = wrap_socket(socket(), ssl_version=ssl_version,
  334. cert_reqs=cert_reqs, ca_certs=ca_certs)
  335. s.connect(addr)
  336. dercert = s.getpeercert(True)
  337. s.close()
  338. return DER_cert_to_PEM_cert(dercert)
  339. def get_protocol_name (protocol_code):
  340. if protocol_code == PROTOCOL_TLSv1:
  341. return "TLSv1"
  342. elif protocol_code == PROTOCOL_SSLv23:
  343. return "SSLv23"
  344. elif protocol_code == PROTOCOL_SSLv2:
  345. return "SSLv2"
  346. elif protocol_code == PROTOCOL_SSLv3:
  347. return "SSLv3"
  348. else:
  349. return "<unknown>"
  350. # a replacement for the old socket.ssl function
  351. def sslwrap_simple (sock, keyfile=None, certfile=None):
  352. """A replacement for the old socket.ssl function. Designed
  353. for compability with Python 2.5 and earlier. Will disappear in
  354. Python 3.0."""
  355. if hasattr(sock, "_sock"):
  356. sock = sock._sock
  357. ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE,
  358. PROTOCOL_SSLv23, None)
  359. try:
  360. sock.getpeername()
  361. except:
  362. # no, no connection yet
  363. pass
  364. else:
  365. # yes, do the handshake
  366. ssl_sock.do_handshake()
  367. return ssl_sock