/java/edu/berkeley/nlp/parser/EnglishPennTreebankParseEvaluator.java
Java | 520 lines | 408 code | 92 blank | 20 comment | 81 complexity | 10b0ada3dd64924769d0c81d3822512d MD5 | raw file
- package edu.berkeley.nlp.parser;
- import edu.berkeley.nlp.syntax.Tree;
- import edu.berkeley.nlp.syntax.Trees;
- import java.util.*;
- import java.io.PrintWriter;
- import java.io.StringReader;
- /**
- * Evaluates precision and recall for English Penn Treebank parse trees. NOTE:
- * Unlike the standard evaluation, multiplicity over each span is ignored. Also,
- * punctuation is NOT currently deleted properly (approximate hack), and other
- * normalizations (like ADVP ~ PRT) are NOT done.
- *
- * @author Dan Klein
- */
- public class EnglishPennTreebankParseEvaluator<L> {
- static class UnlabeledConstituent<L> {
- int start;
- int end;
- public int getStart() {
- return start;
- }
- public int getEnd() {
- return end;
- }
- public boolean equals(Object o) {
- if (this == o)
- return true;
- if (!(o instanceof UnlabeledConstituent))
- return false;
- final UnlabeledConstituent unlabeledConstituent = (UnlabeledConstituent) o;
- if (end != unlabeledConstituent.end)
- return false;
- if (start != unlabeledConstituent.start)
- return false;
- return true;
- }
- public int hashCode() {
- int result;
- result = start;
- result = 29 * result + end;
- return result;
- }
- public String toString() {
- return "[" + start + "," + end + "]";
- }
- public UnlabeledConstituent(int start, int end) {
- this.start = start;
- this.end = end;
- }
- }
- abstract static class AbstractEval<L> {
- protected String str = "";
- private int exact = 0;
- private int total = 0;
- private int correctEvents = 0;
- private int guessedEvents = 0;
- private int goldEvents = 0;
- abstract Set<Object> makeObjects(Tree<L> tree);
- public double evaluate(Tree<L> guess, Tree<L> gold) {
- return evaluate(guess, gold, new PrintWriter(System.out, true));
- }
- public double evaluate(Tree<L> guess, Tree<L> gold, boolean b) {
- return evaluate(guess, gold, null);
- }
- /*
- * evaluates precision and recall by calling makeObjects() to make a set
- * of structures for guess Tree and gold Tree, and compares them with
- * each other.
- */
- public double evaluate(Tree<L> guess, Tree<L> gold, PrintWriter pw) {
- Set<Object> guessedSet = makeObjects(guess);
- Set<Object> goldSet = makeObjects(gold);
- Set<Object> correctSet = new HashSet<Object>();
- correctSet.addAll(goldSet);
- correctSet.retainAll(guessedSet);
- correctEvents += correctSet.size();
- guessedEvents += guessedSet.size();
- goldEvents += goldSet.size();
- int currentExact = 0;
- if (correctSet.size() == guessedSet.size()
- && correctSet.size() == goldSet.size()) {
- exact++;
- currentExact = 1;
- }
- total++;
- // guess.pennPrint(pw);
- // gold.pennPrint(pw);
- double f1 = displayPRF(str + " [Current] ", correctSet.size(),
- guessedSet.size(), goldSet.size(), currentExact, 1, pw);
- return f1;
- }
- public double evaluateMultiple(List<Tree<L>> guesses,
- List<Tree<L>> golds, PrintWriter pw) {
- assert (guesses.size() == golds.size());
- int correctCount = 0;
- int guessedCount = 0;
- int goldCount = 0;
- for (int i = 0; i < guesses.size(); i++) {
- Tree<L> guess = guesses.get(i);
- Tree<L> gold = golds.get(i);
- Set<Object> guessedSet = makeObjects(guess);
- Set<Object> goldSet = makeObjects(gold);
- Set<Object> correctSet = new HashSet<Object>();
- correctSet.addAll(goldSet);
- correctSet.retainAll(guessedSet);
- correctCount += correctSet.size();
- guessedCount += guessedSet.size();
- goldCount += goldSet.size();
- }
- correctEvents += correctCount;
- guessedEvents += guessedCount;
- goldEvents += goldCount;
- int currentExact = 0;
- if (correctCount == guessedCount && correctCount == goldCount) {
- exact++;
- currentExact = 1;
- }
- total++;
- // guess.pennPrint(pw);
- // gold.pennPrint(pw);
- double f1 = displayPRF(str + " [Current] ", correctCount,
- guessedCount, goldCount, currentExact, 1, pw);
- return f1;
- }
- public double[] massEvaluate(Tree<L> guess, Tree<L>[] goldTrees) {
- Set<Object> guessedSet = makeObjects(guess);
- double cEvents = 0;
- double guEvents = 0;
- double goEvents = 0;
- double exactM = 0, precision = 0, recall = 0, f1 = 0;
- for (int treeI = 0; treeI < goldTrees.length; treeI++) {
- Tree<L> gold = goldTrees[treeI];
- Set<Object> goldSet = makeObjects(gold);
- Set<Object> correctSet = new HashSet<Object>();
- correctSet.addAll(goldSet);
- correctSet.retainAll(guessedSet);
- cEvents = correctSet.size();
- guEvents = guessedSet.size();
- goEvents = goldSet.size();
- double p = cEvents / guEvents;
- double r = cEvents / goEvents;
- double f = (p > 0.0 && r > 0.0 ? 2.0 / (1.0 / p + 1.0 / r)
- : 0.0);
- precision += p;
- recall += r;
- f1 += f;
- if (cEvents == guEvents && cEvents == goEvents) {
- exactM++;
- }
- }
- double ex = exactM / goldTrees.length;
- double[] results = { precision, recall, f1, ex };
- return results;
- }
- private double displayPRF(String prefixStr, int correct, int guessed,
- int gold, int exact, int total, PrintWriter pw) {
- double precision = (guessed > 0 ? correct / (double) guessed : 1.0);
- double recall = (gold > 0 ? correct / (double) gold : 1.0);
- double f1 = (precision > 0.0 && recall > 0.0 ? 2.0 / (1.0 / precision + 1.0 / recall)
- : 0.0);
- double exactMatch = exact / (double) total;
- String displayStr = " P: " + ((int) (precision * 10000)) / 100.0
- + " R: " + ((int) (recall * 10000)) / 100.0 + " F1: "
- + ((int) (f1 * 10000)) / 100.0 + " EX: "
- + ((int) (exactMatch * 10000)) / 100.0;
- if (pw != null)
- pw.println(prefixStr + displayStr);
- return f1;
- }
- public double display(boolean verbose) {
- return display(verbose, new PrintWriter(System.out, true));
- }
- public double display(boolean verbose, PrintWriter pw) {
- return displayPRF(str + " [Average] ", correctEvents,
- guessedEvents, goldEvents, exact, total, pw);
- }
- }
- static class LabeledConstituent<L> {
- L label;
- int start;
- int end;
- public L getLabel() {
- return label;
- }
- public int getStart() {
- return start;
- }
- public int getEnd() {
- return end;
- }
- public boolean equals(Object o) {
- if (this == o)
- return true;
- if (!(o instanceof LabeledConstituent))
- return false;
- final LabeledConstituent labeledConstituent = (LabeledConstituent) o;
- if (end != labeledConstituent.end)
- return false;
- if (start != labeledConstituent.start)
- return false;
- if (label != null ? !label.equals(labeledConstituent.label)
- : labeledConstituent.label != null)
- return false;
- return true;
- }
- public int hashCode() {
- int result;
- result = (label != null ? label.hashCode() : 0);
- result = 29 * result + start;
- result = 29 * result + end;
- return result;
- }
- public String toString() {
- return label + "[" + start + "," + end + "]";
- }
- public LabeledConstituent(L label, int start, int end) {
- this.label = label;
- this.start = start;
- this.end = end;
- }
- }
- public static class UnlabeledConstituentEval<L> extends AbstractEval<L> {
- public UnlabeledConstituentEval() {
- }
- @Override
- Set<Object> makeObjects(Tree<L> tree) {
- Tree<L> noLeafTree = LabeledConstituentEval.stripLeaves(tree);
- Set<Object> set = new HashSet<Object>();
- addConstituents(noLeafTree, set, 0);
- return set;
- }
- private int addConstituents(Tree<L> tree, Set<Object> set, int start) {
- if (tree == null)
- return 0;
- if (tree.getYield().size() == 1) {
- return 1;
- }
- int end = start;
- for (Tree<L> child : tree.getChildren()) {
- int childSpan = addConstituents(child, set, end);
- end += childSpan;
- }
- set.add(new UnlabeledConstituent<L>(start, end));
- return end - start;
- }
- }
- public static class LabeledConstituentEval<L> extends AbstractEval<L> {
- Set<L> labelsToIgnore;
- Set<L> punctuationTags;
- static <L> Tree<L> stripLeaves(Tree<L> tree) {
- if (tree.isLeaf())
- return null;
- if (tree.isPreTerminal())
- return new Tree<L>(tree.getLabel());
- List<Tree<L>> children = new ArrayList<Tree<L>>();
- for (Tree<L> child : tree.getChildren()) {
- children.add(stripLeaves(child));
- }
- return new Tree<L>(tree.getLabel(), children);
- }
- Set<Object> makeObjects(Tree<L> tree) {
- Tree<L> noLeafTree = stripLeaves(tree);
- Set<Object> set = new HashSet<Object>();
- addConstituents(noLeafTree, set, 0);
- return set;
- }
- private int addConstituents(Tree<L> tree, Set<Object> set, int start) {
- if (tree == null)
- return 0;
- if (tree.isLeaf()) {
- if (punctuationTags.contains(tree.getLabel()))
- return 0;
- else
- return 1;
- }
- int end = start;
- for (Tree<L> child : tree.getChildren()) {
- int childSpan = addConstituents(child, set, end);
- end += childSpan;
- }
- L label = tree.getLabel();
- if (!labelsToIgnore.contains(label)) {
- set.add(new LabeledConstituent<L>(label, start, end));
- }
- return end - start;
- }
- public LabeledConstituentEval(Set<L> labelsToIgnore,
- Set<L> punctuationTags) {
- this.labelsToIgnore = labelsToIgnore;
- this.punctuationTags = punctuationTags;
- }
- public int getHammingDistance(Tree<L> guess, Tree<L> gold) {
- Set<Object> guessedSet = makeObjects(guess);
- Set<Object> goldSet = makeObjects(gold);
- Set<Object> correctSet = new HashSet<Object>();
- correctSet.addAll(goldSet);
- correctSet.retainAll(guessedSet);
- return (guessedSet.size() - correctSet.size())
- + (goldSet.size() - correctSet.size());
- }
- }
- public static void main(String[] args) throws Throwable {
- Tree<String> goldTree = (new Trees.PennTreeReader(new StringReader(
- "(ROOT (S (NP (DT the) (NN can)) (VP (VBD fell))))"))).next();
- Tree<String> guessedTree = (new Trees.PennTreeReader(new StringReader(
- "(ROOT (S (NP (DT the)) (VP (MB can) (VP (VBD fell)))))")))
- .next();
- LabeledConstituentEval<String> eval = new LabeledConstituentEval<String>(
- Collections.singleton("ROOT"), new HashSet<String>());
- RuleEval<String> rule_eval = new RuleEval<String>(Collections
- .singleton("ROOT"), new HashSet<String>());
- System.out.println("Gold tree:\n"
- + Trees.PennTreeRenderer.render(goldTree));
- System.out.println("Guessed tree:\n"
- + Trees.PennTreeRenderer.render(guessedTree));
- eval.evaluate(guessedTree, goldTree);
- eval.display(true);
- rule_eval.evaluate(guessedTree, goldTree);
- rule_eval.display(true);
- }
- public static class RuleEval<L> extends AbstractEval<L> {
- Set<L> labelsToIgnore;
- Set<L> punctuationTags;
- static <L> Tree<L> stripLeaves(Tree<L> tree) {
- if (tree.isLeaf())
- return null;
- if (tree.isPreTerminal())
- return new Tree<L>(tree.getLabel());
- List<Tree<L>> children = new ArrayList<Tree<L>>();
- for (Tree<L> child : tree.getChildren()) {
- children.add(stripLeaves(child));
- }
- return new Tree<L>(tree.getLabel(), children);
- }
- Set<Object> makeObjects(Tree<L> tree) {
- Tree<L> noLeafTree = stripLeaves(tree);
- Set<Object> set = new HashSet<Object>();
- addConstituents(noLeafTree, set, 0);
- return set;
- }
- private int addConstituents(Tree<L> tree, Set<Object> set, int start) {
- if (tree == null)
- return 0;
- if (tree.isLeaf()) {
- /*
- * if (punctuationTags.contains(tree.getLabel())) return 0; else
- */
- return 1;
- }
- int end = start, i = 0;
- L lC = null, rC = null;
- for (Tree<L> child : tree.getChildren()) {
- int childSpan = addConstituents(child, set, end);
- if (i == 0)
- lC = child.getLabel();
- else
- /* i==1 */rC = child.getLabel();
- i++;
- end += childSpan;
- }
- L label = tree.getLabel();
- if (!labelsToIgnore.contains(label)) {
- set.add(new RuleConstituent<L>(label, lC, rC, start, end));
- }
- return end - start;
- }
- public RuleEval(Set<L> labelsToIgnore, Set<L> punctuationTags) {
- this.labelsToIgnore = labelsToIgnore;
- this.punctuationTags = punctuationTags;
- }
- }
- static class RuleConstituent<L> {
- L label, lChild, rChild;
- int start;
- int end;
- public L getLabel() {
- return label;
- }
- public int getStart() {
- return start;
- }
- public int getEnd() {
- return end;
- }
- public boolean equals(Object o) {
- if (this == o)
- return true;
- if (!(o instanceof RuleConstituent))
- return false;
- final RuleConstituent labeledConstituent = (RuleConstituent) o;
- if (end != labeledConstituent.end)
- return false;
- if (start != labeledConstituent.start)
- return false;
- if (label != null ? !label.equals(labeledConstituent.label)
- : labeledConstituent.label != null)
- return false;
- if (lChild != null ? !lChild.equals(labeledConstituent.lChild)
- : labeledConstituent.lChild != null)
- return false;
- if (rChild != null ? !rChild.equals(labeledConstituent.rChild)
- : labeledConstituent.rChild != null)
- return false;
- return true;
- }
- public int hashCode() {
- int result;
- result = (label != null ? label.hashCode() : 0) + 17
- * (lChild != null ? lChild.hashCode() : 0) - 7
- * (rChild != null ? rChild.hashCode() : 0);
- result = 29 * result + start;
- result = 29 * result + end;
- return result;
- }
- public String toString() {
- String rChildStr = (rChild == null) ? "" : rChild.toString();
- return label + "->" + lChild + " " + rChildStr + "[" + start + ","
- + end + "]";
- }
- public RuleConstituent(L label, L lChild, L rChild, int start, int end) {
- this.label = label;
- this.lChild = lChild;
- this.rChild = rChild;
- this.start = start;
- this.end = end;
- }
- }
- }