PageRenderTime 43ms CodeModel.GetById 14ms RepoModel.GetById 0ms app.codeStats 0ms

/compiler/vectorise/Vectorise/Type/Env.hs

https://github.com/kgardas/ghc
Haskell | 190 lines | 137 code | 35 blank | 18 comment | 0 complexity | c31f42f91279697c6ed3dd64716044de MD5 | raw file
  1. {-# OPTIONS_GHC -XNoMonoLocalBinds #-}
  2. module Vectorise.Type.Env (
  3. vectTypeEnv,
  4. ) where
  5. import Vectorise.Env
  6. import Vectorise.Vect
  7. import Vectorise.Monad
  8. import Vectorise.Builtins
  9. import Vectorise.Type.TyConDecl
  10. import Vectorise.Type.Classify
  11. import Vectorise.Type.PADict
  12. import Vectorise.Type.PData
  13. import Vectorise.Type.PRepr
  14. import Vectorise.Type.Repr
  15. import Vectorise.Utils
  16. import HscTypes
  17. import CoreSyn
  18. import CoreUtils
  19. import CoreUnfold
  20. import DataCon
  21. import TyCon
  22. import Type
  23. import FamInstEnv
  24. import OccName
  25. import Id
  26. import MkId
  27. import NameEnv
  28. import Unique
  29. import UniqFM
  30. import Util
  31. import Outputable
  32. import FastString
  33. import MonadUtils
  34. import Control.Monad
  35. import Data.List
  36. -- | Vectorise a type environment.
  37. -- The type environment contains all the type things defined in a module.
  38. --
  39. vectTypeEnv :: TypeEnv
  40. -> VM ( TypeEnv -- Vectorised type environment.
  41. , [FamInst] -- New type family instances.
  42. , [(Var, CoreExpr)]) -- New top level bindings.
  43. vectTypeEnv env
  44. = do
  45. traceVt "** vectTypeEnv" $ ppr env
  46. cs <- readGEnv $ mk_map . global_tycons
  47. -- Split the list of TyCons into the ones we have to vectorise vs the
  48. -- ones we can pass through unchanged. We also pass through algebraic
  49. -- types that use non Haskell98 features, as we don't handle those.
  50. let tycons = typeEnvTyCons env
  51. groups = tyConGroups tycons
  52. let (conv_tcs, keep_tcs) = classifyTyCons cs groups
  53. orig_tcs = keep_tcs ++ conv_tcs
  54. keep_dcs = concatMap tyConDataCons keep_tcs
  55. -- Just use the unvectorised versions of these constructors in vectorised code.
  56. zipWithM_ defTyCon keep_tcs keep_tcs
  57. zipWithM_ defDataCon keep_dcs keep_dcs
  58. -- Vectorise all the declarations.
  59. new_tcs <- vectTyConDecls conv_tcs
  60. -- We don't need to make new representation types for dictionary
  61. -- constructors. The constructors are always fully applied, and we don't
  62. -- need to lift them to arrays as a dictionary of a particular type
  63. -- always has the same value.
  64. let vect_tcs = filter (not . isClassTyCon)
  65. $ keep_tcs ++ new_tcs
  66. reprs <- mapM tyConRepr vect_tcs
  67. repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
  68. pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
  69. updGEnv $ extendFamEnv
  70. $ map mkLocalFamInst
  71. $ repr_tcs ++ pdata_tcs
  72. -- Create PRepr and PData instances for the vectorised types.
  73. -- We get back the binds for the instance functions,
  74. -- and some new type constructors for the representation types.
  75. (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
  76. do
  77. defTyConPAs (zipLazy vect_tcs dfuns')
  78. reprs <- mapM tyConRepr vect_tcs
  79. dfuns <- sequence
  80. $ zipWith5 buildTyConBindings
  81. orig_tcs
  82. vect_tcs
  83. repr_tcs
  84. pdata_tcs
  85. reprs
  86. binds <- takeHoisted
  87. return (dfuns, binds, repr_tcs ++ pdata_tcs)
  88. -- The new type constructors are the vectorised versions of the originals,
  89. -- plus the new type constructors that we use for the representations.
  90. let all_new_tcs = new_tcs ++ inst_tcs
  91. let new_env = extendTypeEnvList env
  92. $ map ATyCon all_new_tcs
  93. ++ [ADataCon dc | tc <- all_new_tcs
  94. , dc <- tyConDataCons tc]
  95. return (new_env, map mkLocalFamInst inst_tcs, binds)
  96. where
  97. mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
  98. buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
  99. buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
  100. = do vectDataConWorkers orig_tc vect_tc pdata_tc
  101. buildPADict vect_tc prepr_tc pdata_tc repr
  102. vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
  103. vectDataConWorkers orig_tc vect_tc arr_tc
  104. = do bs <- sequence
  105. . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
  106. $ zipWith4 mk_data_con (tyConDataCons vect_tc)
  107. rep_tys
  108. (inits rep_tys)
  109. (tail $ tails rep_tys)
  110. mapM_ (uncurry hoistBinding) bs
  111. where
  112. tyvars = tyConTyVars vect_tc
  113. var_tys = mkTyVarTys tyvars
  114. ty_args = map Type var_tys
  115. res_ty = mkTyConApp vect_tc var_tys
  116. cons = tyConDataCons vect_tc
  117. arity = length cons
  118. [arr_dc] = tyConDataCons arr_tc
  119. rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
  120. mk_data_con con tys pre post
  121. = liftM2 (,) (vect_data_con con)
  122. (lift_data_con tys pre post (mkDataConTag con))
  123. sel_replicate len tag
  124. | arity > 1 = do
  125. rep <- builtin (selReplicate arity)
  126. return [rep `mkApps` [len, tag]]
  127. | otherwise = return []
  128. vect_data_con con = return $ mkConApp con ty_args
  129. lift_data_con tys pre_tys post_tys tag
  130. = do
  131. len <- builtin liftingContext
  132. args <- mapM (newLocalVar (fsLit "xs"))
  133. =<< mapM mkPDataType tys
  134. sel <- sel_replicate (Var len) tag
  135. pre <- mapM emptyPD (concat pre_tys)
  136. post <- mapM emptyPD (concat post_tys)
  137. return . mkLams (len : args)
  138. . wrapFamInstBody arr_tc var_tys
  139. . mkConApp arr_dc
  140. $ ty_args ++ sel ++ pre ++ map Var args ++ post
  141. def_worker data_con arg_tys mk_body
  142. = do
  143. arity <- polyArity tyvars
  144. body <- closedV
  145. . inBind orig_worker
  146. . polyAbstract tyvars $ \args ->
  147. liftM (mkLams (tyvars ++ args) . vectorised)
  148. $ buildClosures tyvars [] arg_tys res_ty mk_body
  149. raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
  150. let vect_worker = raw_worker `setIdUnfolding`
  151. mkInlineUnfolding (Just arity) body
  152. defGlobalVar orig_worker vect_worker
  153. return (vect_worker, body)
  154. where
  155. orig_worker = dataConWorkId data_con