PageRenderTime 49ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/Data/Array/Accelerate/CUDA/Compile.hs

https://github.com/robeverest/accelerate-cuda
Haskell | 492 lines | 312 code | 70 blank | 110 comment | 11 complexity | f7cdf4cb06ebffabfda6ce1f8d7c2618 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. {-# LANGUAGE CPP #-}
  2. {-# LANGUAGE GADTs #-}
  3. {-# LANGUAGE ScopedTypeVariables #-}
  4. {-# LANGUAGE TupleSections #-}
  5. {-# OPTIONS_GHC -fno-warn-orphans #-}
  6. -- |
  7. -- Module : Data.Array.Accelerate.CUDA.Compile
  8. -- Copyright : [2008..2010] Manuel M T Chakravarty, Gabriele Keller, Sean Lee
  9. -- [2009..2012] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell
  10. -- License : BSD3
  11. --
  12. -- Maintainer : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
  13. -- Stability : experimental
  14. -- Portability : non-portable (GHC extensions)
  15. --
  16. module Data.Array.Accelerate.CUDA.Compile (
  17. -- * generate and compile kernels to realise a computation
  18. compileAcc, compileAfun
  19. ) where
  20. #include "accelerate.h"
  21. -- friends
  22. import Data.Array.Accelerate.Tuple
  23. import Data.Array.Accelerate.CUDA.AST
  24. import Data.Array.Accelerate.CUDA.State
  25. import Data.Array.Accelerate.CUDA.CodeGen
  26. import Data.Array.Accelerate.CUDA.Array.Sugar
  27. import Data.Array.Accelerate.CUDA.Analysis.Launch
  28. import Data.Array.Accelerate.CUDA.Foreign (canExecute)
  29. import Data.Array.Accelerate.CUDA.Persistent as KT
  30. import qualified Data.Array.Accelerate.CUDA.FullList as FL
  31. import qualified Data.Array.Accelerate.CUDA.Debug as D
  32. -- libraries
  33. import Numeric
  34. import Prelude hiding ( exp, scanl, scanr )
  35. import Control.Applicative hiding ( Const )
  36. import Control.Exception
  37. import Control.Monad
  38. import Control.Monad.Reader ( asks )
  39. import Control.Monad.State ( gets )
  40. import Control.Monad.Trans ( liftIO, MonadIO )
  41. import Control.Concurrent
  42. import Crypto.Hash.MD5 ( hashlazy )
  43. import Data.List ( intercalate )
  44. import Data.Maybe
  45. import Data.Monoid
  46. import System.Directory
  47. import System.Exit ( ExitCode(..) )
  48. import System.FilePath
  49. import System.IO
  50. import System.IO.Error
  51. import System.IO.Unsafe
  52. import System.Process
  53. import Text.PrettyPrint.Mainland ( ppr, renderCompact, displayLazyText )
  54. import qualified Data.ByteString as B
  55. import qualified Data.Text.Lazy as T
  56. import qualified Data.Text.Lazy.IO as T
  57. import qualified Data.Text.Lazy.Encoding as T
  58. import qualified Control.Concurrent.MSem as Q
  59. import qualified Foreign.CUDA.Driver as CUDA
  60. import qualified Foreign.CUDA.Analysis as CUDA
  61. import GHC.Conc ( getNumProcessors )
  62. #ifdef VERSION_unix
  63. import System.Posix.Process
  64. #else
  65. import System.Win32.Process
  66. #endif
  67. #ifndef SIZEOF_HSINT
  68. import Foreign.Storable
  69. #endif
  70. import Paths_accelerate_cuda ( getDataDir )
  -- Kernel modules are recorded per CUDA context (see 'link'), with the context
-- acting as the lookup key, so contexts must support equality. Two contexts
-- are equal exactly when their underlying handles are equal.
--
instance Eq CUDA.Context where
  (==) (CUDA.Context lhs) (CUDA.Context rhs) = lhs == rhs
  -- | Initiate code generation, compilation, and data transfer for an array
-- expression. The result is the same computation, annotated so that it is
-- suitable for execution in the CUDA environment. The annotations include:
--
-- * the list of array variables embedded within scalar expressions
--
-- * the kernel object(s) required to execute each node
--
compileAcc :: Acc a -> CIO (ExecAcc a)
compileAcc acc = prepareOpenAcc acc
  -- | Compile an array function in the same way as 'compileAcc': the lambda
-- structure is kept and the body is prepared for execution.
compileAfun :: Afun f -> CIO (ExecAfun f)
compileAfun afun = prepareOpenAfun afun
  -- | Walk under the lambdas of an open array function and prepare its body
-- with 'prepareOpenAcc'.
prepareOpenAfun :: OpenAfun aenv f -> CIO (PreOpenAfun ExecOpenAcc aenv f)
prepareOpenAfun afun =
  case afun of
    Abody body -> fmap Abody (prepareOpenAcc body)
    Alam  lam  -> fmap Alam  (prepareOpenAfun lam)
  -- | Prepare an open array expression for execution. The expression is
-- traversed depth-first, and kernels are generated and compiled for the nodes
-- that need them (those built with 'exec' below). Each node is annotated with
-- the array variables referenced from its scalar expressions.
--
prepareOpenAcc :: OpenAcc aenv a -> CIO (ExecOpenAcc aenv a)
prepareOpenAcc rootAcc = traverseAcc rootAcc
  where
    -- Traverse an open array expression in depth-first order
    --
    -- The applicative combinators are used to gloss over that we are passing
    -- around the AST nodes together with a set of free variable indices that
    -- are merged at every step.
    --
    -- Nodes finished with 'node' carry no kernel; nodes finished with 'exec'
    -- get the kernel(s) produced by 'build' for this node.
    --
    traverseAcc :: forall aenv arrs. OpenAcc aenv arrs -> CIO (ExecOpenAcc aenv arrs)
    traverseAcc acc@(OpenAcc pacc) =
      case pacc of
        -- Environment and control flow
        Avar ix                 -> node $ pure (Avar ix)
        Alet a b                -> node . pure =<< Alet <$> traverseAcc a <*> traverseAcc b
        Apply f a               -> node . pure =<< Apply <$> compileAfun f <*> traverseAcc a
        Acond p t e             -> node =<< liftA3 Acond <$> travE p <*> travA t <*> travA e
        Atuple tup              -> node =<< liftA Atuple <$> travAtup tup
        Aprj ix tup             -> node =<< liftA (Aprj ix) <$> travA tup

        -- Array injection
        -- ('Use' also registers the embedded arrays via 'useArray')
        Unit e                  -> node =<< liftA Unit <$> travE e
        Use arrs                -> use (arrays (undefined::arrs)) arrs >> node (pure $ Use arrs)

        -- Index space transforms
        Reshape s a             -> node =<< liftA2 Reshape <$> travE s <*> travA a
        Replicate slix e a      -> exec =<< liftA2 (Replicate slix) <$> travE e <*> travA a
        Slice slix a e          -> exec =<< liftA2 (Slice slix) <$> travA a <*> travE e
        Backpermute e f a       -> exec =<< liftA3 Backpermute <$> travE e <*> travF f <*> travD a

        -- Producers
        Generate e f            -> exec =<< liftA2 Generate <$> travE e <*> travF f
        Map f a                 -> exec =<< liftA2 Map <$> travF f <*> travD a
        ZipWith f a b           -> exec =<< liftA3 ZipWith <$> travF f <*> travD a <*> travD b
        Transform e p f a       -> exec =<< liftA4 Transform <$> travE e <*> travF p <*> travF f <*> travD a

        -- Consumers
        Fold f z a              -> exec =<< liftA3 Fold <$> travF f <*> travE z <*> travD a
        Fold1 f a               -> exec =<< liftA2 Fold1 <$> travF f <*> travD a
        FoldSeg f e a s         -> exec =<< liftA4 FoldSeg <$> travF f <*> travE e <*> travD a <*> travD s
        Fold1Seg f a s          -> exec =<< liftA3 Fold1Seg <$> travF f <*> travD a <*> travD s
        Scanl f e a             -> exec =<< liftA3 Scanl <$> travF f <*> travE e <*> travD a
        Scanl' f e a            -> exec =<< liftA3 Scanl' <$> travF f <*> travE e <*> travD a
        Scanl1 f a              -> exec =<< liftA2 Scanl1 <$> travF f <*> travD a
        Scanr f e a             -> exec =<< liftA3 Scanr <$> travF f <*> travE e <*> travD a
        Scanr' f e a            -> exec =<< liftA3 Scanr' <$> travF f <*> travE e <*> travD a
        Scanr1 f a              -> exec =<< liftA2 Scanr1 <$> travF f <*> travD a
        Permute f d g a         -> exec =<< liftA4 Permute <$> travF f <*> travA d <*> travF g <*> travD a
        Stencil f b a           -> exec =<< liftA2 (flip Stencil b) <$> travF f <*> travA a
        Stencil2 f b1 a1 b2 a2  -> exec =<< liftA3 stencil2 <$> travF f <*> travA a1 <*> travA a2
          where stencil2 f' a1' a2' = Stencil2 f' b1 a1' b2 a2'

        -- A foreign call is compiled as a plain node. When the CUDA backend can
        -- execute it directly, the pure fallback 'afun' is not compiled and a
        -- placeholder ('foreignError') that fails if evaluated is stored instead.
        Foreign ff afun a       -> case canExecute ff of
          -- If it's a foreign call for the CUDA backend don't bother compiling the pure version
          (Just _) -> node =<< liftA (Foreign ff foreignError) <$> travA a
          Nothing  -> node . pure =<< Foreign ff <$> compileAfun afun <*> traverseAcc a
      where
        -- Register every array embedded by a 'Use' node, following the
        -- structure of the (possibly nested pair) 'ArraysR' representation.
        use :: ArraysR a -> a -> CIO ()
        use ArraysRunit         ()       = return ()
        use ArraysRarray        arr      = useArray arr
        use (ArraysRpair r1 r2) (a1, a2) = use r1 a1 >> use r2 a2

        -- Finish a node that needs device code: build its kernel(s) with the
        -- collected free-variable environment and attach both to the node.
        exec :: (Gamma aenv, PreOpenAcc ExecOpenAcc aenv arrs) -> CIO (ExecOpenAcc aenv arrs)
        exec (aenv, eacc) = do
          kernel <- build acc aenv
          return $! ExecAcc (fullOfList kernel) aenv eacc

        -- Finish a node that has no kernel of its own; the collected
        -- environment is discarded.
        node :: (Gamma aenv', PreOpenAcc ExecOpenAcc aenv' arrs') -> CIO (ExecOpenAcc aenv' arrs')
        node = fmap snd . wrap

        -- Wrap a node with the 'noKernel' placeholder and an empty environment,
        -- keeping the collected environment in the first component.
        wrap :: (Gamma aenv', PreOpenAcc ExecOpenAcc aenv' arrs') -> CIO (Gamma aenv', ExecOpenAcc aenv' arrs')
        wrap = return . liftA (ExecAcc noKernel mempty)

        -- Traverse an array argument; it contributes no free variables
        -- ('pure' pairs it with an empty 'Gamma').
        travA :: OpenAcc aenv' a' -> CIO (Gamma aenv', ExecOpenAcc aenv' a')
        travA a = pure <$> traverseAcc a

        -- Traverse an argument that is expected to have been fused into its
        -- consumer ("delayed"). Only array variables and the producers listed
        -- below are accepted; anything else is an internal error.
        travD :: (Shape sh, Elt e) => OpenAcc aenv (Array sh e) -> CIO (Gamma aenv, ExecOpenAcc aenv (Array sh e))
        travD (OpenAcc delayed) =
          case delayed of
            Avar ix             -> wrap (freevar ix, Avar ix)
            Map f a             -> wrap =<< liftA2 Map <$> travF f <*> travD a
            Generate e f        -> wrap =<< liftA2 Generate <$> travE e <*> travF f
            Backpermute e f a   -> wrap =<< liftA3 Backpermute <$> travE e <*> travF f <*> travD a
            Transform e p f a   -> wrap =<< liftA4 Transform <$> travE e <*> travF p <*> travF f <*> travD a
            _                   -> INTERNAL_ERROR(error) "compile" "expected fused/delayable array"

        -- Traverse each component of an array tuple.
        travAtup :: Atuple (OpenAcc aenv) a -> CIO (Gamma aenv, Atuple (ExecOpenAcc aenv) a)
        travAtup NilAtup        = return (pure NilAtup)
        travAtup (SnocAtup t a) = liftA2 SnocAtup <$> travAtup t <*> travA a

        -- Traverse a scalar function: descend through the lambdas to the body.
        travF :: OpenFun env aenv t -> CIO (Gamma aenv, PreOpenFun ExecOpenAcc env aenv t)
        travF (Body b) = liftA Body <$> travE b
        travF (Lam  f) = liftA Lam  <$> travF f

        -- Placeholder kernel list for nodes without device code; fails if the
        -- kernel is ever demanded.
        noKernel :: FL.FullList () (AccKernel a)
        noKernel = FL.FL () (INTERNAL_ERROR(error) "compile" "no kernel module for this node") FL.Nil

        -- Convert the kernels returned by 'build' into a non-empty list; an
        -- empty list is an internal error.
        fullOfList :: [a] -> FL.FullList () a
        fullOfList []       = INTERNAL_ERROR(error) "fullList" "empty list"
        fullOfList [x]      = FL.singleton () x
        fullOfList (x:xs)   = FL.cons () x (fullOfList xs)

        -- Stand-in for the uncompiled pure version of a foreign function.
        foreignError = INTERNAL_ERROR(error) "compile" $ "Didn't compile the pure version of a foreign function call but" ++
                         " it looks like it's being executed anyway"

    -- Traverse a scalar expression
    --
    -- Array references inside the expression ('Index', 'LinearIndex', 'Shape')
    -- are traversed with the local 'travA', which records the referenced array
    -- variable in the resulting 'Gamma'.
    --
    travE :: OpenExp env aenv e
          -> CIO (Gamma aenv, PreOpenExp ExecOpenAcc env aenv e)
    travE exp =
      case exp of
        Var ix                  -> return $ pure (Var ix)
        Const c                 -> return $ pure (Const c)
        PrimConst c             -> return $ pure (PrimConst c)
        IndexAny                -> return $ pure IndexAny
        IndexNil                -> return $ pure IndexNil
        --
        Let a b                 -> liftA2 Let <$> travE a <*> travE b
        IndexCons t h           -> liftA2 IndexCons <$> travE t <*> travE h
        IndexHead h             -> liftA IndexHead <$> travE h
        IndexTail t             -> liftA IndexTail <$> travE t
        IndexSlice slix x s     -> liftA2 (IndexSlice slix) <$> travE x <*> travE s
        IndexFull slix x s      -> liftA2 (IndexFull slix) <$> travE x <*> travE s
        ToIndex s i             -> liftA2 ToIndex <$> travE s <*> travE i
        FromIndex s i           -> liftA2 FromIndex <$> travE s <*> travE i
        Tuple t                 -> liftA Tuple <$> travT t
        Prj ix e                -> liftA (Prj ix) <$> travE e
        Cond p t e              -> liftA3 Cond <$> travE p <*> travE t <*> travE e
        Iterate n f x           -> liftA3 Iterate <$> travE n <*> travE f <*> travE x
--      While p f x             -> liftA3 While <$> travE p <*> travE f <*> travE x
        PrimApp f e             -> liftA (PrimApp f) <$> travE e
        Index a e               -> liftA2 Index <$> travA a <*> travE e
        LinearIndex a e         -> liftA2 LinearIndex <$> travA a <*> travE e
        Shape a                 -> liftA Shape <$> travA a
        ShapeSize e             -> liftA ShapeSize <$> travE e
        Intersect x y           -> liftA2 Intersect <$> travE x <*> travE y

      where
        -- Traverse an array referenced from a scalar expression and record it
        -- as a free variable (see 'bind').
        travA :: (Shape sh, Elt e)
              => OpenAcc aenv (Array sh e) -> CIO (Gamma aenv, ExecOpenAcc aenv (Array sh e))
        travA a = do
          a' <- traverseAcc a
          return $ (bind a', a')

        -- Traverse each component of a scalar tuple.
        travT :: Tuple (OpenExp env aenv) t
              -> CIO (Gamma aenv, Tuple (PreOpenExp ExecOpenAcc env aenv) t)
        travT NilTup        = return (pure NilTup)
        travT (SnocTup t e) = liftA2 SnocTup <$> travT t <*> travE e

        -- Arrays used inside scalar expressions must be array variables; any
        -- other node here is an internal error.
        bind :: (Shape sh, Elt e) => ExecOpenAcc aenv (Array sh e) -> Gamma aenv
        bind (ExecAcc _ _ (Avar ix)) = freevar ix
        bind _                       = INTERNAL_ERROR(error) "bind" "expected array variable"
  -- Applicative
-- -----------
--
-- | Lift a function of four arguments over 'Applicative' values. This is the
-- four-argument analogue of 'liftA3'; effects are sequenced left to right.
liftA4 :: Applicative f => (a -> b -> c -> d -> e) -> f a -> f b -> f c -> f d -> f e
liftA4 f a b c d = liftA3 f a b c <*> d
  -- Compilation
-- -----------

-- | Generate, compile, and link the kernels that evaluate an array
-- computation. The code generator yields a list of code skeletons for the
-- current device, and 'build1' turns each one into a kernel. Linking is
-- deferred with 'unsafePerformIO' inside 'build1', so we only block on the
-- external compiler once a compiled object is truly needed.
--
build :: OpenAcc aenv a -> Gamma aenv -> CIO [AccKernel a]
build acc aenv = do
  props <- asks deviceProps
  forM (codegenAcc props acc aenv) (build1 acc)
  -- | Compile a single code skeleton and assemble the corresponding 'AccKernel'.
--
-- 'compile' runs eagerly and, when needed, queues the external compiler.
-- Linking the module, looking up the entry function and computing occupancy
-- are bundled in one 'unsafePerformIO' thunk. They only run (and block on the
-- compiler) once one of 'mdl', 'fun' or 'occ' is first demanded. The launch
-- configuration is computed from 'occ' in the same lazy 'let', so it is
-- equally deferred.
--
-- NOTE(review): this relies on 'AccKernel' not forcing these fields when it is
-- constructed; confirm its definition has no strict fields for them.
--
build1 :: OpenAcc aenv a -> CUTranslSkel aenv a -> CIO (AccKernel a)
build1 acc code = do
  dev           <- asks deviceProps
  table         <- gets kernelTable
  (entry,key)   <- compile table dev code
  let (cta,blocks,smem) = launchConfig acc dev occ
      (mdl,fun,occ)     = unsafePerformIO $ do
        m <- link table key
        f <- CUDA.getFun m entry
        l <- CUDA.requires f CUDA.MaxKernelThreadsPerBlock
        o <- determineOccupancy acc dev f l
        D.when D.dump_cc (stats entry f o)
        return (m,f,o)
  --
  return $ AccKernel entry fun mdl occ cta smem blocks
  where
    -- Report the register and memory usage of the entry function together
    -- with its multiprocessor occupancy. It is called above under
    -- 'D.when D.dump_cc'.
    stats name fn occ = do
      regs      <- CUDA.requires fn CUDA.NumRegs
      smem      <- CUDA.requires fn CUDA.SharedSizeBytes
      cmem      <- CUDA.requires fn CUDA.ConstSizeBytes
      lmem      <- CUDA.requires fn CUDA.LocalSizeBytes
      let msg1  = "entry function '" ++ name ++ "' used "
                  ++ shows regs " registers, " ++ shows smem " bytes smem, "
                  ++ shows lmem " bytes lmem, " ++ shows cmem " bytes cmem"
          msg2  = "multiprocessor occupancy " ++ showFFloat (Just 1) (CUDA.occupancy100 occ) "% : "
                  ++ shows (CUDA.activeThreads occ)      " threads over "
                  ++ shows (CUDA.activeWarps occ)        " warps in "
                  ++ shows (CUDA.activeThreadBlocks occ) " blocks"
      --
      -- make sure kernel/stats are printed together. Use 'intercalate' rather
      -- than 'unlines' to avoid a trailing newline.
      --
      message $ intercalate "\n ... " [msg1, msg2]
  -- | Obtain the CUDA module for a kernel in the current context, updating the
-- kernel's entry in the hash table.
--
-- If the kernel is still being compiled, the function performs these steps:
--
-- * wait for the external compiler to finish
--
-- * load the binary object into the current context
--
-- * record it in the table and the persistent cache
--
-- * unless compiling with debug symbols, remove the temporary build products
--
-- If the kernel is already compiled but not yet loaded into the current
-- context, it is re-linked from the stored binary and the new module is
-- recorded.
--
link :: KernelTable -> KernelKey -> IO CUDA.Module
link table key = do
  ctx    <- CUDA.get
  cached <- KT.lookup table key
  case fromMaybe missing cached of
    CompileProcess cufile done -> do
      -- A forked thread fills 'done' once the external compilation process
      -- completes. Only the main thread executes kernels, so only one thread
      -- ever takes the MVar in order to link the binary object.
      message "waiting for nvcc..."
      takeMVar done
      let cubin = replaceExtension cufile ".cubin"
      bin <- B.readFile cubin
      mdl <- CUDA.loadData bin
      -- Record the loaded module and stash the binary in the persistent cache.
      KT.insert table key $! KernelObject bin (FL.singleton ctx mdl)
      KT.persist cubin key
      -- Remove the temporary build products. With debugging symbols enabled
      -- the source is kept so that 'cuda-gdb' can refer to it.
      D.unless D.debug_cc $ do
        removeFile cufile
        removeDirectory (dropFileName cufile)
          `catchIOError` \_ -> return ()   -- directory not empty
      return mdl

    -- A real object is already in the persistent cache. Either it was read in
    -- from there, or the branch above just added it.
    KernelObject bin active ->
      case FL.lookup ctx active of
        Just mdl -> return mdl
        Nothing  -> do
          message "re-linking module for current context"
          mdl <- CUDA.loadData bin
          KT.insert table key $! KernelObject bin (FL.cons ctx mdl active)
          return mdl
  where
    missing = INTERNAL_ERROR(error) "link" "missing kernel entry"
  -- | Generate and compile code for a single open array expression.
--
-- A kernel is identified by the device's compute capability together with an
-- MD5 hash of its generated source. If the kernel table has no entry for that
-- key yet, the function does the following:
--
-- * writes the source to a fresh temporary file
--
-- * queues an nvcc process to compile it
--
-- * records the pending 'CompileProcess' in the table, so that 'link' can wait
--   for it later
--
-- Returns the kernel's entry-function name and its table key.
--
compile :: KernelTable -> CUDA.DeviceProperties -> CUTranslSkel aenv a -> CIO (String, KernelKey)
compile table dev cunit = do
  exists <- isJust `fmap` liftIO (KT.lookup table key)
  unless exists $ do
    message $ unlines [ show key, T.unpack code ]
    -- Look up nvcc eagerly so that a missing compiler is reported here, to the
    -- caller. A lazy 'error' thunk may otherwise only be forced inside the
    -- background worker that launches the process. 'link' would then wait on
    -- a completion flag that is never filled.
    nvcc        <- liftIO $ maybe (throwIO (ErrorCall "nvcc: command not found")) return
                              =<< findExecutable "nvcc"
    (file,hdl)  <- openTemporaryFile "dragon.cu"      -- rawr!
    flags       <- compileFlags file
    done        <- liftIO $ do
      message $ "execute: " ++ nvcc ++ " " ++ unwords flags
      -- Do not leave a stray source file behind if writing it or queueing the
      -- compiler fails.
      (do T.hPutStr hdl code `finally` hClose hdl
          enqueueProcess (proc nvcc flags))
        `onException` removeFile file
    --
    liftIO $ KT.insert table key (CompileProcess file done)
  --
  return (entry, key)
  where
    entry = show cunit
    key   = (CUDA.computeCapability dev, hashlazy (T.encodeUtf8 code) )
    code  = displayLazyText . renderCompact $ ppr cunit
  -- | Determine the command-line flags to pass to nvcc for the given source
-- file. They depend on the host architecture (the size of 'Int') and on the
-- compute capability of the current device. Empty entries are dropped from
-- the result.
--
compileFlags :: FilePath -> CIO [String]
compileFlags cufile = do
  CUDA.Compute m n <- CUDA.computeCapability `fmap` asks deviceProps
  ddir             <- liftIO getDataDir
  return $ filter (not . null) $
    [ "-I", ddir </> "cubits"
    , "-arch=sm_" ++ show m ++ show n
    , "-cubin"
    , "-o", cufile `replaceExtension` "cubin"
    , if D.mode D.dump_cc  then "" else "--disable-warnings"
    , if D.mode D.debug_cc then "-G" else "-O3"
    , machine
    , cufile ]
  where
#if SIZEOF_HSINT == 4
    machine = "-m32"
#elif SIZEOF_HSINT == 8
    machine = "-m64"
#else
    -- The word size is unknown at preprocessing time, so decide at runtime.
    -- For any other size, pass no flag and let nvcc use its default, instead
    -- of failing on an incomplete match. The empty string is filtered out
    -- above.
    machine = case sizeOf (undefined :: Int) of
                4 -> "-m32"
                8 -> "-m64"
                _ -> ""
#endif
  -- | Open a unique file, named after the given template, in the temporary
-- directory used for compilation by-products. That directory is per process
-- (its name includes the process id) and is created if it does not exist.
--
openTemporaryFile :: String -> CIO (FilePath, Handle)
openTemporaryFile template = liftIO $ do
  pid <- getProcessID
  dir <- (</>) <$> getTemporaryDirectory <*> pure ("accelerate-cuda-" ++ show pid)
  createDirectoryIfMissing True dir
  openTempFile dir template

#ifndef VERSION_unix
-- Without the unix package there is no 'getProcessID'. Provide one with the
-- same shape ('IO', no arguments) so that the call above type-checks.
-- 'getProcessId' cannot be used directly because it expects a 'ProcessHandle'.
getProcessID :: IO ProcessId
getProcessID = getCurrentProcessId
#endif
  -- Worker pool
-- -----------

-- | Process-wide semaphore limiting how many compiler processes run at once.
-- It starts with one unit per processor. The NOINLINE pragma keeps the
-- 'unsafePerformIO' from being duplicated, so all uses share this single
-- semaphore.
{-# NOINLINE pool #-}
pool :: Q.MSem Int
pool = unsafePerformIO (getNumProcessors >>= Q.new)
  -- | Queue a system process to be executed and return an MVar flag. The flag
-- is filled once the process completes successfully. The task is only
-- launched once a worker is available from the pool. This ensures we don't
-- run out of process handles or flood the IO bus, degrading performance.
--
-- If the compiler exits abnormally, the flag is never filled; 'waitFor'
-- raises the error in the background thread instead. If the process cannot be
-- started at all, the worker slot is handed back to the pool before the
-- exception propagates (in the background thread), so later compilations are
-- not starved.
--
enqueueProcess :: CreateProcess -> IO (MVar ())
enqueueProcess cp = do
  mvar <- newEmptyMVar
  _    <- forkIO $ do
    -- wait for a worker to become available
    Q.wait pool
    -- Give the worker back if the process could not even be launched.
    -- Otherwise the slot would be lost from the pool for good.
    (_,_,_,pid) <- createProcess cp `onException` Q.signal pool
    -- Asynchronously notify the queue when the compiler has completed. Only
    -- after a successful exit is completion signalled to the waiter.
    _ <- forkIO $ do
      waitFor pid `finally` Q.signal pool
      putMVar mvar ()       -- never executed if the compilation fails
    return ()
  --
  return mvar
  -- | Block until a (compilation) process finishes. A non-zero exit code is
-- reported by raising an 'error' that includes the exit code.
--
waitFor :: ProcessHandle -> IO ()
waitFor ph = waitForProcess ph >>= check
  where
    check ExitSuccess     = return ()
    check (ExitFailure c) = error $ "nvcc terminated abnormally (" ++ show c ++ ")"
  -- Debug
-- -----

-- | Emit a compiler-related debugging message; shorthand for 'trace' followed
-- by a no-op action.
{-# INLINE message #-}
message :: MonadIO m => String -> m ()
message msg = trace msg (return ())
  -- | Pass a message, prefixed with "cc: ", to 'D.message' under the
-- 'D.dump_cc' debug flag (presumably printed only when that flag is enabled),
-- then continue with the given action.
{-# INLINE trace #-}
trace :: MonadIO m => String -> m a -> m a
trace msg next = do
  D.message D.dump_cc ("cc: " ++ msg)
  next