PageRenderTime 23ms CodeModel.GetById 36ms RepoModel.GetById 0ms app.codeStats 1ms

/src/CLUtil/Buffer.hs

http://github.com/acowley/CLUtil
Haskell | 219 lines | 126 code | 20 blank | 73 comment | 0 complexity | b88ae83b6109f46529f1f604941ae8ab MD5 | raw file
Possible License(s): BSD-3-Clause
  1. {-# LANGUAGE ConstraintKinds, FlexibleContexts, ScopedTypeVariables,
  2. TupleSections, RankNTypes #-}
  3. -- | Typed monadic interface for working with OpenCL buffers.
  4. module CLUtil.Buffer where
  import Control.Applicative ((<$>))
  import Control.Concurrent (forkIO)
  import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
  import Control.Exception (SomeException, bracket, evaluate, throwIO, try)
  import Control.Monad (when)
  import Control.Monad.ST (ST)
  import Control.Monad.ST.Unsafe (unsafeSTToIO)
  import qualified Data.Vector.Storable as V
  import qualified Data.Vector.Storable.Mutable as VM
  import Foreign.Marshal.Utils (copyBytes)
  import Foreign.ForeignPtr (newForeignPtr_)
  import Foreign.Ptr (castPtr, nullPtr)
  import Foreign.Storable (Storable(..))
  import CLUtil.Async
  import CLUtil.CL
  import CLUtil.State (OpenCLState(clContext, clQueue))
  import Control.Parallel.OpenCL
  -- | Create a raw buffer object in the context held by the given
  -- 'OpenCLState'. No host data is supplied (the host pointer is
  -- 'nullPtr'), so the initial contents of the buffer are undefined.
  -- The size @n@ is handed to 'clCreateBuffer' unchanged.
  initOutputBuffer :: Integral a => OpenCLState -> [CLMemFlag] -> a -> IO CLMem
  initOutputBuffer s flags n = clCreateBuffer ctx flags (n, nullPtr)
    where
      ctx = clContext s
  -- | A @CLBuffer a@ is a buffer object whose elements are of type
  -- @a@. The caller is responsible for choosing an element type that
  -- maps naturally onto an OpenCL type (e.g. 'Word8', 'Int32', 'Float').
  --
  -- The fields are deliberately left lazy: the 'Storable' instance's
  -- 'peek' builds a value whose length is an 'error' thunk.
  data CLBuffer a = CLBuffer
    { bufferLength :: Int
      -- ^ Number of elements (not bytes) the buffer was created with.
    , bufferObject :: CLMem
      -- ^ The underlying OpenCL memory object.
    }
  -- | NOTE: This is an /EVIL/ 'Storable' instance. It exists only so that
  -- a 'CLBuffer' can be marshalled as its underlying 'CLMem' value when
  -- interoperating with OpenCL. It does /not/ roundtrip through 'peek'
  -- and 'poke': 'poke' writes just the 'CLMem', and 'peek' yields a
  -- 'CLBuffer' whose length is an 'error' thunk.
  instance Storable (CLBuffer a) where
    sizeOf = const (sizeOf (undefined :: CLMem))
    alignment = const (alignment (undefined :: CLMem))
    peek ptr =
      do mem <- peek (castPtr ptr)
         return (CLBuffer (error "Tried to peek a CLBuffer") mem)
    poke ptr buf = poke (castPtr ptr) (bufferObject buf)
  -- | A 'CLBuffer' exposes its underlying memory object as its 'CLMem'.
  instance HasCLMem (CLBuffer a) where
    getCLMem = bufferObject
  -- | Allocate a new, uninitialised buffer object with room for the given
  -- number of elements. The byte size requested from OpenCL is the element
  -- count multiplied by the 'sizeOf' the element type.
  allocBuffer :: forall a m. (Storable a, HasCL m)
              => [CLMemFlag] -> Int -> m (CLBuffer a)
  allocBuffer flags n =
    do st <- ask
       mem <- liftIO (initOutputBuffer st flags byteCount)
       return (CLBuffer n mem)
    where
      byteCount = n * sizeOf (undefined :: a)
  -- | Allocate a new buffer object and hand a 'Vector''s storage to
  -- 'clCreateBuffer' as the host pointer. The resulting 'CLBuffer' has the
  -- same element count as the vector.
  --
  -- NOTE(review): the memory flags are passed through untouched, so the
  -- caller presumably needs to include a host-pointer flag such as
  -- @CL_MEM_COPY_HOST_PTR@ for the data to be used -- confirm against
  -- callers.
  initBuffer :: forall a m. (Storable a, HasCL m)
             => [CLMemFlag] -> V.Vector a -> m (CLBuffer a)
  initBuffer flags v =
    do ctx <- fmap clContext ask
       mem <- liftIO $ V.unsafeWith v $ \ptr ->
         clCreateBuffer ctx flags (byteCount, castPtr ptr)
       return (CLBuffer len mem)
    where
      len = V.length v
      byteCount = len * sizeOf (undefined :: a)
  -- | @readBufferAsync' buf n events@ reads back a 'Vector' of the first
  -- @n@ elements of the buffer object @buf@ after waiting for @events@ to
  -- finish.
  --
  -- An error is raised with 'throwError' when @n@ is larger than the
  -- buffer's recorded element count. The buffer is mapped for reading with
  -- the blocking flag set, its bytes are copied into a freshly allocated
  -- vector, and the mapping is then released. The 'CLAsync' handed back is
  -- built from the event of the unmap command.
  readBufferAsync' :: forall a m. (Storable a, HasCL m)
                   => CLBuffer a -> Int -> [CLEvent]
                   -> m (CLAsync (V.Vector a))
  readBufferAsync' (CLBuffer n' mem) n waitForIt =
    do when (n > n') (throwError "Tried to read more elements than a buffer has")
       q <- clQueue <$> ask
       v <- liftIO $ VM.new n
       ev <- liftIO . VM.unsafeWith v $ \ptr ->
         do (mapEv, src) <- clEnqueueMapBuffer q mem True [CL_MAP_READ] 0 sz
                                               waitForIt
            -- The map call above is blocking, so nothing waits on its
            -- event. Release it right away; otherwise every read leaks
            -- one event object.
            _ <- clReleaseEvent mapEv
            copyBytes (castPtr ptr) src sz
            clEnqueueUnmapMemObject q mem src []
       return . clAsync ev $ liftIO $ V.unsafeFreeze v
    where sz = n * sizeOf (undefined::a)
  -- | @readBuffer' buf n events@ reads the first @n@ elements of a buffer
  -- after waiting for @events@, and blocks until the resulting vector is
  -- available.
  readBuffer' :: (Storable a, HasCL m)
              => CLBuffer a -> Int -> [CLEvent] -> m (V.Vector a)
  readBuffer' buf n waitForIt =
    do pending <- readBufferAsync' buf n waitForIt
       waitOne pending
  -- | @readBuffer buf@ performs a blocking read of every element stored in
  -- a 'CLBuffer' and returns them as a 'Vector'.
  readBuffer :: (Storable a, HasCL m) => CLBuffer a -> m (V.Vector a)
  readBuffer buf = readBuffer' buf (bufferLength buf) []
  -- | Start a read of a buffer's entire contents, returning a 'CLAsync'
  -- that yields the resulting 'Vector' once it is waited on.
  readBufferAsync :: (Storable a, HasCL m) => CLBuffer a -> m (CLAsync (V.Vector a))
  readBufferAsync buf = readBufferAsync' buf (bufferLength buf) []
  -- | Write a 'Vector''s contents to the start of a buffer object and
  -- return a 'CLAsync' wrapping the write command's event.
  --
  -- An error is raised with 'throwError' if the vector holds more
  -- elements than the buffer. Note that the write is enqueued with the
  -- blocking flag set to 'True', so the vector's storage is only accessed
  -- while it is pinned by 'V.unsafeWith'.
  writeBufferAsync :: forall a m. (Storable a, HasCL m)
                   => CLBuffer a -> V.Vector a -> m (CLAsync ())
  writeBufferAsync (CLBuffer capacity mem) v =
    do when (len > capacity)
         (throwError "writeBuffer: Vector is bigger than the CLBuffer")
       q <- fmap clQueue ask
       ev <- liftIO $ V.unsafeWith v $ \ptr ->
         clEnqueueWriteBuffer q mem True 0 byteCount (castPtr ptr) []
       return (clAsync ev (return ()))
    where
      len = V.length v
      byteCount = len * sizeOf (undefined :: a)
  -- | Write a 'Vector''s contents to a buffer object and wait for the
  -- write to complete before returning.
  writeBuffer :: (Storable a, HasCL m) => CLBuffer a -> V.Vector a -> m ()
  writeBuffer b v =
    do pending <- writeBufferAsync b v
       waitOne pending
  -- | Create a read-only 'CLBuffer' that shares an underlying pointer
  -- with a 'V.Vector', then apply a function to that buffer. This is
  -- typically used to have an OpenCL kernel directly read from a
  -- vector. If the OpenCL context can not directly use the pointer,
  -- this will raise a runtime error!
  --
  -- The memory object is released once the supplied action has returned,
  -- so the 'CLBuffer' must not be used outside of that action.
  withSharedVector :: forall a r m. (Storable a, HasCL m)
                   => V.Vector a -> (CLBuffer a -> m r) -> m r
  withSharedVector v go =
    do ctx <- clContext <$> ask
       mem <- liftIO . V.unsafeWith v $ \ptr ->
         clCreateBuffer ctx [CL_MEM_READ_ONLY, CL_MEM_USE_HOST_PTR]
                        (sz, castPtr ptr)
       r <- go (CLBuffer (V.length v) mem)
       _ <- liftIO $ clReleaseMemObject mem
       -- The buffer was created with CL_MEM_USE_HOST_PTR, so the vector's
       -- pointer is still in use after the 'V.unsafeWith' above has
       -- returned. Touch the vector here so its storage cannot be
       -- reclaimed while @go@ is running.
       liftIO $ V.unsafeWith v (const (return ()))
       return r
    where sz = V.length v * sizeOf (undefined::a)
  -- | Create a read-write 'CLBuffer' that shares an underlying pointer
  -- with a 'VM.IOVector', then apply the given function to that
  -- buffer. This is typically used to have an OpenCL kernel write
  -- directly to a Haskell vector. If the OpenCL context can not
  -- directly use the pointer, this will raise a runtime error!
  --
  -- The memory object is released once the supplied action has returned,
  -- so the 'CLBuffer' must not be used outside of that action.
  withSharedMVector :: forall a r m. (Storable a, HasCL m)
                    => VM.IOVector a -> (CLBuffer a -> m r) -> m r
  withSharedMVector v go =
    do ctx <- clContext <$> ask
       mem <- liftIO . VM.unsafeWith v $ \ptr ->
         clCreateBuffer ctx [CL_MEM_READ_WRITE, CL_MEM_USE_HOST_PTR]
                        (sz, castPtr ptr)
       r <- go (CLBuffer (VM.length v) mem)
       _ <- liftIO $ clReleaseMemObject mem
       -- The buffer was created with CL_MEM_USE_HOST_PTR, so the vector's
       -- pointer is still in use after the 'VM.unsafeWith' above has
       -- returned. Touch the vector here so its storage cannot be
       -- reclaimed while @go@ is running.
       liftIO $ VM.unsafeWith v (const (return ()))
       return r
    where sz = VM.length v * sizeOf (undefined::a)
  -- | Provides access to a memory-mapped 'VM.MVector' of a
  -- 'CLBuffer'. The result of applying the given function to the vector
  -- is evaluated to WHNF, but the caller should ensure that this is
  -- sufficient to not require hanging onto a reference to the vector
  -- data, as this reference will not be valid. Returning the vector
  -- itself is right out. The 'CLMapFlag's supplied determine if we have
  -- read-only, write-only, or read/write access to the 'VM.MVector'.
  --
  -- The mapping and the supplied function run on a freshly forked thread.
  -- The outer action returns immediately with an inner action that blocks
  -- until that thread has finished and then yields its result. The inner
  -- action takes the result out of an 'MVar', so it is meant to be run
  -- once.
  --
  -- If mapping the buffer or running the function throws, the buffer is
  -- still unmapped (when it had been mapped) and the exception is
  -- re-thrown from the inner action, rather than leaving the caller
  -- blocked on a result that will never arrive.
  withBufferAsync_ :: forall a r m. (Storable a, HasCL m)
                   => [CLMapFlag] -> CLBuffer a
                   -> (forall s. VM.MVector s a -> ST s r) -> m (m r)
  withBufferAsync_ flags (CLBuffer n mem) f =
    do q <- clQueue <$> ask
       liftIO $
         do done <- newEmptyMVar
            _ <- forkIO $
              do outcome <- try (bracket (mapBuffer q) (unmapBuffer q) runMapped)
                 putMVar done (outcome :: Either SomeException r)
            return $ liftIO (takeMVar done >>= either throwIO return)
    where
      sz = n * sizeOf (undefined::a)
      -- Blocking map of the whole buffer into host memory.
      mapBuffer q = clEnqueueMapBuffer q mem True flags 0 sz []
      -- Unmap the buffer, wait for the unmap to finish, and release the
      -- event produced by the map command so it does not leak.
      unmapBuffer q (mapEv, ptr) =
        do _ <- clEnqueueUnmapMemObject q mem ptr [mapEv] >>= waitReleaseEvent
           clReleaseEvent mapEv
      -- Wrap the mapped pointer as a mutable vector and run the user's
      -- function on it, forcing the result to WHNF before the unmap.
      runMapped (_, ptr) =
        do fp <- newForeignPtr_ (castPtr ptr)
           evaluate =<< unsafeSTToIO (f (VM.unsafeFromForeignPtr0 fp n))
  -- | Provides read/write access to a memory-mapped 'VM.MVector' of a
  -- 'CLBuffer', without waiting for the result: the returned inner action
  -- blocks until the result is available. The vector is only valid while
  -- the supplied function runs, so the result must not retain a reference
  -- to the vector data, and must certainly not be the vector itself.
  withBufferRWAsync :: (Storable a, HasCL m)
                    => CLBuffer a -> (forall s. VM.MVector s a -> ST s r)
                    -> m (m r)
  withBufferRWAsync buf f = withBufferAsync_ [CL_MAP_READ, CL_MAP_WRITE] buf f
  -- | Provides read/write access to a memory-mapped 'VM.MVector' of a
  -- 'CLBuffer' and waits for the result. The vector is only valid while
  -- the supplied function runs, so the result must not retain a reference
  -- to the vector data, and must certainly not be the vector itself.
  withBufferRW :: (Storable a, HasCL m)
               => CLBuffer a -> (forall s. VM.MVector s a -> ST s r) -> m r
  withBufferRW img f =
    do waitForResult <- withBufferRWAsync img f
       waitForResult
  -- | Provides read-only access to a memory-mapped 'V.Vector' of a
  -- 'CLBuffer', without waiting for the result: the returned inner action
  -- blocks until the result is available. The function's result is
  -- evaluated to WHNF; the caller must make sure that is enough for the
  -- result not to retain a reference to the vector data, which will not
  -- remain valid. Returning the vector itself is right out.
  withBufferAsync :: (Storable a, HasCL m)
                  => CLBuffer a -> (V.Vector a -> r) -> m (m r)
  withBufferAsync img f =
    withBufferAsync_ [CL_MAP_READ] img (\mapped -> f <$> V.unsafeFreeze mapped)
  -- | Provides read-only access to a memory-mapped 'V.Vector' of a
  -- 'CLBuffer' and waits for the result. The function's result is
  -- evaluated to WHNF; the caller must make sure that is enough for the
  -- result not to retain a reference to the vector data, which will not
  -- remain valid. Returning the vector itself is right out.
  withBuffer :: (Storable a, HasCL m) => CLBuffer a -> (V.Vector a -> r) -> m r
  withBuffer img f =
    do waitForResult <- withBufferAsync img f
       waitForResult