
/Lib/site-packages/gensim/test/test_corpora.py

https://gitlab.com/pierreEffiScience/TwitterClustering
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking corpus I/O formats (the corpora package).
"""

import logging
import os.path
import unittest
import tempfile
import itertools

import numpy

from gensim.utils import to_unicode, smart_extension
from gensim.interfaces import TransformedCorpus
from gensim.corpora import (bleicorpus, mmcorpus, lowcorpus, svmlightcorpus,
                            ucicorpus, malletcorpus, textcorpus, indexedcorpus)


# needed because sample data files are located in the same folder
module_path = os.path.dirname(__file__)
datapath = lambda fname: os.path.join(module_path, 'test_data', fname)


def testfile():
    # temporary data will be stored to this file
    return os.path.join(tempfile.gettempdir(), 'gensim_corpus.tst')


class DummyTransformer(object):
    def __getitem__(self, bow):
        if len(next(iter(bow))) == 2:
            # single bag of words
            transformed = [(termid, count + 1) for termid, count in bow]
        else:
            # sliced corpus
            transformed = [[(termid, count + 1) for termid, count in doc] for doc in bow]
        return transformed

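
# Illustrative sketch, not part of the original tests: a transformer like
# DummyTransformer is what TransformedCorpus expects as its first argument.
# The wiring below mirrors how test_indexing uses it further down; the file
# name 'testcorpus.mm' is an assumption based on the datapath() convention
# used throughout this module.
#
#     corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm'))
#     transformed = TransformedCorpus(DummyTransformer(), corpus)
#     first_doc = transformed[0]   # same doc as corpus[0], with each count + 1
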
class CorpusTestCase(unittest.TestCase):
    TEST_CORPUS = [[(1, 1.0)], [], [(0, 0.5), (2, 1.0)], []]

    def run(self, result=None):
        # only run the tests on concrete subclasses, not on this base class itself
        if type(self) is not CorpusTestCase:
            super(CorpusTestCase, self).run(result)

    def tearDown(self):
        # remove all temporary test files
        fname = testfile()
        extensions = ['', '', '.bz2', '.gz', '.index', '.vocab']
        for ext in itertools.permutations(extensions, 2):
            try:
                os.remove(fname + ext[0] + ext[1])
            except OSError:
                pass

    def test_load(self):
        fname = datapath('testcorpus.' + self.file_extension.lstrip('.'))
        corpus = self.corpus_class(fname)
        docs = list(corpus)
        # the deerwester corpus always has nine documents
        self.assertEqual(len(docs), 9)

    def test_len(self):
        fname = datapath('testcorpus.' + self.file_extension.lstrip('.'))
        corpus = self.corpus_class(fname)
        # make sure corpus.index works, too
        corpus = self.corpus_class(fname)
        self.assertEqual(len(corpus), 9)
        # for subclasses of IndexedCorpus, we need to nuke this so we don't
        # test length on the index, but just testcorpus contents
        if hasattr(corpus, 'index'):
            corpus.index = None
        self.assertEqual(len(corpus), 9)

    def test_empty_input(self):
        with open(testfile(), 'w') as f:
            f.write('')
        with open(testfile() + '.vocab', 'w') as f:
            f.write('')
        corpus = self.corpus_class(testfile())
        self.assertEqual(len(corpus), 0)
        docs = list(corpus)
        self.assertEqual(len(docs), 0)

    def test_save(self):
        corpus = self.TEST_CORPUS
        # make sure the corpus can be saved
        self.corpus_class.save_corpus(testfile(), corpus)
        # and loaded back, resulting in exactly the same corpus
        corpus2 = list(self.corpus_class(testfile()))
        self.assertEqual(corpus, corpus2)

    def test_serialize(self):
        corpus = self.TEST_CORPUS
        # make sure the corpus can be saved
        self.corpus_class.serialize(testfile(), corpus)
        # and loaded back, resulting in exactly the same corpus
        corpus2 = self.corpus_class(testfile())
        self.assertEqual(corpus, list(corpus2))
        # make sure the indexing corpus[i] works
        for i in range(len(corpus)):
            self.assertEqual(corpus[i], corpus2[i])
        # make sure that subclasses of IndexedCorpus support fancy indexing
        # after deserialisation
        if isinstance(corpus, indexedcorpus.IndexedCorpus):
            idx = [1, 3, 5, 7]
            self.assertEqual(corpus[idx], corpus2[idx])

    def test_serialize_compressed(self):
        corpus = self.TEST_CORPUS
        for extension in ['.gz', '.bz2']:
            fname = testfile() + extension
            # make sure the corpus can be saved
            self.corpus_class.serialize(fname, corpus)
            # and loaded back, resulting in exactly the same corpus
            corpus2 = self.corpus_class(fname)
            self.assertEqual(corpus, list(corpus2))
            # make sure the indexing `corpus[i]` syntax works
            for i in range(len(corpus)):
                self.assertEqual(corpus[i], corpus2[i])

    def test_switch_id2word(self):
        fname = datapath('testcorpus.' + self.file_extension.lstrip('.'))
        corpus = self.corpus_class(fname)
        if hasattr(corpus, 'id2word'):
            firstdoc = next(iter(corpus))
            testdoc = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc)
            self.assertEqual(testdoc, set([('computer', 1), ('human', 1), ('interface', 1)]))

            # swapping two entries of the id2word mapping must not change the
            # set of (word, count) pairs produced for the first document
            d = corpus.id2word
            d[0], d[1] = d[1], d[0]
            corpus.id2word = d

            firstdoc2 = next(iter(corpus))
            testdoc2 = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc2)
            self.assertEqual(testdoc2, set([('computer', 1), ('human', 1), ('interface', 1)]))

    def test_indexing(self):
        fname = datapath('testcorpus.' + self.file_extension.lstrip('.'))
        corpus = self.corpus_class(fname)
        docs = list(corpus)

        for idx, doc in enumerate(docs):
            self.assertEqual(doc, corpus[idx])

        self.assertEqual(docs, list(corpus[:]))
        self.assertEqual(docs[0:], list(corpus[0:]))
        self.assertEqual(docs[0:-1], list(corpus[0:-1]))
        self.assertEqual(docs[2:4], list(corpus[2:4]))
        self.assertEqual(docs[::2], list(corpus[::2]))
        self.assertEqual(docs[::-1], list(corpus[::-1]))

        # make sure sliced corpora can be iterated over multiple times
        c = corpus[:]
        self.assertEqual(docs, list(c))
        self.assertEqual(docs, list(c))
        self.assertEqual(len(docs), len(corpus))
        self.assertEqual(len(docs), len(corpus[:]))
        self.assertEqual(len(docs[::2]), len(corpus[::2]))

        def _get_slice(corpus, slice_):
            # assertRaises for python 2.6 takes a callable
            return corpus[slice_]

        # make sure proper input validation for sliced corpora is done
        self.assertRaises(ValueError, _get_slice, corpus, set([1]))
        self.assertRaises(ValueError, _get_slice, corpus, 1.0)

        # check sliced corpora that use fancy indexing
        c = corpus[[1, 3, 4]]
        self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c))
        self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c))
        self.assertEqual(len(corpus[[0, 1, -1]]), 3)
        self.assertEqual(len(corpus[numpy.asarray([0, 1, -1])]), 3)

        # check that TransformedCorpus supports indexing when the underlying
        # corpus does, and throws an error otherwise
        corpus_ = TransformedCorpus(DummyTransformer(), corpus)
        if hasattr(corpus, 'index') and corpus.index is not None:
            self.assertEqual(corpus_[0][0][1], docs[0][0][1] + 1)
            self.assertRaises(ValueError, _get_slice, corpus_, set([1]))
            transformed_docs = [val + 1 for i, d in enumerate(docs) for _, val in d if i in [1, 3, 4]]
            self.assertEqual(transformed_docs, list(v for doc in corpus_[[1, 3, 4]] for _, v in doc))
            self.assertEqual(3, len(corpus_[[1, 3, 4]]))
        else:
            self.assertRaises(RuntimeError, _get_slice, corpus_, [1, 3, 4])
            self.assertRaises(RuntimeError, _get_slice, corpus_, set([1]))
            self.assertRaises(RuntimeError, _get_slice, corpus_, 1.0)

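
# Minimal usage sketch of the save/serialize round trip that CorpusTestCase
# exercises, assuming MmCorpus as the concrete class and a writable /tmp path
# (both assumptions for illustration only, not part of the tests):
#
#     docs = [[(1, 1.0)], [], [(0, 0.5), (2, 1.0)], []]
#     mmcorpus.MmCorpus.serialize('/tmp/example.mm', docs)   # writes the .mm file plus an index
#     loaded = mmcorpus.MmCorpus('/tmp/example.mm')
#     assert list(loaded) == docs      # streamed iteration
#     assert loaded[2] == docs[2]      # random access via the index
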
class TestMmCorpus(CorpusTestCase):
    def setUp(self):
        self.corpus_class = mmcorpus.MmCorpus
        self.file_extension = '.mm'

    def test_serialize_compressed(self):
        # MmCorpus needs file write with seek => doesn't support compressed output (only input)
        pass


class TestSvmLightCorpus(CorpusTestCase):
    def setUp(self):
        self.corpus_class = svmlightcorpus.SvmLightCorpus
        self.file_extension = '.svmlight'


class TestBleiCorpus(CorpusTestCase):
    def setUp(self):
        self.corpus_class = bleicorpus.BleiCorpus
        self.file_extension = '.blei'

    def test_save_format_for_dtm(self):
        corpus = [[(1, 1.0)], [], [(0, 5.0), (2, 1.0)], []]
        test_file = testfile()
        self.corpus_class.save_corpus(test_file, corpus)
        with open(test_file) as f:
            for line in f:
                # each line: unique_word_count index1:count1 index2:count2 ... indexn:countn
                tokens = line.split()
                words_len = int(tokens[0])
                if words_len > 0:
                    tokens = tokens[1:]
                else:
                    tokens = []
                self.assertEqual(words_len, len(tokens))
                for token in tokens:
                    word, count = token.split(':')
                    self.assertEqual(count, str(int(count)))

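
# For reference (an inference from the format comment and assertions above,
# not taken from the original file): with the corpus used in
# test_save_format_for_dtm, the line written for the document
# [(0, 5.0), (2, 1.0)] would look like "2 0:5 2:1" -- a leading unique-word
# count followed by index:count pairs with integer counts, which is the
# LDA-C/DTM layout the test verifies.
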
class TestLowCorpus(CorpusTestCase):
    TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []]

    def setUp(self):
        self.corpus_class = lowcorpus.LowCorpus
        self.file_extension = '.low'


class TestUciCorpus(CorpusTestCase):
    TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []]

    def setUp(self):
        self.corpus_class = ucicorpus.UciCorpus
        self.file_extension = '.uci'

    def test_serialize_compressed(self):
        # UciCorpus needs file write with seek => doesn't support compressed output (only input)
        pass


class TestMalletCorpus(CorpusTestCase):
    TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []]

    def setUp(self):
        self.corpus_class = malletcorpus.MalletCorpus
        self.file_extension = '.mallet'

    def test_load_with_metadata(self):
        fname = datapath('testcorpus.' + self.file_extension.lstrip('.'))
        corpus = self.corpus_class(fname)
        corpus.metadata = True
        self.assertEqual(len(corpus), 9)
        docs = list(corpus)
        self.assertEqual(len(docs), 9)
        for i, docmeta in enumerate(docs):
            doc, metadata = docmeta
            self.assertEqual(metadata[0], str(i + 1))
            self.assertEqual(metadata[1], 'en')

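
# For reference (an inference from the assertions above, not taken from the
# original file): each line of the Mallet-format test file is expected to look
# like "1 en some tokens ...", i.e. a document id, a language tag, then the
# tokens; setting corpus.metadata = True makes iteration yield
# (doc, (doc_id, language)) pairs instead of bare documents.
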
class TestTextCorpus(CorpusTestCase):
    def setUp(self):
        self.corpus_class = textcorpus.TextCorpus
        self.file_extension = '.txt'

    def test_load_with_metadata(self):
        fname = datapath('testcorpus.' + self.file_extension.lstrip('.'))
        corpus = self.corpus_class(fname)
        corpus.metadata = True
        self.assertEqual(len(corpus), 9)
        docs = list(corpus)
        self.assertEqual(len(docs), 9)
        for i, docmeta in enumerate(docs):
            doc, metadata = docmeta
            self.assertEqual(metadata[0], i)

    # the base-class save/serialize/indexing tests don't apply to TextCorpus,
    # so they are overridden here as no-ops
    def test_save(self):
        pass

    def test_serialize(self):
        pass

    def test_serialize_compressed(self):
        pass

    def test_indexing(self):
        pass


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
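
# Usage sketch (an assumption, not part of the original file): besides running
# the whole module via the __main__ block above, a single test case can be
# selected with the standard unittest command-line syntax, e.g.
#
#     python -m unittest test_corpora.TestMmCorpus
#
# executed from the gensim/test directory where this file lives.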