/src/main/scala/edu/umass/cs/iesl/entizer/FieldCollection.scala

https://github.com/kedarbellare/entizer · Scala · 95 lines · 68 code · 22 blank · 5 comment · 5 complexity · a37005e1b8eafbda3ba24ddd22b1f597 MD5 · raw file

  1. package edu.umass.cs.iesl.entizer
  2. import com.mongodb.casbah.Imports._
  3. import scala.util.Random
  4. import collection.mutable.{HashSet, ArrayBuffer, HashMap}
  5. /**
  6. * @author kedar
  7. */
  8. trait FieldCollection extends Field {
  9. val fieldNameIndexer = new HashMap[String, Int]
  10. val fields = new ArrayBuffer[Field]
  11. def addField(field: Field) = {
  12. if (!fieldNameIndexer.contains(field.name)) {
  13. fieldNameIndexer(field.name) = fields.size
  14. fields += field
  15. }
  16. this
  17. }
  18. def indexOf(name: String) = fieldNameIndexer(name)
  19. def numFields = fields.size
  20. def getField(index: Int) = fields(index)
  21. def getField(name: String) = fields(indexOf(name))
  22. def getFields: Seq[Field] = fields
  23. }
  24. case class SimpleRecord(name: String) extends FieldCollection {
  25. val isKey = false
  26. val useFullSegment = true
  27. val useAllRecordSegments = false
  28. val useOracle = false
  29. def setMaxSegmentLength(maxSegLen: Int = Short.MaxValue.toInt) = {
  30. maxSegmentLength = maxSegLen
  31. this
  32. }
  33. def init() = this
  34. def getPossibleValues(mentionId: ObjectId, begin: Int, end: Int) = Seq(FieldValue(this, None))
  35. def getMentionValues(mentionId: Option[ObjectId]) = Seq(FieldValue(this, None))
  36. def getValuePhrase(valueId: Option[ObjectId]) = Seq.empty[String]
  37. def getValueMention(valueId: Option[ObjectId]) = None
  38. def getValueMentionSegment(valueId: Option[ObjectId]) = None
  39. }
  40. class SimpleEntityRecord(val name: String, val repository: MongoRepository,
  41. val useOracle: Boolean = false) extends EntityField with FieldCollection {
  42. val isKey = false
  43. val useFullSegment = true
  44. val useAllRecordSegments = false
  45. }
  46. class ClusterEntityRecord(val name: String, val repository: MongoRepository) extends EntityField {
  47. val isKey = false
  48. val useFullSegment = true
  49. val useAllRecordSegments = false
  50. val useOracle = true
  51. def getRecordClusters = {
  52. val rnd = new Random
  53. val allValues = entityColl.find(MongoDBObject(), MongoDBObject()).map(dbo => FieldValue(this, Some(dbo._id.get))).toSeq
  54. ccPivot(rnd, allValues)
  55. }
  56. private def ccPivot(rnd: Random, values: Seq[FieldValue]): Seq[HashSet[FieldValue]] = {
  57. if (values.length == 0) Seq.empty[HashSet[FieldValue]]
  58. else {
  59. // pick random value id
  60. val pivotIndex = rnd.nextInt(values.length)
  61. val pivotValueId = values(pivotIndex).valueId
  62. // get all segments matching pivot mention segment
  63. val pivotMentionSegment = getValueMentionSegment(pivotValueId)
  64. val currCluster = new HashSet[FieldValue]
  65. currCluster += values(pivotIndex)
  66. for (segment <- pivotMentionSegment) {
  67. currCluster ++= getPossibleValues(segment.mentionId, segment.begin, segment.end).filter(_.valueId.isDefined)
  68. }
  69. val otherValues = values.filter(!currCluster(_))
  70. ccPivot(rnd, otherValues) ++ Seq(currCluster)
  71. }
  72. }
  73. }