/SciFlow/src/Control/Workflow/Language/TH.hs

https://github.com/kaizhang/SciFlow · Haskell · 183 lines · 152 code · 12 blank · 19 comment · 5 complexity · 211f4b0fb201597fda71105617e29a8e MD5 · raw file

  1. {-# LANGUAGE TemplateHaskell #-}
  2. {-# LANGUAGE FlexibleContexts #-}
  3. {-# LANGUAGE OverloadedStrings #-}
  4. {-# LANGUAGE RecordWildCards #-}
  5. {-# LANGUAGE BangPatterns #-}
  6. module Control.Workflow.Language.TH (build) where
  7. import Control.Arrow.Free (mapA, effect)
  8. import Control.Arrow (arr, (>>>))
  9. import qualified Data.Text as T
  10. import Language.Haskell.TH
  11. import Instances.TH.Lift ()
  12. import qualified Data.HashMap.Strict as M
  13. import qualified Data.Graph.Inductive as G
  14. import Control.Monad.State.Lazy (execState)
  15. import Data.Hashable (hash)
  16. import Data.Maybe (fromJust)
  17. import Control.Workflow.Language
  18. import Control.Workflow.Types
  19. import Control.Workflow.Interpreter.FunctionTable (mkFunTable)
  20. import Control.Workflow.Language.TH.Internal
  21. -- | Generate template haskell codes to build the workflow.
  22. build :: String -- ^ The name of the compiled workflow.
  23. -> TypeQ -- ^ The workflow signature.
  24. -> Builder () -- ^ Worflow builder.
  25. -> Q [Dec]
  26. build name sig builder = compile name sig wf
  27. where
  28. wf = addSource $ execState builder $ Workflow M.empty M.empty
  29. {-# INLINE build #-}
  30. -- Generate codes from a DAG. This function will create functions defined in
  31. -- the builder. These pieces will be assembled to form a function that will
  32. -- execute each individual function in a correct order.
  33. compile :: String -- ^ The name of the compiled workflow
  34. -> TypeQ -- ^ The function signature
  35. -> Workflow
  36. -> Q [Dec]
  37. compile name sig wf = do
  38. d1 <- defFlow wfName
  39. d2 <- mkFunTable (name ++ "__Table") (name ++ "__Flow")
  40. -- the function signature
  41. wf_signature <- (mkName name) `sigD` sig
  42. d3 <- [d| $(varP $ mkName name) = SciFlow $(varE wfName) $(varE tableName) $ G.mkGraph nodes edges |]
  43. return $ d1 ++ d2 ++ (wf_signature:d3)
  44. where
  45. nodes =
  46. let mkNodeLabel k (UNode _) = NodeLabel k "" False True
  47. mkNodeLabel k Node{..} = NodeLabel k _node_doc _node_parallel False
  48. in flip map (M.toList $ _nodes wf) $ \(k, nd) -> (hash k, mkNodeLabel k nd)
  49. edges = flip concatMap (M.toList $ _parents wf) $ \(x, ps) ->
  50. flip map ps $ \p -> (hash p, hash x, ())
  51. tableName = mkName $ name ++ "__Table"
  52. wfName = mkName $ name ++ "__Flow"
  53. defFlow nm = do
  54. main <- compileWorkflow wf
  55. return [ValD (VarP nm) (NormalB main) []]
  56. {-# INLINE compile #-}
  57. mkJob :: T.Text -> Node -> ExpQ
  58. mkJob nid Node{..}
  59. | _node_parallel = [| step $ Job
  60. { _job_name = nid
  61. , _job_descr = _node_doc
  62. , _job_resource = _node_job_resource
  63. , _job_parallel = True
  64. , _job_action = mapA $ effect $ Action $_node_function
  65. } |]
  66. | otherwise = [| step $ Job
  67. { _job_name = nid
  68. , _job_descr = _node_doc
  69. , _job_resource = _node_job_resource
  70. , _job_parallel = False
  71. , _job_action = effect $ Action $_node_function
  72. } |]
  73. mkJob nid (UNode fun) = [| ustep nid $fun |]
  74. {-# INLINE mkJob #-}
  75. addSource :: Workflow -> Workflow
  76. addSource wf = execState builder wf
  77. where
  78. builder = do
  79. uNode name [| \() -> return () |]
  80. mapM_ (\x -> [name] ~> x) sources
  81. sources = filter (\x -> not $ x `M.member` _parents wf) $ M.keys $ _nodes wf
  82. name = "SciFlow_Source_Node_2xdj23"
  83. {-# INLINE addSource #-}
  84. adjustIdx :: [Int] -- ^ positions removed
  85. -> [Int] -- ^ Old position in the list
  86. -> Maybe [Int] -- ^ new position in the list
  87. adjustIdx pos old = case filter (>=0) (map f old) of
  88. [] -> Nothing
  89. x -> Just x
  90. where
  91. f x = go 0 pos
  92. where
  93. go !acc (p:ps) | x == p = -1
  94. | p < x = go (acc+1) ps
  95. | otherwise = go acc ps
  96. go !acc _ = x - acc
  97. {-# INLINE adjustIdx #-}
-- | Compile a workflow into one Template Haskell expression.
--
-- The nodes are processed one group at a time, in the order produced by
-- 'groupSortNodes'. Every group contributes one stage. The stages, followed
-- by a final sink @arr (const ())@ that discards whatever is still flowing,
-- are handed in order to 'linkFunctions'.
--
-- The fold over the groups threads three pieces of state:
--
--   * the stages generated so far, newest first (reversed at the end);
--   * @nodeToPos@: for every node that still has unconsumed output copies,
--     the positions of those copies among the values currently flowing
--     between stages;
--   * @nVar@: the number of values currently flowing.
compileWorkflow :: Workflow -> ExpQ
compileWorkflow wf =
    let (functions, _, _) = foldl processNodeGroup ([], M.empty, 0) $ groupSortNodes wf
        sink = [| arr $ const () |]
    in linkFunctions $ reverse $ sink : functions
  where
    -- For each node: its parents, and how many copies of its output to
    -- produce. The copy count is the node's out-degree in 'gr'; a node
    -- without children still produces one copy.
    nodeToParents :: M.HashMap T.Text ([T.Text], Int)
    nodeToParents = M.fromList $ flip map nodes $ \(_, (x, _)) ->
        -- NOTE(review): this lookup cannot miss, because 'nodeToId' is built
        -- from the same 'nodes' list, so 'undefined' is never forced.
        let i = M.lookupDefault undefined x nodeToId
            degree = case G.outdeg gr i of
                0 -> 1
                d -> d
        in (x, (M.lookupDefault [] x $ _parents wf, degree))
    -- Turn one group into a stage and update the state. Inside this
    -- function, @nodes@ is the current group; it shadows the outer 'nodes'.
    processNodeGroup (acc, nodeToPos, nVar) nodes =
        case (map (\x -> M.lookupDefault (error "Impossible") (fst x) nodeToParents) nodes) of
            -- source node: a single-node group whose node has no parents
            -- (expected to be the node inserted by 'addSource'). Its n
            -- output copies occupy positions 0 .. n-1, and n values flow.
            [([], n)] ->
                let oIdx = [0 .. n - 1]
                    (nid, f) = head nodes
                    nodeToPos' = M.insert nid oIdx nodeToPos
                    fun = [| $f >>> $(replicateOutput 1 n) |]
                in (fun:acc, nodeToPos', n)
            parents ->
                -- inputPos: for each node of the group, the positions of the
                -- parent output copies it consumes.
                let inputPos =
                        -- computeIdx pairs every parent occurrence with a
                        -- per-parent counter that is threaded across the
                        -- whole group. The k-th consumer of a parent (0-based)
                        -- therefore takes that parent's k-th output copy.
                        let computeIdx count (x:xs) = let (x', count') = go ([], count) x in x' : computeIdx count' xs
                              where
                                go (acc, m) (y:ys) = case M.lookup y m of
                                    Nothing -> go (acc ++ [(y, 0)], M.insert y 1 m) ys
                                    Just c -> go (acc ++ [(y, c)], M.insert y (c+1) m) ys
                                go acc _ = acc
                            computeIdx _ _ = []
                            -- Position of the i-th output copy of parent p.
                            lookupP (p, i) = M.lookupDefault errMsg p nodeToPos !! i
                              where
                                errMsg = error $ unlines $
                                    ("Node not found: " <> show p) :
                                    show (map fst nodes) :
                                    map show (M.toList nodeToPos)
                                    --map show (M.toList nodeToParents)
                        in map (map lookupP) $ computeIdx M.empty $ map fst parents
                    nInput = map length inputPos
                    nOutput = map snd parents
                    -- New layout: the group's outputs take positions
                    -- 0 .. sum nOutput - 1, in group order. The surviving old
                    -- positions are renumbered by 'adjustIdx' (consumed
                    -- positions removed) and shifted up by sum nOutput.
                    -- Nodes with no surviving position are dropped.
                    nodeToPos' =
                        let outputPos = let i = scanl1 (+) nOutput
                                        in zipWith (\a b -> [a .. b-1]) (0:i) i
                            m = fmap (map (+(sum nOutput))) $
                                M.mapMaybe (adjustIdx $ concat inputPos) nodeToPos
                        in foldl (\x (k, v) -> M.insert k v x) m $ zip (map fst nodes) outputPos
                    -- Each node's arrow is followed by replicateOutput to
                    -- produce its copies. combineArrows and selectInput are
                    -- defined outside this file; presumably selectInput feeds
                    -- the consumed positions to the combined arrow and passes
                    -- the rest through (TODO confirm).
                    fun =
                        let combinedF = combineArrows $
                                flip map (zip3 nodes nInput nOutput) $ \((_, f), ni, no) ->
                                    (ni, [| $f >>> $(replicateOutput 1 no) |], no)
                        in selectInput nVar (sum nOutput) (concat inputPos) combinedF
                    nVar' = nVar - (sum nInput) + (sum nOutput)
                in (fun:acc, nodeToPos', nVar')
    -- Dependency graph over dense vertex ids (indices into 'nodes'), with
    -- edges from parent to child. Parents or children that are not declared
    -- nodes raise an error naming them.
    gr = let edges = flip concatMap (M.toList $ _parents wf) $ \(x, ps) ->
                 let x' = M.lookupDefault (error $ show x) x nodeToId
                 in flip map ps $ \p -> (M.lookupDefault (error $ show p) p nodeToId, x', ())
         in G.mkGraph nodes edges :: G.Gr (T.Text, Node) ()
    -- All workflow nodes, numbered from 0.
    nodes = zip [0..] $ M.toList $ _nodes wf
    -- Node name -> vertex id in 'gr'.
    nodeToId = M.fromList $ map (\(i, (x, _)) -> (x, i)) nodes
{-# INLINE compileWorkflow #-}
  159. groupSortNodes :: Workflow -> [[(T.Text, ExpQ)]]
  160. groupSortNodes wf = go [] $ G.topsort' gr
  161. where
  162. go acc [] = [acc]
  163. go [] (x:xs) = go [x] xs
  164. go acc (x:xs) | any (x `isChildOf`) acc = acc : go [x] xs
  165. | otherwise = go (acc <> [x]) xs
  166. isChildOf x y = gr `G.hasEdge` (hash $ fst y, hash $ fst x)
  167. gr = let edges = flip concatMap (M.toList $ _parents wf) $ \(x, ps) ->
  168. flip map ps $ \p -> (hash p, hash x, ())
  169. in G.mkGraph nodes edges :: G.Gr (T.Text, ExpQ) ()
  170. nodes = map (\(k, x) -> (hash k, (k, mkJob k x))) $ M.toList $ _nodes wf
  171. {-# INLINE groupSortNodes #-}