PageRenderTime 8243ms CodeModel.GetById 22ms RepoModel.GetById 1ms app.codeStats 0ms

/src/saml2/config.py

https://github.com/daryllstrauss/pysaml2
Python | 545 lines | 527 code | 17 blank | 1 comment | 0 complexity | c89e6c9a0de79e04e0aa76edcc9ff416 MD5 | raw file
  1. #!/usr/bin/env python
  2. __author__ = 'rolandh'
  3. import copy
  4. import sys
  5. import os
  6. import re
  7. import logging
  8. import logging.handlers
  9. from importlib import import_module
  10. from saml2 import root_logger, BINDING_URI, SAMLError
  11. from saml2 import BINDING_SOAP
  12. from saml2 import BINDING_HTTP_REDIRECT
  13. from saml2 import BINDING_HTTP_POST
  14. from saml2 import BINDING_HTTP_ARTIFACT
  15. from saml2.attribute_converter import ac_factory
  16. from saml2.assertion import Policy
  17. from saml2.mdstore import MetadataStore
  18. from saml2.virtual_org import VirtualOrg
  19. logger = logging.getLogger(__name__)
  20. from saml2 import md
  21. from saml2 import saml
  22. from saml2.extension import mdui
  23. from saml2.extension import idpdisc
  24. from saml2.extension import dri
  25. from saml2.extension import mdattr
  26. from saml2.extension import ui
  27. import xmldsig
  28. import xmlenc
  29. ONTS = {
  30. saml.NAMESPACE: saml,
  31. mdui.NAMESPACE: mdui,
  32. mdattr.NAMESPACE: mdattr,
  33. dri.NAMESPACE: dri,
  34. ui.NAMESPACE: ui,
  35. idpdisc.NAMESPACE: idpdisc,
  36. md.NAMESPACE: md,
  37. xmldsig.NAMESPACE: xmldsig,
  38. xmlenc.NAMESPACE: xmlenc
  39. }
  40. COMMON_ARGS = [
  41. "entityid", "xmlsec_binary", "debug", "key_file", "cert_file",
  42. "encryption_type", "secret", "accepted_time_diff", "name", "ca_certs",
  43. "description", "valid_for", "verify_ssl_cert",
  44. "organization",
  45. "contact_person",
  46. "name_form",
  47. "virtual_organization",
  48. "logger",
  49. "only_use_keys_in_metadata",
  50. "logout_requests_signed",
  51. "disable_ssl_certificate_validation",
  52. "referred_binding",
  53. "session_storage",
  54. "entity_category",
  55. "xmlsec_path",
  56. "extension_schemas",
  57. "cert_handler_extra_class",
  58. "generate_cert_func",
  59. "generate_cert_info",
  60. "verify_encrypt_cert",
  61. "tmp_cert_file",
  62. "tmp_key_file",
  63. "validate_certificate",
  64. "extensions"
  65. ]
  66. SP_ARGS = [
  67. "required_attributes",
  68. "optional_attributes",
  69. "idp",
  70. "aa",
  71. "subject_data",
  72. "want_response_signed",
  73. "want_assertions_signed",
  74. "authn_requests_signed",
  75. "name_form",
  76. "endpoints",
  77. "ui_info",
  78. "discovery_response",
  79. "allow_unsolicited",
  80. "ecp",
  81. "name_id_format",
  82. "allow_unknown_attributes"
  83. ]
  84. AA_IDP_ARGS = [
  85. "sign_assertion",
  86. "sign_response",
  87. "encrypt_assertion",
  88. "want_authn_requests_signed",
  89. "want_authn_requests_only_with_valid_cert",
  90. "provided_attributes",
  91. "subject_data",
  92. "sp",
  93. "scope",
  94. "endpoints",
  95. "metadata",
  96. "ui_info",
  97. "name_id_format",
  98. "domain",
  99. "name_qualifier",
  100. "edu_person_targeted_id",
  101. ]
  102. PDP_ARGS = ["endpoints", "name_form", "name_id_format"]
  103. AQ_ARGS = ["endpoints"]
  104. AA_ARGS = ["attribute", "attribute_profile"]
  105. COMPLEX_ARGS = ["attribute_converters", "metadata", "policy"]
  106. ALL = set(COMMON_ARGS + SP_ARGS + AA_IDP_ARGS + PDP_ARGS + COMPLEX_ARGS +
  107. AA_ARGS)
  108. SPEC = {
  109. "": COMMON_ARGS + COMPLEX_ARGS,
  110. "sp": COMMON_ARGS + COMPLEX_ARGS + SP_ARGS,
  111. "idp": COMMON_ARGS + COMPLEX_ARGS + AA_IDP_ARGS,
  112. "aa": COMMON_ARGS + COMPLEX_ARGS + AA_IDP_ARGS + AA_ARGS,
  113. "pdp": COMMON_ARGS + COMPLEX_ARGS + PDP_ARGS,
  114. "aq": COMMON_ARGS + COMPLEX_ARGS + AQ_ARGS,
  115. }
  116. # --------------- Logging stuff ---------------
  117. LOG_LEVEL = {
  118. 'debug': logging.DEBUG,
  119. 'info': logging.INFO,
  120. 'warning': logging.WARNING,
  121. 'error': logging.ERROR,
  122. 'critical': logging.CRITICAL}
  123. LOG_HANDLER = {
  124. "rotating": logging.handlers.RotatingFileHandler,
  125. "syslog": logging.handlers.SysLogHandler,
  126. "timerotate": logging.handlers.TimedRotatingFileHandler,
  127. "memory": logging.handlers.MemoryHandler,
  128. }
  129. LOG_FORMAT = "%(asctime)s %(name)s:%(levelname)s %(message)s"
  130. _RPA = [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST, BINDING_HTTP_ARTIFACT]
  131. _PRA = [BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, BINDING_HTTP_ARTIFACT]
  132. _SRPA = [BINDING_SOAP, BINDING_HTTP_REDIRECT, BINDING_HTTP_POST,
  133. BINDING_HTTP_ARTIFACT]
  134. PREFERRED_BINDING = {
  135. "single_logout_service": _SRPA,
  136. "manage_name_id_service": _SRPA,
  137. "assertion_consumer_service": _PRA,
  138. "single_sign_on_service": _RPA,
  139. "name_id_mapping_service": [BINDING_SOAP],
  140. "authn_query_service": [BINDING_SOAP],
  141. "attribute_service": [BINDING_SOAP],
  142. "authz_service": [BINDING_SOAP],
  143. "assertion_id_request_service": [BINDING_URI],
  144. "artifact_resolution_service": [BINDING_SOAP],
  145. "attribute_consuming_service": _RPA
  146. }
  147. class ConfigurationError(SAMLError):
  148. pass
  149. # -----------------------------------------------------------------
  150. class Config(object):
  151. def_context = ""
  152. def __init__(self, homedir="."):
  153. self._homedir = homedir
  154. self.entityid = None
  155. self.xmlsec_binary = None
  156. self.xmlsec_path = []
  157. self.debug = False
  158. self.key_file = None
  159. self.cert_file = None
  160. self.encryption_type = 'both'
  161. self.secret = None
  162. self.accepted_time_diff = None
  163. self.name = None
  164. self.ca_certs = None
  165. self.verify_ssl_cert = False
  166. self.description = None
  167. self.valid_for = None
  168. self.organization = None
  169. self.contact_person = None
  170. self.name_form = None
  171. self.name_id_format = None
  172. self.virtual_organization = None
  173. self.logger = None
  174. self.only_use_keys_in_metadata = True
  175. self.logout_requests_signed = None
  176. self.disable_ssl_certificate_validation = None
  177. self.context = ""
  178. self.attribute_converters = None
  179. self.metadata = None
  180. self.policy = None
  181. self.serves = []
  182. self.vorg = {}
  183. self.preferred_binding = PREFERRED_BINDING
  184. self.domain = ""
  185. self.name_qualifier = ""
  186. self.entity_category = ""
  187. self.crypto_backend = 'xmlsec1'
  188. self.scope = ""
  189. self.allow_unknown_attributes = False
  190. self.extension_schema = {}
  191. self.cert_handler_extra_class = None
  192. self.verify_encrypt_cert = None
  193. self.generate_cert_func = None
  194. self.generate_cert_info = None
  195. self.tmp_cert_file = None
  196. self.tmp_key_file = None
  197. self.validate_certificate = None
  198. self.extensions = {}
  199. self.attribute = []
  200. self.attribute_profile = []
  201. def setattr(self, context, attr, val):
  202. if context == "":
  203. setattr(self, attr, val)
  204. else:
  205. setattr(self, "_%s_%s" % (context, attr), val)
  206. def getattr(self, attr, context=None):
  207. if context is None:
  208. context = self.context
  209. if context == "":
  210. return getattr(self, attr, None)
  211. else:
  212. return getattr(self, "_%s_%s" % (context, attr), None)
  213. def load_special(self, cnf, typ, metadata_construction=False):
  214. for arg in SPEC[typ]:
  215. try:
  216. self.setattr(typ, arg, cnf[arg])
  217. except KeyError:
  218. pass
  219. self.context = typ
  220. self.load_complex(cnf, typ, metadata_construction=metadata_construction)
  221. self.context = self.def_context
  222. def load_complex(self, cnf, typ="", metadata_construction=False):
  223. try:
  224. self.setattr(typ, "policy", Policy(cnf["policy"]))
  225. except KeyError:
  226. pass
  227. # for srv, spec in cnf["service"].items():
  228. # try:
  229. # self.setattr(srv, "policy",
  230. # Policy(cnf["service"][srv]["policy"]))
  231. # except KeyError:
  232. # pass
  233. try:
  234. try:
  235. acs = ac_factory(cnf["attribute_map_dir"])
  236. except KeyError:
  237. acs = ac_factory()
  238. if not acs:
  239. raise ConfigurationError(
  240. "No attribute converters, something is wrong!!")
  241. _acs = self.getattr("attribute_converters", typ)
  242. if _acs:
  243. _acs.extend(acs)
  244. else:
  245. self.setattr(typ, "attribute_converters", acs)
  246. except KeyError:
  247. pass
  248. if not metadata_construction:
  249. try:
  250. self.setattr(typ, "metadata",
  251. self.load_metadata(cnf["metadata"]))
  252. except KeyError:
  253. pass
  254. def unicode_convert(self, item):
  255. try:
  256. return unicode(item, "utf-8")
  257. except TypeError:
  258. _uc = self.unicode_convert
  259. if isinstance(item, dict):
  260. return dict([(key, _uc(val)) for key, val in item.items()])
  261. elif isinstance(item, list):
  262. return [_uc(v) for v in item]
  263. elif isinstance(item, tuple):
  264. return tuple([_uc(v) for v in item])
  265. else:
  266. return item
  267. def load(self, cnf, metadata_construction=False):
  268. """ The base load method, loads the configuration
  269. :param cnf: The configuration as a dictionary
  270. :param metadata_construction: Is this only to be able to construct
  271. metadata. If so some things can be left out.
  272. :return: The Configuration instance
  273. """
  274. _uc = self.unicode_convert
  275. for arg in COMMON_ARGS:
  276. if arg == "virtual_organization":
  277. if "virtual_organization" in cnf:
  278. for key, val in cnf["virtual_organization"].items():
  279. self.vorg[key] = VirtualOrg(None, key, val)
  280. continue
  281. elif arg == "extension_schemas":
  282. # List of filename of modules representing the schemas
  283. if "extension_schemas" in cnf:
  284. for mod_file in cnf["extension_schemas"]:
  285. _mod = self._load(mod_file)
  286. self.extension_schema[_mod.NAMESPACE] = _mod
  287. try:
  288. setattr(self, arg, _uc(cnf[arg]))
  289. except KeyError:
  290. pass
  291. except TypeError: # Something that can't be a string
  292. setattr(self, arg, cnf[arg])
  293. if "service" in cnf:
  294. for typ in ["aa", "idp", "sp", "pdp", "aq"]:
  295. try:
  296. self.load_special(
  297. cnf["service"][typ], typ,
  298. metadata_construction=metadata_construction)
  299. self.serves.append(typ)
  300. except KeyError:
  301. pass
  302. if "extensions" in cnf:
  303. self.do_extensions(cnf["extensions"])
  304. self.load_complex(cnf, metadata_construction=metadata_construction)
  305. self.context = self.def_context
  306. return self
  307. def _load(self, fil):
  308. head, tail = os.path.split(fil)
  309. if head == "":
  310. if sys.path[0] != ".":
  311. sys.path.insert(0, ".")
  312. else:
  313. sys.path.insert(0, head)
  314. return import_module(tail)
  315. def load_file(self, config_file, metadata_construction=False):
  316. if config_file.endswith(".py"):
  317. config_file = config_file[:-3]
  318. mod = self._load(config_file)
  319. #return self.load(eval(open(config_file).read()))
  320. return self.load(copy.deepcopy(mod.CONFIG), metadata_construction)
  321. def load_metadata(self, metadata_conf):
  322. """ Loads metadata into an internal structure """
  323. acs = self.attribute_converters
  324. if acs is None:
  325. raise ConfigurationError(
  326. "Missing attribute converter specification")
  327. try:
  328. ca_certs = self.ca_certs
  329. except:
  330. ca_certs = None
  331. try:
  332. disable_validation = self.disable_ssl_certificate_validation
  333. except:
  334. disable_validation = False
  335. mds = MetadataStore(
  336. ONTS.values(), acs, self, ca_certs,
  337. disable_ssl_certificate_validation=disable_validation)
  338. mds.imp(metadata_conf)
  339. return mds
  340. def endpoint(self, service, binding=None, context=None):
  341. """ Goes through the list of endpoint specifications for the
  342. given type of service and returnes the first endpoint that matches
  343. the given binding. If no binding is given any endpoint for that
  344. service will be returned.
  345. :param service: The service the endpoint should support
  346. :param binding: The expected binding
  347. :return: All the endpoints that matches the given restrictions
  348. """
  349. spec = []
  350. unspec = []
  351. endps = self.getattr("endpoints", context)
  352. if endps and service in endps:
  353. for endpspec in endps[service]:
  354. try:
  355. endp, bind = endpspec
  356. if binding is None or bind == binding:
  357. spec.append(endp)
  358. except ValueError:
  359. unspec.append(endpspec)
  360. if spec:
  361. return spec
  362. else:
  363. return unspec
  364. def log_handler(self):
  365. try:
  366. _logconf = self.logger
  367. except KeyError:
  368. return None
  369. handler = None
  370. for htyp in LOG_HANDLER:
  371. if htyp in _logconf:
  372. if htyp == "syslog":
  373. args = _logconf[htyp]
  374. if "socktype" in args:
  375. import socket
  376. if args["socktype"] == "dgram":
  377. args["socktype"] = socket.SOCK_DGRAM
  378. elif args["socktype"] == "stream":
  379. args["socktype"] = socket.SOCK_STREAM
  380. else:
  381. raise ConfigurationError("Unknown socktype!")
  382. try:
  383. handler = LOG_HANDLER[htyp](**args)
  384. except TypeError: # difference between 2.6 and 2.7
  385. del args["socktype"]
  386. handler = LOG_HANDLER[htyp](**args)
  387. else:
  388. handler = LOG_HANDLER[htyp](**_logconf[htyp])
  389. break
  390. if handler is None:
  391. # default if rotating logger
  392. handler = LOG_HANDLER["rotating"]()
  393. if "format" in _logconf:
  394. formatter = logging.Formatter(_logconf["format"])
  395. else:
  396. formatter = logging.Formatter(LOG_FORMAT)
  397. handler.setFormatter(formatter)
  398. return handler
  399. def setup_logger(self):
  400. if root_logger.level != logging.NOTSET: # Someone got there before me
  401. return root_logger
  402. _logconf = self.logger
  403. if _logconf is None:
  404. return root_logger
  405. try:
  406. root_logger.setLevel(LOG_LEVEL[_logconf["loglevel"].lower()])
  407. except KeyError: # reasonable default
  408. root_logger.setLevel(logging.INFO)
  409. root_logger.addHandler(self.log_handler())
  410. root_logger.info("Logging started")
  411. return root_logger
  412. def endpoint2service(self, endpoint, context=None):
  413. endps = self.getattr("endpoints", context)
  414. for service, specs in endps.items():
  415. for endp, binding in specs:
  416. if endp == endpoint:
  417. return service, binding
  418. return None, None
  419. def do_extensions(self, extensions):
  420. for key, val in extensions.items():
  421. self.extensions[key] = val
  422. class SPConfig(Config):
  423. def_context = "sp"
  424. def __init__(self):
  425. Config.__init__(self)
  426. def vo_conf(self, vo_name):
  427. try:
  428. return self.virtual_organization[vo_name]
  429. except KeyError:
  430. return None
  431. def ecp_endpoint(self, ipaddress):
  432. """
  433. Returns the entity ID of the IdP which the ECP client should talk to
  434. :param ipaddress: The IP address of the user client
  435. :return: IdP entity ID or None
  436. """
  437. _ecp = self.getattr("ecp")
  438. if _ecp:
  439. for key, eid in _ecp.items():
  440. if re.match(key, ipaddress):
  441. return eid
  442. return None
  443. class IdPConfig(Config):
  444. def_context = "idp"
  445. def __init__(self):
  446. Config.__init__(self)
  447. def config_factory(typ, filename):
  448. if typ == "sp":
  449. conf = SPConfig().load_file(filename)
  450. conf.context = typ
  451. elif typ in ["aa", "idp", "pdp", "aq"]:
  452. conf = IdPConfig().load_file(filename)
  453. conf.context = typ
  454. else:
  455. conf = Config().load_file(filename)
  456. conf.context = typ
  457. return conf