/Lambdabot/Pointful.hs
Haskell | 356 lines | 253 code | 45 blank | 58 comment | 10 complexity | 7c21e3ec03509ae9c7204d7d8bbc067d MD5 | raw file
Possible License(s): BSD-3-Clause
- {-# LANGUAGE LambdaCase #-}
- {-# LANGUAGE PatternSynonyms #-}
- {-# LANGUAGE ScopedTypeVariables #-}
- {-# LANGUAGE ViewPatterns #-}
- -- Undo pointfree transformations. Plugin code derived from Pl.hs.
- module Lambdabot.Pointful (pointful) where
- import Lambdabot.Parser (withParsed)
- import Prelude hiding (sum, exp)
- import Control.Monad.Reader
- import Control.Monad.State
- import Data.Generics
- import qualified Data.Set as S
- import qualified Data.Map as M
- import Data.List hiding (sum)
- import Data.Maybe
- import Language.Haskell.Exts.Simple as Hs hiding (alt, name, var)
- ---- Utilities ----
-- | Iterate @f@ starting from @x@ until applying @f@ yields a value '==' to
-- the current one, and return that current value. Diverges if no such
-- fixed point is ever reached.
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f = go
  where
    go current
      | next == current = current
      | otherwise = go next
      where
        next = f current
-- | Variables bound by the outermost patterns or binders of a syntax tree.
--
-- Collects 'PVar' names, as-pattern names (@x\@p@, plus the variables of
-- @p@), the defined name of a 'Match' or 'InfixMatch' (not its parameters),
-- and the variables of a 'PatBind' pattern. The search stops at
-- expressions, so binders nested inside sub-expressions are not reported.
varsBoundHere :: Data d => d -> S.Set Name
varsBoundHere (cast -> Just (PVar name)) = S.singleton name
varsBoundHere (cast -> Just (PAsPat name pat)) =
  S.insert name (varsBoundHere pat)
varsBoundHere (cast -> Just (Match name _ _ _)) = S.singleton name
-- infix definitions (@a <+> b = ...@) bind the operator, not the operand
-- patterns; without this case the generic traversal would report the left
-- operand's variables and miss the operator name
varsBoundHere (cast -> Just (InfixMatch _ name _ _ _)) = S.singleton name
varsBoundHere (cast -> Just (PatBind pat _ _)) = varsBoundHere pat
varsBoundHere (cast -> Just (_ :: Exp)) = S.empty
varsBoundHere d = S.unions (gmapQ varsBoundHere d)
-- | Fold over the free variable occurrences of a syntax tree.
--
-- Every unqualified variable occurrence is passed to @onVar@ together with
-- the set of names bound at that point. The results of sibling subtrees are
-- merged with @combine@. Binders of lambdas, lets, case alternatives,
-- pattern bindings, and prefix and infix function clauses (parameters and
-- where-bindings) are added to the bound set for their scope.
--
-- Note: the tempting idea of using a pattern synonym for the frequent
-- (cast -> Just _) patterns causes compiler crashes with ghc before
-- version 8; cf. https://ghc.haskell.org/trac/ghc/ticket/11336
foldFreeVars
  :: forall a d. Data d
  => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
foldFreeVars onVar combine root = runReader (walk root) S.empty
  where
    walk :: forall t. Data t => t -> Reader (S.Set Name) a
    walk (cast -> Just (Var (UnQual n))) =
      asks (onVar n)
    walk (cast -> Just (Lambda pats body)) =
      within [varsBoundHere pats] (walk body)
    walk (cast -> Just (Let locals body)) =
      within [varsBoundHere locals] (gather [walk locals, walk body])
    walk (cast -> Just (Alt pat body locals)) =
      within [varsBoundHere pat, varsBoundHere locals]
        (gather [walk body, walk locals])
    walk (cast -> Just (PatBind pat body locals)) =
      within [varsBoundHere pat, varsBoundHere locals]
        (gather [walk body, walk locals])
    walk (cast -> Just (Match _ pats body locals)) =
      within [varsBoundHere pats, varsBoundHere locals]
        (gather [walk body, walk locals])
    -- infix clause: the left operand and the remaining patterns are all
    -- parameters of the definition
    walk (cast -> Just (InfixMatch lhs _ pats body locals)) =
      within [varsBoundHere (lhs : pats), varsBoundHere locals]
        (gather [walk body, walk locals])
    walk other = gather (gmapQ walk other)

    -- run the sub-traversals in order and merge their results
    gather :: forall m. Monad m => [m a] -> m a
    gather actions = combine <$> sequence actions

    -- extend the bound-variable environment for a sub-traversal
    within
      :: forall k r. Ord k
      => [S.Set k] -> Reader (S.Set k) r -> Reader (S.Set k) r
    within scopes = local (S.union (S.unions scopes))
-- | Names occurring free in a syntax tree: unqualified variable occurrences
-- that are not bound by an enclosing binder tracked by 'foldFreeVars'.
freeVars :: Data d => d -> S.Set Name
freeVars = foldFreeVars keepIfFree S.unions
  where
    keepIfFree n bound
      | n `S.member` bound = S.empty
      | otherwise = S.singleton n
-- | Number of free occurrences of @name@ in a syntax tree.
countOcc :: Data d => Name -> d -> Int
countOcc name = foldFreeVars occurrence (foldl' (+) 0)
  where
    occurrence n bound
      | n == name && not (n `S.member` bound) = 1
      | otherwise = 0
-- | Variable-capture-avoiding substitution.
--
-- @substAvoiding subst bv t@ replaces every unqualified variable of @t@ that
-- is a key of @subst@ by its image. @bv@ holds the names that must not be
-- captured; callers pass (at least) the free variables of the substituted
-- expressions.
--
-- Before the scope of a binder is entered, 'renameBinds' handles the binder:
--
-- * a binder that clashes with @bv@ is renamed, and @subst@ is extended so
--   that its scope follows;
-- * any other binder is removed from @subst@, so variables it shadows are
--   left alone.
substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
substAvoiding subst bv =
  base `extT` exp `extT` alt `extT` decl `extT` match
  where
    base :: Data d => d -> d
    base = gmapT (substAvoiding subst bv)

    exp e@(Var (UnQual name)) =
      fromMaybe e (M.lookup name subst)
    exp (Lambda ps exp') =
      let (subst', bv', ps') = renameBinds subst bv ps
      in Lambda ps' (substAvoiding subst' bv' exp')
    exp (Let bs exp') =
      let (subst', bv', bs') = renameBinds subst bv bs
      in Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' exp')
    exp d = base d

    alt (Alt pat exp' bs) =
      let (subst1, bv1, pat') = renameBinds subst bv pat
          (subst', bv', bs') = renameBinds subst1 bv1 bs
      in Alt pat'
           (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
    alt _ = error "unexpected"

    -- The pattern of a PatBind is bound (and renamed if needed) by the
    -- enclosing binding group; only its where-bindings are local here.
    decl (PatBind pat exp' bs) =
      let (subst', bv', bs') = renameBinds subst bv bs
      in PatBind
           pat (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
    decl d = base d

    match (Match name ps exp' bs) =
      let (subst1, bv1, ps') = renameBinds subst bv ps
          (subst', bv', bs') = renameBinds subst1 bv1 bs
      in Match name ps'
           (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
    -- Infix clauses (e.g. @x <+> y = ...@): the left operand and the
    -- remaining patterns are parameters exactly as in a prefix 'Match'.
    -- Previously this fell through to the error below.
    match (InfixMatch lhs name ps exp' bs) =
      let (subst1, bv1, (lhs', ps')) = renameBinds subst bv (lhs, ps)
          (subst', bv', bs') = renameBinds subst1 bv1 bs
      in InfixMatch lhs' name ps'
           (substAvoiding subst' bv' exp') (substAvoiding subst' bv' bs')
    match _ = error "unexpected"
-- | Rename the binders at the top of a syntax tree (not the nested
-- expressions) so that none of them clashes with a name in @bv@.
--
-- Returns the updated substitution, the updated set of names to avoid, and
-- the tree with its binders renamed. Each binder is handled as follows:
--
-- * not in @bv@: it is kept, added to @bv@ and removed from @subst@ (it
--   shadows any outer substitution);
-- * in @bv@: it is replaced by a name from 'freshNameAvoiding', and
--   @subst@ maps the old name to the new variable so its scope follows.
--
-- A binder name that occurs several times, such as the clauses of one
-- multi-equation function, is mapped to the same result every time.
renameBinds
  :: Data d
  => M.Map Name Exp -> S.Set Name -> d
  -> (M.Map Name Exp, S.Set Name, d)
renameBinds subst bv d = (subst', bv', d')
  where
    (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)

    go, base
      :: Data d
      => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
    go = base `extM` pat `extM` match `extM` decl `extM` exp
    base d'' = gmapM go d''

    pat (PVar name) = PVar `fmap` rename name
    -- as-patterns bind their name as well as the inner pattern's variables
    pat (PAsPat name p) = do
      name' <- rename name
      p' <- go p
      return $ PAsPat name' p'
    pat d'' = base d''

    match (Match name ps exp' bs) = do
      name' <- rename name
      return $ Match name' ps exp' bs
    -- infix clauses bind only the operator name at this level
    match (InfixMatch lhs name ps exp' bs) = do
      name' <- rename name
      return $ InfixMatch lhs name' ps exp' bs
    match _ = error "unexpected"

    decl (PatBind pat' exp' bs) = do
      pat'' <- go pat'
      return $ PatBind pat'' exp' bs
    decl d'' = base d''

    -- do not descend into expressions
    exp (e :: Exp) = return e

    rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
    rename name = do
      (subst'', bv'', ass) <- get
      case (name `M.lookup` ass, name `S.member` bv'') of
        (Just name', _) ->
          return name'
        (_, False) -> do
          -- Record the identity renaming as well. Otherwise a later
          -- occurrence of the same binder (e.g. the next clause of a
          -- function) finds the name in bv and is renamed differently.
          put (M.delete name subst'', S.insert name bv'', M.insert name name ass)
          return name
        _ -> do
          let name' = freshNameAvoiding name bv''
          put ( M.insert name (Var (UnQual name')) subst''
              , S.insert name' bv''
              , M.insert name name' ass
              )
          return name'
-- | Derive a name from @name@ that is not a member of @forbidden@.
--
-- Trailing suffix characters are stripped first: digits for identifiers,
-- @?@ and @#@ for operator symbols. Suffixes of length 1, 2, ... over the
-- same characters are then tried in order, and the first candidate not in
-- @forbidden@ is returned. The bare stem itself is never returned.
freshNameAvoiding :: Name -> S.Set Name -> Name
freshNameAvoiding name forbidden =
  head [ candidate
       | suffix <- suffixes
       , let candidate = mk (stem ++ suffix)
       , not (candidate `S.member` forbidden)
       ]
  where
    (mk, base, suffixChars) = case name of
      Ident n -> (Ident, n, "0123456789")
      Symbol n -> (Symbol, n, "?#")
      _ -> error "unexpected"
    stem = dropWhileEnd (`elem` suffixChars) base
    suffixes = [ s | len <- [1 ..], s <- replicateM len suffixChars ]
- ---- Optimization (removing explicit lambdas) and restoration of infix ops ----
-- | Move lambda parameters into the left-hand side of a binding:
--
-- * @f = \\x y -> e@ becomes @f x y = e@;
-- * @f a = \\x -> e@ becomes @f a x = e@.
--
-- Lambda parameters that clash with a name already bound on the left-hand
-- side are renamed. The avoid set also contains the lambda's free
-- variables, so a fresh name cannot capture one of them (e.g.
-- @f = \\f -> f0 + f@ must not turn into @f f0 = f0 + f0@).
optimizeD :: Decl -> Decl
optimizeD (PatBind (PVar fname) (UnGuardedRhs lam@(Lambda pats rhs)) Nothing) =
  let avoid = S.insert fname (freeVars lam)
      (subst, bv, pats') = renameBinds M.empty avoid pats
      rhs' = substAvoiding subst bv rhs
  in FunBind [Match fname pats' (UnGuardedRhs rhs') Nothing]
-- combine function binding and lambda
optimizeD (FunBind
    [Match fname pats1 (UnGuardedRhs lam@(Lambda pats2 rhs)) Nothing]) =
  let avoid = varsBoundHere pats1 `S.union` freeVars lam
      (subst, bv, pats2') = renameBinds M.empty avoid pats2
      rhs' = substAvoiding subst bv rhs
  in FunBind [Match fname (pats1 ++ pats2') (UnGuardedRhs rhs') Nothing]
optimizeD x = x
-- | Drop the parentheses around an unguarded right-hand side; any other
-- right-hand side is returned as is.
optimizeRhs :: Rhs -> Rhs
optimizeRhs = \case
  UnGuardedRhs (Paren inner) -> UnGuardedRhs inner
  other -> other
-- | One local simplification step on an expression. The rules are tried
-- top to bottom, and the expression is returned unchanged when none
-- applies.
optimizeE :: Exp -> Exp
-- apply ((\x z -> ...x...) y) yielding (\z -> ...y...) if there is
-- only one x or y is simple. Renamed parameters get names that avoid the
-- free variables of both the argument and the lambda body, so nothing
-- free is captured (e.g. (\f y -> f y0 y) y must not become
-- \y0 -> y y0 y0).
optimizeE (App (Lambda (PVar ident : pats) body) arg)
  | single || simple arg =
      let avoid = freeVars arg `S.union` freeVars (Lambda pats body)
          (subst, bv, pats') = renameBinds (M.singleton ident arg) avoid pats
      in Paren (Lambda pats' (substAvoiding subst bv body))
  where
    single = countOcc ident body <= 1
    simple = \case
      Var _ -> True
      Lit _ -> True
      Paren e' -> simple e'
      _ -> False
-- apply ((\_ z -> ...) y) yielding (\z -> ...)
optimizeE (App (Lambda (PWildCard : pats) body) _) =
  Paren (Lambda pats body)
-- remove 0-arg lambdas resulting from application rules
optimizeE (Lambda [] b) =
  b
-- replace (\x -> \y -> z) with (\x y -> z); renamed inner parameters also
-- avoid the inner lambda's free variables
optimizeE (Lambda p1 inner@(Lambda p2 body)) =
  let avoid = varsBoundHere p1 `S.union` freeVars inner
      (subst, bv, p2') = renameBinds M.empty avoid p2
      body' = substAvoiding subst bv body
  in Lambda (p1 ++ p2') body'
-- remove double parens
optimizeE (Paren (Paren x)) =
  Paren x
-- remove parens around applied lambdas (the pretty printer restores them)
optimizeE (App (Paren (x@Lambda{})) y) =
  App x y
-- remove lambda body parens
optimizeE (Lambda p (Paren x)) =
  Lambda p x
-- remove var, lit parens
optimizeE (Paren x@(Var _)) =
  x
optimizeE (Paren x@(Lit _)) =
  x
-- remove infix+lambda parens
optimizeE (InfixApp a o (Paren l@(Lambda _ _))) =
  InfixApp a o l
-- remove infix+app parens
optimizeE (InfixApp (Paren a@App{}) o l) =
  InfixApp a o l
optimizeE (InfixApp a o (Paren l@App{})) =
  InfixApp a o l
-- remove left-assoc application parens
optimizeE (App (Paren (App a b)) c) =
  App (App a b) c
-- restore infix
optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
  (InfixApp l (QVarOp name') r)
-- eta reduce
optimizeE (Lambda ps@(_:_) (App e (Var (UnQual v))))
  | free && last ps == PVar v = Lambda (init ps) e
  where free = countOcc v e == 0
-- fail
optimizeE x = x
- ---- Decombinatorization ----
-- | One decombinatorization step on an expression:
--
-- * collapses doubled parentheses;
-- * turns operator sections into lambdas;
-- * turns infix applications into prefix ones;
-- * expands @(>>=)@ at the reader monad in the two shapes shown below.
--
-- Any other expression is returned unchanged.
uncomb' :: Exp -> Exp
uncomb' = \case
    Paren (Paren e) -> Paren e
    -- eliminate sections: (`oper` arg) ==> (\a -> a `oper` arg)
    RightSection oper arg ->
      let a = fresh "a" arg
      in Paren (Lambda [PVar a] (InfixApp (unqualVar a) oper arg))
    -- (arg `oper`) ==> (\a -> arg `oper` a)
    LeftSection arg oper ->
      let a = fresh "a" arg
      in Paren (Lambda [PVar a] (InfixApp arg oper (unqualVar a)))
    -- infix to prefix for canonicality
    InfixApp lhs (QVarOp oper) rhs ->
      Paren (App (App (Var oper) (Paren lhs)) (Paren rhs))
    -- reader-monad bind, partially applied:
    --   (>>=) (\x -> e)  ==>  (\a b -> a ((\x -> e) b) b)
    App (Var (UnQual (Symbol ">>="))) (Paren lam@Lambda{}) ->
      let a = fresh "a" lam
          b = fresh "b" lam
          bVar = unqualVar b
      in Paren (Lambda [PVar a, PVar b]
                  (App (App (unqualVar a) (Paren (App lam bVar))) bVar))
    -- reader-monad bind, fully applied:
    --   ((>>=) e1) (\x y -> e2)  ==>  (\a -> (\x y -> e2) (e1 a) a)
    App (App (Var (UnQual (Symbol ">>="))) e1) (Paren lam@(Lambda (_:_:_) _)) ->
      let a = fresh "a" [e1, lam]
          aVar = unqualVar a
      in Paren (Lambda [PVar a] (App (App lam (App e1 aVar)) aVar))
    expr -> expr
  where
    -- a name derived from @stem@ that is not free in @scope@
    fresh :: Data t => String -> t -> Name
    fresh stem scope = freshNameAvoiding (Ident stem) (freeVars scope)
    unqualVar n = Var (UnQual n)
- ---- Simple combinator definitions ---
-- | The definitions in 'combinatorModule', keyed by the defined name, with
-- each right-hand side wrapped in 'Paren'. Fails with 'error' if the module
-- does not parse, or if it contains a declaration that is not a plain
-- @name = expr@ binding.
combinators :: M.Map Name Exp
combinators = M.fromList [ entry decl | decl <- parsedDecls ]
  where
    parsedDecls = case parseModule combinatorModule of
      ParseOk (Hs.Module _ _ _ ds) -> ds
      ParseOk _ -> error "unexpected"
      failure@(ParseFailed _ _) ->
        error ("Combinator loading: " ++ show failure)
    entry (PatBind (PVar fname) (UnGuardedRhs rhs) Nothing) =
      (fname, Paren rhs)
    entry _ =
      error "Pointful Plugin error: can't convert declaration to tuple"
-- | Haskell source for the combinators that 'combinators' loads and
-- 'unfoldCombinators' expands. Each definition must have the shape
-- @name = expr@.
--
-- @(>>=)@ and @return@ are deliberately left undefined (see the commented
-- lines at the end); @uncomb'@ expands @(>>=)@ only in specific
-- reader-monad shapes.
combinatorModule :: String
combinatorModule = unlines [
    "(.) = \\f g x -> f (g x) ",
    "($) = \\f x -> f x ",
    "flip = \\f x y -> f y x ",
    "const = \\x _ -> x ",
    "id = \\x -> x ",
    "(=<<) = flip (>>=) ",
    "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) ",
    "join = (>>= id) ",
    "ap = liftM2 id ",
    "(>=>) = flip (<=<) ",
    -- Kleisli composition: (f <=< g) x = g x >>= f. The previous
    -- definition, f >>= g x, had the operands swapped.
    "(<=<) = \\f g x -> g x >>= f ",
    " ",
    "-- ASSUMED reader monad ",
    "-- (>>=) = (\\f k r -> k (f r) r) ",
    "-- return = const ",
    ""]
- ---- Top level ----
-- | Replace free occurrences of the known combinators by their definitions.
-- Binders in the term that clash with names free in those definitions are
-- renamed by 'substAvoiding'.
unfoldCombinators :: (Data a) => a -> a
unfoldCombinators term = substAvoiding combinators avoid term
  where
    avoid = freeVars combinators
-- | Apply @uncomb'@ once at every expression node (via SYB 'everywhere',
-- which rewrites bottom-up).
uncombOnce :: (Data a) => a -> a
uncombOnce = everywhere (mkT uncomb')
-- | Repeat 'uncombOnce' until the tree stops changing.
uncomb :: (Eq a, Data a) => a -> a
uncomb term = stabilize uncombOnce term
-- | Apply the declaration, right-hand-side and expression simplifications
-- once at every matching node of the tree (via SYB 'everywhere').
optimizeOnce :: (Data a) => a -> a
optimizeOnce = everywhere step
  where
    step :: Data b => b -> b
    step = mkT optimizeD `extT` optimizeRhs `extT` optimizeE
-- | Repeat 'optimizeOnce' until the tree stops changing.
optimize :: (Eq a, Data a) => a -> a
optimize term = stabilize optimizeOnce term
-- | Convert point-free Haskell source to pointful form.
--
-- First, combinators are unfolded and decombinatorized until nothing
-- changes. Then decombinatorization and optimization alternate until
-- nothing changes. Parsing, error reporting and printing are handled by
-- 'withParsed'.
pointful :: String -> Either String String
pointful = withParsed (optimizePhase . unfoldPhase)
  where
    unfoldPhase, optimizePhase :: (Data a, Eq a) => a -> a
    unfoldPhase = stabilize (unfoldCombinators . uncomb)
    optimizePhase = stabilize (optimize . uncomb)
- -- TODO: merge this into a proper test suite once one exists
- -- test s = case parseModule s of
- -- f@(ParseFailed _ _) -> fail (show f)
--   ParseOk (Hs.Module _ _ _ defs) ->
- -- flip mapM_ defs $ \def -> do
- -- putStrLn . prettyPrintInLine $ def
- -- putStrLn . prettyPrintInLine . uncomb $ def
- -- putStrLn . prettyPrintInLine . optimize . uncomb $ def
- -- putStrLn . prettyPrintInLine . stabilize (optimize . uncomb) $ def
- -- putStrLn ""
- --
- -- main = test "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"
- --