diff --git a/src/Windows12.hs b/src/Windows12.hs index 185b4a3..8ccebaf 100644 --- a/src/Windows12.hs +++ b/src/Windows12.hs @@ -5,3 +5,4 @@ import Windows12.Lexer import Windows12.Parser import Windows12.CodeGen import Windows12.TAst +import Windows12.Semant diff --git a/src/Windows12/Semant.hs b/src/Windows12/Semant.hs new file mode 100644 index 0000000..fb36cd2 --- /dev/null +++ b/src/Windows12/Semant.hs @@ -0,0 +1,263 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} + +module Windows12.Semant where + +import Data.Text (Text) + +import Control.Monad.State + +import Data.List (find) + +import Windows12.Ast as Ast +import Windows12.TAst as TAst + +suppliedFuncs :: [Text] +suppliedFuncs = ["printf"] + +-- Convert an Ast to a TAst +-- Performs type inference and type checking + +data Ctx = Ctx { structs :: [TLStruct], + enums :: [TLEnum], + funcs :: [TTLFunc], + vars :: [(Text, Type)] } + deriving (Eq, Show) + +-- Main conversion function. May return an error message if the program +-- is not well-typed. +convert :: Ast.Program -> Either String TAst.TProgram +convert (Ast.Program structs enums funcs) = do + let ctx = Ctx structs enums [] [] + let (funcs', _) = runState (mapM convertFunc funcs) ctx + return $ TAst.TProgram structs enums funcs' + +-- Convert a TLFunc (Top Level Function) to a TTLFunc (Typed Top Level Function) +-- Note that the function must be added to the context before converting statements +-- of the function. This is because the function may call itself recursively. +-- After converting the function, the function's statements are converted +-- and added to the context. +convertFunc :: MonadState Ctx m => Ast.TLFunc -> m TAst.TTLFunc +convertFunc (Ast.Func name args retType body) = do + args' <- mapM (\(Bind name t) -> return (name, t)) args + oldFuncs <- gets funcs + modify (\ctx -> ctx { funcs = funcs ctx ++ [TTLFunc name args retType []], vars = args' }) + body' <- mapM convertStmt body + ctx <- get + let func = (last $ funcs ctx) { TAst.funcBody = body' } + put $ ctx { funcs = oldFuncs ++ [func] } + return func + +-- Convert a statement +convertStmt :: MonadState Ctx m => Ast.Stmt -> m TAst.TStmt + +convertStmt (Ast.Expr expr) = do + expr' <- convertExpr expr + return $ TAst.TExprStmt expr' + +convertStmt (Ast.Return expr) = do + expr' <- convertExpr expr + return $ TAst.TReturn expr' + +convertStmt (Ast.If cond thenStmts elseStmts) = do + thenStmts' <- mapM convertStmt thenStmts + elseStmts' <- mapM convertStmt $ maybe [] id elseStmts + cond' <- convertExpr cond + return $ TAst.TIf cond' thenStmts' (Just elseStmts') + +convertStmt (Ast.While cond stmts) = do + stmts' <- mapM convertStmt stmts + cond' <- convertExpr cond + return $ TAst.TWhile cond' stmts' + +convertStmt (Ast.Assign op lval expr) = do + lval' <- convertLVal lval + expr' <- convertExpr expr + return $ TAst.TAssign op lval' expr' + +convertStmt (Ast.Block stmts) = do + stmts' <- mapM convertStmt stmts + return $ TAst.TBlock stmts' + +convertStmt (Ast.Var name (Just t) maybeExpr) = do + expr' <- maybe (return Nothing) (fmap Just . convertExpr) maybeExpr + modify (\ctx -> ctx { vars = (name, t) : vars ctx }) + return $ TAst.TDeclVar name t expr' + +-- TODO +convertStmt (Ast.Var name Nothing maybeExpr) = error "Type inference not implemented" + +-- Convert an expression to an LValue +-- Only certain expressions are allowed as LValues +convertLVal :: MonadState Ctx m => Ast.Expr -> m TAst.TLVal +convertLVal (Ast.Id name) = do + ctx <- get + case lookup name (vars ctx) of + Just t -> return (t, TAst.TId name) + Nothing -> error $ "Variable " ++ show name ++ " not in scope" + +convertLVal (Ast.Index arr idx) = do + arr' <- convertLVal arr + idx' <- convertExpr idx + return (fst arr', TAst.LTIndex arr' idx') + +convertLVal (Ast.Member e (Id m)) = do + e' <- convertLVal e + return (fst e', TAst.LTMember e' m) + +convertLVal (Ast.Member e m) = do error $ "Invalid member access " ++ show m ++ " on " ++ show e + +convertLVal (Ast.UnOp Ast.Deref e) = error "Dereferencing not implemented" + +convertLVal e = do error $ "Invalid or unimplemented LValue " ++ show e + +-- Convert an expression +convertExpr :: MonadState Ctx m => Ast.Expr -> m TAst.TExpr +convertExpr (Ast.Id name) = do + ctx <- get + case lookup name (vars ctx) of + Just t -> return (t, TAst.TVar name) + Nothing -> error $ "Variable " ++ show name ++ " not in scope" + +convertExpr (Ast.IntLit x) = return (IntType, TAst.TIntLit x) +convertExpr (Ast.UIntLit x) = return (UIntType, TAst.TUIntLit x) +convertExpr (Ast.FloatLit x) = return (FloatType, TAst.TFloatLit x) +convertExpr (Ast.StrLit x) = return (StrType, TAst.TStrLit x) +convertExpr (Ast.BoolLit x) = return (BoolType, TAst.TBoolLit x) +convertExpr (Ast.CharLit x) = return (CharType, TAst.TCharLit x) + +convertExpr (Ast.BinOp Add l r) = arithOp Add l r +convertExpr (Ast.BinOp Sub l r) = arithOp Sub l r +convertExpr (Ast.BinOp Mul l r) = arithOp Mul l r +convertExpr (Ast.BinOp Div l r) = arithOp Div l r +convertExpr (Ast.BinOp Mod l r) = arithOp Mod l r + +convertExpr (Ast.BinOp Eq l r) = compOp Eq l r +convertExpr (Ast.BinOp Ne l r) = compOp Ne l r +convertExpr (Ast.BinOp Lt l r) = compOp Lt l r +convertExpr (Ast.BinOp Gt l r) = compOp Gt l r +convertExpr (Ast.BinOp Le l r) = compOp Le l r +convertExpr (Ast.BinOp Ge l r) = compOp Ge l r + +convertExpr (Ast.BinOp And l r) = boolOp And l r +convertExpr (Ast.BinOp Or l r) = boolOp Or l r + +convertExpr (Ast.BinOp BitAnd l r) = bitOp BitAnd l r +convertExpr (Ast.BinOp BitOr l r) = bitOp BitOr l r +convertExpr (Ast.BinOp BitXor l r) = bitOp BitXor l r + +convertExpr (Ast.BinOp ShiftL l r) = shiftOp ShiftL l r +convertExpr (Ast.BinOp ShiftR l r) = shiftOp ShiftR l r + +convertExpr (Ast.UnOp Neg e) = do + e' <- convertExpr e + if fst e' `elem` [IntType, UIntType, FloatType] + then return (fst e', TAst.TUnOp Neg e') + else error $ "Type mismatch: " ++ show e + +convertExpr (Ast.UnOp Not e) = do + e' <- convertExpr e + if fst e' == BoolType + then return (BoolType, TAst.TUnOp Not e') + else error $ "Type mismatch: " ++ show e + +convertExpr (Ast.UnOp BitNot e) = undefined +convertExpr (Ast.UnOp Deref e) = undefined +convertExpr (Ast.UnOp AddrOf e) = undefined + +-- TODO type check function return +-- TODO ensure returns on all paths +-- Lower priority since LLVM checks this also +convertExpr (Ast.Call (Id f) args) = do + ctx <- get + if f == "printf" + then do + args' <- mapM convertExpr args + return (IntType, TAst.TCall "printf" args') + else case find (\(TTLFunc n a r _) -> n == f) (funcs ctx) of + Just t -> do + args' <- mapM convertExpr args + if length args' == length (TAst.funcArgs t) && all (\(t1, t2) -> t1 == t2) (zip (map fst args') (map bindType (TAst.funcArgs t))) + then return (TAst.funcRetType t, TAst.TCall f args') + else error $ "Type mismatch in call to " ++ show f + Nothing -> error $ "Function " ++ show f ++ " not in scope. Available functions: " ++ show (map TAst.funcName (funcs ctx)) + +convertExpr (Ast.Index arr idx) = do + arr' <- convertExpr arr + idx' <- convertExpr idx + case fst arr' of + ArrayType t -> if fst idx' == IntType + then return (t, TAst.TIndex arr' idx') + else error $ "Index must be an integer: " ++ show idx + _ -> error $ "Indexing non-array: " ++ show arr + +convertExpr (Ast.Cast t e) = do + e' <- convertExpr e + return (t, TAst.TCast t e') + +convertExpr (Ast.Sizeof t) = return (IntType, TAst.TSizeof t) + +convertExpr (Ast.Member e (Id m)) = do + e' <- convertExpr e + case fst e' of + StructType name -> do + ctx <- get + case find (\(Struct n _) -> n == name) (structs ctx) of + Just (Struct _ binds) -> case find (\(Bind n t) -> n == m) binds of + Just (Bind _ t) -> return (t, TAst.TMember e' m) + Nothing -> error $ "Field " ++ show m ++ " not in struct " ++ show name + Nothing -> error $ "Struct " ++ show name ++ " not in scope" + _ -> error $ "Member access on non-struct " ++ show e + +convertExpr (Ast.StructInit name fields) = do + ctx <- get + case find (\(Struct n _) -> n == name) (structs ctx) of + Just (Struct _ binds) -> do + fields' <- mapM (\(n, e) -> do + e' <- convertExpr e + case find (\(Bind n' t) -> n == n') binds of + Just (Bind _ t) -> if fst e' == t + then return (n, e') + else error $ "Type mismatch in struct initialization: " ++ show e + Nothing -> error $ "Field " ++ show n ++ " not in struct " ++ show name) fields + + return (StructType name, TAst.TStructInit name fields') + Nothing -> error $ "Struct " ++ show name ++ " not in scope" + +convertExpr e = error $ "Invalid or Unimplemented conversion for expression " ++ show e + +-- Ensure that the types of the left and right expressions are the same +-- and return the type of the result +arithOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr +arithOp o l r = do + l' <- convertExpr l + r' <- convertExpr r + if fst l' == fst r' + then return (fst l', TAst.TBinOp o l' r') + else error $ "Type mismatch: " ++ show l ++ " and " ++ show r + +-- Ensure that the types of the left and right expressions are the same +-- and return a boolean type +compOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr +compOp o l r = do + l' <- convertExpr l + r' <- convertExpr r + if fst l' == fst r' + then return (BoolType, TAst.TBinOp o l' r') + else error $ "Type mismatch: " ++ show l ++ " and " ++ show r + +-- Ensure that the types of both expressions are boolean +-- and return a boolean type +boolOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr +boolOp o l r = do + l' <- convertExpr l + r' <- convertExpr r + if fst l' == fst r' && fst l' == BoolType + then return (BoolType, TAst.TBinOp o l' r') + else error $ "Type mismatch: " ++ show l ++ " and " ++ show r + +bitOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr +bitOp o l r = do error $ "Bit operations not implemented" + +shiftOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr +shiftOp o l r = do error $ "Shift operations not implemented" diff --git a/windows12.cabal b/windows12.cabal index f274cd5..35720ce 100644 --- a/windows12.cabal +++ b/windows12.cabal @@ -70,6 +70,7 @@ executable windows12 Windows12.Parser Windows12.CodeGen Windows12.TAst + Windows12.Semant -- LANGUAGE extensions used by modules in this package. -- other-extensions: