/institution_test_suite.py

https://github.com/adsabs/scripts-affiliation-disambiguation · Python · 187 lines · 140 code · 45 blank · 2 comment · 48 complexity · 24b55f8c024b0bcdc901eecc5b3ed5f9 MD5 · raw file

  1. from collections import defaultdict
  2. import Levenshtein
  3. import marshal
  4. import os
  5. import re
  6. import time
  7. import institution_searcher as s
  8. RE_SPACES = re.compile('\s+')
  9. def get_icns(reextract=False):
  10. if reextract:
  11. os.chdir('desy_affiliations')
  12. import desy_affs
  13. icns = desy_affs.get_icns()
  14. os.chdir('..')
  15. # Find the match file between old and new ICNs.
  16. match = dict(line.strip().split('\t') for line in open('old_new.txt').readlines())
  17. for icn in icns.keys():
  18. if '; 'in icn or icn.endswith(' to be removed'):
  19. # Delete pairs of affiliations.
  20. del icns[icn]
  21. else:
  22. if icn in match:
  23. new_icn = RE_SPACES.sub(' ', match[icn].strip())
  24. else:
  25. new_icn = icn
  26. try:
  27. new_icn = new_icn.decode('utf-8')
  28. except:
  29. pass
  30. new_icn = RE_SPACES.sub(' ', new_icn.strip())
  31. if new_icn != icn:
  32. icns[new_icn] = sorted(list(set(icns.get(new_icn, []) + icns.pop(icn, []))))
  33. return icns
  34. else:
  35. return marshal.load(open('icns.marshal'))
  36. PROCESS_NUMBER = 20
  37. def extend_icns(icns):
  38. out = []
  39. for icn, institutions in icns.items():
  40. for institution in institutions:
  41. out.append((icn, institution))
  42. return out
  43. def test_icns_only(icns):
  44. return test([(icn, icn) for icn in icns])
  45. def analyse_icns(res, icns):
  46. out = []
  47. for original, _, matched in res:
  48. if matched is not None:
  49. if not isinstance(original, str):
  50. original = original.encode('utf-8')
  51. if not isinstance(matched, str):
  52. matched = matched.encode('utf-8')
  53. if original != matched:
  54. out.append((
  55. Levenshtein.distance(original, matched),
  56. 1. / len(icns[original]),
  57. original,
  58. matched
  59. ))
  60. return sorted(out)
  61. def test(icns):
  62. if isinstance(icns, dict):
  63. icns = extend_icns(icns)
  64. results = []
  65. chunk_size = len(icns) / PROCESS_NUMBER + 1
  66. while icns:
  67. chunk = icns[:chunk_size]
  68. icns = icns[chunk_size:]
  69. results.append(s.match_institutions.delay(chunk))
  70. while not all([r.ready() for r in results]):
  71. time.sleep(0.1)
  72. out = []
  73. for r in results:
  74. out += r.result
  75. print_statistics(out)
  76. return out
  77. def test_ratio(icns):
  78. res = test(icns)
  79. correct_matches, incorrect_matches = separate_results(res)
  80. correct_ratios = compute_ratios(correct_matches)
  81. incorrect_ratios = compute_ratios(incorrect_matches)
  82. return correct_ratios, incorrect_ratios
  83. def compute_ratios(matches):
  84. out = []
  85. for inst, results in get_two_first_results(matches):
  86. if results and len(results) >= 2:
  87. out.append((inst, float(int(results[1][0] / results[0][0] * 20)) / 20))
  88. return out
  89. def display_ratios(ratios):
  90. clustered = defaultdict(list)
  91. for inst, ratio in ratios:
  92. clustered[ratio].append(clustered)
  93. for i in [float(i) / 20 for i in range(0, 21)]:
  94. print len(clustered[i])
  95. def get_two_first_results(institutions):
  96. results = []
  97. chunk_size = len(institutions) / PROCESS_NUMBER + 1
  98. first = []
  99. while institutions:
  100. chunk = institutions[:chunk_size]
  101. institutions = institutions[chunk_size:]
  102. results.append(s.get_match_ratio.delay(chunk))
  103. while not all([r.ready() for r in results]):
  104. time.sleep(0.1)
  105. out = []
  106. for r in results:
  107. out += r.result
  108. return out
  109. def separate_results(res):
  110. correct, error = [], []
  111. for r in res:
  112. if r[0] == r[2]:
  113. correct.append(r[1])
  114. else:
  115. error.append(r[1])
  116. return correct, error
  117. def get_sorted_errors(res):
  118. errors = [r for r in res if r[0] != r[2]]
  119. sorted_errors = defaultdict(list)
  120. for r in errors:
  121. sorted_errors[r[0]].append((r[1], r[2]))
  122. return sorted(((len(v), k) for k, v in sorted_errors.items()), reverse=True)
  123. PREVIOUS_SCORE = 0
  124. def print_statistics(results):
  125. correct = [r for r in results if r[0] == r[2]]
  126. score = float(len(correct)) / len(results) * 100
  127. print '%s/%s (%.2f%%)' % (len(correct), len(results), score)
  128. global PREVIOUS_SCORE
  129. print 'Previous score: %.2f%%' % PREVIOUS_SCORE
  130. PREVIOUS_SCORE = score
  131. def format_results(results):
  132. from BeautifulSoup import UnicodeDammit
  133. new_results = []
  134. for line in results:
  135. new_line = []
  136. for elem in line:
  137. if elem is None:
  138. new_line.append('')
  139. else:
  140. new_line.append(UnicodeDammit(elem).unicode)
  141. new_results.append('\t'.join(new_line))
  142. return '\n'.join(new_results)