PageRenderTime 47ms CodeModel.GetById 24ms RepoModel.GetById 1ms app.codeStats 0ms

/backend/geneaprove/sql/sources.py

https://github.com/briot/geneapro
Python | 158 lines | 129 code | 12 blank | 17 comment | 4 complexity | 07406e2471e8965f38be0d5247bb27e3 MD5 | raw file
Possible License(s): GPL-2.0
import collections
import collections.abc
import logging

import django.db
from django.db.models import F, IntegerField, TextField, Value, Count

from .. import models
from .sqlsets import SQLSet
from .asserts import AssertList
  8. logger = logging.getLogger(__name__)
  9. CitationDetails = collections.namedtuple(
  10. 'CitationDetails', 'name value fromHigh')
  11. class SourceSet(SQLSet):
  12. def __init__(self):
  13. self.sources = collections.OrderedDict() # id -> Source
  14. self.asserts = AssertList()
  15. self._higher = None # id -> list of higher source ids, recursively
  16. self._citations = None # id -> list of CitationDetails
  17. def add_ids(self, ids=None, offset=None, limit=None):
  18. """
  19. Fetch sources for all the given ids, along with related data like
  20. researcher and repository.
  21. """
  22. assert ids is None or isinstance(ids, collections.abc.Iterable)
  23. pm = models.Source.objects.select_related()
  24. pm = self.limit_offset(pm, offset=offset, limit=limit)
  25. for chunk in self.sqlin(pm, id__in=ids):
  26. for s in chunk:
  27. self.sources[s.id] = s
  28. self._higher = None
  29. self._citations = None
  30. self.asserts.add_known(
  31. sources=self.sources.values()) # Do not fetch them again
  32. def fetch_higher_sources(self):
  33. """
  34. Fetch the 'higher' source relationships. This only gets the ids
  35. """
  36. if self._higher is not None:
  37. return # already computed
  38. logger.debug('SourceSet.fetch_higher_sources')
  39. with django.db.connection.cursor() as cur:
  40. ids = ", ".join(str(k) for k in self.sources.keys())
  41. q = (
  42. "WITH RECURSIVE higher(source_id, parent) AS ("
  43. "SELECT id, higher_source_id FROM source "
  44. "WHERE higher_source_id IS NOT NULL "
  45. "UNION "
  46. "SELECT higher.source_id, source.higher_source_id "
  47. "FROM source, higher "
  48. "WHERE source.id=higher.parent "
  49. "AND source.higher_source_id IS NOT NULL"
  50. ") SELECT higher.source_id, higher.parent "
  51. "FROM higher "
  52. f"WHERE higher.source_id IN ({ids}) "
  53. )
  54. cur.execute(q)
  55. self._higher = collections.defaultdict(list)
  56. for s, parent in cur.fetchall():
  57. self._higher[s].append(parent)
  58. def count_asserts(self):
  59. """
  60. Count all asserts for the sources, but doesn't fetch them
  61. """
  62. assert len(self.sources) == 1
  63. sid = next(iter(self.sources))
  64. count = 0
  65. for table in (models.P2E, models.P2C, models.P2P, models.P2G):
  66. count += table.objects.filter(source=sid).count()
  67. return count
  68. def fetch_asserts(self, offset=None, limit=None):
  69. """
  70. Fetch all assertions for all sources
  71. """
  72. logger.debug('SourceSet.fetch_asserts')
  73. assert len(self.sources) == 1
  74. sid = next(iter(self.sources))
  75. self.asserts.fetch_asserts_subset(
  76. [models.P2E.objects.filter(source=sid),
  77. models.P2C.objects.filter(source=sid),
  78. models.P2P.objects.filter(source=sid),
  79. models.P2G.objects.filter(source=sid)],
  80. offset=offset,
  81. limit=limit)
  82. def fetch_citations(self):
  83. """
  84. Fetch all citation parts for all sources and their higher sources
  85. """
  86. if self._citations is not None:
  87. return # already computed
  88. self.fetch_higher_sources()
  89. logger.debug('SourceSet.fetch_citations')
  90. all_ids = set(self.sources.keys())
  91. all_ids.update(
  92. h
  93. for higher_list in self._higher.values()
  94. for h in higher_list)
  95. self._citations = collections.defaultdict(list)
  96. for p in models.Citation_Part.objects.filter(source__in=all_ids):
  97. self._citations[p.source_id].append(CitationDetails(
  98. name=p.type_id,
  99. value=p.value,
  100. fromHigh=False))
  101. def get_citations(self, source):
  102. """
  103. Return the citations for a given source, recursively looking at
  104. higher sources
  105. """
  106. if isinstance(source, models.Source):
  107. source = source.id
  108. assert isinstance(source, int)
  109. if source not in self.sources:
  110. self.add_ids([source])
  111. self.fetch_citations()
  112. result = list(self._citations[source])
  113. for h in self._higher[source]:
  114. for c in self._citations[h]:
  115. result.append(CitationDetails(
  116. name=c.name, value=c.value, fromHigh=True))
  117. return result
  118. def get_higher_sources(self, source):
  119. """
  120. Get the ids of higher sources for a specific source
  121. """
  122. if isinstance(source, models.Source):
  123. source = source.id
  124. assert isinstance(source, int)
  125. if source not in self.sources:
  126. self.add_ids([source])
  127. self.fetch_higher_sources()
  128. return self._higher[source]