PageRenderTime 72ms CodeModel.GetById 27ms RepoModel.GetById 1ms app.codeStats 1ms

/Lambdabot/Pointful.hs

https://github.com/23Skidoo/pointful
Haskell | 356 lines | 253 code | 45 blank | 58 comment | 10 complexity | 7c21e3ec03509ae9c7204d7d8bbc067d MD5 | raw file
Possible License(s): BSD-3-Clause
  1. {-# LANGUAGE LambdaCase #-}
  2. {-# LANGUAGE PatternSynonyms #-}
  3. {-# LANGUAGE ScopedTypeVariables #-}
  4. {-# LANGUAGE ViewPatterns #-}
  5. -- Undo pointfree transformations. Plugin code derived from Pl.hs.
  6. module Lambdabot.Pointful (pointful) where
  7. import Lambdabot.Parser (withParsed)
  8. import Prelude hiding (sum, exp)
  9. import Control.Monad.Reader
  10. import Control.Monad.State
  11. import Data.Generics
  12. import qualified Data.Set as S
  13. import qualified Data.Map as M
  14. import Data.List hiding (sum)
  15. import Data.Maybe
  16. import Language.Haskell.Exts.Simple as Hs hiding (alt, name, var)
  17. ---- Utilities ----
  18. stabilize :: Eq a => (a -> a) -> a -> a
  19. stabilize f x = let x' = f x in if x' == x then x else stabilize f x'
  20. -- varsBoundHere returns variables bound by top patterns or binders
  21. varsBoundHere :: Data d => d -> S.Set Name
  22. varsBoundHere (cast -> Just (PVar name)) = S.singleton name
  23. varsBoundHere (cast -> Just (Match name _ _ _)) = S.singleton name
  24. varsBoundHere (cast -> Just (PatBind pat _ _)) = varsBoundHere pat
  25. varsBoundHere (cast -> Just (_ :: Exp)) = S.empty
  26. varsBoundHere d = S.unions
  27. (gmapQ varsBoundHere d)
  28. -- note: the tempting idea of using a pattern synonym for the frequent
  29. -- (cast -> Just _) patterns causes compiler crashes with ghc before
  30. -- version 8; cf. https://ghc.haskell.org/trac/ghc/ticket/11336
  31. foldFreeVars
  32. :: forall a d. Data d
  33. => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
  34. foldFreeVars var sum e = runReader (go e) S.empty where
  35. go :: forall d'. Data d' => d' -> Reader (S.Set Name) a
  36. go (cast -> Just (Var (UnQual name))) =
  37. asks (var name)
  38. go (cast -> Just (Lambda ps exp)) =
  39. bind [varsBoundHere ps] $ go exp
  40. go (cast -> Just (Let bs exp)) =
  41. bind [varsBoundHere bs] $ collect [go bs, go exp]
  42. go (cast -> Just (Alt pat exp bs)) =
  43. bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
  44. go (cast -> Just (PatBind pat exp bs)) =
  45. bind [varsBoundHere pat, varsBoundHere bs] $ collect [go exp, go bs]
  46. go (cast -> Just (Match _ ps exp bs)) =
  47. bind [varsBoundHere ps, varsBoundHere bs] $ collect [go exp, go bs]
  48. go d = collect (gmapQ go d)
  49. collect :: forall m. Monad m => [m a] -> m a
  50. collect ms = sum `liftM` sequence ms
  51. bind
  52. :: forall a' b. Ord a'
  53. => [S.Set a'] -> Reader (S.Set a') b -> Reader (S.Set a') b
  54. bind ss = local (S.unions ss `S.union`)
  55. -- return free variables
  56. freeVars :: Data d => d -> S.Set Name
  57. freeVars =
  58. foldFreeVars (\name bv -> S.singleton name `S.difference` bv) S.unions
  59. -- return number of free occurrences of a variable
  60. countOcc :: Data d => Name -> d -> Int
  61. countOcc name = foldFreeVars var sum where
  62. sum = foldl' (+) 0
  63. var name' bv = if name /= name' || name' `S.member` bv then 0 else 1
  64. -- variable capture avoiding substitution
  65. substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
  66. substAvoiding subst bv =
  67. base `extT` exp `extT` alt `extT` decl `extT` match
  68. where
  69. base :: Data d => d -> d
  70. base = gmapT (substAvoiding subst bv)
  71. exp e@(Var (UnQual name)) =
  72. fromMaybe e (M.lookup name subst)
  73. exp (Lambda ps exp') =
  74. let (subst', bv', ps') = renameBinds subst bv ps
  75. in Lambda ps' (substAvoiding subst' bv' exp')
  76. exp (Let bs exp') =
  77. let (subst', bv', bs') = renameBinds subst bv bs
  78. in Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' exp')
  79. exp d = base d
  80. alt (Alt pat exp' bs) =
  81. let (subst1, bv1, pat') = renameBinds subst bv pat
  82. (subst', bv', bs') = renameBinds subst1 bv1 bs
  83. in Alt pat'
  84. (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
  85. alt _ = error "unexpected"
  86. decl (PatBind pat exp' bs) =
  87. let (subst', bv', bs') = renameBinds subst bv bs
  88. in PatBind
  89. pat (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
  90. decl d = base d
  91. match (Match name ps exp' bs) =
  92. let (subst1, bv1, ps') = renameBinds subst bv ps
  93. (subst', bv', bs') = renameBinds subst1 bv1 bs
  94. in Match name ps'
  95. (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
  96. match _ = error "unexpected"
  97. -- rename local binders (but not the nested expressions)
  98. renameBinds
  99. :: Data d
  100. => M.Map Name Exp -> S.Set Name -> d
  101. -> (M.Map Name Exp, S.Set Name, d)
  102. renameBinds subst bv d = (subst', bv', d') where
  103. (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)
  104. go, base
  105. :: Data d
  106. => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
  107. go = base `extM` pat `extM` match `extM` decl `extM` exp
  108. base d'' = gmapM go d''
  109. pat (PVar name) = PVar `fmap` rename name
  110. pat d'' = base d''
  111. match (Match name ps exp' bs) = do
  112. name' <- rename name
  113. return $ Match name' ps exp' bs
  114. match _ = error "unexpected"
  115. decl (PatBind pat' exp' bs) = do
  116. pat'' <- go pat'
  117. return $ PatBind pat'' exp' bs
  118. decl d'' = base d''
  119. exp (e :: Exp) = return e
  120. rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
  121. rename name = do
  122. (subst'', bv'', ass) <- get
  123. case (name `M.lookup` ass, name `S.member` bv'') of
  124. (Just name', _) -> do
  125. return name'
  126. (_, False) -> do
  127. put (M.delete name subst'', S.insert name bv'', ass)
  128. return name
  129. _ -> do
  130. let name' = freshNameAvoiding name bv''
  131. put (M.insert name (Var (UnQual name')) subst'',
  132. S.insert name' bv'', M.insert name name' ass)
  133. return name'
  134. -- generate fresh names
  135. freshNameAvoiding :: Name -> S.Set Name -> Name
  136. freshNameAvoiding name forbidden = con (pre ++ suf) where
  137. (con, nm, cs) = case name of
  138. Ident n -> (Ident, n, "0123456789")
  139. Symbol n -> (Symbol, n, "?#")
  140. _ -> error "unexpected"
  141. pre = reverse . dropWhile (`elem` cs) . reverse $ nm
  142. sufs = [1..] >>= flip replicateM cs
  143. suf = head $ dropWhile (\suff -> con (pre ++ suff) `S.member` forbidden)
  144. sufs
  145. ---- Optimization (removing explicit lambdas) and restoration of infix ops ----
  146. -- move lambda patterns into LHS
  147. optimizeD :: Decl -> Decl
  148. optimizeD (PatBind (PVar fname) (UnGuardedRhs (Lambda pats rhs)) Nothing) =
  149. let (subst, bv, pats') = renameBinds M.empty (S.singleton fname) pats
  150. rhs' = substAvoiding subst bv rhs
  151. in FunBind [Match fname pats' (UnGuardedRhs rhs') Nothing]
  152. ---- combine function binding and lambda
  153. optimizeD (FunBind
  154. [Match fname pats1 (UnGuardedRhs (Lambda pats2 rhs)) Nothing]) =
  155. let (subst, bv, pats2') = renameBinds M.empty (varsBoundHere pats1) pats2
  156. rhs' = substAvoiding subst bv rhs
  157. in FunBind [Match fname (pats1 ++ pats2') (UnGuardedRhs rhs') Nothing]
  158. optimizeD x = x
  159. -- remove parens
  160. optimizeRhs :: Rhs -> Rhs
  161. optimizeRhs (UnGuardedRhs (Paren x)) = UnGuardedRhs x
  162. optimizeRhs x = x
  163. optimizeE :: Exp -> Exp
  164. -- apply ((\x z -> ...x...) y) yielding (\z -> ...y...) if there is
  165. -- only one x or y is simple
  166. optimizeE (App (Lambda (PVar ident : pats) body) arg) | single || simple arg =
  167. let (subst, bv, pats') =
  168. renameBinds (M.singleton ident arg) (freeVars arg) pats
  169. in Paren (Lambda pats' (substAvoiding subst bv body))
  170. where
  171. single = countOcc ident body <= 1
  172. simple = \case Var _ -> True
  173. Lit _ -> True
  174. Paren e' -> simple e'
  175. _ -> False
  176. -- apply ((\_ z -> ...) y) yielding (\z -> ...)
  177. optimizeE (App (Lambda (PWildCard : pats) body) _) =
  178. Paren (Lambda pats body)
  179. -- remove 0-arg lambdas resulting from application rules
  180. optimizeE (Lambda [] b) =
  181. b
  182. -- replace (\x -> \y -> z) with (\x y -> z)
  183. optimizeE (Lambda p1 (Lambda p2 body)) =
  184. let (subst, bv, p2') = renameBinds M.empty (varsBoundHere p1) p2
  185. body' = substAvoiding subst bv body
  186. in Lambda (p1 ++ p2') body'
  187. -- remove double parens
  188. optimizeE (Paren (Paren x)) =
  189. Paren x
  190. -- remove parens around applied lambdas (the pretty printer restores them)
  191. optimizeE (App (Paren (x@Lambda{})) y) =
  192. App x y
  193. -- remove lambda body parens
  194. optimizeE (Lambda p (Paren x)) =
  195. Lambda p x
  196. -- remove var, lit parens
  197. optimizeE (Paren x@(Var _)) =
  198. x
  199. optimizeE (Paren x@(Lit _)) =
  200. x
  201. -- remove infix+lambda parens
  202. optimizeE (InfixApp a o (Paren l@(Lambda _ _))) =
  203. InfixApp a o l
  204. -- remove infix+app aprens
  205. optimizeE (InfixApp (Paren a@App{}) o l) =
  206. InfixApp a o l
  207. optimizeE (InfixApp a o (Paren l@App{})) =
  208. InfixApp a o l
  209. -- remove left-assoc application parens
  210. optimizeE (App (Paren (App a b)) c) =
  211. App (App a b) c
  212. -- restore infix
  213. optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
  214. (InfixApp l (QVarOp name') r)
  215. -- eta reduce
  216. optimizeE (Lambda ps@(_:_) (App e (Var (UnQual v))))
  217. | free && last ps == PVar v = Lambda (init ps) e
  218. where free = countOcc v e == 0
  219. -- fail
  220. optimizeE x = x
  221. ---- Decombinatorization ----
  222. uncomb' :: Exp -> Exp
  223. uncomb' (Paren (Paren e)) = Paren e
  224. -- eliminate sections
  225. uncomb' (RightSection op' arg) =
  226. let a = freshNameAvoiding (Ident "a") (freeVars arg)
  227. in (Paren (Lambda [PVar a] (InfixApp (Var (UnQual a)) op' arg)))
  228. uncomb' (LeftSection arg op') =
  229. let a = freshNameAvoiding (Ident "a") (freeVars arg)
  230. in (Paren (Lambda [PVar a] (InfixApp arg op' (Var (UnQual a)))))
  231. -- infix to prefix for canonicality
  232. uncomb' (InfixApp lf (QVarOp name') rf) =
  233. (Paren (App (App (Var name') (Paren lf)) (Paren rf)))
  234. -- Expand (>>=) when it is obviously the reader monad:
  235. -- rewrite: (>>=) (\x -> e)
  236. -- to: (\ a b -> a ((\ x -> e) b) b)
  237. uncomb' (App (Var (UnQual (Symbol ">>="))) (Paren lam@Lambda{})) =
  238. let a = freshNameAvoiding (Ident "a") (freeVars lam)
  239. b = freshNameAvoiding (Ident "b") (freeVars lam)
  240. in (Paren (Lambda [PVar a, PVar b]
  241. (App (App (Var (UnQual a))
  242. (Paren (App lam (Var (UnQual b))))) (Var (UnQual b)))))
  243. -- rewrite: ((>>=) e1) (\x y -> e2)
  244. -- to: (\a -> (\x y -> e2) (e1 a) a)
  245. uncomb' (App (App (Var (UnQual (Symbol ">>="))) e1)
  246. (Paren lam@(Lambda (_:_:_) _))) =
  247. let a = freshNameAvoiding (Ident "a") (freeVars [e1,lam])
  248. in (Paren (Lambda [PVar a]
  249. (App (App lam (App e1 (Var (UnQual a)))) (Var (UnQual a)))))
  250. -- fail
  251. uncomb' expr = expr
  252. ---- Simple combinator definitions ---
  253. combinators :: M.Map Name Exp
  254. combinators = M.fromList $ map declToTuple defs
  255. where defs = case parseModule combinatorModule of
  256. ParseOk (Hs.Module _ _ _ d) -> d
  257. ParseOk _ -> error "unexpected"
  258. f@(ParseFailed _ _) -> error
  259. ("Combinator loading: " ++ show f)
  260. declToTuple (PatBind (PVar fname) (UnGuardedRhs body) Nothing)
  261. = (fname, Paren body)
  262. declToTuple _
  263. = error
  264. "Pointful Plugin error: can't convert declaration to tuple"
  265. combinatorModule :: String
  266. combinatorModule = unlines [
  267. "(.) = \\f g x -> f (g x) ",
  268. "($) = \\f x -> f x ",
  269. "flip = \\f x y -> f y x ",
  270. "const = \\x _ -> x ",
  271. "id = \\x -> x ",
  272. "(=<<) = flip (>>=) ",
  273. "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) ",
  274. "join = (>>= id) ",
  275. "ap = liftM2 id ",
  276. "(>=>) = flip (<=<) ",
  277. "(<=<) = \\f g x -> f >>= g x ",
  278. " ",
  279. "-- ASSUMED reader monad ",
  280. "-- (>>=) = (\\f k r -> k (f r) r) ",
  281. "-- return = const ",
  282. ""]
  283. ---- Top level ----
  284. unfoldCombinators :: (Data a) => a -> a
  285. unfoldCombinators = substAvoiding combinators (freeVars combinators)
  286. uncombOnce :: (Data a) => a -> a
  287. uncombOnce x = everywhere (mkT uncomb') x
  288. uncomb :: (Eq a, Data a) => a -> a
  289. uncomb = stabilize uncombOnce
  290. optimizeOnce :: (Data a) => a -> a
  291. optimizeOnce x = everywhere
  292. (mkT optimizeD `extT` optimizeRhs `extT` optimizeE) x
  293. optimize :: (Eq a, Data a) => a -> a
  294. optimize = stabilize optimizeOnce
  295. pointful :: String -> Either String String
  296. pointful =
  297. withParsed
  298. (stabilize (optimize . uncomb) . stabilize (unfoldCombinators . uncomb))
  299. -- TODO: merge this into a proper test suite once one exists
  300. -- test s = case parseModule s of
  301. -- f@(ParseFailed _ _) -> fail (show f)
  302. -- ParseOk (Hs.Module _ _ _ _ _ _ defs) ->
  303. -- flip mapM_ defs $ \def -> do
  304. -- putStrLn . prettyPrintInLine $ def
  305. -- putStrLn . prettyPrintInLine . uncomb $ def
  306. -- putStrLn . prettyPrintInLine . optimize . uncomb $ def
  307. -- putStrLn . prettyPrintInLine . stabilize (optimize . uncomb) $ def
  308. -- putStrLn ""
  309. --
  310. -- main = test "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"
  311. --