/Retrie/Rewrites/Function.hs

https://github.com/facebookincubator/retrie · Haskell · 168 lines · 137 code · 17 blank · 14 comment · 1 complexity · 52a137ecc60968ea3e54a3af2c6ba251 MD5 · raw file

  1. -- Copyright (c) Facebook, Inc. and its affiliates.
  2. --
  3. -- This source code is licensed under the MIT license found in the
  4. -- LICENSE file in the root directory of this source tree.
  5. --
  6. {-# LANGUAGE CPP #-}
  7. {-# LANGUAGE TupleSections #-}
  8. module Retrie.Rewrites.Function
  9. ( dfnsToRewrites
  10. , getImports
  11. , matchToRewrites
  12. ) where
  13. import Control.Monad
  14. import Control.Monad.State.Lazy
  15. import Data.List
  16. import Data.Maybe
  17. import Data.Traversable
  18. import Retrie.ExactPrint
  19. import Retrie.Expr
  20. import Retrie.GHC
  21. import Retrie.Quantifiers
  22. import Retrie.Types
  23. dfnsToRewrites
  24. :: [(FastString, Direction)]
  25. -> AnnotatedModule
  26. -> IO (UniqFM [Rewrite (LHsExpr GhcPs)])
  27. dfnsToRewrites specs am = fmap astA $ transformA am $ \ (L _ m) -> do
  28. let
  29. fsMap = uniqBag specs
  30. rrs <- sequence
  31. [ do
  32. fe <- mkLocatedHsVar fRdrName
  33. imps <- getImports dir (hsmodName m)
  34. (fName,) . concat <$>
  35. forM (unLoc $ mg_alts $ fun_matches f) (matchToRewrites fe imps dir)
  36. #if __GLASGOW_HASKELL__ < 806
  37. | L _ (ValD f@FunBind{}) <- hsmodDecls m
  38. #else
  39. | L _ (ValD _ f@FunBind{}) <- hsmodDecls m
  40. #endif
  41. , let fRdrName = fun_id f
  42. , let fName = occNameFS (occName (unLoc fRdrName))
  43. , dir <- fromMaybe [] (lookupUFM fsMap fName)
  44. ]
  45. return $ listToUFM_C (++) rrs
  46. ------------------------------------------------------------------------
  47. getImports
  48. :: Direction -> Maybe (Located ModuleName) -> TransformT IO AnnotatedImports
  49. getImports RightToLeft (Just (L _ mn)) = -- See Note [fold only]
  50. lift $ liftIO $ parseImports ["import " ++ moduleNameString mn]
  51. getImports _ _ = return mempty
  52. matchToRewrites
  53. :: LHsExpr GhcPs
  54. -> AnnotatedImports
  55. -> Direction
  56. -> LMatch GhcPs (LHsExpr GhcPs)
  57. -> TransformT IO [Rewrite (LHsExpr GhcPs)]
  58. matchToRewrites e imps dir (L _ alt) = do
  59. let
  60. pats = m_pats alt
  61. grhss = m_grhss alt
  62. qss <- for (zip (inits pats) (tails pats)) $
  63. makeFunctionQuery e imps dir grhss mkApps
  64. qs <- backtickRules e imps dir grhss pats
  65. return $ qs ++ concat qss
  66. type AppBuilder =
  67. LHsExpr GhcPs -> [LHsExpr GhcPs] -> TransformT IO (LHsExpr GhcPs)
  68. irrefutablePat :: LPat GhcPs -> Bool
  69. irrefutablePat = go . unLoc
  70. where
  71. go WildPat{} = True
  72. go VarPat{} = True
  73. #if __GLASGOW_HASKELL__ < 806
  74. go (LazyPat p) = irrefutablePat p
  75. go (AsPat _ p) = irrefutablePat p
  76. go (ParPat p) = irrefutablePat p
  77. go (BangPat p) = irrefutablePat p
  78. #else
  79. go (LazyPat _ p) = irrefutablePat p
  80. go (AsPat _ _ p) = irrefutablePat p
  81. go (ParPat _ p) = irrefutablePat p
  82. go (BangPat _ p) = irrefutablePat p
  83. #endif
  84. go _ = False
  85. makeFunctionQuery
  86. :: LHsExpr GhcPs
  87. -> AnnotatedImports
  88. -> Direction
  89. -> GRHSs GhcPs (LHsExpr GhcPs)
  90. -> AppBuilder
  91. -> ([LPat GhcPs], [LPat GhcPs])
  92. -> TransformT IO [Rewrite (LHsExpr GhcPs)]
  93. makeFunctionQuery e imps dir grhss mkAppFn (argpats, bndpats)
  94. | any (not . irrefutablePat) bndpats = return []
  95. | otherwise = do
  96. let
  97. #if __GLASGOW_HASKELL__ < 806
  98. GRHSs rhss lbs = grhss
  99. #else
  100. GRHSs _ rhss lbs = grhss
  101. #endif
  102. bs = collectPatsBinders argpats
  103. -- See Note [Wildcards]
  104. (es,(_,bs')) <- runStateT (mapM patToExpr argpats) (wildSupply bs, bs)
  105. lhs <- mkAppFn e es
  106. for rhss $ \ grhs -> do
  107. le <- mkLet (unLoc lbs) (grhsToExpr grhs)
  108. rhs <- mkLams bndpats le
  109. let
  110. (pat, temp) =
  111. case dir of
  112. LeftToRight -> (lhs,rhs)
  113. RightToLeft -> (rhs,lhs)
  114. p <- pruneA pat
  115. t <- pruneA temp
  116. return $ addRewriteImports imps $ mkRewrite (mkQs bs') p t
  117. backtickRules
  118. :: LHsExpr GhcPs
  119. -> AnnotatedImports
  120. -> Direction
  121. -> GRHSs GhcPs (LHsExpr GhcPs)
  122. -> [LPat GhcPs]
  123. -> TransformT IO [Rewrite (LHsExpr GhcPs)]
  124. backtickRules e imps dir@LeftToRight grhss ps@[p1, p2] = do
  125. let
  126. both, left, right :: AppBuilder
  127. #if __GLASGOW_HASKELL__ < 806
  128. both op [l, r] = mkLoc (OpApp l op PlaceHolder r)
  129. both _ _ = fail "backtickRules - both: impossible!"
  130. left op [l] = mkLoc (SectionL l op)
  131. left _ _ = fail "backtickRules - left: impossible!"
  132. right op [r] = mkLoc (SectionR op r)
  133. right _ _ = fail "backtickRules - right: impossible!"
  134. #else
  135. both op [l, r] = mkLoc (OpApp noExtField l op r)
  136. both _ _ = fail "backtickRules - both: impossible!"
  137. left op [l] = mkLoc (SectionL noExtField l op)
  138. left _ _ = fail "backtickRules - left: impossible!"
  139. right op [r] = mkLoc (SectionR noExtField op r)
  140. right _ _ = fail "backtickRules - right: impossible!"
  141. #endif
  142. qs <- makeFunctionQuery e imps dir grhss both (ps, [])
  143. qsl <- makeFunctionQuery e imps dir grhss left ([p1], [p2])
  144. qsr <- makeFunctionQuery e imps dir grhss right ([p2], [p1])
  145. return $ qs ++ qsl ++ qsr
  146. backtickRules _ _ _ _ _ = return []
  147. -- Note [fold only]
  148. -- Currently we only generate imports for folds, because it is easy.
  149. -- (We only need to add an import for the module defining the folded
  150. -- function.) Generating the imports for unfolds will require some
  151. -- sort of analysis with haskell-names and is a TODO.