/max_ent.py

https://bitbucket.org/Ayuei/viterbi
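"""A maximum-entropy (logistic regression) NER tagger over CoNLL-style
four-column data, with leave-one-out, keep-one-in, and group-level
feature-ablation experiments."""
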
from copy import deepcopy
from itertools import combinations  # used by the commented-out exhaustive search below

from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Feature extraction: build a dict of features for the token at `index`.
def features(sentence, index, tags, syns, trths, keys_to_drop=None):
    d = {
        'word': sentence[index],
        'is_first': index == 0,
        'is_last': index == len(sentence) - 1,
        'is_capitalized': sentence[index][0].upper() == sentence[index][0],
        'is_all_caps': sentence[index].upper() == sentence[index],
        'is_all_lower': sentence[index].lower() == sentence[index],
        'prefix-1': sentence[index][0],
        'prefix-2': sentence[index][:2],
        'prefix-3': sentence[index][:3],
        'suffix-1': sentence[index][-1],
        'suffix-2': sentence[index][-2:],
        'suffix-3': sentence[index][-3:],
        'prev_word': '' if index == 0 else sentence[index - 1],
        'next_word': '' if index == len(sentence) - 1 else sentence[index + 1],
        'has_hyphen': '-' in sentence[index],
        'is_numeric': sentence[index].isdigit(),
        'capitals_inside': sentence[index][1:].lower() != sentence[index][1:],
        'pos': tags[index],
        'prev_pos': '' if index == 0 else tags[index - 1],
        'next_pos': '' if index == len(tags) - 1 else tags[index + 1],
        'prev_pos_2': '' if index <= 1 else tags[index - 2],
        'next_pos_2': '' if index >= len(tags) - 2 else tags[index + 2],
        'syn': syns[index],
        'prev_syn': '' if index == 0 else syns[index - 1],
        'next_syn': '' if index == len(syns) - 1 else syns[index + 1],
        'prev_syn_2': '' if index <= 1 else syns[index - 2],
        'next_syn_2': '' if index >= len(syns) - 2 else syns[index + 2],
        # Gold NER tag of the previous token (maximum-entropy Markov style).
        'previous_class': '' if index == 0 else trths[index - 1],
    }
    # Optionally ablate features by name.
    if keys_to_drop is not None:
        for key in keys_to_drop:
            d.pop(key)
    return d
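
# For illustration (a hypothetical toy sentence, not drawn from the data):
# the call below builds the dict for index 0, which contains entries such as
# 'word': 'John', 'is_first': True, 'suffix-2': 'hn', 'pos': 'NNP',
# 'prev_word': '' and 'previous_class': '' (there is no preceding token).
#
#   features(['John', 'lives', 'here', '.'], 0,
#            ['NNP', 'VBZ', 'RB', '.'],
#            ['I-NP', 'I-VP', 'I-ADVP', 'O'],
#            ['I-PER', 'O', 'O', 'O'])
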
def get_features_from_file(f, keys_to_drop=None):
    lines = f.readlines()
    del lines[0]  # drop the first line of the file
    X_train = []
    y_train = []
    sentence = []
    tags = []
    syn_tags = []
    NER_tag = []
    sentences = []
    for line in lines:
        if line == "\n":  # skip blank separator lines
            continue
        line = line.strip().split()
        sentence.append(line[0])
        tags.append(line[1])
        syn_tags.append(line[2])
        NER_tag.append(line[3])
        if line[1] == '.':  # sentence boundary: POS tag of sentence-final punctuation
            TEMP = []
            for i in range(len(sentence)):
                fet = features(sentence, i, tags, syn_tags, NER_tag, keys_to_drop)
                X_train.append(fet)
                TEMP.append(fet)
                y_train.append(NER_tag[i])
            sentences.append(TEMP)
            sentence = []
            tags = []
            syn_tags = []
            NER_tag = []
    return X_train, y_train, sentences
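
# The reader above expects a CoNLL-style file: the first line is discarded,
# blank lines are skipped, and every other line carries one token in four
# whitespace-separated columns (word, POS tag, syntactic chunk tag, NER tag),
# with a sentence ending on a token whose POS tag is '.'. A hypothetical
# two-token fragment:
#
#   Berlin NNP I-NP I-LOC
#   . . O O
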
def main(language="deu", clf_class=LogisticRegression, keys_to_drop=None, ret_score=False):
    with open('data/' + language + '.train', 'r') as f:
        X_train, y_train, _ = get_features_from_file(f, keys_to_drop)
    # Fold the development set (testa) into the training data.
    with open('data/' + language + '.testa', 'r') as f:
        addit_X, addit_y, _ = get_features_from_file(f, keys_to_drop)
    X_train.extend(addit_X)
    y_train.extend(addit_y)
    clf = Pipeline([
        ('vectoriser', DictVectorizer()),
        ('classifier', clf_class()),
    ])
    clf.fit(X_train, y_train)
    # Optional dev-set check (testa is already part of the training data here):
    # X_dev, y_dev, _ = get_features_from_file(open('data/' + language + '.testa', 'r'))
    # print("Accuracy:", clf.score(X_dev, y_dev))
    # Evaluate on the held-out testb split. keys_to_drop is passed here as well
    # so train and test features stay consistent (DictVectorizer would ignore
    # unseen test-time keys anyway).
    with open('data/' + language + '.testb', 'r') as f:
        X_test, y_test, _ = get_features_from_file(f, keys_to_drop)
    if ret_score:
        return float(clf.score(X_test, y_test))
    return clf
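
# Minimal usage sketch (assumes data/deu.train, data/deu.testa and
# data/deu.testb exist in the four-column format described above):
#
#   clf = main()                                      # fitted Pipeline
#   acc = main(ret_score=True)                        # accuracy on testb
#   acc = main(keys_to_drop=['pos'], ret_score=True)  # single-feature ablation
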
if __name__ == '__main__':
    keys = ['is_first', 'is_last', 'is_capitalized', 'is_all_caps', 'is_all_lower',
            'prefix-1', 'prefix-2', 'prefix-3', 'suffix-1', 'suffix-2', 'suffix-3', 'prev_word',
            'next_word', 'has_hyphen', 'is_numeric', 'capitals_inside', 'pos', 'prev_pos',
            'next_pos', 'prev_pos_2', 'next_pos_2', 'syn', 'prev_syn', 'next_syn',
            'prev_syn_2', 'next_syn_2', 'previous_class']

    # Two reference points: a word-only model (every listed feature dropped)
    # versus the full feature set.
    temp = deepcopy(keys)
    baseline_score = main(keys_to_drop=temp, ret_score=True)
    baseline_best = main(ret_score=True)
    print('baseline', baseline_score, 'all_fets', baseline_best)

    delta = 0.001  # unused; kept with the commented-out selection output below
    best_keys_key_spef = []
    best_keys_key_rem = []

    # Leave-one-out: drop each feature on its own.
    print('From top')
    for key in keys:
        score = main(keys_to_drop=[key], ret_score=True)
        print(key, score)

    # Keep-one-in: drop everything except this feature (and 'word', which is
    # never in the drop list).
    print('From bottom')
    for key in keys:
        temp = deepcopy(keys)
        temp.remove(key)
        score = main(keys_to_drop=temp, ret_score=True)
        print(key, score)

    # Ablation by feature group.
    print('Group analysis')
    position = ['is_first', 'is_last', 'prev_word', 'next_word']
    word_fets = ['is_capitalized', 'is_all_caps', 'is_numeric']
    char_fets = ['prefix-1', 'prefix-2', 'prefix-3', 'suffix-1', 'suffix-2', 'suffix-3', 'has_hyphen',
                 'capitals_inside']
    pos_fet = ['pos', 'prev_pos', 'next_pos', 'prev_pos_2', 'next_pos_2']
    syn_fet = ['syn', 'prev_syn', 'next_syn', 'prev_syn_2', 'next_syn_2', 'previous_class']
    groups = [position, word_fets, char_fets, pos_fet, syn_fet]
    for group in groups:
        # Keep only this group: drop every feature outside it.
        temp = deepcopy(keys)
        for key in group:
            temp.remove(key)
        score = main(keys_to_drop=temp, ret_score=True)
        print(group, score)
        # Drop only this group, keeping everything else.
        print('Group removed, not run')
        score = main(keys_to_drop=group, ret_score=True)
        print(group, score)

    # Exhaustive subset search, disabled: the number of combinations explodes.
    '''
    for i in range(2, len(keys)):
        for comb_key in combinations(keys, i):
            temp = deepcopy(keys)
            for k in comb_key:
                temp.remove(k)
            score = main(keys_to_drop=temp, ret_score=True)
            print(list(comb_key), score)
    '''
    #print(best_keys_key_spef)
    #print(best_keys_key_rem)
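
# A hypothetical sketch of tagging a fresh, already POS/chunk-tagged sentence
# with the trained pipeline. The tags below are invented for illustration, and
# gold previous-token classes are unavailable at prediction time, so
# placeholder 'O' values are fed to 'previous_class':
#
#   clf = main(language='deu')
#   sent = ['Angela', 'Merkel', 'besucht', 'Paris', '.']
#   pos = ['NE', 'NE', 'VVFIN', 'NE', '$.']
#   chunk = ['I-NC', 'I-NC', 'I-VC', 'I-NC', 'O']
#   prev = ['O'] * len(sent)
#   feats = [features(sent, i, pos, chunk, prev) for i in range(len(sent))]
#   print(clf.predict(feats))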