module Aux where

import Lex
import Parser

-----------------------------------------------------
-- Small list helpers

-- | Number of elements in a list.
len :: [a] -> Int
len = length

-- | 1-based list indexing: @nth 1 xs@ is the first element of @xs@.
-- Calls 'error' when @xs@ has fewer than @k@ elements (which includes every
-- @k < 1@).
nth :: Int -> [a] -> a
nth _ [] = error "No n-th element in List!"
nth k (x:xs)
  | k == 1 = x
  | otherwise = nth (k - 1) xs

-- | @isIn x ys@ is 'True' when @x@ is an element of @ys@.
isIn :: String -> [String] -> Bool
isIn = elem

-----------------------------------------------------
-- Variable collection (used below to pick fresh names with 'newVar')

-- | Program-variable names occurring in the programs embedded in a type
-- (predicate arguments and the operands of @=@ / @<@). Quantifier binders
-- are not added themselves; duplicates are not removed.
allVarT :: Type -> [String]
allVarT Nat = []
allVarT (TypeVar _) = []
allVarT (Pred _ ps) = allVarP (PSum ps)
allVarT (TArrow t1 t2) = allVarT t1 ++ allVarT t2
allVarT (T_And ts) = concatMap allVarT ts
allVarT (T_Or ts) = concatMap allVarT ts
allVarT (TUniv _ t) = allVarT t
allVarT (TExist _ t) = allVarT t
allVarT (T_Equal p q) = allVarP p ++ allVarP q
allVarT (T_Less p q) = allVarP p ++ allVarP q

-- | Variable names occurring in a program. 'Abs' and 'Fix' binders are not
-- added, while the case binder of 'Casz' is. Calls 'error' on any
-- constructor not listed here.
allVarP :: Prog -> [String]
--allVarP (Num _) = []
allVarP Zero = []
allVarP (Succ p) = allVarP p
allVarP (Var v) = [v]
allVarP (Abs _ _ p) = allVarP p
allVarP (App p q) = allVarP p ++ allVarP q
allVarP (Fix _ _ p) = allVarP p
allVarP (PSum ps) = concatMap allVarP ps
allVarP (Proj _ p) = allVarP p
allVarP (Casz t m v n) = allVarP t ++ allVarP m ++ [v] ++ allVarP n
allVarP (Plus p q) = allVarP p ++ allVarP q
allVarP (Axiom _ _) = []
allVarP p = error ("allVarP is not defined for program: " ++ printP p)

-- | Names bound in a typing context.
allVarG :: Gamma -> [String]
allVarG g = map fst g

-------------------------------------------------------------------------------
-- Fresh names

-- | The first name of the form @_n0@, @_n1@, ... that does not occur in the
-- given list.
newVar :: [String] -> String
newVar vs = newVarN vs 0

-- | Like 'newVar', but starts the search at index @n@.
newVarN :: [String] -> Int -> String
newVarN vs n
  | nv `isIn` vs = newVarN vs (n + 1)
  | otherwise = nv
  where
    nv = "_n" ++ show n

-------------------------------------------------------------------------------
-- Alpha-equivalence

-- | Structural equality of types up to renaming of quantifier binders: the
-- bodies of two 'TUniv' (or two 'TExist') are both instantiated with a
-- common fresh variable before being compared.
equalT :: Type -> Type -> Bool
equalT Nat Nat = True
equalT (TypeVar v1) (TypeVar v2) = v1 == v2
equalT (TArrow t1 t2) (TArrow r1 r2) = equalT t1 r1 && equalT t2 r2
-- Lengths are compared first because zipWith silently truncates the longer
-- list, which would make e.g. #&# equal to #.
equalT (T_And ts) (T_And ws) =
  length ts == length ws && and (zipWith equalT ts ws)
equalT (T_Or ts) (T_Or ws) =
  length ts == length ws && and (zipWith equalT ts ws)
equalT (Pred p ps) (Pred q qs) = p == q && equalP (PSum ps) (PSum qs)
equalT (TUniv x t) (TUniv y s) = equalT (substT t x nv) (substT s y nv)
  where
    nv = Var (newVar (allVarT t ++ allVarT s))
-- An existential is compared exactly like the corresponding universal.
equalT (TExist x t) (TExist y s) = TUniv x t `equalT` TUniv y s
equalT (T_Equal p1 p2) (T_Equal q1 q2) = equalP p1 q1 && equalP p2 q2
equalT (T_Less p1 p2) (T_Less q1 q2) = equalP p1 q1 && equalP p2 q2
equalT _ _ = False

