PageRenderTime 74ms CodeModel.GetById 15ms RepoModel.GetById 1ms app.codeStats 0ms

/java/edu/berkeley/nlp/parser/EnglishPennTreebankParseEvaluator.java

https://bitbucket.org/yoavgo/blatt_spmrl
Java | 520 lines | 408 code | 92 blank | 20 comment | 81 complexity | 10b0ada3dd64924769d0c81d3822512d MD5 | raw file
  1. package edu.berkeley.nlp.parser;
  2. import edu.berkeley.nlp.syntax.Tree;
  3. import edu.berkeley.nlp.syntax.Trees;
  4. import java.util.*;
  5. import java.io.PrintWriter;
  6. import java.io.StringReader;
  7. /**
  8. * Evaluates precision and recall for English Penn Treebank parse trees. NOTE:
  9. * Unlike the standard evaluation, multiplicity over each span is ignored. Also,
  10. * punctuation is NOT currently deleted properly (approximate hack), and other
  11. * normalizations (like ADVP ~ PRT) are NOT done.
  12. *
  13. * @author Dan Klein
  14. */
  15. public class EnglishPennTreebankParseEvaluator<L> {
  16. static class UnlabeledConstituent<L> {
  17. int start;
  18. int end;
  19. public int getStart() {
  20. return start;
  21. }
  22. public int getEnd() {
  23. return end;
  24. }
  25. public boolean equals(Object o) {
  26. if (this == o)
  27. return true;
  28. if (!(o instanceof UnlabeledConstituent))
  29. return false;
  30. final UnlabeledConstituent unlabeledConstituent = (UnlabeledConstituent) o;
  31. if (end != unlabeledConstituent.end)
  32. return false;
  33. if (start != unlabeledConstituent.start)
  34. return false;
  35. return true;
  36. }
  37. public int hashCode() {
  38. int result;
  39. result = start;
  40. result = 29 * result + end;
  41. return result;
  42. }
  43. public String toString() {
  44. return "[" + start + "," + end + "]";
  45. }
  46. public UnlabeledConstituent(int start, int end) {
  47. this.start = start;
  48. this.end = end;
  49. }
  50. }
  51. abstract static class AbstractEval<L> {
  52. protected String str = "";
  53. private int exact = 0;
  54. private int total = 0;
  55. private int correctEvents = 0;
  56. private int guessedEvents = 0;
  57. private int goldEvents = 0;
  58. abstract Set<Object> makeObjects(Tree<L> tree);
  59. public double evaluate(Tree<L> guess, Tree<L> gold) {
  60. return evaluate(guess, gold, new PrintWriter(System.out, true));
  61. }
  62. public double evaluate(Tree<L> guess, Tree<L> gold, boolean b) {
  63. return evaluate(guess, gold, null);
  64. }
  65. /*
  66. * evaluates precision and recall by calling makeObjects() to make a set
  67. * of structures for guess Tree and gold Tree, and compares them with
  68. * each other.
  69. */
  70. public double evaluate(Tree<L> guess, Tree<L> gold, PrintWriter pw) {
  71. Set<Object> guessedSet = makeObjects(guess);
  72. Set<Object> goldSet = makeObjects(gold);
  73. Set<Object> correctSet = new HashSet<Object>();
  74. correctSet.addAll(goldSet);
  75. correctSet.retainAll(guessedSet);
  76. correctEvents += correctSet.size();
  77. guessedEvents += guessedSet.size();
  78. goldEvents += goldSet.size();
  79. int currentExact = 0;
  80. if (correctSet.size() == guessedSet.size()
  81. && correctSet.size() == goldSet.size()) {
  82. exact++;
  83. currentExact = 1;
  84. }
  85. total++;
  86. // guess.pennPrint(pw);
  87. // gold.pennPrint(pw);
  88. double f1 = displayPRF(str + " [Current] ", correctSet.size(),
  89. guessedSet.size(), goldSet.size(), currentExact, 1, pw);
  90. return f1;
  91. }
  92. public double evaluateMultiple(List<Tree<L>> guesses,
  93. List<Tree<L>> golds, PrintWriter pw) {
  94. assert (guesses.size() == golds.size());
  95. int correctCount = 0;
  96. int guessedCount = 0;
  97. int goldCount = 0;
  98. for (int i = 0; i < guesses.size(); i++) {
  99. Tree<L> guess = guesses.get(i);
  100. Tree<L> gold = golds.get(i);
  101. Set<Object> guessedSet = makeObjects(guess);
  102. Set<Object> goldSet = makeObjects(gold);
  103. Set<Object> correctSet = new HashSet<Object>();
  104. correctSet.addAll(goldSet);
  105. correctSet.retainAll(guessedSet);
  106. correctCount += correctSet.size();
  107. guessedCount += guessedSet.size();
  108. goldCount += goldSet.size();
  109. }
  110. correctEvents += correctCount;
  111. guessedEvents += guessedCount;
  112. goldEvents += goldCount;
  113. int currentExact = 0;
  114. if (correctCount == guessedCount && correctCount == goldCount) {
  115. exact++;
  116. currentExact = 1;
  117. }
  118. total++;
  119. // guess.pennPrint(pw);
  120. // gold.pennPrint(pw);
  121. double f1 = displayPRF(str + " [Current] ", correctCount,
  122. guessedCount, goldCount, currentExact, 1, pw);
  123. return f1;
  124. }
  125. public double[] massEvaluate(Tree<L> guess, Tree<L>[] goldTrees) {
  126. Set<Object> guessedSet = makeObjects(guess);
  127. double cEvents = 0;
  128. double guEvents = 0;
  129. double goEvents = 0;
  130. double exactM = 0, precision = 0, recall = 0, f1 = 0;
  131. for (int treeI = 0; treeI < goldTrees.length; treeI++) {
  132. Tree<L> gold = goldTrees[treeI];
  133. Set<Object> goldSet = makeObjects(gold);
  134. Set<Object> correctSet = new HashSet<Object>();
  135. correctSet.addAll(goldSet);
  136. correctSet.retainAll(guessedSet);
  137. cEvents = correctSet.size();
  138. guEvents = guessedSet.size();
  139. goEvents = goldSet.size();
  140. double p = cEvents / guEvents;
  141. double r = cEvents / goEvents;
  142. double f = (p > 0.0 && r > 0.0 ? 2.0 / (1.0 / p + 1.0 / r)
  143. : 0.0);
  144. precision += p;
  145. recall += r;
  146. f1 += f;
  147. if (cEvents == guEvents && cEvents == goEvents) {
  148. exactM++;
  149. }
  150. }
  151. double ex = exactM / goldTrees.length;
  152. double[] results = { precision, recall, f1, ex };
  153. return results;
  154. }
  155. private double displayPRF(String prefixStr, int correct, int guessed,
  156. int gold, int exact, int total, PrintWriter pw) {
  157. double precision = (guessed > 0 ? correct / (double) guessed : 1.0);
  158. double recall = (gold > 0 ? correct / (double) gold : 1.0);
  159. double f1 = (precision > 0.0 && recall > 0.0 ? 2.0 / (1.0 / precision + 1.0 / recall)
  160. : 0.0);
  161. double exactMatch = exact / (double) total;
  162. String displayStr = " P: " + ((int) (precision * 10000)) / 100.0
  163. + " R: " + ((int) (recall * 10000)) / 100.0 + " F1: "
  164. + ((int) (f1 * 10000)) / 100.0 + " EX: "
  165. + ((int) (exactMatch * 10000)) / 100.0;
  166. if (pw != null)
  167. pw.println(prefixStr + displayStr);
  168. return f1;
  169. }
  170. public double display(boolean verbose) {
  171. return display(verbose, new PrintWriter(System.out, true));
  172. }
  173. public double display(boolean verbose, PrintWriter pw) {
  174. return displayPRF(str + " [Average] ", correctEvents,
  175. guessedEvents, goldEvents, exact, total, pw);
  176. }
  177. }
  178. static class LabeledConstituent<L> {
  179. L label;
  180. int start;
  181. int end;
  182. public L getLabel() {
  183. return label;
  184. }
  185. public int getStart() {
  186. return start;
  187. }
  188. public int getEnd() {
  189. return end;
  190. }
  191. public boolean equals(Object o) {
  192. if (this == o)
  193. return true;
  194. if (!(o instanceof LabeledConstituent))
  195. return false;
  196. final LabeledConstituent labeledConstituent = (LabeledConstituent) o;
  197. if (end != labeledConstituent.end)
  198. return false;
  199. if (start != labeledConstituent.start)
  200. return false;
  201. if (label != null ? !label.equals(labeledConstituent.label)
  202. : labeledConstituent.label != null)
  203. return false;
  204. return true;
  205. }
  206. public int hashCode() {
  207. int result;
  208. result = (label != null ? label.hashCode() : 0);
  209. result = 29 * result + start;
  210. result = 29 * result + end;
  211. return result;
  212. }
  213. public String toString() {
  214. return label + "[" + start + "," + end + "]";
  215. }
  216. public LabeledConstituent(L label, int start, int end) {
  217. this.label = label;
  218. this.start = start;
  219. this.end = end;
  220. }
  221. }
  222. public static class UnlabeledConstituentEval<L> extends AbstractEval<L> {
  223. public UnlabeledConstituentEval() {
  224. }
  225. @Override
  226. Set<Object> makeObjects(Tree<L> tree) {
  227. Tree<L> noLeafTree = LabeledConstituentEval.stripLeaves(tree);
  228. Set<Object> set = new HashSet<Object>();
  229. addConstituents(noLeafTree, set, 0);
  230. return set;
  231. }
  232. private int addConstituents(Tree<L> tree, Set<Object> set, int start) {
  233. if (tree == null)
  234. return 0;
  235. if (tree.getYield().size() == 1) {
  236. return 1;
  237. }
  238. int end = start;
  239. for (Tree<L> child : tree.getChildren()) {
  240. int childSpan = addConstituents(child, set, end);
  241. end += childSpan;
  242. }
  243. set.add(new UnlabeledConstituent<L>(start, end));
  244. return end - start;
  245. }
  246. }
  247. public static class LabeledConstituentEval<L> extends AbstractEval<L> {
  248. Set<L> labelsToIgnore;
  249. Set<L> punctuationTags;
  250. static <L> Tree<L> stripLeaves(Tree<L> tree) {
  251. if (tree.isLeaf())
  252. return null;
  253. if (tree.isPreTerminal())
  254. return new Tree<L>(tree.getLabel());
  255. List<Tree<L>> children = new ArrayList<Tree<L>>();
  256. for (Tree<L> child : tree.getChildren()) {
  257. children.add(stripLeaves(child));
  258. }
  259. return new Tree<L>(tree.getLabel(), children);
  260. }
  261. Set<Object> makeObjects(Tree<L> tree) {
  262. Tree<L> noLeafTree = stripLeaves(tree);
  263. Set<Object> set = new HashSet<Object>();
  264. addConstituents(noLeafTree, set, 0);
  265. return set;
  266. }
  267. private int addConstituents(Tree<L> tree, Set<Object> set, int start) {
  268. if (tree == null)
  269. return 0;
  270. if (tree.isLeaf()) {
  271. if (punctuationTags.contains(tree.getLabel()))
  272. return 0;
  273. else
  274. return 1;
  275. }
  276. int end = start;
  277. for (Tree<L> child : tree.getChildren()) {
  278. int childSpan = addConstituents(child, set, end);
  279. end += childSpan;
  280. }
  281. L label = tree.getLabel();
  282. if (!labelsToIgnore.contains(label)) {
  283. set.add(new LabeledConstituent<L>(label, start, end));
  284. }
  285. return end - start;
  286. }
  287. public LabeledConstituentEval(Set<L> labelsToIgnore,
  288. Set<L> punctuationTags) {
  289. this.labelsToIgnore = labelsToIgnore;
  290. this.punctuationTags = punctuationTags;
  291. }
  292. public int getHammingDistance(Tree<L> guess, Tree<L> gold) {
  293. Set<Object> guessedSet = makeObjects(guess);
  294. Set<Object> goldSet = makeObjects(gold);
  295. Set<Object> correctSet = new HashSet<Object>();
  296. correctSet.addAll(goldSet);
  297. correctSet.retainAll(guessedSet);
  298. return (guessedSet.size() - correctSet.size())
  299. + (goldSet.size() - correctSet.size());
  300. }
  301. }
  302. public static void main(String[] args) throws Throwable {
  303. Tree<String> goldTree = (new Trees.PennTreeReader(new StringReader(
  304. "(ROOT (S (NP (DT the) (NN can)) (VP (VBD fell))))"))).next();
  305. Tree<String> guessedTree = (new Trees.PennTreeReader(new StringReader(
  306. "(ROOT (S (NP (DT the)) (VP (MB can) (VP (VBD fell)))))")))
  307. .next();
  308. LabeledConstituentEval<String> eval = new LabeledConstituentEval<String>(
  309. Collections.singleton("ROOT"), new HashSet<String>());
  310. RuleEval<String> rule_eval = new RuleEval<String>(Collections
  311. .singleton("ROOT"), new HashSet<String>());
  312. System.out.println("Gold tree:\n"
  313. + Trees.PennTreeRenderer.render(goldTree));
  314. System.out.println("Guessed tree:\n"
  315. + Trees.PennTreeRenderer.render(guessedTree));
  316. eval.evaluate(guessedTree, goldTree);
  317. eval.display(true);
  318. rule_eval.evaluate(guessedTree, goldTree);
  319. rule_eval.display(true);
  320. }
  321. public static class RuleEval<L> extends AbstractEval<L> {
  322. Set<L> labelsToIgnore;
  323. Set<L> punctuationTags;
  324. static <L> Tree<L> stripLeaves(Tree<L> tree) {
  325. if (tree.isLeaf())
  326. return null;
  327. if (tree.isPreTerminal())
  328. return new Tree<L>(tree.getLabel());
  329. List<Tree<L>> children = new ArrayList<Tree<L>>();
  330. for (Tree<L> child : tree.getChildren()) {
  331. children.add(stripLeaves(child));
  332. }
  333. return new Tree<L>(tree.getLabel(), children);
  334. }
  335. Set<Object> makeObjects(Tree<L> tree) {
  336. Tree<L> noLeafTree = stripLeaves(tree);
  337. Set<Object> set = new HashSet<Object>();
  338. addConstituents(noLeafTree, set, 0);
  339. return set;
  340. }
  341. private int addConstituents(Tree<L> tree, Set<Object> set, int start) {
  342. if (tree == null)
  343. return 0;
  344. if (tree.isLeaf()) {
  345. /*
  346. * if (punctuationTags.contains(tree.getLabel())) return 0; else
  347. */
  348. return 1;
  349. }
  350. int end = start, i = 0;
  351. L lC = null, rC = null;
  352. for (Tree<L> child : tree.getChildren()) {
  353. int childSpan = addConstituents(child, set, end);
  354. if (i == 0)
  355. lC = child.getLabel();
  356. else
  357. /* i==1 */rC = child.getLabel();
  358. i++;
  359. end += childSpan;
  360. }
  361. L label = tree.getLabel();
  362. if (!labelsToIgnore.contains(label)) {
  363. set.add(new RuleConstituent<L>(label, lC, rC, start, end));
  364. }
  365. return end - start;
  366. }
  367. public RuleEval(Set<L> labelsToIgnore, Set<L> punctuationTags) {
  368. this.labelsToIgnore = labelsToIgnore;
  369. this.punctuationTags = punctuationTags;
  370. }
  371. }
  372. static class RuleConstituent<L> {
  373. L label, lChild, rChild;
  374. int start;
  375. int end;
  376. public L getLabel() {
  377. return label;
  378. }
  379. public int getStart() {
  380. return start;
  381. }
  382. public int getEnd() {
  383. return end;
  384. }
  385. public boolean equals(Object o) {
  386. if (this == o)
  387. return true;
  388. if (!(o instanceof RuleConstituent))
  389. return false;
  390. final RuleConstituent labeledConstituent = (RuleConstituent) o;
  391. if (end != labeledConstituent.end)
  392. return false;
  393. if (start != labeledConstituent.start)
  394. return false;
  395. if (label != null ? !label.equals(labeledConstituent.label)
  396. : labeledConstituent.label != null)
  397. return false;
  398. if (lChild != null ? !lChild.equals(labeledConstituent.lChild)
  399. : labeledConstituent.lChild != null)
  400. return false;
  401. if (rChild != null ? !rChild.equals(labeledConstituent.rChild)
  402. : labeledConstituent.rChild != null)
  403. return false;
  404. return true;
  405. }
  406. public int hashCode() {
  407. int result;
  408. result = (label != null ? label.hashCode() : 0) + 17
  409. * (lChild != null ? lChild.hashCode() : 0) - 7
  410. * (rChild != null ? rChild.hashCode() : 0);
  411. result = 29 * result + start;
  412. result = 29 * result + end;
  413. return result;
  414. }
  415. public String toString() {
  416. String rChildStr = (rChild == null) ? "" : rChild.toString();
  417. return label + "->" + lChild + " " + rChildStr + "[" + start + ","
  418. + end + "]";
  419. }
  420. public RuleConstituent(L label, L lChild, L rChild, int start, int end) {
  421. this.label = label;
  422. this.lChild = lChild;
  423. this.rChild = rChild;
  424. this.start = start;
  425. this.end = end;
  426. }
  427. }
  428. }