PageRenderTime 65ms CodeModel.GetById 33ms RepoModel.GetById 0ms app.codeStats 1ms

/nltk/tag/crf.py

https://github.com/BrucePHill/nltk
Python | 755 lines | 707 code | 8 blank | 40 comment | 17 complexity | f3b0770ecff534d7129dd0427525916a MD5 | raw file
Possible License(s): Apache-2.0
  1. # Natural Language Toolkit: Conditional Random Fields
  2. #
  3. # Copyright (C) 2001-2013 NLTK Project
  4. # Author: Edward Loper <edloper@gradient.cis.upenn.edu>
  5. # URL: <http://www.nltk.org/>
  6. # For license information, see LICENSE.TXT
  7. """
  8. An interface to Mallet <http://mallet.cs.umass.edu/>'s Linear Chain
  9. Conditional Random Field (LC-CRF) implementation.
  10. A user-supplied feature detector function is used to convert each
  11. token to a featureset. Each feature/value pair is then encoded as a
  12. single binary feature for Mallet.
  13. """
  14. from __future__ import print_function, unicode_literals
  15. import os
  16. import pickle
  17. import re
  18. import subprocess
  19. import codecs
  20. from tempfile import mkstemp
  21. import textwrap
  22. import time
  23. import zipfile
  24. from xml.etree import ElementTree
  25. import nltk
  26. from nltk import compat
  27. from nltk.classify import call_mallet
  28. from nltk.tag.api import FeaturesetTaggerI
  29. @compat.python_2_unicode_compatible
  30. class MalletCRF(FeaturesetTaggerI):
  31. """
  32. A conditional random field tagger, which is trained and run by
  33. making external calls to Mallet. Tokens are converted to
  34. featuresets using a feature detector function::
  35. feature_detector(tokens, index) -> featureset
  36. These featuresets are then encoded into feature vectors by
  37. converting each feature (name, value) pair to a unique binary
  38. feature.
  39. Ecah MalletCRF object is backed by a crf model file. This
  40. model file is actually a zip file, and it contains one file for
  41. the serialized model ``crf-model.ser`` and one file for
  42. information about the structure of the CRF ``crf-info.xml``.
  43. Create a new MalletCRF.
  44. :param filename: The filename of the model file that backs this CRF.
  45. :param feature_detector: The feature detector function that is
  46. used to convert tokens to featuresets. This parameter
  47. only needs to be given if the model file does not contain
  48. a pickled pointer to the feature detector (e.g., if the
  49. feature detector was a lambda function).
  50. """
  51. def __init__(self, filename, feature_detector=None):
  52. # Read the CRFInfo from the model file.
  53. zf = zipfile.ZipFile(filename)
  54. crf_info = CRFInfo.fromstring(zf.read('crf-info.xml'))
  55. zf.close()
  56. self.crf_info = crf_info
  57. """A CRFInfo object describing this CRF."""
  58. # Ensure that our crf_info object has a feature detector.
  59. if crf_info.feature_detector is not None:
  60. if (feature_detector is not None and
  61. self.crf_info.feature_detector != feature_detector):
  62. raise ValueError('Feature detector mismatch: %r vs %r' %
  63. (feature_detector, self.crf_info.feature_detector))
  64. elif feature_detector is None:
  65. raise ValueError('Feature detector not found; supply it manually.')
  66. elif feature_detector.__name__ != crf_info.feature_detector_name:
  67. raise ValueError('Feature detector name mismatch: %r vs %r' %
  68. (feature_detector.__name__,
  69. crf_info.feature_detector_name))
  70. else:
  71. self.crf_info.feature_detector = feature_detector
  72. #/////////////////////////////////////////////////////////////////
  73. # Convenience accessors (info also available via self.crf_info)
  74. #/////////////////////////////////////////////////////////////////
  75. @property
  76. def filename(self):
  77. """
  78. The filename of the crf model file that backs this
  79. MalletCRF. The crf model file is actually a zip file, and
  80. it contains one file for the serialized model
  81. ``crf-model.ser`` and one file for information about the
  82. structure of the CRF ``crf-info.xml``).
  83. """
  84. return self.crf_info.model_filename
  85. @property
  86. def feature_detector(self):
  87. """
  88. The feature detector function that is used to convert tokens
  89. to featuresets. This function has the signature::
  90. feature_detector(tokens, index) -> featureset
  91. """
  92. return self.crf_info.model_feature_detector
  93. #/////////////////////////////////////////////////////////////////
  94. # Tagging
  95. #/////////////////////////////////////////////////////////////////
  96. #: The name of the java script used to run MalletCRFs.
  97. _RUN_CRF = "org.nltk.mallet.RunCRF"
  98. def batch_tag(self, sentences):
  99. # Write the test corpus to a temporary file
  100. (fd, test_file) = mkstemp('.txt', 'test')
  101. self.write_test_corpus(sentences, os.fdopen(fd, 'w'))
  102. try:
  103. # Run mallet on the test file.
  104. stdout, stderr = call_mallet([self._RUN_CRF,
  105. '--model-file', os.path.abspath(self.crf_info.model_filename),
  106. '--test-file', test_file], stdout='pipe')
  107. # Decode the output
  108. labels = self.parse_mallet_output(stdout)
  109. # strip __start__ and __end__
  110. if self.crf_info.add_start_state and self.crf_info.add_end_state:
  111. labels = [labs[1:-1] for labs in labels]
  112. elif self.crf_info.add_start_state:
  113. labels = [labs[1:] for labs in labels]
  114. elif self.crf_info.add_end_state:
  115. labels = [labs[:-1] for labs in labels]
  116. # Combine the labels and the original sentences.
  117. return [zip(sent, label) for (sent,label) in
  118. zip(sentences, labels)]
  119. finally:
  120. os.remove(test_file)
  121. #/////////////////////////////////////////////////////////////////
  122. # Training
  123. #/////////////////////////////////////////////////////////////////
  124. #: The name of the java script used to train MalletCRFs.
  125. _TRAIN_CRF = "org.nltk.mallet.TrainCRF"
  126. @classmethod
  127. def train(cls, feature_detector, corpus, filename=None,
  128. weight_groups=None, gaussian_variance=1, default_label='O',
  129. transduction_type='VITERBI', max_iterations=500,
  130. add_start_state=True, add_end_state=True, trace=1):
  131. """
  132. Train a new linear chain CRF tagger based on the given corpus
  133. of training sequences. This tagger will be backed by a crf
  134. model file, containing both a serialized Mallet model and
  135. information about the CRF's structure. This crf model file
  136. will not be automatically deleted -- if you wish to delete
  137. it, you must delete it manually. The filename of the model
  138. file for a MalletCRF crf is available as ``crf.filename()``.
  139. :type corpus: list(tuple(str, str))
  140. :param corpus: Training data, represented as a list of
  141. sentences, where each sentence is a list of (token, tag) tuples.
  142. :type filename: str
  143. :param filename: The filename that should be used for the crf
  144. model file that backs the new MalletCRF. If no
  145. filename is given, then a new filename will be chosen
  146. automatically.
  147. :type weight_groups: list(CRFInfo.WeightGroup)
  148. :param weight_groups: Specifies how input-features should
  149. be mapped to joint-features. See CRFInfo.WeightGroup
  150. for more information.
  151. :type gaussian_variance: float
  152. :param gaussian_variance: The gaussian variance of the prior
  153. that should be used to train the new CRF.
  154. :type default_label: str
  155. :param default_label: The "label for initial context and
  156. uninteresting tokens" (from Mallet's SimpleTagger.java.)
  157. It's unclear whether this currently has any effect.
  158. :type transduction_type: str
  159. :param transduction_type: The type of transduction used by
  160. the CRF. Can be VITERBI, VITERBI_FBEAM, VITERBI_BBEAM,
  161. VITERBI_FBBEAM, or VITERBI_FBEAMKL.
  162. :type max_iterations: int
  163. :param max_iterations: The maximum number of iterations that
  164. should be used for training the CRF.
  165. :type add_start_state: bool
  166. :param add_start_state: If true, then NLTK will add a special
  167. start state, named '__start__'. The initial cost for
  168. the start state will be set to 0; and the initial cost for
  169. all other states will be set to +inf.
  170. :type add_end_state: bool
  171. :param add_end_state: If true, then NLTK will add a special
  172. end state, named '__end__'. The final cost for the end
  173. state will be set to 0; and the final cost for all other
  174. states will be set to +inf.
  175. :type trace: int
  176. :param trace: Controls the verbosity of trace output generated
  177. while training the CRF. Higher numbers generate more verbose
  178. output.
  179. """
  180. t0 = time.time() # Record starting time.
  181. # If they did not supply a model filename, then choose one.
  182. if filename is None:
  183. (fd, filename) = mkstemp('.crf', 'model')
  184. os.fdopen(fd).close()
  185. # Ensure that the filename ends with '.zip'
  186. if not filename.endswith('.crf'):
  187. filename += '.crf'
  188. if trace >= 1:
  189. print('[MalletCRF] Training a new CRF: %s' % filename)
  190. # Create crf-info object describing the new CRF.
  191. crf_info = MalletCRF._build_crf_info(
  192. corpus, gaussian_variance, default_label, max_iterations,
  193. transduction_type, weight_groups, add_start_state,
  194. add_end_state, filename, feature_detector)
  195. # Create a zipfile, and write crf-info to it.
  196. if trace >= 2:
  197. print('[MalletCRF] Adding crf-info.xml to %s' % filename)
  198. zf = zipfile.ZipFile(filename, mode='w')
  199. zf.writestr('crf-info.xml', crf_info.toxml()+'\n')
  200. zf.close()
  201. # Create the CRF object.
  202. crf = MalletCRF(filename, feature_detector)
  203. # Write the Training corpus to a temporary file.
  204. if trace >= 2:
  205. print('[MalletCRF] Writing training corpus...')
  206. (fd, train_file) = mkstemp('.txt', 'train')
  207. crf.write_training_corpus(corpus, os.fdopen(fd, 'w'))
  208. try:
  209. if trace >= 1:
  210. print('[MalletCRF] Calling mallet to train CRF...')
  211. cmd = [MalletCRF._TRAIN_CRF,
  212. '--model-file', os.path.abspath(filename),
  213. '--train-file', train_file]
  214. if trace > 3:
  215. call_mallet(cmd)
  216. else:
  217. p = call_mallet(cmd, stdout=subprocess.PIPE,
  218. stderr=subprocess.STDOUT,
  219. blocking=False)
  220. MalletCRF._filter_training_output(p, trace)
  221. finally:
  222. # Delete the temp file containing the training corpus.
  223. os.remove(train_file)
  224. if trace >= 1:
  225. print('[MalletCRF] Training complete.')
  226. print('[MalletCRF] Model stored in: %s' % filename)
  227. if trace >= 2:
  228. dt = time.time()-t0
  229. print('[MalletCRF] Total training time: %d seconds' % dt)
  230. # Return the completed CRF.
  231. return crf
  232. @staticmethod
  233. def _build_crf_info(corpus, gaussian_variance, default_label,
  234. max_iterations, transduction_type, weight_groups,
  235. add_start_state, add_end_state,
  236. model_filename, feature_detector):
  237. """
  238. Construct a CRFInfo object describing a CRF with a given
  239. set of configuration parameters, and based on the contents of
  240. a given corpus.
  241. """
  242. state_info_list = []
  243. labels = set()
  244. if add_start_state:
  245. labels.add('__start__')
  246. if add_end_state:
  247. labels.add('__end__')
  248. transitions = set() # not necessary to find this?
  249. for sent in corpus:
  250. prevtag = default_label
  251. for (tok,tag) in sent:
  252. labels.add(tag)
  253. transitions.add( (prevtag, tag) )
  254. prevtag = tag
  255. if add_start_state:
  256. transitions.add( ('__start__', sent[0][1]) )
  257. if add_end_state:
  258. transitions.add( (sent[-1][1], '__end__') )
  259. labels = sorted(labels)
  260. # 0th order default:
  261. if weight_groups is None:
  262. weight_groups = [CRFInfo.WeightGroup(name=l, src='.*',
  263. dst=re.escape(l))
  264. for l in labels]
  265. # Check that weight group names are unique
  266. if len(weight_groups) != len(set(wg.name for wg in weight_groups)):
  267. raise ValueError("Weight group names must be unique")
  268. # Construct a list of state descriptions. Currently, we make
  269. # these states fully-connected, with one parameter per
  270. # transition.
  271. for src in labels:
  272. if add_start_state:
  273. initial_cost = (0 if src == '__start__' else '+inf')
  274. if add_end_state:
  275. final_cost = (0 if src == '__end__' else '+inf')
  276. state_info = CRFInfo.State(src, initial_cost, final_cost, [])
  277. for dst in labels:
  278. state_weight_groups = [wg.name for wg in weight_groups
  279. if wg.match(src, dst)]
  280. state_info.transitions.append(
  281. CRFInfo.Transition(dst, dst, state_weight_groups))
  282. state_info_list.append(state_info)
  283. return CRFInfo(state_info_list, gaussian_variance,
  284. default_label, max_iterations,
  285. transduction_type, weight_groups,
  286. add_start_state, add_end_state,
  287. model_filename, feature_detector)
  288. #: A table used to filter the output that mallet generates during
  289. #: training. By default, mallet generates very verbose output.
  290. #: This table is used to select which lines of output are actually
  291. #: worth displaying to the user, based on the level of the *trace*
  292. #: parameter. Each entry of this table is a tuple
  293. #: (min_trace_level, regexp). A line will be displayed only if
  294. #: trace>=min_trace_level and the line matches regexp for at
  295. #: least one table entry.
  296. _FILTER_TRAINING_OUTPUT = [
  297. (1, r'DEBUG:.*'),
  298. (1, r'Number of weights.*'),
  299. (1, r'CRF about to train.*'),
  300. (1, r'CRF finished.*'),
  301. (1, r'CRF training has converged.*'),
  302. (2, r'CRF weights.*'),
  303. (2, r'getValue\(\) \(loglikelihood\) .*'),
  304. ]
  305. @staticmethod
  306. def _filter_training_output(p, trace):
  307. """
  308. Filter the (very verbose) output that is generated by mallet,
  309. and only display the interesting lines. The lines that are
  310. selected for display are determined by _FILTER_TRAINING_OUTPUT.
  311. """
  312. out = []
  313. while p.poll() is None:
  314. while True:
  315. line = p.stdout.readline()
  316. if not line: break
  317. out.append(line)
  318. for (t, regexp) in MalletCRF._FILTER_TRAINING_OUTPUT:
  319. if t <= trace and re.match(regexp, line):
  320. indent = ' '*t
  321. print('[MalletCRF] %s%s' % (indent, line.rstrip()))
  322. break
  323. if p.returncode != 0:
  324. print("\nError encountered! Mallet's most recent output:")
  325. print(''.join(out[-100:]))
  326. raise OSError('Mallet command failed')
  327. #/////////////////////////////////////////////////////////////////
  328. # Communication w/ mallet
  329. #/////////////////////////////////////////////////////////////////
  330. def write_training_corpus(self, corpus, stream, close_stream=True):
  331. """
  332. Write a given training corpus to a given stream, in a format that
  333. can be read by the java script org.nltk.mallet.TrainCRF.
  334. """
  335. feature_detector = self.crf_info.feature_detector
  336. for sentence in corpus:
  337. if self.crf_info.add_start_state:
  338. stream.write('__start__ __start__\n')
  339. unlabeled_sent = [tok for (tok,tag) in sentence]
  340. for index in range(len(unlabeled_sent)):
  341. featureset = feature_detector(unlabeled_sent, index)
  342. for (fname, fval) in featureset.items():
  343. stream.write(self._format_feature(fname, fval)+" ")
  344. stream.write(sentence[index][1]+'\n')
  345. if self.crf_info.add_end_state:
  346. stream.write('__end__ __end__\n')
  347. stream.write('\n')
  348. if close_stream: stream.close()
  349. def write_test_corpus(self, corpus, stream, close_stream=True):
  350. """
  351. Write a given test corpus to a given stream, in a format that
  352. can be read by the java script org.nltk.mallet.TestCRF.
  353. """
  354. feature_detector = self.crf_info.feature_detector
  355. for sentence in corpus:
  356. if self.crf_info.add_start_state:
  357. stream.write('__start__ __start__\n')
  358. for index in range(len(sentence)):
  359. featureset = feature_detector(sentence, index)
  360. for (fname, fval) in featureset.items():
  361. stream.write(self._format_feature(fname, fval)+" ")
  362. stream.write('\n')
  363. if self.crf_info.add_end_state:
  364. stream.write('__end__ __end__\n')
  365. stream.write('\n')
  366. if close_stream: stream.close()
  367. def parse_mallet_output(self, s):
  368. """
  369. Parse the output that is generated by the java script
  370. org.nltk.mallet.TestCRF, and convert it to a labeled
  371. corpus.
  372. """
  373. if re.match(r'\s*<<start>>', s):
  374. assert 0, 'its a lattice'
  375. corpus = [[]]
  376. for line in s.split('\n'):
  377. line = line.strip()
  378. # Label with augmentations?
  379. if line:
  380. corpus[-1].append(line.strip())
  381. # Start of new instance?
  382. elif corpus[-1] != []:
  383. corpus.append([])
  384. if corpus[-1] == []: corpus.pop()
  385. return corpus
  386. _ESCAPE_RE = re.compile('[^a-zA-Z0-9]')
  387. @staticmethod
  388. def _escape_sub(m):
  389. return '%' + ('%02x' % ord(m.group()))
  390. @staticmethod
  391. def _format_feature(fname, fval):
  392. """
  393. Return a string name for a given feature (name, value) pair,
  394. appropriate for consumption by mallet. We escape every
  395. character in fname or fval that's not a letter or a number,
  396. just to be conservative.
  397. """
  398. fname = MalletCRF._ESCAPE_RE.sub(MalletCRF._escape_sub, fname)
  399. if isinstance(fval, compat.string_types):
  400. fval = "'%s'" % MalletCRF._ESCAPE_RE.sub(
  401. MalletCRF._escape_sub, fval)
  402. else:
  403. fval = MalletCRF._ESCAPE_RE.sub(MalletCRF._escape_sub, '%r'%fval)
  404. return fname+'='+fval
  405. #/////////////////////////////////////////////////////////////////
  406. # String Representation
  407. #/////////////////////////////////////////////////////////////////
  408. def __repr__(self):
  409. return 'MalletCRF(%r)' % self.crf_info.model_filename
  410. ###########################################################################
  411. ## Serializable CRF Information Object
  412. ###########################################################################
  413. class CRFInfo(object):
  414. """
  415. An object used to record configuration information about a
  416. MalletCRF object. This configuration information can be
  417. serialized to an XML file, which can then be read by NLTK's custom
  418. interface to Mallet's CRF.
  419. CRFInfo objects are typically created by the ``MalletCRF.train()``
  420. method.
  421. Advanced users may wish to directly create custom
  422. CRFInfo.WeightGroup objects and pass them to the
  423. ``MalletCRF.train()`` function. See CRFInfo.WeightGroup for
  424. more information.
  425. """
  426. def __init__(self, states, gaussian_variance, default_label,
  427. max_iterations, transduction_type, weight_groups,
  428. add_start_state, add_end_state, model_filename,
  429. feature_detector):
  430. self.gaussian_variance = float(gaussian_variance)
  431. self.default_label = default_label
  432. self.states = states
  433. self.max_iterations = max_iterations
  434. self.transduction_type = transduction_type
  435. self.weight_groups = weight_groups
  436. self.add_start_state = add_start_state
  437. self.add_end_state = add_end_state
  438. self.model_filename = model_filename
  439. if isinstance(feature_detector, compat.string_types):
  440. self.feature_detector_name = feature_detector
  441. self.feature_detector = None
  442. else:
  443. self.feature_detector_name = feature_detector.__name__
  444. self.feature_detector = feature_detector
  445. _XML_TEMPLATE = (
  446. '<crf>\n'
  447. ' <modelFile>%(model_filename)s</modelFile>\n'
  448. ' <gaussianVariance>%(gaussian_variance)d</gaussianVariance>\n'
  449. ' <defaultLabel>%(default_label)s</defaultLabel>\n'
  450. ' <maxIterations>%(max_iterations)s</maxIterations>\n'
  451. ' <transductionType>%(transduction_type)s</transductionType>\n'
  452. ' <featureDetector name="%(feature_detector_name)s">\n'
  453. ' %(feature_detector)s\n'
  454. ' </featureDetector>\n'
  455. ' <addStartState>%(add_start_state)s</addStartState>\n'
  456. ' <addEndState>%(add_end_state)s</addEndState>\n'
  457. ' <states>\n'
  458. '%(states)s\n'
  459. ' </states>\n'
  460. ' <weightGroups>\n'
  461. '%(w_groups)s\n'
  462. ' </weightGroups>\n'
  463. '</crf>\n')
  464. def toxml(self):
  465. info = self.__dict__.copy()
  466. info['states'] = '\n'.join(state.toxml() for state in self.states)
  467. info['w_groups'] = '\n'.join(wg.toxml() for wg in self.weight_groups)
  468. info['feature_detector_name'] = (info['feature_detector_name']
  469. .replace('&', '&amp;')
  470. .replace('<', '&lt;'))
  471. try:
  472. fd = pickle.dumps(self.feature_detector)
  473. fd = fd.replace('&', '&amp;').replace('<', '&lt;')
  474. fd = fd.replace('\n', '&#10;') # put pickle data all on 1 line.
  475. info['feature_detector'] = '<pickle>%s</pickle>' % fd
  476. except pickle.PicklingError:
  477. info['feature_detector'] = ''
  478. return self._XML_TEMPLATE % info
  479. @staticmethod
  480. def fromstring(s):
  481. return CRFInfo._read(ElementTree.fromstring(s))
  482. @staticmethod
  483. def _read(etree):
  484. states = [CRFInfo.State._read(et) for et in
  485. etree.findall('states/state')]
  486. weight_groups = [CRFInfo.WeightGroup._read(et) for et in
  487. etree.findall('weightGroups/weightGroup')]
  488. fd = etree.find('featureDetector')
  489. feature_detector = fd.get('name')
  490. if fd.find('pickle') is not None:
  491. try: feature_detector = pickle.loads(fd.find('pickle').text)
  492. except pickle.PicklingError as e: pass # unable to unpickle it.
  493. return CRFInfo(states,
  494. float(etree.find('gaussianVariance').text),
  495. etree.find('defaultLabel').text,
  496. int(etree.find('maxIterations').text),
  497. etree.find('transductionType').text,
  498. weight_groups,
  499. bool(etree.find('addStartState').text),
  500. bool(etree.find('addEndState').text),
  501. etree.find('modelFile').text,
  502. feature_detector)
  503. def write(self, filename, encoding='utf8'):
  504. with codecs.open(filename, 'w', encoding) as out:
  505. out.write(self.toxml())
  506. out.write('\n')
  507. class State(object):
  508. """
  509. A description of a single CRF state.
  510. """
  511. def __init__(self, name, initial_cost, final_cost, transitions):
  512. if initial_cost != '+inf': initial_cost = float(initial_cost)
  513. if final_cost != '+inf': final_cost = float(final_cost)
  514. self.name = name
  515. self.initial_cost = initial_cost
  516. self.final_cost = final_cost
  517. self.transitions = transitions
  518. _XML_TEMPLATE = (
  519. ' <state name="%(name)s" initialCost="%(initial_cost)s" '
  520. 'finalCost="%(final_cost)s">\n'
  521. ' <transitions>\n'
  522. '%(transitions)s\n'
  523. ' </transitions>\n'
  524. ' </state>\n')
  525. def toxml(self):
  526. info = self.__dict__.copy()
  527. info['transitions'] = '\n'.join(transition.toxml()
  528. for transition in self.transitions)
  529. return self._XML_TEMPLATE % info
  530. @staticmethod
  531. def _read(etree):
  532. transitions = [CRFInfo.Transition._read(et)
  533. for et in etree.findall('transitions/transition')]
  534. return CRFInfo.State(etree.get('name'),
  535. etree.get('initialCost'),
  536. etree.get('finalCost'),
  537. transitions)
  538. class Transition(object):
  539. """
  540. A description of a single CRF transition.
  541. """
  542. def __init__(self, destination, label, weightgroups):
  543. """
  544. :param destination: The name of the state that this transition
  545. connects to.
  546. :param label: The tag that is generated when traversing this
  547. transition.
  548. :param weightgroups: A list of WeightGroup names, indicating
  549. which weight groups should be used to calculate the cost
  550. of traversing this transition.
  551. """
  552. self.destination = destination
  553. self.label = label
  554. self.weightgroups = weightgroups
  555. _XML_TEMPLATE = (' <transition label="%(label)s" '
  556. 'destination="%(destination)s" '
  557. 'weightGroups="%(w_groups)s"/>')
  558. def toxml(self):
  559. info = self.__dict__
  560. info['w_groups'] = ' '.join(wg for wg in self.weightgroups)
  561. return self._XML_TEMPLATE % info
  562. @staticmethod
  563. def _read(etree):
  564. return CRFInfo.Transition(etree.get('destination'),
  565. etree.get('label'),
  566. etree.get('weightGroups').split())
  567. class WeightGroup(object):
  568. """
  569. A configuration object used by MalletCRF to specify how
  570. input-features (which are a function of only the input) should be
  571. mapped to joint-features (which are a function of both the input
  572. and the output tags).
  573. Each weight group specifies that a given set of input features
  574. should be paired with all transitions from a given set of source
  575. tags to a given set of destination tags.
  576. """
  577. def __init__(self, name, src, dst, features='.*'):
  578. """
  579. :param name: A unique name for this weight group.
  580. :param src: The set of source tags that should be used for
  581. this weight group, specified as either a list of state
  582. names or a regular expression.
  583. :param dst: The set of destination tags that should be used
  584. for this weight group, specified as either a list of state
  585. names or a regular expression.
  586. :param features: The set of input feature that should be used
  587. for this weight group, specified as either a list of
  588. feature names or a regular expression. WARNING: currently,
  589. this regexp is passed streight to java -- i.e., it must
  590. be a java-style regexp!
  591. """
  592. if re.search('\s', name):
  593. raise ValueError('weight group name may not '
  594. 'contain whitespace.')
  595. if re.search('"', name):
  596. raise ValueError('weight group name may not contain \'"\'.')
  597. self.name = name
  598. self.src = src
  599. self.dst = dst
  600. self.features = features
  601. self._src_match_cache = {}
  602. self._dst_match_cache = {}
  603. _XML_TEMPLATE = (' <weightGroup name="%(name)s" src="%(src)s" '
  604. 'dst="%(dst)s" features="%(features)s" />')
  605. def toxml(self):
  606. return self._XML_TEMPLATE % self.__dict__
  607. @staticmethod
  608. def _read(etree):
  609. return CRFInfo.WeightGroup(etree.get('name'),
  610. etree.get('src'),
  611. etree.get('dst'),
  612. etree.get('features'))
  613. # [xx] feature name????
  614. def match(self, src, dst):
  615. # Check if the source matches
  616. src_match = self._src_match_cache.get(src)
  617. if src_match is None:
  618. if isinstance(self.src, compat.string_types):
  619. src_match = bool(re.match(self.src+'\Z', src))
  620. else:
  621. src_match = src in self.src
  622. self._src_match_cache[src] = src_match
  623. # Check if the dest matches
  624. dst_match = self._dst_match_cache.get(dst)
  625. if dst_match is None:
  626. if isinstance(self.dst, compat.string_types):
  627. dst_match = bool(re.match(self.dst+'\Z', dst))
  628. else:
  629. dst_match = dst in self.dst
  630. self._dst_match_cache[dst] = dst_match
  631. # Return true if both matched.
  632. return src_match and dst_match
  633. ###########################################################################
  634. ## Demonstration code
  635. ###########################################################################
  636. def demo(train_size=100, test_size=100, java_home=None, mallet_home=None):
  637. from nltk.corpus import brown
  638. import textwrap
  639. # Define a very simple feature detector
  640. def fd(sentence, index):
  641. word = sentence[index]
  642. return dict(word=word, suffix=word[-2:], len=len(word))
  643. # Let nltk know where java & mallet are.
  644. nltk.internals.config_java(java_home)
  645. nltk.classify.mallet.config_mallet(mallet_home)
  646. # Get the training & test corpus. We simplify the tagset a little:
  647. # just the first 2 chars.
  648. def strip(corpus): return [[(w, t[:2]) for (w,t) in sent]
  649. for sent in corpus]
  650. brown_train = strip(brown.tagged_sents(categories='news')[:train_size])
  651. brown_test = strip(brown.tagged_sents(categories='editorial')[:test_size])
  652. crf = MalletCRF.train(fd, brown_train, #'/tmp/crf-model',
  653. transduction_type='VITERBI')
  654. sample_output = crf.tag([w for (w,t) in brown_test[5]])
  655. acc = nltk.tag.accuracy(crf, brown_test)
  656. print('\nAccuracy: %.1f%%' % (acc*100))
  657. print('Sample output:')
  658. print(textwrap.fill(' '.join('%s/%s' % w for w in sample_output),
  659. initial_indent=' ', subsequent_indent=' ')+'\n')
  660. # Clean up
  661. print('Clean-up: deleting', crf.filename)
  662. os.remove(crf.filename)
  663. return crf
  664. if __name__ == "__main__":
  665. demo()