/src/main/scala/edu/umass/cs/iesl/entizer/Env.scala

https://github.com/kedarbellare/entizer · Scala · 159 lines · 142 code · 11 blank · 6 comment · 48 complexity · c83987e11adb0852f9baff1c62e2f708 MD5 · raw file

  1. package edu.umass.cs.iesl.entizer
  2. import com.mongodb.casbah.Imports._
  3. import org.riedelcastro.nurupo.HasLogger
  4. import collection.mutable.HashMap
  5. import uk.ac.shef.wit.simmetrics.similaritymetrics.JaroWinkler
  /**
   * Helpers for pruning overlapping segment alignments and for testing whether a span of
   * mention tokens is (approximately) contained in the phrase of a field value.
   *
   * @author kedar
   */
  trait Env extends HasLogger {
    /** Jaro-Winkler string similarity used for fuzzy token-to-token matching. */
    val JWINK = new JaroWinkler

    /** Jaro-Winkler similarity at or above which two tokens are treated as the same token. */
    private final val TokenSimThreshold = 0.9

    def repo: MongoRepository

    def mentions: MongoCollection

    /**
     * Groups the values currently in `alignSegPred` by their mention segment.
     * The result is a snapshot: later removals from `alignSegPred` do not change it.
     */
    private def groupValuesBySegment(
        alignSegPred: AlignSegmentPredicate): HashMap[MentionSegment, Seq[FieldValueMentionSegment]] = {
      val segmentToValues = new HashMap[MentionSegment, Seq[FieldValueMentionSegment]]
      for (value <- alignSegPred) {
        val segment = value.mentionSegment
        segmentToValues(segment) =
          segmentToValues.getOrElse(segment, Seq.empty[FieldValueMentionSegment]) :+ value
      }
      segmentToValues
    }

    /**
     * Enumerates every strictly shorter segment `(mentionId, i, j)` with
     * `begin <= i < j <= end` of the given segment, ordered by `i` and then by `j`.
     */
    private def properSubsegments(segment: MentionSegment): Iterator[MentionSegment] = {
      val begin = segment.begin
      val end = segment.end
      for {
        i <- (begin until end).iterator
        j <- ((i + 1) to end).iterator
        if (j - i) < (end - begin)
      } yield MentionSegment(segment.mentionId, i, j)
    }

    /**
     * Removes from `alignSegPred` every value whose segment is a proper subsegment of another
     * segment that also has values in `alignSegPred`, keeping the larger segments.
     *
     * @param alignSegPred predicate to prune; it is modified in place
     * @return the same `alignSegPred` instance
     */
    def removeSubsegmentAligns(alignSegPred: AlignSegmentPredicate): AlignSegmentPredicate = {
      val segmentToValues = groupValuesBySegment(alignSegPred)
      var numRemoved = 0
      for (segment <- segmentToValues.keys; subsegment <- properSubsegments(segment)) {
        for (subvalue <- segmentToValues.getOrElse(subsegment, Seq.empty[FieldValueMentionSegment])) {
          if (alignSegPred.remove(subvalue)) {
            logger.debug("Removing subsegment[" + subsegment + "] of segment[" + segment + "] value: " + subvalue)
            numRemoved += 1
          }
        }
      }
      if (numRemoved > 0)
        logger.debug("Deleted " + numRemoved + " (overlapping) alignments from " + alignSegPred.predicateName)
      alignSegPred
    }

    /**
     * Removes from `alignSegPred` every value whose segment has a proper subsegment that also
     * had values in `alignSegPred` when the method was called, keeping the smaller segments.
     *
     * @param alignSegPred predicate to prune; it is modified in place
     * @return the same `alignSegPred` instance
     */
    def removeSupersegmentAligns(alignSegPred: AlignSegmentPredicate): AlignSegmentPredicate = {
      val segmentToValues = groupValuesBySegment(alignSegPred)
      var numRemoved = 0
      for (segment <- segmentToValues.keys) {
        // One contained subsegment is enough to drop the larger segment; stopping at the first
        // hit avoids re-attempting the same removals for every further subsegment.
        properSubsegments(segment).find(sub => segmentToValues.contains(sub)).foreach { subsegment =>
          for (supervalue <- segmentToValues(segment)) {
            if (alignSegPred.remove(supervalue)) {
              logger.debug("Removing supersegment[" + segment + "] of segment[" + subsegment + "] value: " + supervalue)
              numRemoved += 1
            }
          }
        }
      }
      if (numRemoved > 0)
        logger.debug("Deleted " + numRemoved + " (overlapping) alignments from " + alignSegPred.predicateName)
      alignSegPred
    }

    /** Lower-cases each token, strips every character outside `[a-z0-9]` and drops empty results. */
    private def normalizeTokens(seq: Seq[String]): Seq[String] =
      seq.map(_.toLowerCase.replaceAll("[^a-z0-9]+", "")).filter(_.length() > 0)

    /** Counts how often each token occurs in `tokens`. */
    private def tokenCounts(tokens: Seq[String]): HashMap[String, Int] = {
      val counts = new HashMap[String, Int]
      for (token <- tokens) counts(token) = counts.getOrElse(token, 0) + 1
      counts
    }

    /**
     * Shared driver of the containment checks: normalizes the mention tokens in `[begin, end)`
     * and every transformed variant of the value phrase, and reports whether `isContained`
     * accepts at least one variant. Always false when `fv` has no value id.
     *
     * @param isContained called as `isContained(valueTokens, mentionTokens)`
     */
    private def someTransformedValueContains(fv: FieldValue, m: Mention, begin: Int, end: Int,
                                             transforms: Seq[(Seq[String], Seq[String])])
                                            (isContained: (Seq[String], Seq[String]) => Boolean): Boolean = {
      if (!fv.valueId.isDefined) false
      else {
        val mentionPhraseClean = normalizeTokens(m.words.slice(begin, end))
        val valuePhrase = fv.field.getValuePhrase(fv.valueId)
        PhraseHash.transformedPhrases(valuePhrase, transforms).exists { transformValuePhrase =>
          isContained(normalizeTokens(transformValuePhrase), mentionPhraseClean)
        }
      }
    }

    /**
     * Greedy soft overlap of `phrTo` in `phrFrom`: each distinct `phrTo` token is paired with at
     * most one distinct `phrFrom` token, and each `phrFrom` token is used at most once.
     * An exact pair contributes the smaller of the two occurrence counts; a fuzzy pair
     * (similarity >= `TokenSimThreshold`) contributes that count scaled by the similarity.
     *
     * @return true when the overlap divided by `phrTo.length` reaches `simThreshold`;
     *         false when `phrTo` is empty
     */
    private def isApproxContained(phrFrom: Seq[String], phrTo: Seq[String], simThreshold: Double): Boolean = {
      if (phrTo.isEmpty) false
      else {
        val fromCounts = tokenCounts(phrFrom)
        val toCounts = tokenCounts(phrTo)
        val fromOrder = phrFrom.distinct
        // Exact matches are resolved before fuzzy ones so that a fuzzy match can never consume
        // a token that another mention token matches exactly; both passes follow phrase order,
        // which keeps the result independent of hash-map iteration order.
        val (exact, inexact) = phrTo.distinct.partition(w => fromCounts.contains(w))
        var numIntersection = 0.0
        for (w <- exact) {
          numIntersection += math.min(fromCounts(w), toCounts(w))
          fromCounts.remove(w)
        }
        for (w <- inexact) {
          var bestScore = 0.0
          var bestKey: Option[String] = None
          for (ow <- fromOrder if fromCounts.contains(ow)) {
            val score: Double = JWINK.getSimilarity(w, ow)
            if (score > bestScore) {
              bestScore = score
              bestKey = Some(ow)
            }
          }
          for (ow <- bestKey if bestScore >= TokenSimThreshold) {
            numIntersection += bestScore * math.min(toCounts(w), fromCounts(ow))
            fromCounts.remove(ow)
          }
        }
        (numIntersection / phrTo.length) >= simThreshold
      }
    }

    /**
     * True when `phrTo` is non-empty, not longer than `phrFrom`, and each of its tokens either
     * occurs in `phrFrom` or is within `TokenSimThreshold` similarity of some `phrFrom` token.
     */
    private def isEveryTokenContained(phrFrom: Seq[String], phrTo: Seq[String]): Boolean = {
      if (phrTo.isEmpty || phrFrom.length < phrTo.length) false
      else {
        val phrFromSet = phrFrom.toSet
        phrTo.forall { wto =>
          phrFromSet(wto) || phrFromSet.exists(wfrom => JWINK.getSimilarity(wfrom, wto) >= TokenSimThreshold)
        }
      }
    }

    /**
     * Tests whether the mention tokens in `[begin, end)` are approximately contained in the
     * phrase of `fv`, or in any variant of it produced by `transforms`.
     *
     * @param fv           field value to compare against; the result is false if it has no value id
     * @param m            mention whose `words` are sliced
     * @param begin        start index of the slice (inclusive)
     * @param end          end index of the slice (exclusive)
     * @param transforms   phrase rewrites handed to `PhraseHash.transformedPhrases`
     * @param simThreshold minimum matched fraction of the normalized mention tokens
     * @return true if at least one variant of the value phrase reaches `simThreshold`
     */
    def isMentionPhraseApproxContainedInValue(fv: FieldValue, m: Mention, begin: Int, end: Int,
                                              transforms: Seq[(Seq[String], Seq[String])],
                                              simThreshold: Double = 0.9): Boolean =
      someTransformedValueContains(fv, m, begin, end, transforms) { (valueTokens, mentionTokens) =>
        isApproxContained(valueTokens, mentionTokens, simThreshold)
      }

    /**
     * Tests whether every mention token in `[begin, end)` has an exact or near-exact counterpart
     * in the phrase of `fv`, or in any variant of it produced by `transforms`.
     *
     * @param fv         field value to compare against; the result is false if it has no value id
     * @param m          mention whose `words` are sliced
     * @param begin      start index of the slice (inclusive)
     * @param end        end index of the slice (exclusive)
     * @param transforms phrase rewrites handed to `PhraseHash.transformedPhrases`
     * @return true if at least one variant of the value phrase covers all mention tokens
     */
    def isMentionPhraseContainedInValue(fv: FieldValue, m: Mention, begin: Int, end: Int,
                                        transforms: Seq[(Seq[String], Seq[String])]): Boolean =
      someTransformedValueContains(fv, m, begin, end, transforms)(isEveryTokenContained)

    /**
     * Returns the `_id` of every document in `mentions` that matches `query`
     * (an empty query by default).
     *
     * Only `_id` is projected because no other field is read here.
     * NOTE(review): the result is built with `toSeq` on the cursor; whether that is evaluated
     * lazily depends on the Scala version in use -- confirm before relying on either behavior.
     */
    def getMentionIds(query: DBObject = MongoDBObject()) =
      mentions.find(query, MongoDBObject("_id" -> 1)).map(_._id.get).toSeq
  }