Deep Definition removal
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 import Data.Generics hiding (typeOf)
9 import Data.Generics.Schemes
10
11 data TypedExpr = TypedExpr
12         { unTypeExpr    :: Expr
13         , typeOf        :: Typ
14         } deriving (Eq, Typeable, Data)
15
16 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
17 typedLeft  e t = TypedExpr e (instType False t)
18 typedRight e t = TypedExpr e (instType True t)
19
20 data Expr
21         = Var String
22         | App Expr Expr
23         | Conc [Expr] -- Conc [] is Id
24         | Lambda TypedExpr Expr
25         | Pair Expr Expr
26         | Map
27             deriving (Eq, Typeable, Data)
28
29 data BoolExpr 
30         = Equal Expr Expr
31         | And [BoolExpr] -- And [] is True
32         | AllZipWith TypedExpr TypedExpr BoolExpr Expr Expr
33         | Condition [TypedExpr] BoolExpr BoolExpr
34         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
35         | TypeVarInst Int BoolExpr
36             deriving (Eq, Typeable, Data)
37
38 -- Smart constructors
39
40 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
41               | otherwise                = Equal (unTypeExpr te1) (unTypeExpr te2)
42
43 unpackPair = UnpackPair
44
45 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
46 allZipWith v1 v2 rel e1 e2 | Just (v1', _) <- defFor v1 rel =
47                                 e1 `equal` amap (lambda v2 v1') e2
48                            | Just (v2', _) <- defFor v2 rel =
49                                 amap (lambda v1 v2') e1 `equal` e2
50                            | otherwise =
51                                 AllZipWith v1 v2 rel (unTypeExpr e1) (unTypeExpr e2)
52
53 amap tf tl | Arrow t1 t2 <- typeOf tf
54            , List t      <- typeOf tl
55            , t1 == t
56            = let tMap = TypedExpr Map (Arrow (List t1) (List t2))
57              in app (app tMap tf) tl
58 amap tf tl | otherwise = error "Type error in map"
59
60 aand (And xs) (And ys) = And (xs  ++ ys)
61 aand (And xs) y        = And (xs  ++ [y])
62 aand x        (And ys) = And ([x] ++ ys)
63 aand x        y        = And ([x,y])
64
65 beTrue = And []
66
67 -- | Is any var (or part of var) defined in cond, and can be replaced in concl?
68 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
69 condition vars cond concl | ((vars',cond',concl'):_) <- mapMaybe try vars
70                           = condition vars' cond' concl'
71                           | otherwise
72                           = Condition vars cond concl
73   where try v = do (def,del) <- defFor v cond --Maybe Monad
74                    return (delete v vars, del cond, del concl)
75
76 -- | Replaces a Term in a BoolExpr
77 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
78 replaceTermBE d r = go
79   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
80                            | d == e2 && r == e1 = beTrue
81                            | otherwise          = go' e1 `Equal` go' e2
82         go (And es)        = foldr aand beTrue (map go es)
83         go (AllZipWith v1 v2 be e1 e2) 
84                            = AllZipWith v1 v2 (go be) (go' e1) (go' e2)
85         go (Condition vs cond concl)
86                            = condition vs (go cond) (go concl)
87         go (UnpackPair v1 v2 e be)
88                            = unpackPair v1 v2 (goT e) (go be)
89         go (TypeVarInst _ _) = error "TypeVarInst not expected here"
90         goT = replaceTypedExpr d r
91         go' = replaceExpr d r
92
93 replaceExpr :: Expr -> Expr -> Expr -> Expr
94 replaceExpr d r = everywhere (mkT go)
95   where go e | e == d    = r 
96              | otherwise = e
97
98 replaceTypedExpr :: Expr -> Expr -> TypedExpr -> TypedExpr
99 replaceTypedExpr d r = everywhere (mkT go)
100   where go e | unTypeExpr e == d = e { unTypeExpr = r }
101              | otherwise         = e
102
103 -- | Is inside the term a definition for the variable?
104 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr, BoolExpr -> BoolExpr)
105 defFor tv be | Just (e', delDef) <- defFor' (unTypeExpr tv) be
106                          = Just (TypedExpr e' (typeOf tv), delDef)
107              | otherwise = Nothing
108         
109 -- | Find a definition, and return it along the definition remover
110 defFor' :: Expr -> BoolExpr -> Maybe (Expr, BoolExpr -> BoolExpr)
111 defFor' (Pair x y) e | Just (dx, delX) <- defFor' x e
112                      , Just (dy, delY) <- defFor' y e
113                     = Just ((Pair dx dy), delX . delY)
114 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2, replaceTermBE e e2)
115                           | e == e2                 = Just (e1, replaceTermBE e e1)
116 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
117                                                     = Just d
118 defFor' _ _                                         = Nothing
119
120 app te1 te2 | Arrow t1 t2 <- typeOf te1
121             , t3          <- typeOf te2 
122             , t1 == t3 
123             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
124  where app' Map (Conc []) = Conc []
125        app' (Conc []) v   = v
126        app' f v           = App f v
127 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
128                                                            show te1 ++ " " ++ show te2
129
130 unCond v (Equal l r) | (Just l') <- isApplOn (unTypeExpr v) l 
131                      , (Just r') <- isApplOn (unTypeExpr v) r = 
132         if v `occursIn` l' || v `occursIn` r'
133         then Condition [v] beTrue (Equal l' r')
134         else (Equal l' r')
135 unCond v e = Condition [v] beTrue e
136
137 lambda tv e = TypedExpr inner (Arrow (typeOf tv) (typeOf e))
138   where inner | (Just e') <- isApplOn (unTypeExpr tv) (unTypeExpr e)
139               , not (unTypeExpr tv `occursIn` e')
140                           = e'
141               | tv == e   = Conc []
142               | otherwise = Lambda tv (unTypeExpr e)
143
144 conc f (Conc fs) = Conc (f:fs)
145
146 -- Helpers
147
148 isApplOn e e'         | e == e'                       = Nothing
149 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
150 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
151 isApplOn _ _                                          = Nothing
152
153 hasVar v (Var v')     = v == v'
154 hasVar v (App e1 e2)  = hasVar v e1 && hasVar v e2
155 hasVar v (Conc es)    = any (hasVar v) es
156 hasVar v (Lambda _ e) = hasVar v e
157 hasVar v Map          = False
158
159 e `occursIn` e'       = not (null (listify (==e) e'))
160
161 isTuple (TPair _ _) = True
162 isTuple _           = False
163
164
165 -- showing
166
167 -- Precedences:
168 -- 10 fun app
169 --  9 (.)
170 --  8 ==
171 --  7 ==>
172 --  6 forall
173
174 instance Show Expr where
175         showsPrec d (Var s)     = showString s
176         showsPrec d (App e1 e2) = showParen (d>10) $
177                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
178         showsPrec d (Conc [])   = showString "id"
179         showsPrec d (Conc [e])  = showsPrec d e
180         showsPrec d (Conc es)   = showParen (d>9) $
181                 showIntercalate (showString " . ") (map (showsPrec 10) es)
182         showsPrec d (Lambda tv e) = showParen True $ 
183                                     showString "\\" .
184                                     showsPrec 0 tv .
185                                     showString " -> ".
186                                     showsPrec 0 e 
187         showsPrec _ (Pair e1 e2) = showParen True $ 
188                                    showsPrec 0 e1 .
189                                    showString "," .
190                                    showsPrec 0 e2
191         showsPrec _ Map           = showString "map"
192
193 showIntercalate i []  = id
194 showIntercalate i [x] = x
195 showIntercalate i (x:xs) = x . i . showIntercalate i xs
196
197 instance Show TypedExpr where
198         showsPrec d (TypedExpr e t) = 
199                 showParen (d>10) $
200                         showsPrec 0 e .
201                         showString " :: " .
202                         showString (showTypePrec 0 t)
203
204 instance Show BoolExpr where
205         show (Equal e1 e2) = showsPrec 9 e1 $
206                              showString " == " $
207                              showsPrec 9 e2 ""
208         show (And [])      = show "True"
209         show (And bes)     = intercalate " && " $ map show bes
210         show (AllZipWith v1 v2 be e1 e2) =
211                         "allZipWith " ++
212                         "( " ++
213                         "\\" ++
214                         showsPrec 11 v1 "" ++
215                         " " ++
216                         showsPrec 11 v2 "" ++
217                         " -> " ++
218                         show be ++
219                         ")" ++
220                         " " ++
221                         showsPrec 11 e1 "" ++
222                         " " ++
223                         showsPrec 11 e2 ""
224         show (Condition tvars be1 be2) = 
225                         "forall " ++
226                         intercalate ", " (map show tvars) ++
227                         ".\n" ++
228                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
229                         indent 2 (show be2)
230         show (UnpackPair v1 v2 e be) = 
231                         "let (" ++
232                         showsPrec 0 v1 "" ++
233                         "," ++
234                         showsPrec 0 v2 "" ++
235                         ") = " ++
236                         showsPrec 0 e "" ++
237                         " in\n" ++
238                         indent 2 (show be)
239         show (TypeVarInst i be) = 
240                         "forall types t" ++
241                         show (2*i-1) ++
242                         ", t" ++
243                         show (2*i) ++
244                         ", function g" ++
245                         show i ++
246                         " :: t" ++
247                         show (2*i-1) ++
248                         " -> t" ++
249                         show (2*i) ++ 
250                         ".\n" ++
251                         indent 2 (show be)
252
253 indent n = unlines . map (replicate n ' ' ++) . lines
254
255 showTypePrec :: Int -> Typ -> String
256 showTypePrec _ Int                          = "Int" 
257 showTypePrec _ (TVar (TypVar i))            = "a"++show i
258 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
259                                     |     b = "t" ++  show (2*i)
260 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
261                                   showTypePrec 10 t1 ++ " -> " ++ showTypePrec 9 t2 
262 showTypePrec d (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
263 showTypePrec d (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
264                                                     " " ++ showTypePrec 11 t2
265 showTypePrec d (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
266                                               "," ++ showTypePrec 0 t2 ++ ")"
267
268 paren b p   =  if b then "(" ++ p ++ ")" else p