]> Joshua Wise's Git repositories - snipe.git/blobdiff - type/typechecker.sml
Initial import of l3c
[snipe.git] / type / typechecker.sml
index 32f80a16e40b9cddb7abae4e82ee4332b457719a..63608bdb403628faaa4aed2d79e640a2e3476fee 100644 (file)
@@ -1,10 +1,9 @@
-(* L1 Compiler
+(* L3 Compiler
  * TypeChecker
  * Author: Alex Vaynberg <alv@andrew.cmu.edu>
  * Modified: Frank Pfenning <fp@cs.cmu.edu>
- *
- * Simple typechecker that is based on a unit Symbol.table
- * This is all that is needed since there is only an integer type present
+ * Modified: Joshua Wise <jwise>
+ * Modified: Chris Lu <czl>
  *) 
 
 signature TYPE_CHECK =
@@ -16,6 +15,8 @@ end;
 structure TypeChecker :> TYPE_CHECK = 
 struct
   structure A = Ast
+  
+  datatype asn = ASSIGNED | UNASSIGNED
 
   fun returns nil = false
     | returns (A.Assign _ :: stms) = returns stms
@@ -27,7 +28,7 @@ struct
     | returns (A.If (_, s1, SOME s2) :: stms) = (returns s1 andalso returns s2) orelse returns stms
     | returns (A.For _ :: stms) = returns stms
     | returns (A.While _ :: stms) = returns stms
-    | returns (A.MarkedStm m :: stms) = returns (Mark.data m :: stms)
+    | returns (A.MarkedStm m :: stms) = returns (Mark.kane m :: stms)
   
   fun breakcheck nil mark = ()
     | breakcheck (A.Break :: stms) mark = ( ErrorMsg.error mark ("Illegal break outside loop") ;
@@ -41,21 +42,41 @@ struct
         ( breakcheck s1 mark;
           breakcheck s2 mark;
           breakcheck stms mark)
-    | breakcheck (A.MarkedStm m :: stms) mark = (breakcheck [(Mark.data m)] (Mark.ext m); breakcheck stms mark)
+    | breakcheck (A.MarkedStm m :: stms) mark = (breakcheck [(Mark.kane m)] (Mark.ext m); breakcheck stms mark)
     | breakcheck (_ :: stms) mark = breakcheck stms mark
   
-  fun varcheck_exp env (A.Var v) mark =
+  fun varcheck_exp env fenv (A.Var v) mark : Ast.vtype =
         ( case Symbol.look env v
           of NONE => ( ErrorMsg.error mark ("undefined variable `" ^ Symbol.name v ^ "'") ;
                        raise ErrorMsg.Error )
-           | SOME _ => ())
-    | varcheck_exp env (A.ConstExp _) mark = ()
-    | varcheck_exp env (A.OpExp (_, l)) mark = List.app (fn znt => varcheck_exp env znt mark) l
-    | varcheck_exp env (A.Marked m) mark = varcheck_exp env (Mark.data m) (Mark.ext m)
+           | SOME (t, UNASSIGNED) => ( ErrorMsg.error mark ("usage of unassigned variable `" ^ Symbol.name v ^ "'") ;
+                                       raise ErrorMsg.Error )
+           | SOME (t, ASSIGNED) => t)
+    | varcheck_exp env fenv (A.ConstExp _) mark = (A.Int)
+    | varcheck_exp env fenv (A.OpExp (_, l)) mark = (List.app (fn znt => (varcheck_exp env fenv znt mark; ())) l; A.Int)
+    | varcheck_exp env fenv (A.FuncCall (f, l)) mark =
+      let
+        val types = map (fn znt => varcheck_exp env fenv znt mark) l
+        val func = case Symbol.look fenv f
+                     of NONE => ( ErrorMsg.error mark ("undefined function `" ^ Symbol.name f ^ "'") ;
+                                  raise ErrorMsg.Error )
+                      | SOME a => a
+        val (rtype, params) = case func
+                               of A.Extern (rtype, _, params) => (rtype, params)
+                                | A.Function (rtype, _, params, _, _) => (rtype, params)
+        val paramtypes = map (fn (i, t) => t) params
+        val () = if not (types = paramtypes)
+                 then ( ErrorMsg.error mark ("incorrect parameters for function `" ^ Symbol.name f ^ "'") ;
+                        raise ErrorMsg.Error )
+                 else ()
+      in
+        rtype
+      end
+    | varcheck_exp env fenv (A.Marked m) mark = varcheck_exp env fenv (Mark.kane m) (Mark.ext m)
   
   fun computeassigns env nil = env
     | computeassigns env (A.Assign (id,e) :: stms) =
-        computeassigns (Symbol.bind env (id, ())) stms
+        computeassigns (Symbol.bind env (id, (A.Int, ASSIGNED))) stms
     | computeassigns env (A.Return _ :: stms) = env
     | computeassigns env (A.Nop :: stms) = computeassigns env stms
     | computeassigns env (A.Break :: stms) = env
@@ -65,7 +86,11 @@ struct
         let
           val env1 = computeassigns env s1
           val env2 = computeassigns env s2
-          val env' = Symbol.intersect (env1, env2)
+          val env' =
+            Symbol.intersect
+              (fn ((t, ASSIGNED), (t', ASSIGNED)) => (t, ASSIGNED) (* XXX check types for equality *)
+                | ((t, _), (t', _)) => (t, UNASSIGNED))
+              (env1, env2)
           val env' =
             if (returns s1) then env2
             else if (returns s2) then env1
@@ -82,59 +107,137 @@ struct
        in
          computeassigns env' stms
        end
-    | computeassigns env (A.MarkedStm m :: stms) = computeassigns env ((Mark.data m) :: stms)
+    | computeassigns env (A.MarkedStm m :: stms) = computeassigns env ((Mark.kane m) :: stms)
   
-  fun varcheck env nil mark = nil
-    | varcheck env (A.Assign (id, e) :: stms) mark =
-        ( varcheck_exp env e mark ;
-          A.Assign (id, e) :: (varcheck (Symbol.bind env (id, ())) stms mark) )
-    | varcheck env (A.Return (e) :: stms) mark =
-        ( varcheck_exp env e mark;
+  fun varcheck env fenv nil mark = nil
+    | varcheck env fenv (A.Assign (id, e) :: stms) mark =
+        let
+          val sym = Symbol.look env id
+          val _ = if not (isSome sym)
+                  then (ErrorMsg.error mark ("assignment to undeclared variable " ^ (Symbol.name id)); raise ErrorMsg.Error)
+                  else ()
+          val (t, a) = valOf sym
+          val t' = varcheck_exp env fenv e mark
+        in 
+          A.Assign (id, e) :: (varcheck (Symbol.bind env (id, (t, ASSIGNED))) fenv stms mark)
+        end
+    | varcheck env fenv (A.Return (e) :: stms) mark =
+        ( varcheck_exp env fenv e mark;
           A.Return (e) :: nil )
-    | varcheck env (A.Nop :: stms) mark =
-        ( A.Nop :: (varcheck env stms mark))
-    | varcheck env (A.Break :: stms) mark =
+    | varcheck env fenv (A.Nop :: stms) mark =
+        ( A.Nop :: (varcheck env fenv stms mark))
+    | varcheck env fenv (A.Break :: stms) mark =
         ( A.Break :: nil )
-    | varcheck env (A.Continue :: stms) mark =
+    | varcheck env fenv (A.Continue :: stms) mark =
         ( A.Continue :: nil )
-    | varcheck env (A.If (e, s1, NONE) :: stms) mark =
-        ( varcheck_exp env e mark ;
-          varcheck env s1 mark ;
-          A.If (e, s1, NONE) :: (varcheck env stms mark) )
-    | varcheck env ((i as A.If (e, s1, SOME s2)) :: stms) mark =
-        ( varcheck_exp env e mark ;
-          varcheck env s1 mark ; 
-          varcheck env s2 mark ;
+    | varcheck env fenv (A.If (e, s1, NONE) :: stms) mark =
+        ( varcheck_exp env fenv e mark ;
+          varcheck env fenv s1 mark ;
+          A.If (e, s1, NONE) :: (varcheck env fenv stms mark) )
+    | varcheck env fenv ((i as A.If (e, s1, SOME s2)) :: stms) mark =
+        ( varcheck_exp env fenv e mark ;
+          varcheck env fenv s1 mark ; 
+          varcheck env fenv s2 mark ;
           A.If (e, s1, SOME s2) ::
             (if (returns [i])
              then nil
-             else varcheck (computeassigns env [i]) stms mark)  )
-    | varcheck env (A.While (e, s1) :: stms) mark =
-        ( varcheck_exp env e mark ;
-          varcheck env s1 mark ;
-          A.While (e, s1) :: (varcheck env stms mark) )
-    | varcheck env (A.For (sbegin, e, sloop, inner) :: stms) mark =
+             else varcheck (computeassigns env [i]) fenv stms mark)  )
+    | varcheck env fenv (A.While (e, s1) :: stms) mark =
+        ( varcheck_exp env fenv e mark ;
+          varcheck env fenv s1 mark ;
+          A.While (e, s1) :: (varcheck env fenv stms mark) )
+    | varcheck env fenv (A.For (sbegin, e, sloop, inner) :: stms) mark =
         let
           val sbegin = case sbegin
-                       of SOME(s) => SOME (hd (varcheck env [s] mark))
+                       of SOME(s) => SOME (hd (varcheck env fenv [s] mark))
                         | NONE => NONE
           val env' = case sbegin
                      of SOME(s) => computeassigns env [s]
                       | NONE => env
-          val _ = varcheck_exp env' e
-          val inner = varcheck env' inner mark
+          val _ = varcheck_exp env' fenv e
+          val inner = varcheck env' fenv inner mark
           val env'' = computeassigns env' inner
           val sloop = case sloop
-                  of SOME(s) => SOME (hd (varcheck env'' [s] mark))
+                  of SOME(s) => SOME (hd (varcheck env'' fenv [s] mark))
                    | NONE => NONE
         in
-          A.For (sbegin, e, sloop, inner) :: (varcheck env' stms mark)
+          A.For (sbegin, e, sloop, inner) :: (varcheck env' fenv stms mark)
         end
-    | varcheck env (A.MarkedStm m :: stms) mark = varcheck env ((Mark.data m) :: stms) (Mark.ext m)
+    | varcheck env fenv (A.MarkedStm m :: stms) mark = varcheck env fenv ((Mark.kane m) :: stms) (Mark.ext m)
 
-  fun typecheck prog =
-      ( breakcheck prog NONE ;
-        if not (returns prog)
-        then (ErrorMsg.error NONE ("program does not return in all cases"); raise ErrorMsg.Error)
-        else varcheck Symbol.empty prog NONE)
+  fun bindvars sym stat l = foldr (fn ((i,t), s) => Symbol.bind s (i,(t, stat))) sym l
+  fun bindfuns sym l =
+    foldr
+      (fn (a as (A.Function (_, id, _, _, _)), s) => Symbol.bind s (id, a)
+        | (a as (A.Extern (_, id, _)), s) => Symbol.bind s (id, a))
+      sym l
+
+  fun dupchk l =
+        List.app
+          (fn (n, _) =>
+            let
+              val name = Symbol.name n
+              val all = List.filter (fn (n', _) => name = (Symbol.name n')) l
+              val count = length all
+            in
+              if count = 1
+              then ()
+              else ( ErrorMsg.error NONE ("multiple definition of variable " ^ (Symbol.name n));
+                     raise ErrorMsg.Error )
+            end) l
+
+  fun typecheck_fn p (e as (A.Extern (t, id, al))) = (dupchk al; e)
+    | typecheck_fn p (A.Function (t, id, al, vl, sl)) =
+      let
+        val () = breakcheck sl NONE
+        val () = if not (returns sl)
+                 then ( ErrorMsg.error NONE ("function `"^ Symbol.name id ^ "' does not return in all cases");
+                        raise ErrorMsg.Error )
+                 else ()
+        val env = Symbol.empty
+        val env = bindvars env ASSIGNED al
+        val env = bindvars env UNASSIGNED vl
+        val fenv = bindfuns Symbol.empty p
+        val () = dupchk (al @ vl)
+      in
+        A.Function (t, id, al, vl, varcheck env fenv sl NONE)
+      end
+  
+  fun typecheck p =
+      let
+        fun getFun n =
+          List.find (fn A.Extern (_, id, _) => ((Symbol.name id) = n)
+                      | A.Function (_, id, _, _, _) => ((Symbol.name id) = n))
+                    p
+        val main = case (getFun "main")
+                   of NONE => ( ErrorMsg.error NONE ("no function named main");
+                                raise ErrorMsg.Error )
+                    | SOME m => m
+        val () = case main
+                 of A.Extern _ => ( ErrorMsg.error NONE ("you anus, main can't be an extern");
+                                    raise ErrorMsg.Error )
+                  | A.Function (A.Int, _, nil, _, _) => ()
+                  | A.Function (A.Int, _, _, _, _) => ( ErrorMsg.error NONE ("main should take no parameters");
+                                                        raise ErrorMsg.Error )
+        val () = List.app
+                   (fn a =>
+                      let
+                        val id = case a
+                          of A.Extern (_, id, _) => id
+                           | A.Function (_, id, _, _, _) => id
+                        val name = Symbol.name id
+                        val all = List.filter
+                          (fn A.Extern (_, id, _) => (Symbol.name id) = name
+                            | A.Function (_, id, _, _, _) => (Symbol.name id) = name)
+                          p
+                        val num = length all
+                      in
+                        if num = 1
+                        then ()
+                        else ( ErrorMsg.error NONE ("multiple definition of " ^ name);
+                               raise ErrorMsg.Error )
+                      end) p
+      in
+        List.map (typecheck_fn p) p
+      end
 end
This page took 0.032049 seconds and 4 git commands to generate.