/commons/src/main/scala/io/prediction/commons/settings/mongodb/MongoEngineInfos.scala

https://github.com/SNaaS/PredictionIO · Scala · 58 lines · 43 code · 12 blank · 3 comment · 0 complexity · 06a0d715f5ac293bb7dc7e91c660a049 MD5 · raw file

  1. package io.prediction.commons.settings.mongodb
  2. import io.prediction.commons.MongoUtils
  3. import io.prediction.commons.settings.{ EngineInfo, EngineInfos }
  4. import com.mongodb.casbah.Imports._
  /**
   * MongoDB implementation of [[EngineInfos]].
   *
   * Each [[EngineInfo]] is stored as one document in the `engineInfos`
   * collection of `db`, using the engine info id as the document `_id`.
   *
   * @param db database holding the `engineInfos` collection
   */
  class MongoEngineInfos(db: MongoDB) extends EngineInfos {
    private val coll = db("engineInfos")

    /**
     * Rebuilds an [[EngineInfo]] from a stored document.
     *
     * `description` is read with `getAs` and is therefore optional; every other
     * field is read with `as` and is treated as required.
     */
    private def dbObjToEngineInfo(dbObj: DBObject) = EngineInfo(
      id = dbObj.as[String]("_id"),
      name = dbObj.as[String]("name"),
      description = dbObj.getAs[String]("description"),
      // every value under `params` is expected to be a sub-document (hence the cast);
      // its key is forwarded to `dbObjToParam` together with it
      params = (dbObj.as[DBObject]("params") map {
        case (key, value) => key -> MongoParam.dbObjToParam(key, value.asInstanceOf[DBObject])
      }).toMap,
      paramsections = dbObj.as[Seq[DBObject]]("paramsections") map { MongoParam.dbObjToParamSection(_) },
      defaultalgoinfoid = dbObj.as[String]("defaultalgoinfoid"),
      defaultofflineevalmetricinfoid = dbObj.as[String]("defaultofflineevalmetricinfoid"),
      defaultofflineevalsplitterinfoid = dbObj.as[String]("defaultofflineevalsplitterinfoid"))

    /**
     * Serializes `engineInfo` into the document layout read by `dbObjToEngineInfo`.
     *
     * Shared by `insert` and `update` so both always write exactly the same
     * fields. `description` is optional and is written only when defined.
     */
    private def engineInfoToDBObj(engineInfo: EngineInfo) = {
      val required = MongoDBObject(
        "_id" -> engineInfo.id,
        "name" -> engineInfo.name,
        // strict `map` rather than `mapValues`: `mapValues` returns a lazy view
        // that re-runs the conversion every time an entry is read
        "params" -> (engineInfo.params map { case (key, param) => key -> MongoParam.paramToDBObj(param) }),
        "paramsections" -> (engineInfo.paramsections map { MongoParam.paramSectionToDBObj(_) }),
        "defaultalgoinfoid" -> engineInfo.defaultalgoinfoid,
        "defaultofflineevalmetricinfoid" -> engineInfo.defaultofflineevalmetricinfoid,
        "defaultofflineevalsplitterinfoid" -> engineInfo.defaultofflineevalsplitterinfoid)
      val optional = engineInfo.description map { d => MongoDBObject("description" -> d) } getOrElse MongoUtils.emptyObj
      required ++ optional
    }

    /** Inserts `engineInfo` as a new document. */
    def insert(engineInfo: EngineInfo) = coll.insert(engineInfoToDBObj(engineInfo))

    /** Looks up the engine info whose `_id` is `id`; `None` if there is none. */
    def get(id: String): Option[EngineInfo] =
      coll.findOne(MongoDBObject("_id" -> id)) map { dbObjToEngineInfo(_) }

    /**
     * Returns every stored engine info.
     *
     * The cursor is drained into a strict `List`. `Iterator.toSeq` would yield a
     * lazy `Stream`, leaving most results to be fetched later from a still-open cursor.
     */
    def getAll(): Seq[EngineInfo] = coll.find().toList map { dbObjToEngineInfo(_) }

    /**
     * Writes `engineInfo` over the document with the same `_id`.
     *
     * The new document is passed as-is (no `$set`), so MongoDB replaces the stored
     * document; a `None` description therefore drops any previously stored one.
     * When `upsert` is true, a missing document is created.
     */
    def update(engineInfo: EngineInfo, upsert: Boolean = false) =
      coll.update(MongoDBObject("_id" -> engineInfo.id), engineInfoToDBObj(engineInfo), upsert)

    /** Removes the document whose `_id` is `id`. */
    def delete(id: String) = coll.remove(MongoDBObject("_id" -> id))
  }