/wsgiauth0/__init__.py

https://gitlab.com/dialogue/wsgiauth0 · Python · 252 lines · 183 code · 57 blank · 12 comment · 31 complexity · 047d35739c5229c9eb7ca11618a8b3b0 MD5 · raw file

"""WSGI Auth0 middleware that checks HS256 and RS256 JWTs.

Encoded JWTs are expected in the `Authorization` HTTP header, e.g.::

    Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.....
"""
  5. import logging
  6. import sys
  7. from collections import namedtuple
  8. from jose import jws, jwt
  9. from jose.exceptions import JOSEError, JWTError
  10. from jose.utils import base64url_decode
  11. from .exception import Error
  12. PY2 = sys.version_info[0] == 2
  13. log = logging.getLogger(__name__)
  14. Client = namedtuple('Client', 'label id audience secret')
  15. Secret = namedtuple('Secret', 'type value')
  16. def factory(application, config=None, **kwargs):
  17. monkeypatch_jws_get_keys()
  18. config = config.copy() if config else {}
  19. config.update(kwargs)
  20. app = auth0_middleware(application, config)
  21. return app
  22. def auth0_middleware(application, config):
  23. log.info('Setup auth0_middleware')
  24. log.debug('application=%s config=%s', application, config)
  25. clients = read_clients(config)
  26. def app(environ, start_response):
  27. authorization = environ.get('HTTP_AUTHORIZATION')
  28. jwt_environ = validate_jwt_claims(clients, authorization)
  29. environ.update(jwt_environ)
  30. return application(environ, start_response)
  31. return app
  32. def validate_jwt_claims(clients, authorization):
  33. claims = None
  34. jwt_environ = {}
  35. if not authorization:
  36. jwt_environ['wsgiauth0.jwt_token'] = None
  37. jwt_environ['wsgiauth0.jwt_claims'] = None
  38. jwt_environ['wsgiauth0.jwt_auth0_client'] = None
  39. jwt_environ['wsgiauth0.jwt_error'] = {
  40. 'code': 'no_authorization',
  41. 'description': 'No authorization in headers.',
  42. 'origin': None,
  43. }
  44. else:
  45. client = None
  46. try:
  47. token = extract_token(authorization)
  48. jwt_environ['wsgiauth0.jwt_token'] = token
  49. client = extract_client(clients, token)
  50. claims = jwt.decode(
  51. token,
  52. client.secret.value,
  53. audience=client.audience,
  54. )
  55. except JOSEError as jose_error:
  56. log.warn(
  57. 'Fail decoding authorization=%r error=%r',
  58. authorization,
  59. jose_error,
  60. exc_info=True,
  61. )
  62. jwt_environ['wsgiauth0.jwt_claims'] = None
  63. jwt_environ['wsgiauth0.jwt_error'] = {
  64. 'code': 'invalid_token',
  65. 'description': repr(jose_error),
  66. 'origin': jose_error,
  67. }
  68. except Error as jwt_error:
  69. log.warn(
  70. 'Fail extracting info from authorization=%r error=%r',
  71. authorization,
  72. jwt_error,
  73. exc_info=True,
  74. )
  75. jwt_environ['wsgiauth0.jwt_claims'] = None
  76. jwt_environ['wsgiauth0.jwt_error'] = jwt_error.to_dict()
  77. else:
  78. jwt_environ['wsgiauth0.jwt_claims'] = claims
  79. jwt_environ['wsgiauth0.jwt_error'] = None
  80. if client is not None:
  81. client_dict = client._asdict()
  82. client_dict.pop('secret')
  83. jwt_environ['wsgiauth0.jwt_auth0_client'] = client_dict
  84. else:
  85. jwt_environ['wsgiauth0.jwt_auth0_client'] = None
  86. if claims is not None and 'sub' in claims:
  87. jwt_environ['REMOTE_USER'] = claims['sub']
  88. log.debug(
  89. 'wsgiauth0.jwt_error=%s wsgiauth0.jwt_claims=%s '
  90. 'wsgiauth0.jwt_auth0_client=%s',
  91. jwt_environ['wsgiauth0.jwt_error'],
  92. jwt_environ['wsgiauth0.jwt_claims'],
  93. jwt_environ['wsgiauth0.jwt_auth0_client'],
  94. )
  95. return jwt_environ
  96. def read_clients(config):
  97. """Load client configuration from all possible sources."""
  98. client_settings = read_clients_settings(config)
  99. log.debug('client_settings: %s', client_settings)
  100. verify_client_settings(client_settings)
  101. return parse_clients(client_settings)
  102. def read_clients_settings(config):
  103. """Load client configuration from all possible sources."""
  104. from .config_yaml import config_yaml
  105. from .config_dynamodb import config_dynamodb
  106. client_settings = config.get('clients', {})
  107. clients = config_yaml(config)
  108. client_settings.update(clients)
  109. clients = config_dynamodb(config)
  110. client_settings.update(clients)
  111. return client_settings
  112. def verify_client_settings(client_settings):
  113. if not client_settings:
  114. raise Error(
  115. 'missing_config',
  116. "No auth0 clients configured",
  117. )
  118. def parse_clients(client_settings):
  119. """Convert input client specs to clients map used for lookup."""
  120. return {client.id: client
  121. for client in map(parse_client, client_settings.items())}
  122. def parse_client(item):
  123. """Convert input client specs to clients map used for lookup."""
  124. label, client_dict = item
  125. try:
  126. secret_type = client_dict['secret']['type']
  127. secret_value = client_dict['secret']['value']
  128. client_id = client_dict['id']
  129. audience = client_dict['audience']
  130. except (TypeError, KeyError):
  131. raise Error(
  132. 'missing_config_key',
  133. 'Client config missing key client_dict.',
  134. )
  135. if secret_type == 'base64_url_encoded':
  136. if PY2:
  137. secret_value = secret_value.encode('utf-8')
  138. secret_value = base64url_decode(secret_value)
  139. return Client(
  140. label=label,
  141. id=client_id,
  142. secret=Secret(type=secret_type, value=secret_value),
  143. audience=audience,
  144. )
  145. def extract_token(authorization):
  146. parts = authorization.split()
  147. if len(parts) != 2:
  148. raise Error(
  149. 'invalid_header',
  150. 'Authorization header must be "Bearer token".',
  151. )
  152. if parts[0].lower() != 'bearer':
  153. raise Error(
  154. 'invalid_header',
  155. 'Authorization header must start with "Bearer".',
  156. )
  157. return parts[1]
  158. def extract_client(clients, token):
  159. try:
  160. claims = jwt.get_unverified_claims(token)
  161. except JWTError:
  162. raise Error('invalid_token', 'Error decoding token claims.')
  163. try:
  164. audience = claims['aud']
  165. except KeyError:
  166. raise Error('invalid_claims', 'No key aud in claims.')
  167. if audience in clients:
  168. return clients[audience]
  169. try:
  170. subject = claims['sub']
  171. except KeyError:
  172. raise Error('invalid_claims', 'No key sub in claims.')
  173. try:
  174. return clients[subject]
  175. except KeyError:
  176. log.debug(
  177. 'No client found for: audience %s, subject %s',
  178. audience,
  179. subject,
  180. )
  181. raise Error('invalid_client', 'No config found for this client.')
  182. original_get_keys = None
  183. def monkeypatch_jws_get_keys(): # pragma: no cover
  184. # Monkey patch jws._get_keys to avoid failing with a base64 decoded secret
  185. global original_get_keys
  186. if original_get_keys is None:
  187. original_get_keys = jws._get_keys
  188. def jws_get_keys(key):
  189. if isinstance(key, bytes):
  190. return (key, )
  191. return original_get_keys(key)
  192. jws._get_keys = jws_get_keys