PageRenderTime 48ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 0ms

/Statistics/Resampling.hs

http://github.com/bos/statistics
Haskell | 279 lines | 175 code | 34 blank | 70 comment | 4 complexity | a1d1e621f834da2b3d35230aecbaedbe MD5 | raw file
Possible License(s): BSD-2-Clause
  1. {-# LANGUAGE BangPatterns #-}
  2. {-# LANGUAGE CPP #-}
  3. {-# LANGUAGE DeriveDataTypeable #-}
  4. {-# LANGUAGE DeriveFoldable #-}
  5. {-# LANGUAGE DeriveFunctor #-}
  6. {-# LANGUAGE DeriveGeneric #-}
  7. {-# LANGUAGE DeriveTraversable #-}
  8. {-# LANGUAGE FlexibleContexts #-}
  9. {-# LANGUAGE TypeFamilies #-}
  10. -- |
  11. -- Module : Statistics.Resampling
  12. -- Copyright : (c) 2009, 2010 Bryan O'Sullivan
  13. -- License : BSD3
  14. --
  15. -- Maintainer : bos@serpentine.com
  16. -- Stability : experimental
  17. -- Portability : portable
  18. --
  19. -- Resampling statistics.
  20. module Statistics.Resampling
  21. ( -- * Data types
  22. Resample(..)
  23. , Bootstrap(..)
  24. , Estimator(..)
  25. , estimate
  26. -- * Resampling
  27. , resampleST
  28. , resample
  29. , resampleVector
  30. -- * Jackknife
  31. , jackknife
  32. , jackknifeMean
  33. , jackknifeVariance
  34. , jackknifeVarianceUnb
  35. , jackknifeStdDev
  36. -- * Helper functions
  37. , splitGen
  38. ) where
  39. import Data.Aeson (FromJSON, ToJSON)
  40. import Control.Concurrent.Async (forConcurrently_)
  41. import Control.Monad (forM_, forM, replicateM, liftM2)
  42. import Control.Monad.Primitive (PrimMonad(..))
  43. import Data.Binary (Binary(..))
  44. import Data.Data (Data, Typeable)
  45. import Data.Vector.Algorithms.Intro (sort)
  46. import Data.Vector.Binary ()
  47. import Data.Vector.Generic (unsafeFreeze,unsafeThaw)
  48. import Data.Word (Word32)
  49. import qualified Data.Foldable as T
  50. import qualified Data.Traversable as T
  51. import qualified Data.Vector.Generic as G
  52. import qualified Data.Vector.Unboxed as U
  53. import qualified Data.Vector.Unboxed.Mutable as MU
  54. import GHC.Conc (numCapabilities)
  55. import GHC.Generics (Generic)
  56. import Numeric.Sum (Summation(..), kbn)
  57. import Statistics.Function (indices)
  58. import Statistics.Sample (mean, stdDev, variance, varianceUnbiased)
  59. import Statistics.Types (Sample)
  60. import System.Random.MWC (Gen, GenIO, initialize, uniformR, uniformVector)
  61. ----------------------------------------------------------------
  62. -- Data types
  63. ----------------------------------------------------------------
  -- | A bootstrap resample: values drawn at random, with replacement,
  -- from a set of data points.  It is a newtype over an unboxed vector
  -- so that the type checker keeps resamples apart from ordinary
  -- arrays of doubles.
  newtype Resample = Resample
    { fromResample :: U.Vector Double
      -- ^ The wrapped vector of resampled values.
    }
    deriving (Eq, Read, Show, Typeable, Data, Generic)

  -- The JSON instances have no explicit methods, so the class defaults
  -- are used.
  instance FromJSON Resample

  instance ToJSON Resample

  -- | Encoded and decoded exactly like the wrapped vector.
  instance Binary Resample where
    put (Resample v) = put v
    get = Resample `fmap` get
  -- | A value computed over a full sample, paired with a container of
  -- the same kind of value computed over resamples.  As built by
  -- 'resample' and 'resampleST' in this module, 'fullSample' is the
  -- estimator applied to the original data and 'resamples' is the
  -- sorted vector of estimates over the resampled data.
  data Bootstrap v a = Bootstrap
    { fullSample :: !a
      -- ^ Value for the original sample (strict field).
    , resamples :: v a
      -- ^ Values for the resamples, held in the container @v@.
    }
    deriving
      ( Eq, Read, Show, Generic, Functor, T.Foldable, T.Traversable
#if __GLASGOW_HASKELL__ >= 708
      , Typeable, Data
#endif
      )

  -- | Writes 'fullSample' first, then 'resamples'; reads them back in
  -- the same order.
  instance (Binary a, Binary (v a)) => Binary (Bootstrap v a) where
    put (Bootstrap fs rs) = do
      put fs
      put rs
    get = do
      fs <- get
      rs <- get
      return (Bootstrap fs rs)

  -- The JSON instances have no explicit methods, so the class defaults
  -- are used.
  instance (FromJSON a, FromJSON (v a)) => FromJSON (Bootstrap v a)

  instance (ToJSON a, ToJSON (v a)) => ToJSON (Bootstrap v a)
  -- | An estimator of some property of a sample, for example its
  -- 'mean'.
  --
  -- The common estimators are spelled out as constructors rather than
  -- hidden behind a plain function so that code such as 'jackknife'
  -- can pick a specialised, faster algorithm for them.  'Function'
  -- wraps an arbitrary user-supplied estimator.
  data Estimator
    = Mean
    | Variance
    | VarianceUnbiased
    | StdDev
    | Function (Sample -> Double)
  -- | Turn an 'Estimator' into the function that computes it over a
  -- sample.  The named constructors map onto the functions of the same
  -- name from "Statistics.Sample"; 'Function' simply unwraps the
  -- user-supplied function.
  estimate :: Estimator -> Sample -> Double
  estimate est = case est of
    Mean             -> mean
    Variance         -> variance
    VarianceUnbiased -> varianceUnbiased
    StdDev           -> stdDev
    Function f       -> f
  106. ----------------------------------------------------------------
  107. -- Resampling
  108. ----------------------------------------------------------------
  -- | Single-threaded, deterministic counterpart of 'resample'.
  --
  -- Estimators are processed one after another.  For each of them,
  -- @numResamples@ fresh resamples are drawn from the generator and
  -- the estimate of each resample is recorded.  The recorded estimates
  -- are then sorted and paired with the estimate over the original
  -- sample.  The result list is in the same order as the estimator
  -- list.
  resampleST :: PrimMonad m
             => Gen (PrimState m)
             -> [Estimator]         -- ^ Estimation functions.
             -> Int                 -- ^ Number of resamples to compute.
             -> U.Vector Double     -- ^ Original sample.
             -> m [Bootstrap U.Vector Double]
  resampleST gen ests numResamples sample = do
      raw <- mapM drawEstimates ests
      -- Sort in place: the freshly built vectors are not referenced
      -- again after being thawed, so no copy is needed.
      mutable <- mapM unsafeThaw raw
      mapM_ sort mutable
      sorted <- mapM unsafeFreeze mutable
      return (zipWith Bootstrap pointEstimates sorted)
    where
      -- All resampled estimates for one estimator.
      drawEstimates e = U.replicateM numResamples (oneEstimate e)
      -- Draw one resample and force its estimate.
      oneEstimate e = do
        v <- resampleVector gen sample
        return $! estimate e v
      -- Each estimator applied to the original sample.
      pointEstimates = map (`estimate` sample) ests
  -- | /O(e*r*s)/ Repeatedly resample a data set, with replacement, and
  -- compute every estimate over each resample.
  --
  -- The cost is proportional to /e*r*s/, where /e/ is the number of
  -- estimation functions, /r/ the number of resamples and /s/ the size
  -- of the original sample.
  --
  -- The resample indices are divided into 'numCapabilities' contiguous
  -- chunks, each filled concurrently using its own generator obtained
  -- from 'splitGen'.  Every resample is drawn once and fed to all
  -- estimators.  At least with GHC 7.0, parallel performance seems
  -- best when the parallel garbage collector is disabled (RTS option
  -- @-qg@).
  --
  -- Each estimator is returned together with its 'Bootstrap', whose
  -- resampled estimates are sorted.
  resample :: GenIO
           -> [Estimator]         -- ^ Estimation functions.
           -> Int                 -- ^ Number of resamples to compute.
           -> U.Vector Double     -- ^ Original sample.
           -> IO [(Estimator, Bootstrap U.Vector Double)]
  resample gen ests numResamples samples = do
      -- One output buffer per estimator, shared by all workers; the
      -- workers write to disjoint index ranges.
      buffers <- forM ests $ \_ -> MU.new numResamples
      workerGens <- splitGen numCapabilities gen
      let targets = zip (map estimate ests) buffers
          -- Fill slots lo .. hi-1 of every buffer.
          fill g lo hi = forM_ [lo .. hi - 1] $ \k -> do
            re <- resampleVector g samples
            forM_ targets $ \(est, buf) ->
              MU.write buf k (est re)
      forConcurrently_ (zip3 bounds (tail bounds) workerGens) $
        \(lo, !hi, g) -> fill g lo hi
      mapM_ sort buffers
      frozen <- mapM unsafeFreeze buffers
      return $ zip ests
             $ zipWith Bootstrap (map (`estimate` samples) ests) frozen
    where
      -- Spread the resamples as evenly as possible: the first r chunks
      -- receive one extra element.
      (q, r) = numResamples `quotRem` numCapabilities
      chunkSizes = zipWith (+) (replicate numCapabilities q)
                               (replicate r 1 ++ repeat 0)
      -- Running chunk boundaries; consecutive pairs are the half-open
      -- index ranges handed to the workers.
      bounds = scanl (+) 0 chunkSizes
  -- | Build a resample of a vector: a new vector of the same length
  -- whose elements are each picked from the input at a uniformly
  -- random index, i.e. with replacement.
  resampleVector :: (PrimMonad m, G.Vector v a)
                 => Gen (PrimState m) -> v a -> m (v a)
  resampleVector gen v = G.replicateM n pick
    where
      n = G.length v
      -- The index is drawn from [0, n-1], so it is always in bounds
      -- and the unchecked read is safe.
      pick = do
        i <- uniformR (0, n - 1) gen
        return $! G.unsafeIndex v i
  177. ----------------------------------------------------------------
  178. -- Jackknife
  179. ----------------------------------------------------------------
  -- | /O(n) or O(n^2)/ Compute an estimate repeatedly over a sample,
  -- leaving out one successive element each time.
  --
  -- The built-in estimators are routed to their dedicated jackknife
  -- routines.  A user-supplied 'Function' is evaluated once for every
  -- index, on a copy of the sample with that element removed.  A
  -- one-element sample given to a 'Function' estimator is an error.
  jackknife :: Estimator -> Sample -> U.Vector Double
  jackknife est sample = case est of
    Mean             -> jackknifeMean sample
    Variance         -> jackknifeVariance sample
    VarianceUnbiased -> jackknifeVarianceUnb sample
    StdDev           -> jackknifeStdDev sample
    Function f
      | G.length sample == 1 -> singletonErr "jackknife"
      | otherwise ->
          U.map (\i -> f (dropAt i sample)) (indices sample)
  -- | /O(n)/ Compute the jackknife mean of a sample: element /i/ of
  -- the result is the mean of the sample with element /i/ left out.
  -- It is obtained from the sum of the elements before /i/ plus the
  -- sum of the elements after /i/, divided by the number of remaining
  -- elements.  A one-element sample is an error.
  jackknifeMean :: Sample -> U.Vector Double
  jackknifeMean samp
      | len == 1  = singletonErr "jackknifeMean"
      | otherwise =
          G.zipWith (\before after -> (before + after) / remaining)
                    (pfxSumL samp) (pfxSumR samp)
    where
      len = G.length samp
      -- Size of each leave-one-out subsample.
      remaining = fromIntegral (len - 1)
  -- | /O(n)/ Compute the jackknife variance of a sample using a
  -- correction factor @c@ that is subtracted from the divisor: @0@
  -- gives the regular variance, @1@ the \"unbiased\" one.
  --
  -- For every index, the sums of deviations and of squared deviations
  -- (both relative to the mean of the /full/ sample) over all other
  -- elements are assembled from left and right prefix sums and then
  -- combined.  A one-element sample is an error.
  jackknifeVariance_ :: Double -> Sample -> U.Vector Double
  jackknifeVariance_ c samp
      | len == 1  = singletonErr "jackknifeVariance"
      | otherwise = G.zipWith4 combine sqLeft sqRight devLeft devRight
    where
      len = G.length samp
      -- Size of each leave-one-out subsample.
      q = fromIntegral len - 1
      m = mean samp
      deviations = G.map (subtract m) samp
      squares    = G.map (\d -> d * d) deviations
      sqLeft   = pfxSumL squares
      sqRight  = pfxSumR squares
      devLeft  = pfxSumL deviations
      devRight = pfxSumR deviations
      -- Sum of squared deviations of the remaining elements, corrected
      -- for the mean having been taken over the full sample.
      combine sl sr dl dr = (sl + sr - (d * d) / q) / (q - c)
        where d = dl + dr
  -- | /O(n)/ Compute the unbiased jackknife variance of a sample.
  --
  -- Samples of one or two elements are rejected with an error naming
  -- this function.  With two elements every leave-one-out subsample
  -- has a single element, so the unbiased divisor (subsample size
  -- minus one) would be zero.  With one element no subsample exists.
  jackknifeVarianceUnb :: Sample -> U.Vector Double
  jackknifeVarianceUnb samp
      | len == 1 || len == 2 = singletonErr "jackknifeVarianceUnb"
      | otherwise            = jackknifeVariance_ 1 samp
    where
      len = G.length samp
  -- | /O(n)/ Compute the jackknife variance of a sample, i.e.
  -- 'jackknifeVariance_' with a correction factor of zero.
  jackknifeVariance :: Sample -> U.Vector Double
  jackknifeVariance samp = jackknifeVariance_ 0 samp
  -- | /O(n)/ Compute the jackknife standard deviation of a sample: the
  -- element-wise square root of the unbiased jackknife variance.
  jackknifeStdDev :: Sample -> U.Vector Double
  jackknifeStdDev samp = G.map sqrt (jackknifeVarianceUnb samp)
  -- | Running sums from the left, accumulated with the 'kbn'
  -- compensated summation.  The scan starts from 'zero', so the
  -- result has one more element than the input and element /i/ is the
  -- sum of the first /i/ input elements.
  pfxSumL :: U.Vector Double -> U.Vector Double
  pfxSumL xs = G.map kbn (G.scanl add zero xs)
  -- | Running sums from the right, accumulated with the 'kbn'
  -- compensated summation.  The leading element of the scan (the sum
  -- of the whole input) is dropped, so element /i/ of the result is
  -- the sum of the input elements after index /i/.
  pfxSumR :: U.Vector Double -> U.Vector Double
  pfxSumR xs = G.tail (G.map kbn (G.scanr (flip add) zero xs))
  -- | Remove the element at index @n@ from a vector by joining the
  -- slice before it with the slice after it.
  dropAt :: U.Unbox e => Int -> U.Vector e -> U.Vector e
  dropAt n v = front U.++ back
    where
      -- Elements 0 .. n-1.
      front = U.slice 0 n v
      -- Elements n+1 .. end.
      back  = U.slice (n + 1) (U.length v - n - 1) v
  -- | Abort with an error reporting that the named function of this
  -- module was given a sample with too few elements.
  singletonErr :: String -> a
  singletonErr func =
    error (concat [ "Statistics.Resampling."
                  , func
                  , ": not enough elements in sample"
                  ])
  -- | Produce @n@ generators intended for independent use.  The first
  -- one is the generator passed in; each of the remaining @n-1@ is
  -- newly initialised from 256 random words drawn from that generator.
  -- A non-positive @n@ gives an empty list.
  splitGen :: Int -> GenIO -> IO [GenIO]
  splitGen n gen
      | n <= 0    = return []
      | otherwise = do
          others <- replicateM (n - 1) spawn
          return (gen : others)
    where
      -- Seed a fresh generator from the parent's output.
      spawn = do
        seed <- uniformVector gen 256 :: IO (U.Vector Word32)
        initialize seed