/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala

https://gitlab.com/KiaraGrouwstra/spark · Scala · 106 lines · 45 code · 16 blank · 45 comment · 2 complexity · 9edb211fa95cbd620797d63414bb646e MD5 · raw file

  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package org.apache.spark.util.random
  18. import java.nio.ByteBuffer
  19. import java.util.{Random => JavaRandom}
  20. import scala.util.hashing.MurmurHash3
  21. import org.apache.spark.util.Utils.timeIt
  /**
   * An XORShift pseudo-random number generator exposed through the
   * {@link java.util.Random java.util.Random} API.
   *
   * Only `next(bits)` and `setSeed` are overridden; the other `java.util.Random` methods
   * (nextInt, nextDouble, nextGaussian, nextLong, ...) obtain their randomness via `next`.
   *
   * This implementation is approximately 3.5 times faster than
   * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
   * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
   * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
   * for each thread; an instance must therefore never be shared between threads.
   *
   * Source: Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8,
   * Issue 14.
   *
   * @param init initial seed; it is passed through `XORShiftRandom.hashSeed` before use
   * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
   */
  private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {

    /** Creates a generator seeded with the current `System.nanoTime`. */
    def this() = this(System.nanoTime)

    // Current 64-bit xorshift state; a plain (unsynchronized) var, hence not thread-safe.
    private var seed = XORShiftRandom.hashSeed(init)

    // Overriding next(bits) is enough: nextInt, nextDouble, nextGaussian, nextLong, etc.
    // are built on top of it.
    override protected def next(bits: Int): Int = {
      var nextSeed = seed ^ (seed << 21)
      nextSeed ^= (nextSeed >>> 35)
      nextSeed ^= (nextSeed << 4)
      seed = nextSeed
      // Return only the low `bits` bits of the new state.
      (nextSeed & ((1L << bits) - 1)).toInt
    }

    /**
     * Reseeds this generator from `s`, using the same hashing as the constructor.
     *
     * Also delegates to `java.util.Random.setSeed`, which clears the superclass's cached
     * second Gaussian. Without that call, the first `nextGaussian()` after reseeding could
     * return a value derived from the previous state. The superclass's own seed is updated
     * as well, but `next` above never reads it.
     *
     * @param s the new seed
     */
    override def setSeed(s: Long): Unit = {
      seed = XORShiftRandom.hashSeed(s)
      super.setSeed(s)
    }
  }
  /**
   * Companion of the `XORShiftRandom` class: provides the seed-hashing helper, plus a
   * benchmark against java.util.Random with a `main` entry point to run it.
   */
  private[spark] object XORShiftRandom {

    /**
     * Hashes a seed so that 0 and 1 bits are spread across all 64 bits of the result.
     *
     * The seed's 8 bytes are hashed twice with MurmurHash3: once with the default array seed
     * (giving the low 32 bits), and once seeded with that first hash (giving the high 32
     * bits). A single 32-bit hash widened to Long would leave the upper half as sign-extension
     * copies (all 0s or all 1s), so the state would carry only 32 bits of seed entropy.
     *
     * @param seed the user-supplied seed
     * @return a 64-bit value suitable as the initial xorshift state
     */
    private def hashSeed(seed: Long): Long = {
      // Long.SIZE is a bit count (64); divide by Byte.SIZE to allocate exactly 8 bytes.
      val numBytes = java.lang.Long.SIZE / java.lang.Byte.SIZE
      val bytes = ByteBuffer.allocate(numBytes).putLong(seed).array()
      val lowBits = MurmurHash3.bytesHash(bytes, MurmurHash3.arraySeed)
      val highBits = MurmurHash3.bytesHash(bytes, lowBits)
      (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
    }

    /**
     * Main method for running the benchmark; prints the map returned by `benchmark`.
     *
     * @param args exactly one argument: the number of random numbers to generate. Otherwise a
     *             usage message is printed and the JVM exits with status 1.
     */
    def main(args: Array[String]): Unit = {
      // scalastyle:off println
      if (args.length != 1) {
        println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
        println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
        System.exit(1)
      }
      println(benchmark(args(0).toInt))
      // scalastyle:on println
    }

    /**
     * Times `nextInt()` on {@link java.util.Random java.util.Random} and on XORShiftRandom,
     * both seeded with 1.
     *
     * @param numIters number of random numbers to generate per generator in the timed runs
     * @return map with keys "javaTime" and "xorTime" holding the values reported by
     *         `Utils.timeIt` (presumably elapsed milliseconds — confirm against Utils)
     */
    def benchmark(numIters: Int): Map[String, Long] = {
      val seed = 1L
      val million = 1e6.toInt
      val javaRand = new JavaRandom(seed)
      val xorRand = new XORShiftRandom(seed)

      // JIT warm-up only; this timing result is deliberately discarded.
      timeIt(million) {
        javaRand.nextInt()
        xorRand.nextInt()
      }

      // Return the results as a map instead of printing them, so callers can post-process.
      Map(
        "javaTime" -> timeIt(numIters) { javaRand.nextInt() },
        "xorTime" -> timeIt(numIters) { xorRand.nextInt() })
    }
  }