PageRenderTime 53ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/regurgitator.py

https://gitlab.com/crossref/hairball
Python | 376 lines | 249 code | 102 blank | 25 comment | 43 complexity | 8a3daa8857770283313a2379654c5700 MD5 | raw file
  1. import logging
  2. import os
  3. import sys
  4. import time
  5. import json
  6. import concurrent.futures
  7. from collections import OrderedDict
  8. import click
  9. import pathlib
  10. import requests
  11. from bs4 import BeautifulSoup as bs
  12. from jinja2 import Template
  13. from lxml import etree
  14. from rich.logging import RichHandler
  15. import cr_schema
  16. APP_NAME = "Regurgitator"
  17. CONTACT = "labs@crossref.org"
  18. USER_AGENT = {"UserAgent": f"{APP_NAME}; mailto:{CONTACT}"}
  19. CRAPI_URI = "https://api.crossref.org"
  20. XML_API = "http://doi.crossref.org/search/doi"
  21. DEPOSIT_TEMPLATE_PATH = "deposit_templates"
  22. SCHEMA_PATH = "schemas"
  23. logging.basicConfig(level=logging.WARNING, handlers=[RichHandler()])
  24. logger = logging.getLogger("rich")
  25. logger.info(f"starting: {APP_NAME}")
  26. def get_template(fn):
  27. template = pathlib.Path(fn).read_text()
  28. return Template(template)
  29. def get_header_name_element(schema_version):
  30. """given schema version, which element name do we use in
  31. header?"""
  32. version_changed = "4.3.4"
  33. cutoff = cr_schema.standard.index(version_changed)
  34. return (
  35. "name"
  36. if cr_schema.standard.index(schema_version) > cutoff
  37. else "depositor_name"
  38. )
  39. def standard_template():
  40. return get_template(
  41. os.path.join(DEPOSIT_TEMPLATE_PATH, "standard_deposit_template.xml")
  42. )
  43. def grant_template():
  44. return get_template(
  45. os.path.join(DEPOSIT_TEMPLATE_PATH, "grant_deposit_template.xml")
  46. )
  47. def cn(doi, accept="application/vnd.crossref.unixsd+xml"):
  48. return requests.get(
  49. f"{CRAPI_URI}/works/{doi}/transform/{accept}", headers=USER_AGENT
  50. ).text
  51. def xml_api(doi):
  52. return requests.get(
  53. f"{XML_API}?pid={CONTACT}&format=unixsd&doi={doi}",
  54. headers=USER_AGENT,
  55. ).text
  56. def get_unixsd(doi):
  57. # We can either use CN (preferred)
  58. # or go directly to xml api. Just
  59. # easy way to switch between the two.
  60. # return xml_api(doi)
  61. return cn(doi)
  62. def remove_uneeded_namespaces_from_elements(record):
  63. for e in record.find_all():
  64. if e.prefix not in cr_schema.crossref_namespaces:
  65. e.prefix = None
  66. def is_namespace(attr: str) -> bool:
  67. """
  68. is the attribute a namespace?
  69. """
  70. return attr.startswith("xmlns:")
  71. def ns_name(attr: str) -> bool:
  72. """
  73. given namspace attribute, return the name
  74. """
  75. _, ns = attr.split(":")
  76. return ns
  77. def remove_uneeded_namespaces_decalrations(record):
  78. content_type = detect_content_type(record)
  79. root = record.find(content_type)
  80. namespaces_to_keep = {
  81. attr: root.attrs[attr]
  82. for attr in root.attrs.keys()
  83. if is_namespace(attr) and ns_name(attr) in cr_schema.crossref_namespaces
  84. }
  85. non_namespaces_to_keep = {
  86. attr: root.attrs[attr] for attr in root.attrs.keys() if not is_namespace(attr)
  87. }
  88. root.attrs = namespaces_to_keep | non_namespaces_to_keep
  89. def remove_non_crossref_namespaces(record):
  90. remove_uneeded_namespaces_from_elements(record)
  91. remove_uneeded_namespaces_decalrations(record)
  92. def extract_doi_record(xml):
  93. # NB if you just specify ust "lxml", bs4 case-folds element names and your XML will no longer validate.
  94. bs_content = bs(xml, "lxml-xml")
  95. return bs_content.find("crossref")
  96. def pp_xml(xml):
  97. x = etree.fromstring(xml.encode(encoding="utf-8"))
  98. etree.indent(x, space=" ", level=0)
  99. return etree.tostring(x, pretty_print=True, encoding=str)
  100. def prettify_xml(xml):
  101. return str(bs(xml, features="xml"))
  102. #return bs(xml, features="lxml-xml").prettify()
  103. def detect_content_type(record):
  104. return record.find().name
  105. def validate(xml: str, xsd_path: str) -> bool:
  106. xmlschema_doc = etree.parse(xsd_path)
  107. xmlschema = etree.XMLSchema(xmlschema_doc)
  108. xml_doc = etree.fromstring(xml.encode(encoding="utf-8"))
  109. return xmlschema.validate(xml_doc)
  110. def schema_path(content_type, schema_version):
  111. schema_type = "grant_id" if content_type == "grant" else "crossref"
  112. return os.path.join(SCHEMA_PATH, f"{schema_type}{schema_version}.xsd")
  113. def move_element_after(record, element_to_move_name, target_element_name):
  114. if element_to_move := record.find(element_to_move_name):
  115. if target_element := record.find(target_element_name):
  116. target_element.insert_after(element_to_move)
  117. def remove_attribute_from_elements(record, attribute_name, element_name):
  118. for element in record.findAll(element_name):
  119. if attribute_name in element.attrs:
  120. del element.attrs[attribute_name]
  121. def rename_element(record, old_name, new_name):
  122. if element := record.find(old_name):
  123. element.name = new_name
  124. def copy_to_new_date(date_part, old_pub_date, new_pub_date):
  125. if old_date_part_element := old_pub_date.find(date_part):
  126. bs_content = bs(features="lxml-xml")
  127. new_month_tag = bs_content.new_tag(date_part)
  128. new_month_tag.string = old_date_part_element.text
  129. new_pub_date.append(new_month_tag)
  130. def canonicalize_date(old_pub_date):
  131. # bs_content = bs(features="lxml-xml")
  132. # new_pub_date = bs_content.new_tag("publication_date")
  133. new_pub_date = bs(features="lxml-xml").new_tag("publication_date")
  134. new_pub_date.attrs = old_pub_date.attrs
  135. copy_to_new_date("month", old_pub_date, new_pub_date)
  136. copy_to_new_date("day", old_pub_date, new_pub_date)
  137. copy_to_new_date("year", old_pub_date, new_pub_date)
  138. return new_pub_date
  139. def fix_dates(record):
  140. for publication_date in record.findAll("publication_date"):
  141. publication_date.replace_with(canonicalize_date(publication_date))
  142. def degunk_book_chapter(record):
  143. remove_attribute_from_elements(record, "provider", "doi")
  144. remove_attribute_from_elements(record, "provider", "rel:intra_work_relation")
  145. remove_attribute_from_elements(record, "setbyID", "collection")
  146. rename_element(record, "volume", "edition_number")
  147. fix_dates(record)
  148. def degunk_journal_article(record):
  149. remove_attribute_from_elements(record, "provider", "doi")
  150. remove_attribute_from_elements(record, "provider", "rel:intra_work_relation")
  151. move_element_after(record, "ai:program", "pages")
  152. move_element_after(record, "publisher_item", "pages")
  153. move_element_after(record, "abstract", "contributors")
  154. # move_element_after(record,'journal_volume','contributors')
  155. def degunk(record, content_type) -> None:
  156. """
  157. Remove output schema anomolies
  158. The Crossref output schema will often do things in a slightly different order
  159. to the crossref deposit schema. It may also incluude elements that are not in the
  160. deposit schema. This *needs to be fixed in the output schema*, but this is my current
  161. workaround.
  162. """
  163. remove_non_crossref_namespaces(record)
  164. if content_type == "book":
  165. degunk_book_chapter(record)
  166. elif content_type == "journal":
  167. degunk_journal_article(record)
  168. def validate_all(content_type: str, all_schemas: dict) -> OrderedDict:
  169. results = {}
  170. with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
  171. future_to_version = {
  172. executor.submit(validate, xml, schema_path(content_type, schema_version)): (
  173. xml,
  174. schema_version,
  175. )
  176. for schema_version, xml in all_schemas.items()
  177. }
  178. for future in concurrent.futures.as_completed(future_to_version):
  179. version = future_to_version[future][1]
  180. try:
  181. validation_result = future.result()
  182. results[version] = validation_result
  183. except Exception as exc:
  184. logger.error("%r generated an exception: %s" % (url, exc))
  185. return OrderedDict(sorted(results.items(), reverse=True))
  186. def try_schema_versions(
  187. record,
  188. doi,
  189. content_type,
  190. doi_batch_id,
  191. timestamp,
  192. depositor_name,
  193. email_address,
  194. registrant,
  195. ):
  196. new_xml = None
  197. template = grant_template() if content_type == "grant" else standard_template()
  198. schema_versions = (
  199. cr_schema.grants if content_type == "grant" else cr_schema.standard
  200. )
  201. all_schemas = {}
  202. for schema_version in schema_versions:
  203. header_name_element = (
  204. "depositor_name"
  205. if content_type == "grant"
  206. else get_header_name_element(schema_version)
  207. )
  208. new_xml = template.render(
  209. body=record,
  210. schema_version=schema_version,
  211. doi_batch_id=doi_batch_id,
  212. timestamp=timestamp,
  213. depositor_name=depositor_name,
  214. email_address=email_address,
  215. registrant=registrant,
  216. header_name_element=header_name_element,
  217. )
  218. all_schemas[schema_version] = new_xml
  219. validates_against = validate_all(content_type, all_schemas)
  220. if not any(all_valid := [key for key, value in validates_against.items() if value]):
  221. return None, all_schemas[schema_versions[0]]
  222. most_recent_to_validate = all_valid[0]
  223. return most_recent_to_validate, all_schemas[most_recent_to_validate]
  224. def regurgitate(
  225. doi, doi_batch_id, timestamp, depositor_name, email_address, registrant
  226. ):
  227. xml = get_unixsd(doi)
  228. record = extract_doi_record(xml)
  229. record.name = "body"
  230. record.attrs = {}
  231. content_type = detect_content_type(record)
  232. degunk(record, content_type)
  233. schema_version, xml = try_schema_versions(
  234. record=record,
  235. doi=doi,
  236. content_type=content_type,
  237. doi_batch_id=doi_batch_id,
  238. timestamp=timestamp,
  239. depositor_name=depositor_name,
  240. email_address=email_address,
  241. registrant=registrant,
  242. )
  243. return schema_version, pp_xml(xml)
  244. if __name__ == "__main__":
  245. @click.command()
  246. @click.argument("input", type=click.File("rb"), nargs=-1)
  247. @click.option("-v", "--verbose", default=False, show_default=True, is_flag=True)
  248. def cli(input, verbose):
  249. if verbose:
  250. logging.getLogger().setLevel(logging.INFO)
  251. logger.info("verbose mode")
  252. dois = []
  253. for f in input:
  254. dois += [line.decode("utf-8").rstrip() for line in f.readlines()]
  255. for index, doi in enumerate(dois):
  256. ts = int(time.time())
  257. schema_version, new_xml = regurgitate(
  258. doi=doi,
  259. timestamp=ts,
  260. doi_batch_id=ts,
  261. email_address="gbilder@crossref.org",
  262. registrant="crossref",
  263. depositor_name="gbilder",
  264. )
  265. if schema_version:
  266. fn = f"results/valid-{index}.xml"
  267. else:
  268. logger.error(f"failed to regurgitate {doi}")
  269. fn = f"results/invalid-{index}.xml"
  270. with open(fn, "w") as f:
  271. f.write(prettify_xml(new_xml))
  272. logger.info(f"saved: {fn}")
  273. cli()