PageRenderTime 54ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/commons/src/main/scala/io/prediction/commons/settings/mongodb/MongoAlgos.scala

https://github.com/eddieliu/PredictionIO
Scala | 121 lines | 92 code | 24 blank | 5 comment | 0 complexity | 5085dedcaadc1286e003f31c6226a5a2 MD5 | raw file
  1. package io.prediction.commons.settings.mongodb
  2. import io.prediction.commons.MongoUtils
  3. import io.prediction.commons.settings.{ Algo, Algos }
  4. import com.mongodb.casbah.Imports._
  5. import com.mongodb.casbah.commons.conversions.scala.RegisterJodaTimeConversionHelpers
  6. import com.github.nscala_time.time.Imports._
  /**
   * MongoDB implementation of [[Algos]].
   *
   * Each algo is stored as one document in the `algos` collection, keyed by an
   * integer `_id` taken from the `algoid` sequence when the algo is inserted.
   * Optional [[Algo]] fields that are `None` are left out of the stored
   * document entirely rather than written as explicit values.
   *
   * @param db database holding the `algos` collection; it is also handed to
   *           [[MongoSequences]] for id generation.
   */
  class MongoAlgos(db: MongoDB) extends Algos {
    private val algoColl = db("algos")
    private val seq = new MongoSequences(db)

    // Register Casbah's Joda-Time conversion helpers so the DateTime fields
    // below (createtime, updatetime, lasttraintime) can be written and read.
    RegisterJodaTimeConversionHelpers()

    /** Sort specification shared by all name-ordered listings: `name` ascending. */
    private def nameAscending: DBObject = MongoDBObject("name" -> 1)

    /** Rebuilds an [[Algo]] from a stored `algos` document. */
    private def dbObjToAlgo(dbObj: DBObject): Algo = {
      Algo(
        id = dbObj.as[Int]("_id"),
        engineid = dbObj.as[Int]("engineid"),
        name = dbObj.as[String]("name"),
        // TODO: temporary default for backward compatibility; presumably covers
        // documents written before the `infoid` field existed.
        infoid = dbObj.getAs[String]("infoid").getOrElse("pdio-knnitembased"),
        command = dbObj.as[String]("command"),
        params = MongoUtils.dbObjToMap(dbObj.as[DBObject]("params")),
        settings = MongoUtils.dbObjToMap(dbObj.as[DBObject]("settings")),
        modelset = dbObj.as[Boolean]("modelset"),
        createtime = dbObj.as[DateTime]("createtime"),
        updatetime = dbObj.as[DateTime]("updatetime"),
        status = dbObj.as[String]("status"),
        offlineevalid = dbObj.getAs[Int]("offlineevalid"),
        offlinetuneid = dbObj.getAs[Int]("offlinetuneid"),
        loop = dbObj.getAs[Int]("loop"),
        paramset = dbObj.getAs[Int]("paramset"),
        lasttraintime = dbObj.getAs[DateTime]("lasttraintime"))
    }

    /**
     * Field/value pairs describing `algo`, excluding `_id`.
     *
     * Shared by [[insert]] and [[update]] so both always write the same
     * document layout. Required fields come first, followed by each optional
     * field that is defined; undefined optional fields are omitted.
     */
    private def algoFields(algo: Algo): Seq[(String, Any)] = {
      val required: Seq[(String, Any)] = Seq(
        "engineid" -> algo.engineid,
        "name" -> algo.name,
        "infoid" -> algo.infoid,
        "command" -> algo.command,
        "params" -> algo.params,
        "settings" -> algo.settings,
        "modelset" -> algo.modelset,
        "createtime" -> algo.createtime,
        "updatetime" -> algo.updatetime,
        "status" -> algo.status)
      val optional: Seq[(String, Any)] = Seq(
        algo.offlineevalid.map("offlineevalid" -> _),
        algo.offlinetuneid.map("offlinetuneid" -> _),
        algo.loop.map("loop" -> _),
        algo.paramset.map("paramset" -> _),
        algo.lasttraintime.map("lasttraintime" -> _)).flatten
      required ++ optional
    }

    /**
     * Stores `algo` under a freshly generated id; `algo.id` is ignored.
     *
     * @return the id generated from the `algoid` sequence.
     */
    def insert(algo: Algo) = {
      val id = seq.genNext("algoid")
      algoColl.insert(MongoDBObject((("_id" -> id) +: algoFields(algo)): _*))
      id
    }

    /** Looks up an algo by id. */
    def get(id: Int) = algoColl.findOne(MongoDBObject("_id" -> id)).map(dbObjToAlgo)

    /** Iterates over every stored algo, in unspecified order. */
    def getAll() = new MongoAlgoIterator(algoColl.find())

    /** Iterates over the algos of `engineid`, ordered by name. */
    def getByEngineid(engineid: Int) = new MongoAlgoIterator(
      algoColl.find(MongoDBObject("engineid" -> engineid)).sort(nameAscending))

    /** Iterates over the algos of `engineid` whose status is "deployed", ordered by name. */
    def getDeployedByEngineid(engineid: Int) = new MongoAlgoIterator(
      algoColl.find(MongoDBObject("engineid" -> engineid, "status" -> "deployed")).sort(nameAscending))

    /**
     * Iterates over the algos of offline evaluation `evalid`, ordered by name.
     *
     * @param loop     when given, only algos with this `loop` value match.
     * @param paramset when given, only algos with this `paramset` value match.
     */
    def getByOfflineEvalid(evalid: Int, loop: Option[Int] = None, paramset: Option[Int] = None) = {
      val criteria = Seq(
        Some("offlineevalid" -> evalid),
        loop.map("loop" -> _),
        paramset.map("paramset" -> _)).flatten
      new MongoAlgoIterator(algoColl.find(MongoDBObject(criteria: _*)).sort(nameAscending))
    }

    /**
     * Returns an algo attached to offline tune `tuneid` that has neither `loop`
     * nor `paramset` set (presumably the algo being tuned rather than one of
     * its per-loop copies).
     *
     * A MongoDB `null` condition also matches documents where the field is
     * absent. That is how unset optional fields are stored (see [[insert]]).
     */
    def getTuneSubjectByOfflineTuneid(tuneid: Int) =
      algoColl.findOne(MongoDBObject("offlinetuneid" -> tuneid, "loop" -> null, "paramset" -> null)).map(dbObjToAlgo)

    /** Looks up an algo by id, returning it only if it belongs to `engineid`. */
    def getByIdAndEngineid(id: Int, engineid: Int): Option[Algo] =
      algoColl.findOne(MongoDBObject("_id" -> id, "engineid" -> engineid)).map(dbObjToAlgo)

    /**
     * Replaces the stored document whose `_id` is `algo.id` with the current
     * contents of `algo`. The whole document is replaced, so optional fields
     * that are now `None` disappear from the stored copy.
     *
     * @param upsert when true, insert the document if no algo with `algo.id` exists.
     */
    def update(algo: Algo, upsert: Boolean = false) = {
      val doc = MongoDBObject((("_id" -> algo.id) +: algoFields(algo)): _*)
      algoColl.update(MongoDBObject("_id" -> algo.id), doc, upsert)
    }

    /** Removes the algo with the given id, if any. */
    def delete(id: Int) = algoColl.remove(MongoDBObject("_id" -> id))

    /**
     * Whether `engineid` already has an algo called `name` that is not tied to
     * an offline evaluation. The `null` condition matches algos with no
     * `offlineevalid` field.
     */
    def existsByEngineidAndName(engineid: Int, name: String) =
      algoColl.findOne(MongoDBObject("name" -> name, "engineid" -> engineid, "offlineevalid" -> null)).isDefined

    /** Iterator adapter converting each document yielded by `it` into an [[Algo]]. */
    class MongoAlgoIterator(it: MongoCursor) extends Iterator[Algo] {
      // Side-effecting: advances the underlying cursor.
      def next() = dbObjToAlgo(it.next())
      def hasNext = it.hasNext
    }
  }