PageRenderTime 26ms CodeModel.GetById 24ms RepoModel.GetById 1ms app.codeStats 0ms

/indexer/src/main/scala/S2CoveringAkkaWorkers.scala

https://gitlab.com/18runt88/twofishes
Scala | 192 lines | 159 code | 22 blank | 11 comment | 17 complexity | 3b027c22568c68aa151a7ace24d9d5cc MD5 | raw file
  1. package com.foursquare.twofishes
  2. import akka.actor.{Actor, ActorSystem, PoisonPill, Props}
  3. import akka.routing.{Broadcast, RoundRobinRouter}
  4. import com.mongodb.Bytes
  5. import com.foursquare.geo.shapes.ShapefileS2Util
  6. import com.foursquare.twofishes.util.{DurationUtils, GeometryCleanupUtils, GeometryUtils, RevGeoConstants, S2CoveringConstants}
  7. import com.foursquare.twofishes.mongo.{PolygonIndexDAO, RevGeoIndexDAO, RevGeoIndex, S2CoveringIndexDAO, S2CoveringIndex}
  8. import com.google.common.geometry.S2CellId
  9. import com.mongodb.casbah.Imports._
  10. import com.twitter.ostrich.stats.Stats
  11. import com.vividsolutions.jts.geom.{Point => JTSPoint, Geometry}
  12. import com.vividsolutions.jts.geom.prep.PreparedGeometryFactory
  13. import com.vividsolutions.jts.io.{WKBReader, WKBWriter}
  14. import com.weiglewilczek.slf4s.Logging
  15. import java.util.concurrent.CountDownLatch
  16. import org.bson.types.ObjectId
  17. import scalaj.collection.Implicits._
  18. import java.util.concurrent.atomic.AtomicInteger
  19. // ====================
  20. // ===== Messages =====
  21. // ====================
  /**
   * Selects which index collections a covering pass should populate.
   *
   * @param forS2CoveringIndex write one S2 covering-index record per polygon
   * @param forRevGeoIndex     write per-cell reverse-geocoding index records
   */
  case class CoverOptions(
    forS2CoveringIndex: Boolean = true,
    forRevGeoIndex: Boolean = true
  )
  /** Root of the message protocol between the covering master and its workers. */
  sealed trait CoverMessage

  /**
   * Tells the master that the producer has finished submitting requests; the master
   * then stops its workers and shuts down once every in-flight request is acknowledged.
   */
  case class Done() extends CoverMessage

  /** Asks a worker to load the polygons with these ids from Mongo and cover each one. */
  case class CalculateCoverFromMongo(
    polyIds: List[ObjectId],
    options: CoverOptions
  ) extends CoverMessage

  /**
   * Asks a worker to cover a single polygon supplied as WKB bytes.
   * Note: `geomBytes` is an Array, so case-class equality compares it by reference.
   */
  case class CalculateCover(
    polyId: ObjectId,
    geomBytes: Array[Byte],
    options: CoverOptions
  ) extends CoverMessage

  /** Acknowledgement a worker sends back to the requester after handling a cover request. */
  case class FinishedCover() extends CoverMessage
  /** An actor that accepts every message and discards it without doing anything. */
  class NullActor extends Actor {
    def receive: Receive = {
      case _ => ()
    }
  }
  /**
   * Process-wide progress counter. `S2CoveringWorker` increments it once per polygon it
   * covers and logs on every 1000th value.
   */
  object GlobalCounter {
    val count: AtomicInteger = new AtomicInteger(0)
  }
  /**
   * Actor that computes S2 cell coverings for polygons and persists them.
   *
   * Depending on the request's [[CoverOptions]], a polygon's covering is written to the
   * S2 covering index (one record listing all cell ids) and/or the reverse-geocoding index
   * (one record per cell, either flagged `full` or carrying the polygon clipped to the cell).
   *
   * Every request is answered with exactly one `FinishedCover`, even when covering throws,
   * so the requester's in-flight bookkeeping stays balanced. The exception is still
   * propagated to this actor's supervisor.
   */
  class S2CoveringWorker extends Actor with DurationUtils with RevGeoConstants with S2CoveringConstants with Logging {
    // Per-actor codec instances; an actor processes one message at a time.
    val wkbReader = new WKBReader()
    val wkbWriter = new WKBWriter()

    /** Loads the polygons whose ids are listed in `msg` from the polygon index and covers each one. */
    def calculateCoverFromMongo(msg: CalculateCoverFromMongo) {
      val records = PolygonIndexDAO.find(MongoDBObject("_id" -> MongoDBObject("$in" -> msg.polyIds)))
      // Covering a batch can be slow; disable the server-side idle-cursor timeout.
      records.option = Bytes.QUERYOPTION_NOTIMEOUT
      records.foreach(p => calculateCover(p._id, p.polygon, msg.options))
    }

    /** Covers the single WKB-encoded polygon carried by `msg`. */
    def calculateCover(msg: CalculateCover) {
      calculateCover(msg.polyId, msg.geomBytes, msg.options)
    }

    /**
     * Covers one polygon: bumps the global progress counter (logging every 1000th),
     * decodes the WKB bytes, and writes whichever indexes `options` selects.
     *
     * @param polyId    id of the polygon, stored on every record written
     * @param geomBytes the polygon geometry as WKB
     * @param options   which indexes to populate
     */
    def calculateCover(polyId: ObjectId, geomBytes: Array[Byte], options: CoverOptions) {
      logDuration("totalCovering", "generated cover for %s".format(polyId)) {
        val currentCount = GlobalCounter.count.getAndIncrement()
        if (currentCount % 1000 == 0) {
          logger.info("processed about %s polygons for s2 coverage".format(currentCount))
        }
        val geom = wkbReader.read(geomBytes)
        if (options.forS2CoveringIndex) {
          writeS2CoveringIndex(polyId, geom)
        }
        if (options.forRevGeoIndex) {
          writeRevGeoIndex(polyId, geom)
        }
      }
    }

    /** Computes the S2-covering-index cells for `geom` and inserts one [[S2CoveringIndex]] record. */
    private def writeS2CoveringIndex(polyId: ObjectId, geom: Geometry) {
      val cells = logDuration("s2CoveringForS2CoveringIndex", "generated cover for %s for s2covering index".format(polyId)) {
        GeometryUtils.s2PolygonCovering(
          geom, minS2LevelForS2Covering, maxS2LevelForS2Covering,
          levelMod = Some(defaultLevelModForS2Covering),
          maxCellsHintWhichMightBeIgnored = Some(defaultMaxCellsHintForS2Covering)
        ).toList
      }
      S2CoveringIndexDAO.insert(S2CoveringIndex(polyId, cells.map(_.id())))
    }

    /**
     * Computes the reverse-geocoding cells for `geom` and inserts one [[RevGeoIndex]] record
     * per cell. For a point, each record stores the point's WKB. Otherwise a record is
     * `full` when the buffered polygon contains the whole cell, and otherwise stores the
     * WKB of the polygon clipped to that cell.
     */
    private def writeRevGeoIndex(polyId: ObjectId, geom: Geometry) {
      val cells = logDuration("s2CoveringForRevGeoIndex", "generated cover for %s for revgeo index".format(polyId)) {
        GeometryUtils.s2PolygonCovering(
          geom, minS2LevelForRevGeo, maxS2LevelForRevGeo,
          levelMod = Some(defaultLevelModForRevGeo),
          maxCellsHintWhichMightBeIgnored = Some(defaultMaxCellsHintForRevGeo)
        )
      }
      logDuration("coverClippingForRevGeoIndex", "clipped and outputted cover for %d cells (%s) for revgeo index".format(cells.size, polyId)) {
        val records = geom match {
          case _: JTSPoint =>
            cells.map((cellid: S2CellId) =>
              RevGeoIndex(
                cellid.id(), polyId,
                full = false,
                geom = Some(wkbWriter.write(geom))
              )
            )
          case _ =>
            // buffer(0) and preparation depend only on the polygon, not on the cell, so they
            // are computed once per polygon rather than once per covering cell. They are lazy
            // so nothing is computed when the covering is empty, as before.
            lazy val recordShape = geom.buffer(0)
            lazy val preparedRecordShape = PreparedGeometryFactory.prepare(recordShape)
            cells.map((cellid: S2CellId) => {
              val s2shape = ShapefileS2Util.fullGeometryForCell(cellid)
              if (preparedRecordShape.contains(s2shape)) {
                RevGeoIndex(cellid.id(), polyId, full = true, geom = None)
              } else {
                val intersection = s2shape.intersection(recordShape)
                val geomToIndex = if (intersection.getGeometryType == "GeometryCollection") {
                  GeometryCleanupUtils.cleanupGeometryCollection(intersection)
                } else {
                  intersection
                }
                RevGeoIndex(
                  cellid.id(), polyId,
                  full = false,
                  geom = Some(wkbWriter.write(geomToIndex))
                )
              }
            })
        }
        RevGeoIndexDAO.insert(records)
      }
    }

    def receive = {
      // Always acknowledge, even if covering throws. Otherwise the master's in-flight count
      // never returns to zero and the run never finishes. `finally` still rethrows, so the
      // failure reaches the supervisor.
      case msg: CalculateCover =>
        try calculateCover(msg) finally sender ! FinishedCover()
      case msg: CalculateCoverFromMongo =>
        try calculateCoverFromMongo(msg) finally sender ! FinishedCover()
    }
  }
  120. // ==================
  121. // ===== Master =====
  122. // ==================
  /**
   * Coordinates S2 covering work.
   *
   * Each `CalculateCover` / `CalculateCoverFromMongo` request is forwarded to a round-robin
   * pool of 8 [[S2CoveringWorker]]s, and the number of unacknowledged requests is tracked.
   * After `Done` arrives, the workers are sent a broadcast `PoisonPill`. Once every
   * in-flight request has been answered with `FinishedCover`, `latch` is counted down and
   * this actor stops itself. When it stops, the actor system it created for the workers
   * is shut down.
   *
   * @param latch counted down when all covering work has finished
   */
  class S2CoveringMaster(val latch: CountDownLatch) extends Actor with Logging {
    // Wall-clock time (millis) when this actor started; reported in postStop.
    var start: Long = 0
    // Dedicated actor system hosting the workers. It is owned by this master and shut down
    // in postStop; its threads would otherwise outlive the master and may keep the JVM alive.
    val _system = ActorSystem("RoundRobinRouterExample")
    val router = _system.actorOf(Props[S2CoveringWorker].withRouter(RoundRobinRouter(8)), name = "myRoundRobinRouterActor")
    // Requests forwarded to the router that have not yet been acknowledged with FinishedCover.
    var inFlight = 0
    // Set once Done has been received, i.e. no further requests are expected.
    var seenDone = false

    def receive = {
      case msg: FinishedCover =>
        inFlight -= 1
        if (inFlight == 0 && seenDone) {
          shutdownWithMessage("finished all s2 covers, shutting down system")
        }
        if (inFlight < 0) {
          logger.error("inFlight < 0 ... we're bad at a counting")
        }
      case msg: CalculateCover =>
        Stats.incr("s2.akkaWorkers.CalculateCover")
        inFlight += 1
        router ! msg
      case msg: CalculateCoverFromMongo =>
        inFlight += 1
        router ! msg
      case msg: Done =>
        logger.info("all done with s2 cover indexing, sending poison pills")
        // Each worker stops after draining the requests already queued ahead of the pill.
        router ! Broadcast(PoisonPill)
        seenDone = true
        if (inFlight == 0) {
          shutdownWithMessage("had already finished all s2 covers, shutting down system")
        }
    }

    /** Logs `message`, releases whoever is waiting on `latch`, and stops this actor. */
    private def shutdownWithMessage(message: String): Unit = {
      logger.info(message)
      latch.countDown()
      self ! PoisonPill
    }

    override def preStart() {
      start = System.currentTimeMillis
    }

    override def postStop() {
      // Report how long the whole covering run took.
      logger.info(
        "s2 covering calculation time: \t%s millis"
          .format((System.currentTimeMillis - start))
      )
      // Release the worker actor system created in the constructor.
      _system.shutdown()
    }
  }