PageRenderTime 52ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala

https://github.com/shenyanls/spark
Scala | 299 lines | 216 code | 26 blank | 57 comment | 50 complexity | 1dae7110080e5c25c903e1282d98fbe0 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package org.apache.spark.util.collection
  18. import java.util.{Arrays, Comparator}
  19. import com.google.common.hash.Hashing
  20. import org.apache.spark.annotation.DeveloperApi
  /**
   * :: DeveloperApi ::
   * A simple open hash table optimized for the append-only use case, where keys
   * are never removed, but the value for each key may be changed.
   *
   * This implementation uses quadratic probing with a power-of-2 hash table size.
   * Successive probes for a key are offset from its home slot by the triangular numbers
   * 1, 3, 6, 10, ..., a sequence that is guaranteed to visit every slot of a power-of-2
   * sized table (see http://en.wikipedia.org/wiki/Quadratic_probing).
   *
   * Keys and values are stored interleaved in one flat array. The null key is kept outside
   * that array, so a null entry in the array can mark an empty slot. Once the table has
   * 2^29 slots, an insert that would require growing it again throws an exception.
   *
   * TODO: Cache the hash values of each key? java.util.HashMap does that.
   *
   * @param initialCapacity requested starting number of slots; rounded up to a power of 2
   *                        and required to lie in [1, 2^29]
   * @tparam K key type; a null key is allowed
   * @tparam V value type; null values are allowed
   */
  @DeveloperApi
  class AppendOnlyMap[K, V](initialCapacity: Int = 64)
    extends Iterable[(K, V)] with Serializable {
    require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
    require(initialCapacity >= 1, "Invalid initial capacity")

    private val LOAD_FACTOR = 0.7

    private var capacity = nextPowerOf2(initialCapacity)
    private var mask = capacity - 1
    private var curSize = 0
    private var growThreshold = (LOAD_FACTOR * capacity).toInt

    // Holds keys and values in the same array for memory locality; specifically, the order of
    // elements is key0, value0, key1, value1, key2, value2, etc.
    private var data = new Array[AnyRef](2 * capacity)

    // Treat the null key differently so we can use nulls in "data" to represent empty items.
    private var haveNullValue = false
    private var nullValue: V = null.asInstanceOf[V]

    // Triggered by destructiveSortedIterator; the underlying data array may no longer be used
    private var destroyed = false
    private val destructionMessage = "Map state is invalid from destructive sorting!"

    /**
     * Get the value for a given key.
     *
     * @return the value mapped to `key`, or null (cast to V) if `key` has no mapping
     */
    def apply(key: K): V = {
      assert(!destroyed, destructionMessage)
      val k = key.asInstanceOf[AnyRef]
      if (k.eq(null)) {
        // Still null if the null key has never been set.
        nullValue
      } else {
        val pos = findSlot(k)
        if (data(2 * pos).eq(null)) null.asInstanceOf[V] else data(2 * pos + 1).asInstanceOf[V]
      }
    }

    /**
     * Set the value for a key, inserting the key if it is not yet present.
     */
    def update(key: K, value: V): Unit = {
      assert(!destroyed, destructionMessage)
      val k = key.asInstanceOf[AnyRef]
      if (k.eq(null)) {
        val isNewKey = !haveNullValue
        nullValue = value
        haveNullValue = true
        // Count the key only after its state is stored, so a failing grow cannot leave
        // curSize counting a null key that haveNullValue says is absent.
        if (isNewKey) {
          incrementSize()
        }
      } else {
        val pos = findSlot(k)
        if (data(2 * pos).eq(null)) {
          data(2 * pos) = k
          data(2 * pos + 1) = value.asInstanceOf[AnyRef]
          incrementSize() // Since we added a new key
        } else {
          data(2 * pos + 1) = value.asInstanceOf[AnyRef]
        }
      }
    }

    /**
     * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
     * for key, if any, or null otherwise. Returns the newly updated value.
     *
     * @param updateFunc receives whether the key already had a value and that value (or null)
     * @return the value now stored for `key`
     */
    def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
      assert(!destroyed, destructionMessage)
      val k = key.asInstanceOf[AnyRef]
      if (k.eq(null)) {
        val hadNullValue = haveNullValue
        val newValue = updateFunc(hadNullValue, nullValue)
        nullValue = newValue
        haveNullValue = true
        // Increment only after updateFunc succeeded and the entry is stored, as the
        // non-null path does; otherwise an exception would leave curSize overcounted.
        if (!hadNullValue) {
          incrementSize()
        }
        newValue
      } else {
        val pos = findSlot(k)
        if (data(2 * pos).eq(null)) {
          val newValue = updateFunc(false, null.asInstanceOf[V])
          data(2 * pos) = k
          data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
          incrementSize()
          newValue
        } else {
          val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
          data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
          newValue
        }
      }
    }

    /**
     * Locate `k` in the table by quadratic probing.
     *
     * @param k a non-null key
     * @return the slot index holding a key equal to `k` or, if there is none, the first empty
     *         slot on `k`'s probe sequence (the slot where `k` would be inserted)
     */
    private def findSlot(k: AnyRef): Int = {
      var pos = rehash(k.hashCode) & mask
      var delta = 1
      var curKey = data(2 * pos)
      // Check for an empty slot before calling equals, so equals is never invoked with null.
      // The loop relies on the table always having an empty slot; incrementSize grows the
      // table once curSize exceeds growThreshold (LOAD_FACTOR * capacity).
      while (!curKey.eq(null) && !(k.eq(curKey) || k.equals(curKey))) {
        pos = (pos + delta) & mask
        delta += 1
        curKey = data(2 * pos)
      }
      pos
    }

    /** Iterator method from Iterable. The null-key entry, if present, is returned first. */
    override def iterator: Iterator[(K, V)] = {
      assert(!destroyed, destructionMessage)
      new Iterator[(K, V)] {
        // Slot index of the next candidate entry; -1 stands for the null-key entry.
        var pos = -1

        /** Get the next value we should return from next(), or null if we're finished iterating */
        def nextValue(): (K, V) = {
          if (pos == -1) { // Treat position -1 as looking at the null value
            if (haveNullValue) {
              return (null.asInstanceOf[K], nullValue)
            }
            pos += 1
          }
          while (pos < capacity) {
            if (!data(2 * pos).eq(null)) {
              return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
            }
            pos += 1
          }
          null
        }

        override def hasNext: Boolean = nextValue() != null

        override def next(): (K, V) = {
          val value = nextValue()
          if (value == null) {
            throw new NoSuchElementException("End of iterator")
          }
          pos += 1
          value
        }
      }
    }

    override def size: Int = curSize

    /** Increase table size by 1, rehashing if necessary */
    private def incrementSize(): Unit = {
      curSize += 1
      if (curSize > growThreshold) {
        growTable()
      }
    }

    /**
     * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
     */
    private def rehash(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt()

    /**
     * Double the table's size and re-hash everything.
     *
     * @throws Exception if the doubled table would exceed 2^29 slots
     */
    protected def growTable(): Unit = {
      val newCapacity = capacity * 2
      if (newCapacity >= (1 << 30)) {
        // We can't make the table this big because we want an array of 2x
        // that size for our data, but array sizes are at most Int.MaxValue
        throw new Exception("Can't make capacity bigger than 2^29 elements")
      }
      val newData = new Array[AnyRef](2 * newCapacity)
      val newMask = newCapacity - 1
      // Insert all our old values into the new array. Note that because our old keys are
      // unique, there's no need to check for equality here when we insert.
      var oldPos = 0
      while (oldPos < capacity) {
        if (!data(2 * oldPos).eq(null)) {
          val key = data(2 * oldPos)
          val value = data(2 * oldPos + 1)
          var newPos = rehash(key.hashCode) & newMask
          var i = 1
          var keepGoing = true
          while (keepGoing) {
            val curKey = newData(2 * newPos)
            if (curKey.eq(null)) {
              newData(2 * newPos) = key
              newData(2 * newPos + 1) = value
              keepGoing = false
            } else {
              val delta = i
              newPos = (newPos + delta) & newMask
              i += 1
            }
          }
        }
        oldPos += 1
      }
      data = newData
      capacity = newCapacity
      mask = newMask
      growThreshold = (LOAD_FACTOR * newCapacity).toInt
    }

    /** Smallest power of 2 that is >= n (n itself if it already is one); expects n >= 1. */
    private def nextPowerOf2(n: Int): Int = {
      val highBit = Integer.highestOneBit(n)
      if (highBit == n) n else highBit << 1
    }

    /**
     * Return an iterator of the map in sorted order. This provides a way to sort the map without
     * using additional memory, at the expense of destroying the validity of the map.
     *
     * After this call every other accessor (including a second call to this method) fails its
     * `destroyed` assertion. The null-key entry, if present, is returned first and is not
     * passed to `cmp`.
     *
     * @param cmp ordering applied to the non-null-key entries
     */
    def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = {
      // A second call would re-pack the already-packed tuples and produce garbage.
      assert(!destroyed, destructionMessage)
      destroyed = true
      // Pack KV pairs into the front of the underlying array. Writing to index newIndex is safe
      // because newIndex <= keyIndex, so it never overwrites a key slot not yet visited.
      var keyIndex, newIndex = 0
      while (keyIndex < capacity) {
        if (!data(2 * keyIndex).eq(null)) {
          data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1))
          newIndex += 1
        }
        keyIndex += 1
      }
      assert(curSize == newIndex + (if (haveNullValue) 1 else 0))

      // Sort by the given ordering
      val rawOrdering = new Comparator[AnyRef] {
        def compare(x: AnyRef, y: AnyRef): Int = {
          cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)])
        }
      }
      Arrays.sort(data, 0, newIndex, rawOrdering)

      new Iterator[(K, V)] {
        var i = 0
        var nullValueReady = haveNullValue
        def hasNext: Boolean = (i < newIndex || nullValueReady)
        def next(): (K, V) = {
          if (nullValueReady) {
            nullValueReady = false
            (null.asInstanceOf[K], nullValue)
          } else if (i < newIndex) {
            val item = data(i).asInstanceOf[(K, V)]
            i += 1
            item
          } else {
            // Slots past newIndex hold stale keys/values, so never read them.
            throw new NoSuchElementException("End of iterator")
          }
        }
      }
    }

    /**
     * Return whether the next insert of a new key will cause the map to grow
     */
    def atGrowThreshold: Boolean = curSize == growThreshold
  }