
/vectorology.py

https://gitlab.com/solstag/abstractology
# coding: utf-8
# Corporalogister
#
# Author(s):
# * Ale Abdo <abdo@member.fsf.org>
#
# License:
# [GNU-GPLv3+](https://www.gnu.org/licenses/gpl-3.0.html)
#
# Project:
# <https://en.wikiversity.org/wiki/The_dynamics_and_social_organization_of_innovation_in
# _the_field_of_oncology>
#
# Reference repository for this file:
# <https://gitlab.com/solstag/abstractology>
#
# Contributions are welcome, get in touch with the author(s).

try:
    import gensim
except ImportError:
    print("Warning: Failed to import `gensim`, some functions may not be available")

import multiprocessing, numpy, os, pandas
from os import path
from pathlib import Path
from collections import OrderedDict
from itertools import combinations
from copy import deepcopy
from matplotlib import pyplot as plt, colors as colors

from .ioio import ioio
from .corporalogy import Corporalogy
from .scorology import Scorology


class Vectorology(Scorology, Corporalogy):
    """
    Word embedding model training and scoring
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        os.makedirs(self.embeddings_path, exist_ok=True)
        self.embeddings = OrderedDict()
        self.shuffled_index = None
        if self._to_load.get("models", []):
            for model in self._to_load["models"]:
                self.load_models_from_store(model)
            if len(self.loaded["data"]) == len(
                self.loaded["models"]
            ) == 1 and path.basename(self.loaded["models"][0]).startswith(
                path.splitext(path.basename(self.loaded["data"][0]))[0] + "-"
            ):
                index_name = self.loaded["models"][0] + ".index.pickle.xz"
                self.shuffled_index = ioio.load(Path(index_name))
                print("Index loaded: {}".format(index_name))
            else:
                print("Index not loaded")

    # Directory under storage_dir where trained embedding models are kept
    embeddings_path = property(lambda self: path.join(self.storage_dir, "models"))

    def set_storage_dir(self, storage_dir):
        storage_dir = path.normpath(storage_dir)
        super().set_storage_dir(storage_dir)
        # Create the embeddings directory only after the new storage_dir is
        # set, since embeddings_path derives from it
        os.makedirs(self.embeddings_path, exist_ok=True)

    def clear_models(self):
        self.embeddings = OrderedDict()
        self.loaded["models"] = []

    def load_models_from_store(self, fdir):
        # Models are stored one file per group inside `fdir`; load them in
        # creation order, restoring all-digit filenames as integer keys
        models = OrderedDict()
        for fname in sorted(
            os.listdir(fdir), key=lambda x: path.getctime(path.join(fdir, x))
        ):
            if fname.isdigit():
                fname = int(fname)
            models[fname] = ioio.load(Path(fdir, str(fname)))
            print("Models loaded: {}".format(path.join(fdir, str(fname))))
        self.update_models(models, fdir)

    def update_models(self, models, fullname):
        if not self.embeddings.keys().isdisjoint(models):
            print(
                "Warning: models were overwritten upon loading: {}".format(
                    set(self.embeddings).intersection(models)
                )
            )
        self.embeddings.update(models)
        self.loaded["models"].append(path.normpath(fullname))

    def load_models(
        self,
        name,
        balance,
        iterations,
        window,
        dimensions,
        mode="document",
        sg=False,
        hs=True,
        groupby=None,
        load=True,
        store=True,
        localvocab=False,
    ):
        if name is None:
            if len(self.loaded["data"]) == 1:
                name = path.splitext(path.basename(self.loaded["data"][0]))[0]
            else:
                raise Exception(
                    'Must provide "name" when more than one dataset is loaded'
                )
        else:
            name = path.normpath(name)
        name = (
            "-".join(
                [
                    name.replace("/", "++").replace(".", "+"),
                    str("sg" if sg else "cb"),
                    str("hs" if hs else "ns"),
                    str(mode),
                    str(iterations),
                    str(window),
                    str(dimensions),
                    str(balance),
                    str(localvocab),
                    str(groupby),
                    str(self.column),
                ]
            )
            + ".vectors"
        )
        fullname = path.join(self.embeddings_path, name)
        if load:
            try:
                return self.load_models_from_store(fullname)
            except FileNotFoundError:
                pass
        if window == "full":
            window = max(len(d) for d in list(self.itersentences("document")))
            print("Window set to {}".format(window))
        models = OrderedDict()
        # Get data shuffled to reduce training bias
        sdata = self.shuffled_data()
        # Create the base model; hs=1 and negative=0 are required by .score().
        # Note: `iter` and `size` are the gensim 3.x parameter names (renamed
        # to `epochs` and `vector_size` in gensim 4)
        basemodel = gensim.models.Word2Vec(
            workers=multiprocessing.cpu_count(),
            iter=iterations,
            window=window,
            size=dimensions,
            sg=sg,
            hs=hs,
            negative=0 if hs else 5,
        )
        if not localvocab:
            basemodel.build_vocab(self.itersentences(mode, sdata[self.column]))
        # Train a model for each group of documents
        grouped_data = sdata.groupby((lambda x: 0) if groupby is None else groupby)
        print("Training these models:", list(grouped_data.groups))
        for gname, gdata in self.balance_groups(grouped_data, balance):
            print("\rTraining {:<42}".format(gname), end="")
            models[gname] = deepcopy(basemodel)
            trainlist = list(self.itersentences(mode, gdata[self.column]))
            if localvocab:
                models[gname].build_vocab(trainlist)
            models[gname].train(
                trainlist, total_examples=len(trainlist), epochs=models[gname].iter
            )
            if store:
                ioio.store(models[gname], Path(fullname, f"{gname}.pickle.xz"))
                print("\nModels stored: {}".format(gname))
        if store:
            ioio.store(
                self.shuffled_index,
                Path(self.embeddings_path, f"{name}.index.pickle.xz"),
            )
            print(
                "Model training index stored: {}".format(fullname + ".index.pickle.xz")
            )
        self.update_models(models, fullname)
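
    # Illustrative usage sketch (not part of the original file; the instance
    # name `v` and the "year" column are assumptions): with a single dataset
    # loaded, one CBOW model per year could be trained with something like:
    #
    #   v.load_models(
    #       name=None,
    #       balance=None,
    #       iterations=5,
    #       window=5,
    #       dimensions=100,
    #       groupby=lambda idx: v.data.loc[idx, "year"],
    #   )
    #
    # Trained models end up in `v.embeddings`, keyed by group, and are
    # persisted under `v.embeddings_path` for later `load_models` calls.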

    def calc_scores(self, mode="document", lenfix=True):
        allscores = pandas.DataFrame()
        print("Calculating scores for: {}".format(list(self.embeddings.keys())))
        for name, model in self.embeddings.items():
            print("\rCalculating {:<42}".format(name), end="")
            # Get sentences, indexes and length of documents to correct likelihoods
            sentencelist = list(self.itersentences(mode=mode))
            indexlist = list(self.indexsentences(mode=mode))
            lenabs = pandas.Series(
                (
                    len([w for w in sentence if w in model.wv.vocab])
                    for sentence in self.itersentences(mode=mode)
                ),
                name="lenabs",
            )
            assert len(sentencelist) == len(indexlist) == len(lenabs)
            # The score (log likelihood) of each sentence for the model
            scores = pandas.Series(model.score(sentencelist, len(sentencelist)))
            if lenfix:
                if model.sg:
                    # Normalize by the approximate number of (center, context)
                    # pairs a skip-gram model scores for a sentence of l
                    # in-vocabulary words with window w
                    w = model.window
                    sgfix = lenabs.apply(
                        lambda l: max(0, l - 2 * w) * 2 * w
                        + min(l, 2 * w) * min(l - 1, w)
                        + sum([int(i / 2) for i in range(min(l, 2 * w))])
                    )
                    scores = scores.div(sgfix)
                else:
                    scores = scores.div(lenabs)  # abstract-size correction
            scorenans = scores[scores.isnull()]
            if not scorenans.empty:
                print("NaN found for model {}: {}".format(name, list(scorenans.index)))
            allscores[name] = scores.groupby(indexlist).mean().loc[self.data.index]
        print()
        return allscores
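
    # Sketch of how the result might be used (illustrative, not from the
    # original file): `calc_scores` returns one column per loaded model and
    # one row per document, holding length-corrected mean log likelihoods, so
    # per-document model preferences can be read off directly:
    #
    #   scores = v.calc_scores()
    #   best = scores.idxmax(axis=1)  # model scoring highest per document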

    def load_scores(self, mode="document"):
        print("Loading scores for {}".format(self.column))
        fname = f"scores-{self.column}.pickle.xz"
        try:
            self.scores = ioio.load(Path(self.analysis_dir, fname))
        except FileNotFoundError:
            self.scores = self.calc_scores(mode)
            ioio.store(self.scores, Path(self.analysis_dir, fname))
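
    # Note (descriptive, not in the original file): scores are cached at
    # `<analysis_dir>/scores-<column>.pickle.xz`; deleting that file forces
    # `load_scores` to recompute them via `calc_scores` on the next call.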

    def plot_wordpair_similarity_matrix(
        self, words, name="", scale="linear", upper=False, diagonal=True
    ):
        functions = OrderedDict(
            (mname, getattr(model, "similarity"))
            for mname, model in self.embeddings.items()
        )
        return self.plot_wordpair_matrix(
            words,
            functions,
            funcname="similarity",
            name=name,
            scale=scale,
            upper=upper,
            diagonal=diagonal,
        )

    def plot_wordpair_similarity_profile(self, words, name=""):
        functions = OrderedDict(
            (mname, getattr(model, "similarity"))
            for mname, model in self.embeddings.items()
        )
        return self.plot_wordpair_profile(
            words, functions, funcname="similarity", name=name
        )
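
    # Illustrative call (the word list is a placeholder; `plot_wordpair_matrix`
    # and `plot_wordpair_profile` are expected from the parent classes):
    #
    #   v.plot_wordpair_similarity_matrix(["tumor", "therapy", "gene"])
    #
    # This would compare pairwise word similarities (gensim's cosine
    # similarity) for the given words across every loaded model.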