-- | Structural equality of programs up to renaming of the binders of 'Abs',
-- 'Fix' and the successor branch of 'Casz'.
equalP :: Prog -> Prog -> Bool
--equalP (Num n) (Num m) = n==m
equalP Zero Zero = True
equalP (Succ p) (Succ q) = equalP p q
equalP (Var v) (Var u) = v == u
equalP (Abs x tx px) (Abs y ty py) =
  equalT tx ty && (substP px x nv `equalP` substP py y nv)
  where
    nv = Var (newVar (allVarP px ++ allVarP py))
equalP (App a b) (App p q) = equalP a p && equalP b q
equalP (Fix f tf pf) (Fix g tg pg) = Abs f tf pf `equalP` Abs g tg pg
-- Length check first: zipWith would otherwise ignore surplus components.
equalP (PSum ps) (PSum qs) =
  length ps == length qs && and (zipWith equalP ps qs)
equalP (Proj k p) (Proj h q) = k == h && equalP p q
equalP (Casz t m v n) (Casz t' m' u n') =
  equalP t t' && equalP m m' && (substP n v nv `equalP` substP n' u nv)
  where
    nv = Var (newVar (allVarP n ++ allVarP n'))
equalP (Plus p q) (Plus p' q') = equalP p p' && equalP q q'
equalP (Axiom s1 t1) (Axiom s2 t2) = s1 == s2 && equalT t1 t2
equalP _ _ = False

-------------------------------------------------------------------------------
-- Pretty printing

-- | Wrap a string in parentheses.
inPar :: String -> String
inPar s = "(" ++ s ++ ")"

-- | Render a type in concrete syntax (@#@ is 'Nat', @!@ / @?@ are the
-- universal / existential quantifiers).
-- NOTE(review): empty 'T_And' / 'T_Or' lists match no equation below and
-- fail at runtime.
printT :: Type -> String
printT Nat = "#"
printT (TypeVar s) = s
printT (Pred s ps) = s ++ "(" ++ printL ps ++ ")"
printT (TArrow t1 t2) = inPar (printT t1 ++ "->" ++ printT t2)
printT (T_And [t]) = printT t
printT (T_And (t:ts)) = printT t ++ "&" ++ printT (T_And ts)
printT (T_Or [t]) = printT t
printT (T_Or (t:ts)) = printT t ++ "|" ++ printT (T_Or ts)
printT (TUniv s t) = "!" ++ s ++ "." ++ printT t
printT (TExist s t) = "?" ++ s ++ "." ++ printT t
printT (T_Equal p q) = printP p ++ "=" ++ printP q
printT (T_Less p q) = printP p ++ "<" ++ printP q

-- | Render a program in concrete syntax.
printP :: Prog -> String
--printP (Num n) = show n
printP Zero = "0"
printP (Succ p) = prnSucc p 1
printP (Var s) = s
printP (Abs s t p) = inPar ("\\" ++ s ++ ":" ++ printT t ++ "." ++ printP p)
printP (App p1 p2) = inPar (printP p1 ++ " " ++ printP p2)
printP (Fix f t p) = "fix " ++ f ++ ":" ++ printT t ++ "." ++ printP p
printP (If s i p1 p2) =
  inPar ("if (" ++ s ++ "=" ++ show i ++ ") " ++ printP p1
         ++ " else " ++ printP p2)
printP (PSum ps) = "{" ++ printL ps ++ "}"
printP (Proj i p) = inPar ("proj " ++ show i ++ " " ++ printP p)
printP (Inj i p) = inPar ("inj " ++ show i ++ " " ++ printP p)
printP (Case p ps) = "case " ++ printP p ++ " -> (" ++ printL ps ++ ")"
printP (Inx p1 p2) = inPar ("inx " ++ printP p1 ++ " -> " ++ printP p2)
printP (Casx p1 p2) = "casx " ++ printP p1 ++ " -> " ++ printP p2
printP (Casz t m v n) =
  "case " ++ printP t ++ " of 0 -> " ++ printP m
  ++ " | " ++ v ++ "' -> " ++ printP n
printP (Plus p q) = inPar (printP p ++ "+" ++ printP q)
printP (Axiom s _) = s

-- | Comma-separated rendering of a list of programs.
printL :: [Prog] -> String
printL [] = ""
printL [p] = printP p
printL (p:ps) = printP p ++ "," ++ printL ps

-- | @prnSucc p n@ renders @p@ wrapped in @n@ successors. A chain ending in
-- 'Zero' prints as a numeral; otherwise one or two successors print as
-- primes and more as @+n@.
prnSucc :: Prog -> Int -> String
prnSucc Zero n = show n
prnSucc (Succ p) n = prnSucc p (n + 1)
prnSucc p 0 = printP p
prnSucc p 1 = printP p ++ "'"
prnSucc p 2 = printP p ++ "''"
prnSucc p n = printP p ++ "+" ++ show n

-------------------------------------------------------------------------------
-- Typing contexts

-- | Typing context: variable names paired with their types; 'addTo' conses,
-- so the most recent binding comes first.
type Gamma = [(String, Type)]

-- | Type of the first binding for the variable; 'error' if it is unbound.
varType :: Gamma -> String -> Type
varType [] s = error ("Variable " ++ s ++ " not found in context!")
varType ((v, t):g) s
  | v == s = t
  | otherwise = varType g s

-- | Extend a context after checking the new type with 'checkType';
-- raises 'error' when the check fails.
addTo :: (String, Type) -> Gamma -> Gamma
addTo (v, t) g
  | checkType g t = (v, t) : g
  | otherwise =
      error ("Bad Type for variable " ++ v
             ++ "; All programs in predicates must be #")

-- | Well-formedness of a type in a context: every program inside a
-- predicate or an @=@ / @<@ must have type @#@, and quantified variables are
-- in scope with type @#@ in the quantifier's body.
checkType :: Gamma -> Type -> Bool
checkType _ Nat = True
checkType _ (TypeVar _) = True
checkType g (Pred _ ps) = all (\p -> equalT (tp1 g p) Nat) ps
checkType g (TArrow t1 t2) = checkType g t1 && checkType g t2
checkType g (T_And ts) = all (checkType g) ts
checkType g (T_Or ts) = all (checkType g) ts
checkType g (TUniv x t) = checkType ((x, Nat) : g) t
checkType g (TExist x t) = checkType ((x, Nat) : g) t
checkType g (T_Equal p q) = equalT (tp1 g p) Nat && equalT (tp1 g q) Nat
checkType g (T_Less p q) = checkType g (T_Equal p q)

-------------------------------------------------------------------------------
-- Type checking with derivation output

-- | Type a program and record its derivation.
--
-- @tp g l s p@ types @p@ in context @g@. @l@ is the depth of the current
-- judgement (used as indentation by 'printJ'), and @s@ is the derivation text
-- produced so far. The result is the type of @p@ paired with @s@ extended by
-- this sub-derivation; each rule's judgement line is appended after the
-- lines of its premises. Ill-typed programs raise 'error'.
tp :: Gamma -> Int -> String -> Prog -> (Type, String)
tp g l s (Var v) = tt l g (varType g v) s "Ax"
--tp g l s (Num n) = tt l g Nat s "Ax#"
tp g l s Zero = tt l g Nat s "Ax0"
tp g l s (Succ p) =
  let (t, s') = tp g (l + 1) s p
  in if equalT t Nat
       then tt l g Nat s' "Succ#"
       else error ("Type error: term " ++ printP p ++ " is not #")
-- A binder of type # yields a universal type (rule Gen); any other binder
-- type yields an arrow.
tp g l s (Abs v t p) =
  let (t0, s') = tp ((v, t) `addTo` g) (l + 1) s p
  in if equalT t Nat
       then tt l g (TUniv v t0) s' "Gen"
       else tt l g (TArrow t t0) s' "->"
tp g l s (App a b) =
  let (t1, s1) = tp g (l + 1) s a
      (t2, s2) = tp g (l + 1) s1 b
  in case t1 of
       TArrow p q ->
         if equalT t2 p
           then tt l g q s2 "MP"
           else error ("Application Type Error: term "
                       ++ printP b ++ ": type " ++ printT t2
                       ++ " does not match " ++ printT p)
       -- NOTE(review): this branch continues from s1, so the derivation text
       -- of the argument b is dropped (the MP branch keeps it via s2). The
       -- Casz rule also discards its scrutinee's derivation, so this may be
       -- intentional -- confirm.
       TUniv x tx ->
         if equalT t2 Nat
           then tt l g (substT tx x b) s1 "Subst"
           else error ("Application Type Error: term "
                       ++ printP b ++ " is not of type #")
       _ -> error ("Application Type Error: term '"
                   ++ printP a ++ "': " ++ printT t1
                   ++ " is not a function type!")
tp g l s (Fix f t p) =
  let (t1, s') = tp ((f, t) `addTo` g) (l + 1) s p
  in if equalT t t1
       then tt l g t s' "Fix"
       else error ("Fix Type Error: term " ++ printP (Fix f t p))
-- Each component is typed with an empty derivation; the pieces are then
-- concatenated after s in component order.
tp g l s (PSum ps) =
  let pairs = map (tp g (l + 1) "") ps
      ts = map fst pairs
      ss = map snd pairs
  in tt l g (T_And ts) (s ++ concat ss) "&"
tp g l s (Proj k p) =
  let (t, s') = tp g (l + 1) s p
  in case t of
       T_And ts ->
         if k < 1 || k > len ts
           then error ("Index out of range in: " ++ printP (Proj k p))
           else tt l g (nth k ts) s' "Proj"
       _ -> error (printP p ++ " is not of Record Type!")
-- The result type C[t/x] is obtained by reconstructing C from the types of
-- the two branches (see synthC) using a variable x fresh for g, t, m and n.
tp g l s (Casz t m v n) =
  let (e, _) = tp g (l + 1) "" t
  in if e `equalT` Nat
       then tt l g (substT (synthC v x c0 c1) x t) s' "Case"
       else error ("Type error: " ++ printP t ++ " is not #")
  where
    (c0, s0) = tp g (l + 1) s m
    (c1, s') = tp ((v, Nat) `addTo` g) (l + 1) s0 n
    x = newVar (v : allVarG g ++ allVarP t ++ allVarP m ++ allVarP n)
tp g l s (Plus p q) =
  let (t1, s1) = tp g (l + 1) s p
      (t2, s2) = tp g (l + 1) s1 q
  in if equalT t1 Nat && equalT t2 Nat
       then tt l g Nat s2 "Plus#"
       else error ("Type error: term " ++ printP (Plus p q))
tp g l s (Axiom name t) = tt l g t s name
tp _ _ _ p = error ("Type Rule not defined for: " ++ printP p)

-------------------------------------------------------------------------------
-- Synthesis of the motive of a case on naturals

-- | @synthC v x c0 c1@ reconstructs a type @C@ from @c0 = C[0/x]@ (type of
-- the zero branch of a 'Casz') and @c1 = C[v'/x]@ (type of the successor
-- branch): wherever @c0@ has @0@ and @c1@ has @Succ (Var v)@ the result has
-- @Var x@; everything else must agree. Raises 'error' when the two types
-- cannot be aligned.
synthC :: String -> String -> Type -> Type -> Type
synthC _ _ Nat Nat = Nat
synthC _ _ (TypeVar r1) (TypeVar r2)
  | r1 == r2 = TypeVar r1
  | otherwise = error ("synthC failed: TypeVar " ++ r1 ++ "!=" ++ r2)
synthC v x (TArrow t r) (TArrow t' r') =
  TArrow (synthC v x t t') (synthC v x r r')
-- Component lists must have equal length; otherwise these guards fail and
-- matching falls through to the final error equation.
synthC v x (T_And ts) (T_And ws)
  | length ts == length ws = T_And (zipWith (synthC v x) ts ws)
synthC v x (T_Or ts) (T_Or ws)
  | length ts == length ws = T_Or (zipWith (synthC v x) ts ws)
synthC v x (Pred p ps) (Pred q qs)
  | p /= q = error ("synthC: predicate " ++ p ++ "!=" ++ q)
  | length ps == length qs = Pred p (zipWith (synthP v x) ps qs)
synthC v x (TUniv z t) (TUniv y s) =
  TUniv nv (synthC v x (substT t z (Var nv)) (substT s y (Var nv)))
  where
    nv = newVar (v : x : z : y : (allVarT t ++ allVarT s))
synthC v x (TExist z t) (TExist y s) =
  TExist nv (synthC v x (substT t z (Var nv)) (substT s y (Var nv)))
  where
    nv = newVar (v : x : z : y : (allVarT t ++ allVarT s))
synthC v x (T_Equal p q) (T_Equal p' q') =
  T_Equal (synthP v x p p') (synthP v x q q')
synthC v x (T_Less p q) (T_Less p' q') =
  T_Less (synthP v x p p') (synthP v x q q')
synthC v x t1 t2 =
  error ("synthC [" ++ v ++ "'/" ++ x ++ "] failed on: "
         ++ printT t1 ++ "," ++ printT t2)

-- | Program-level counterpart of 'synthC': merges @p0 = P[0/x]@ and
-- @p1 = P[v'/x]@ into @P@. Raises 'error' when the programs cannot be
-- aligned.
--
-- Fresh binders always avoid @v@ and @x@ too: @Var x@ can appear in the
-- result (Zero/Succ case), and a binder named @x@ would capture it.
synthP :: String -> String -> Prog -> Prog -> Prog
--synthP v x (Num n) (Num m) = ...
synthP _ _ Zero Zero = Zero
synthP v x Zero (Succ p) =
  case p of
    Var w | v == w -> Var x
    _ -> error ("synthP failed on " ++ "0 vs. Succ(" ++ printP p ++ ")")
synthP v x (Succ p) (Succ q) = Succ (synthP v x p q)
synthP _ _ (Var w) (Var u)
  | w == u = Var w
  | otherwise = error ("synthP failed on Var: " ++ w ++ "!=" ++ u)
synthP v x (Abs z tz pz) (Abs y ty py) =
  Abs nv (synthC v x tz ty)
         (synthP v x (substP pz z (Var nv)) (substP py y (Var nv)))
  where
    nv = newVar (v : x : z : y : (allVarP pz ++ allVarP py))
synthP v x (App a b) (App p q) = App (synthP v x a p) (synthP v x b q)
synthP v x (Fix f tf pf) (Fix g tg pg) =
  Fix nv (synthC v x tf tg)
         (synthP v x (substP pf f (Var nv)) (substP pg g (Var nv)))
  where
    nv = newVar (v : x : f : g : (allVarP pf ++ allVarP pg))
-- Equal arity required; otherwise fall through to the final error equation.
synthP v x (PSum ps) (PSum qs)
  | length ps == length qs = PSum (zipWith (synthP v x) ps qs)
synthP v x (Proj k p) (Proj h q)
  | k == h = Proj k (synthP v x p q)
  | otherwise = error ("synthP proj index mismatch")
synthP v x (Casz t m w n) (Casz t' m' u n') =
  Casz (synthP v x t t') (synthP v x m m') nv
       (synthP v x (substP n w (Var nv)) (substP n' u (Var nv)))
  where
    nv = newVar (v : x : w : u : allVarP (PSum [t, t', m, m', n, n']))
synthP v x (Plus p q) (Plus p' q') = Plus (synthP v x p p') (synthP v x q q')
synthP _ _ (Axiom a1 t1) (Axiom a2 t2)
  | a1 == a2 && equalT t1 t2 = Axiom a1 t1
  | otherwise =
      error ("Type Error within Case: " ++ "synthP failed on "
             ++ a1 ++ "," ++ a2)
synthP v x p1 p2 =
  error ("Type Error within Case: synthP [" ++ v ++ "'/" ++ x ++ "]"
         ++ "failed on: " ++ printP p1 ++ "," ++ printP p2)

-------------------------------------------------------------------------------
-- Judgement printing

-- | Comma-separated @name:type@ rendering of a context.
printG :: Gamma -> String
printG [] = ""
printG [(v, vt)] = v ++ ":" ++ printT vt
printG ((v, vt):g) = v ++ ":" ++ printT vt ++ "," ++ printG g

-- | One derivation line: indentation by depth, context, turnstile, type and
-- the rule name between slashes.
printJ :: Int -> Gamma -> Type -> String -> String
printJ level g t rule =
  spaceN level ++ printG g ++ "|- " ++ printT t ++ " /" ++ rule ++ "/"

-- | @n@ spaces (empty for @n <= 0@; the old recursion never terminated on a
-- negative argument).
spaceN :: Int -> String
spaceN n = replicate n ' '
-------------------------------------------------------------------------------
-- Substitution

-- | Capture-avoiding substitution: @substP q s p@ replaces the free
-- occurrences of variable @s@ in program @q@ with program @p@. A binder that
-- occurs in @p@ is first renamed to a fresh name. Calls 'error' on
-- constructors not handled here.
substP :: Prog -> String -> Prog -> Prog
--substP (Num n) _ _ = (Num n)
substP Zero _ _ = Zero
substP (Succ q) s p = Succ (substP q s p)
substP (Var v) s p
  | v == s = p
  | otherwise = Var v
substP (App a b) s p = App (substP a s p) (substP b s p)
-- The type annotation of a binder lies outside the binder's scope (tp checks
-- it in the outer context via addTo), so it is substituted even when the
-- binder shadows s.
substP (Abs x t b) s p
  | x == s = Abs x nt b
  | x `isIn` allVarP p = Abs nv nt (substP (substP b x (Var nv)) s p)
  | otherwise = Abs x nt (substP b s p)
  where
    nt = substT t s p
    nv = newVar (allVarP b ++ allVarP p)
substP (Fix x t b) s p
  | x == s = Fix x nt b
  | x `isIn` allVarP p = Fix nv nt (substP (substP b x (Var nv)) s p)
  | otherwise = Fix x nt (substP b s p)
  where
    nt = substT t s p
    nv = newVar (allVarP b ++ allVarP p)
substP (PSum qs) s p = PSum (map (\q -> substP q s p) qs)
substP (Proj k q) s p = Proj k (substP q s p)
-- The case binder v scopes over the successor branch n only.
substP (Casz t m v n) s p
  | v == s = Casz new_t new_m v n
  | v `isIn` allVarP p =
      Casz new_t new_m nv (substP (substP n v (Var nv)) s p)
  | otherwise = Casz new_t new_m v (substP n s p)
  where
    nv = newVar (allVarP n ++ allVarP p)
    new_t = substP t s p
    new_m = substP m s p
substP (Plus r q) s p = Plus (substP r s p) (substP q s p)
substP (Axiom a t) _ _ = Axiom a t
substP q _ _ = error ("Substitution not defined on term: " ++ printP q)

-- | Capture-avoiding substitution of program @p@ for variable @s@ inside the
-- programs embedded in a type. 'TUniv' and 'TExist' binders are treated
-- alike: a binder equal to @s@ stops the substitution, and a binder occurring
-- in @p@ is renamed first.
substT :: Type -> String -> Prog -> Type
substT Nat _ _ = Nat
substT (TypeVar v) _ _ = TypeVar v
substT (Pred v qs) s p = Pred v (map (\q -> substP q s p) qs)
substT (TArrow t1 t2) s p = TArrow (substT t1 s p) (substT t2 s p)
substT (T_And ts) s p = T_And (map (\t -> substT t s p) ts)
substT (T_Or ts) s p = T_Or (map (\t -> substT t s p) ts)
substT (TUniv x t) s p
  | x == s = TUniv x t
  | x `isIn` allVarP p = TUniv nx (substT (substT t x (Var nx)) s p)
  | otherwise = TUniv x (substT t s p)
  where
    nx = newVar (allVarT t ++ allVarP p)
-- Existentials previously had no equation here and fell through to the
-- error case below.
substT (TExist x t) s p
  | x == s = TExist x t
  | x `isIn` allVarP p = TExist nx (substT (substT t x (Var nx)) s p)
  | otherwise = TExist x (substT t s p)
  where
    nx = newVar (allVarT t ++ allVarP p)
substT (T_Equal r q) s p = T_Equal (substP r s p) (substP q s p)
substT (T_Less r q) s p = T_Less (substP r s p) (substP q s p)
substT t _ _ = error ("Substitution not defined on type: " ++ printT t)

-------------------------------------------------------------------------------
-- Entry points

-- | Parse a program, type it in the empty context and print its type (no
-- trailing newline). The derivation text is discarded by 'tp1'.
t :: String -> IO ()
t s = putStr (printT (tp1 [] (sp s)))

-- | Return the type @ty@ together with @s@ extended by the judgement line
-- for @rule@ at depth @level@, followed by a newline.
tt :: Int -> Gamma -> Type -> String -> String -> (Type, String)
tt level g ty s rule = (ty, s ++ printJ level g ty rule ++ "\n")

-- | Type a program at depth 0 with an empty derivation, keeping only the
-- type.
tp1 :: Gamma -> Prog -> Type
tp1 g p = fst (tp g 0 "" p)

-- | Parse a type from source text with the imported 'lex1' and 'getType'.
st :: String -> Type
st s = getType (lex1 s)

-- | Parse a program from source text with the imported 'lex1' and 'getProg'.
sp :: String -> Prog
sp s = getProg (lex1 s)

-------------------------------------------------------------------------------