PageRenderTime 23ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/compiler/vectorise/Vectorise/Type/Env.hs

https://github.com/albertz/ghc
Haskell | 191 lines | 138 code | 35 blank | 18 comment | 0 complexity | 041c173867e51db6cadc9d3b37c258df MD5 | raw file
  1. {-# OPTIONS_GHC -XNoMonoLocalBinds #-}
  2. module Vectorise.Type.Env (
  3. vectTypeEnv,
  4. ) where
  5. import Vectorise.Env
  6. import Vectorise.Vect
  7. import Vectorise.Monad
  8. import Vectorise.Builtins
  9. import Vectorise.Type.TyConDecl
  10. import Vectorise.Type.Classify
  11. import Vectorise.Type.PADict
  12. import Vectorise.Type.PData
  13. import Vectorise.Type.PRepr
  14. import Vectorise.Type.Repr
  15. import Vectorise.Utils
  16. import HscTypes
  17. import CoreSyn
  18. import CoreUtils
  19. import CoreUnfold
  20. import DataCon
  21. import TyCon
  22. import Type
  23. import FamInstEnv
  24. import OccName
  25. import Id
  26. import MkId
  27. import Var
  28. import NameEnv
  29. import Unique
  30. import UniqFM
  31. import Util
  32. import Outputable
  33. import FastString
  34. import MonadUtils
  35. import Control.Monad
  36. import Data.List
  37. -- | Vectorise a type environment.
  38. -- The type environment contains all the type things defined in a module.
  39. --
  40. vectTypeEnv :: TypeEnv
  41. -> VM ( TypeEnv -- Vectorised type environment.
  42. , [FamInst] -- New type family instances.
  43. , [(Var, CoreExpr)]) -- New top level bindings.
  44. vectTypeEnv env
  45. = do
  46. traceVt "** vectTypeEnv" $ ppr env
  47. cs <- readGEnv $ mk_map . global_tycons
  48. -- Split the list of TyCons into the ones we have to vectorise vs the
  49. -- ones we can pass through unchanged. We also pass through algebraic
  50. -- types that use non Haskell98 features, as we don't handle those.
  51. let tycons = typeEnvTyCons env
  52. groups = tyConGroups tycons
  53. let (conv_tcs, keep_tcs) = classifyTyCons cs groups
  54. orig_tcs = keep_tcs ++ conv_tcs
  55. keep_dcs = concatMap tyConDataCons keep_tcs
  56. -- Just use the unvectorised versions of these constructors in vectorised code.
  57. zipWithM_ defTyCon keep_tcs keep_tcs
  58. zipWithM_ defDataCon keep_dcs keep_dcs
  59. -- Vectorise all the declarations.
  60. new_tcs <- vectTyConDecls conv_tcs
  61. -- We don't need to make new representation types for dictionary
  62. -- constructors. The constructors are always fully applied, and we don't
  63. -- need to lift them to arrays as a dictionary of a particular type
  64. -- always has the same value.
  65. let vect_tcs = filter (not . isClassTyCon)
  66. $ keep_tcs ++ new_tcs
  67. reprs <- mapM tyConRepr vect_tcs
  68. repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
  69. pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
  70. updGEnv $ extendFamEnv
  71. $ map mkLocalFamInst
  72. $ repr_tcs ++ pdata_tcs
  73. -- Create PRepr and PData instances for the vectorised types.
  74. -- We get back the binds for the instance functions,
  75. -- and some new type constructors for the representation types.
  76. (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
  77. do
  78. defTyConPAs (zipLazy vect_tcs dfuns')
  79. reprs <- mapM tyConRepr vect_tcs
  80. dfuns <- sequence
  81. $ zipWith5 buildTyConBindings
  82. orig_tcs
  83. vect_tcs
  84. repr_tcs
  85. pdata_tcs
  86. reprs
  87. binds <- takeHoisted
  88. return (dfuns, binds, repr_tcs ++ pdata_tcs)
  89. -- The new type constructors are the vectorised versions of the originals,
  90. -- plus the new type constructors that we use for the representations.
  91. let all_new_tcs = new_tcs ++ inst_tcs
  92. let new_env = extendTypeEnvList env
  93. $ map ATyCon all_new_tcs
  94. ++ [ADataCon dc | tc <- all_new_tcs
  95. , dc <- tyConDataCons tc]
  96. return (new_env, map mkLocalFamInst inst_tcs, binds)
  97. where
  98. mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
  99. buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
  100. buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
  101. = do vectDataConWorkers orig_tc vect_tc pdata_tc
  102. buildPADict vect_tc prepr_tc pdata_tc repr
  103. vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
  104. vectDataConWorkers orig_tc vect_tc arr_tc
  105. = do bs <- sequence
  106. . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
  107. $ zipWith4 mk_data_con (tyConDataCons vect_tc)
  108. rep_tys
  109. (inits rep_tys)
  110. (tail $ tails rep_tys)
  111. mapM_ (uncurry hoistBinding) bs
  112. where
  113. tyvars = tyConTyVars vect_tc
  114. var_tys = mkTyVarTys tyvars
  115. ty_args = map Type var_tys
  116. res_ty = mkTyConApp vect_tc var_tys
  117. cons = tyConDataCons vect_tc
  118. arity = length cons
  119. [arr_dc] = tyConDataCons arr_tc
  120. rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
  121. mk_data_con con tys pre post
  122. = liftM2 (,) (vect_data_con con)
  123. (lift_data_con tys pre post (mkDataConTag con))
  124. sel_replicate len tag
  125. | arity > 1 = do
  126. rep <- builtin (selReplicate arity)
  127. return [rep `mkApps` [len, tag]]
  128. | otherwise = return []
  129. vect_data_con con = return $ mkConApp con ty_args
  130. lift_data_con tys pre_tys post_tys tag
  131. = do
  132. len <- builtin liftingContext
  133. args <- mapM (newLocalVar (fsLit "xs"))
  134. =<< mapM mkPDataType tys
  135. sel <- sel_replicate (Var len) tag
  136. pre <- mapM emptyPD (concat pre_tys)
  137. post <- mapM emptyPD (concat post_tys)
  138. return . mkLams (len : args)
  139. . wrapFamInstBody arr_tc var_tys
  140. . mkConApp arr_dc
  141. $ ty_args ++ sel ++ pre ++ map Var args ++ post
  142. def_worker data_con arg_tys mk_body
  143. = do
  144. arity <- polyArity tyvars
  145. body <- closedV
  146. . inBind orig_worker
  147. . polyAbstract tyvars $ \args ->
  148. liftM (mkLams (tyvars ++ args) . vectorised)
  149. $ buildClosures tyvars [] arg_tys res_ty mk_body
  150. raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
  151. let vect_worker = raw_worker `setIdUnfolding`
  152. mkInlineUnfolding (Just arity) body
  153. defGlobalVar orig_worker vect_worker
  154. return (vect_worker, body)
  155. where
  156. orig_worker = dataConWorkId data_con