/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala

https://github.com/Stratio/Spark-MongoDB · Scala · 139 lines · 81 code · 21 blank · 37 comment · 4 complexity · 85ee5c15529ed8ddc60b4cde0e0b7799 MD5 · raw file

  1. /*
  2. * Copyright (C) 2015 Stratio (http://stratio.com)
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package com.stratio.datasource.mongodb.reader
  17. import com.mongodb.casbah.Imports._
  18. import com.mongodb.casbah.MongoCursorBase
  19. import com.stratio.datasource.mongodb.client.MongodbClientFactory
  20. import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbCredentials, MongodbSSLOptions}
  21. import com.stratio.datasource.mongodb.partitioner.MongodbPartition
  22. import com.stratio.datasource.mongodb.query.FilterSection
  23. import com.stratio.datasource.util.Config
  24. import org.apache.spark.Partition
  25. import scala.util.Try
  /**
   * Reads the documents of a single Spark partition from a MongoDB collection.
   *
   * Lifecycle: call [[init]] with the partition to open a cursor, iterate with
   * [[hasNext]] / [[next]], then call [[close]] to release the cursor and the client
   * obtained from `MongodbClientFactory`.
   *
   * @param config          Configuration object (hosts/credentials/SSL, database, collection, cursor options).
   * @param requiredColumns Pruning fields: the document fields to project; an empty array projects all fields.
   * @param filters         Query filters to push down to MongoDB.
   */
  class MongodbReader(config: Config,
                      requiredColumns: Array[String],
                      filters: FilterSection) {

    /** Client obtained in [[init]]; `None` before init and after [[close]]. */
    private var mongoClient: Option[MongodbClientFactory.Client] = None

    /** Key under which `MongodbClientFactory` registered `mongoClient`, if it returned one. */
    private var mongoClientKey: Option[String] = None

    /** Cursor opened in [[init]]; `None` before init, after [[close]], or if init failed. */
    private var dbCursor: Option[MongoCursorBase] = None

    private val batchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)

    /**
     * Closes the open cursor (if any) and releases the client (if any).
     *
     * The client is released by key when the factory returned one, otherwise by client instance.
     * Calling this more than once, or without a successful [[init]], is a no-op for whatever
     * was not acquired.
     */
    def close(): Unit = {
      dbCursor.foreach { cursor =>
        cursor.close()
        dbCursor = None
      }
      mongoClient.foreach { client =>
        mongoClientKey.fold(MongodbClientFactory.closeByClient(client)) { key =>
          MongodbClientFactory.closeByKey(key)
        }
        mongoClient = None
        // Reset the key too, so a later init/close cycle never reuses a stale key.
        mongoClientKey = None
      }
    }

    /** @return `true` if the cursor is open and has another document; `false` otherwise. */
    def hasNext: Boolean = dbCursor.exists(_.hasNext)

    /**
     * @return the next document from the cursor.
     * @throws IllegalStateException if [[init]] has not opened a cursor.
     */
    def next(): DBObject = {
      dbCursor.fold(ifEmpty = throw new IllegalStateException("DbCursor is not initialized"))(cursor => cursor.next())
    }

    /**
     * Initialize MongoDB reader: obtains a client and opens a cursor over the partition's key range.
     *
     * @param partition Where to read from; expected to be a `MongodbPartition`.
     * @throws MongodbReadException if any step fails (wrong partition type, connection or query error).
     *                              Resources acquired before the failure are released first.
     */
    def init(partition: Partition): Unit = {
      import scala.util.control.NonFatal
      // A plain try/catch is required here: throwing inside Try#recover is caught again and turned
      // into a discarded Failure, which would silently hide every initialization error.
      try {
        val mongoPartition = partition.asInstanceOf[MongodbPartition]
        val hosts = mongoPartition.hosts.map(add => new ServerAddress(add)).toList
        val credentials = config.getOrElse[List[MongodbCredentials]](MongodbConfig.Credentials, MongodbConfig.DefaultCredentials).map {
          case MongodbCredentials(user, database, password) =>
            MongoCredential.createCredential(user, database, password)
        }
        val sslOptions = config.get[MongodbSSLOptions](MongodbConfig.SSLOptions)
        // NOTE(review): each key is tested with `contains(ListMongoClientOptions)`. If
        // `ListMongoClientOptions` is a collection rather than a String, this compares a String with a
        // collection and matches no key — confirm the intended check (e.g. `ListMongoClientOptions.contains(key)`).
        val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))
        val mongoClientResponse = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions)
        mongoClient = Option(mongoClientResponse.clientConnection)
        mongoClientKey = Option(mongoClientResponse.key)

        // Best effort: if the filters cannot be translated, fall back to an empty query document.
        val emptyFilter = MongoDBObject(List())
        val filter = Try(queryPartition(filters)).getOrElse(emptyFilter)

        dbCursor = for {
          client <- mongoClient
          collection <- Option(client(config(MongodbConfig.Database))(config(MongodbConfig.Collection)))
          cursor <- Option(collection.find(filter, selectFields(requiredColumns)))
        } yield {
          // Bound the cursor to this partition's key range.
          mongoPartition.partitionRange.minKey.foreach(min => cursor.addSpecial("$min", min))
          mongoPartition.partitionRange.maxKey.foreach(max => cursor.addSpecial("$max", max))
          cursor.batchSize(batchSize)
        }
      } catch {
        case NonFatal(throwable) =>
          // Release anything acquired before the failure; never let cleanup errors mask the original cause.
          Try(close())
          throw MongodbReadException(throwable.getMessage, throwable)
      }
    }

    /**
     * Create query partition using given filters.
     *
     * @param filters the Spark filters to be converted to Mongo filters
     * @return the query DBObject
     */
    private def queryPartition(filters: FilterSection): DBObject = {
      implicit val c: Config = config
      filters.filtersToDBObject()
    }

    /**
     * Builds the projection DBObject passed to MongoDB `find`.
     *
     * An empty `fields` array yields an empty projection (all fields). Otherwise every requested
     * field is included, and `_id` is included only if it was explicitly requested.
     *
     * @param fields Required fields
     * @return A mongodb object that represents required fields.
     */
    private def selectFields(fields: Array[String]): DBObject = {
      val projection =
        if (fields.isEmpty) List()
        else fields.toList.filterNot(_ == "_id").map(_ -> 1) ::: List("_id" -> (if (fields.contains("_id")) 1 else 0))
      MongoDBObject(projection)
    }
  }
  /**
   * Exception used to report a failure while reading from MongoDB
   * (for example while a `MongodbReader` initializes its cursor).
   *
   * @param msg      Human-readable description of the failure; becomes `getMessage`.
   * @param causedBy The underlying exception; becomes `getCause`.
   */
  case class MongodbReadException(msg: String, causedBy: Throwable)
    extends RuntimeException(msg, causedBy)