{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module Windows12.Semant where

import Control.Monad.State
import Data.List (find)
import Data.Text (Text)

import Windows12.Ast as Ast
import Windows12.TAst as TAst

-- | Names of functions that are supplied externally rather than defined in
-- the program being compiled.
suppliedFuncs :: [Text]
suppliedFuncs = ["printf"]

-- Convert an Ast to a TAst.
-- Performs type checking (type inference is not implemented yet).

-- | State threaded through the conversion.
data Ctx = Ctx
  { structs :: [TLStruct]
    -- ^ Struct declarations of the program.
  , enums :: [TLEnum]
    -- ^ Enum declarations of the program.
  , funcs :: [TTLFunc]
    -- ^ Functions converted so far, plus a placeholder for the one being
    -- converted (so it can call itself).
  , vars :: [(Text, Type)]
    -- ^ Variables in scope. New declarations are consed on the front, so
    -- 'lookup' finds the innermost binding first.
  }
  deriving (Eq, Show)

-- | Main conversion function: turn an untyped 'Ast.Program' into a typed
-- 'TAst.TProgram', type checking every function body.
--
-- NOTE: although the result type is 'Either', every check in this module
-- currently reports failure by calling 'error'. This function therefore
-- never returns 'Left'; an ill-typed program raises an exception when the
-- converted functions are forced.
--
-- Functions are converted in declaration order, starting from an empty
-- function list. A function can therefore only call itself, functions
-- declared before it, and printf.
convert :: Ast.Program -> Either String TAst.TProgram
convert (Ast.Program progStructs progEnums progFuncs) =
  Right $ TAst.TProgram progStructs progEnums typedFuncs
  where
    initialCtx = Ctx progStructs progEnums [] []
    -- Only the converted functions are needed; the final context is dropped.
    typedFuncs = evalState (mapM convertFunc progFuncs) initialCtx

-- | Convert a TLFunc (top-level function) to a TTLFunc (typed top-level
-- function). A placeholder for the function is added to the context before
-- its body is converted, because the body may call the function recursively.
-- Once the body has been converted, the placeholder is replaced by the fully
-- converted function.
convertFunc :: MonadState Ctx m => Ast.TLFunc -> m TAst.TTLFunc
convertFunc (Ast.Func name args retType body) = do
  oldFuncs <- gets funcs
  -- Register a placeholder (empty body) first so that recursive calls inside
  -- the body resolve. Only the parameters are in scope when the body starts.
  modify $ \ctx ->
    ctx
      { funcs = oldFuncs ++ [TTLFunc name args retType []]
      , vars = map (\(Bind argName argType) -> (argName, argType)) args
      }
  body' <- mapM convertStmt body
  -- Swap the placeholder for the fully converted function.
  let func = TTLFunc name args retType body'
  modify $ \ctx -> ctx {funcs = oldFuncs ++ [func]}
  return func

-- | Convert a nested statement list (a branch, loop body or block) in its own
-- scope. Variables it declares are visible only inside it and are removed
-- from the context again once the list has been converted.
convertScoped :: MonadState Ctx m => [Ast.Stmt] -> m [TAst.TStmt]
convertScoped stmts = do
  outerVars <- gets vars
  stmts' <- mapM convertStmt stmts
  modify $ \ctx -> ctx {vars = outerVars}
  return stmts'

-- | Convert a statement. A declaration ('Ast.Var') adds its variable to the
-- current scope. Branches, loop bodies and blocks get a scope of their own.
--
-- NOTE(review): conditions, assignments, initialisers and returned values are
-- not yet checked against the type they are expected to have.
convertStmt :: MonadState Ctx m => Ast.Stmt -> m TAst.TStmt
convertStmt (Ast.Expr expr) = TAst.TExprStmt <$> convertExpr expr
-- TODO type check the returned value against the function's return type.
-- TODO ensure returns on all paths.
-- Lower priority since LLVM checks this also.
convertStmt (Ast.Return expr) = TAst.TReturn <$> convertExpr expr
convertStmt (Ast.If cond thenStmts elseStmts) = do
  -- The condition is converted first, in the enclosing scope, so it cannot
  -- see variables declared inside either branch.
  cond' <- convertExpr cond
  thenStmts' <- convertScoped thenStmts
  elseStmts' <- convertScoped (maybe [] id elseStmts)
  -- A missing else branch is represented as an empty one.
  return $ TAst.TIf cond' thenStmts' (Just elseStmts')
convertStmt (Ast.While cond stmts) = do
  -- As for 'Ast.If': the condition must not see the body's declarations.
  cond' <- convertExpr cond
  stmts' <- convertScoped stmts
  return $ TAst.TWhile cond' stmts'
convertStmt (Ast.Assign op lval expr) =
  TAst.TAssign op <$> convertLVal lval <*> convertExpr expr
convertStmt (Ast.Block stmts) = TAst.TBlock <$> convertScoped stmts
convertStmt (Ast.Var name (Just t) maybeExpr) = do
  -- The initialiser is converted before the variable enters scope, so it
  -- cannot refer to the variable being declared.
  expr' <- traverse convertExpr maybeExpr
  modify $ \ctx -> ctx {vars = (name, t) : vars ctx}
  return $ TAst.TDeclVar name t expr'
-- TODO: infer the type of declarations without an explicit type.
convertStmt (Ast.Var _ Nothing _) = error "Type inference not implemented"

-- | Convert an expression to an LValue, tagged with the type of the storage
-- location it denotes. Only variables, indexing, field access and
-- (unimplemented) dereferencing are allowed as LValues.
convertLVal :: MonadState Ctx m => Ast.Expr -> m TAst.TLVal
convertLVal (Ast.Id name) = do
  t <- lookupVar name
  return (t, TAst.TId name)
convertLVal (Ast.Index arr idx) = do
  arr' <- convertLVal arr
  idx' <- convertExpr idx
  -- The location is an array element, so it has the element type.
  elemType <- indexedType arr idx (fst arr') (fst idx')
  return (elemType, TAst.LTIndex arr' idx')
convertLVal (Ast.Member e (Id m)) = do
  e' <- convertLVal e
  -- The location is a struct field, so it has the field's type.
  fieldType <- memberType e (fst e') m
  return (fieldType, TAst.LTMember e' m)
convertLVal (Ast.Member e m) =
  error $ "Invalid member access " ++ show m ++ " on " ++ show e
convertLVal (Ast.UnOp Ast.Deref _) = error "Dereferencing not implemented"
convertLVal e = error $ "Invalid or unimplemented LValue " ++ show e

-- | Look up the type of a variable in the current scope (innermost binding
-- wins); raises 'error' if the variable is not in scope.
lookupVar :: MonadState Ctx m => Text -> m Type
lookupVar name = do
  ctx <- get
  case lookup name (vars ctx) of
    Just t -> return t
    Nothing -> error $ "Variable " ++ show name ++ " not in scope"

-- | Type of @arr[idx]@, given the types of @arr@ and @idx@. The array
-- expression must have an array type and the index must be an 'IntType';
-- otherwise 'error' is raised. The source expressions are used only for
-- error messages.
indexedType :: Monad m => Ast.Expr -> Ast.Expr -> Type -> Type -> m Type
indexedType arr idx arrType idxType =
  case arrType of
    ArrayType elemType
      | idxType == IntType -> return elemType
      | otherwise -> error $ "Index must be an integer: " ++ show idx
    _ -> error $ "Indexing non-array: " ++ show arr

-- | Type of field @m@ of an expression @e@ whose type is @objType@. Raises
-- 'error' if @objType@ is not a declared struct or has no such field.
memberType :: MonadState Ctx m => Ast.Expr -> Type -> Text -> m Type
memberType e objType m =
  case objType of
    StructType name -> do
      ctx <- get
      case find (\(Struct n _) -> n == name) (structs ctx) of
        Just (Struct _ binds) ->
          case find (\(Bind n _) -> n == m) binds of
            Just (Bind _ t) -> return t
            Nothing -> error $ "Field " ++ show m ++ " not in struct " ++ show name
        Nothing -> error $ "Struct " ++ show name ++ " not in scope"
    _ -> error $ "Member access on non-struct " ++ show e

-- | Convert an expression, pairing it with its type. Ill-typed, unknown or
-- unimplemented expressions raise 'error'.
convertExpr :: MonadState Ctx m => Ast.Expr -> m TAst.TExpr
convertExpr (Ast.Id name) = do
  t <- lookupVar name
  return (t, TAst.TVar name)
convertExpr (Ast.IntLit x) = return (IntType, TAst.TIntLit x)
convertExpr (Ast.UIntLit x) = return (UIntType, TAst.TUIntLit x)
convertExpr (Ast.FloatLit x) = return (FloatType, TAst.TFloatLit x)
convertExpr (Ast.StrLit x) = return (StrType, TAst.TStrLit x)
convertExpr (Ast.BoolLit x) = return (BoolType, TAst.TBoolLit x)
convertExpr (Ast.CharLit x) = return (CharType, TAst.TCharLit x)
convertExpr (Ast.BinOp Add l r) = arithOp Add l r
convertExpr (Ast.BinOp Sub l r) = arithOp Sub l r
convertExpr (Ast.BinOp Mul l r) = arithOp Mul l r
convertExpr (Ast.BinOp Div l r) = arithOp Div l r
convertExpr (Ast.BinOp Mod l r) = arithOp Mod l r
convertExpr (Ast.BinOp Eq l r) = compOp Eq l r
convertExpr (Ast.BinOp Ne l r) = compOp Ne l r
convertExpr (Ast.BinOp Lt l r) = compOp Lt l r
convertExpr (Ast.BinOp Gt l r) = compOp Gt l r
convertExpr (Ast.BinOp Le l r) = compOp Le l r
convertExpr (Ast.BinOp Ge l r) = compOp Ge l r
convertExpr (Ast.BinOp And l r) = boolOp And l r
convertExpr (Ast.BinOp Or l r) = boolOp Or l r
convertExpr (Ast.BinOp BitAnd l r) = bitOp BitAnd l r
convertExpr (Ast.BinOp BitOr l r) = bitOp BitOr l r
convertExpr (Ast.BinOp BitXor l r) = bitOp BitXor l r
convertExpr (Ast.BinOp ShiftL l r) = shiftOp ShiftL l r
convertExpr (Ast.BinOp ShiftR l r) = shiftOp ShiftR l r
convertExpr (Ast.UnOp Neg e) = do
  e' <- convertExpr e
  if fst e' `elem` [IntType, UIntType, FloatType]
    then return (fst e', TAst.TUnOp Neg e')
    else error $ "Type mismatch: " ++ show e
convertExpr (Ast.UnOp Not e) = do
  e' <- convertExpr e
  if fst e' == BoolType
    then return (BoolType, TAst.TUnOp Not e')
    else error $ "Type mismatch: " ++ show e
convertExpr (Ast.UnOp BitNot _) = error "Bitwise negation not implemented"
convertExpr (Ast.UnOp Deref _) = error "Dereferencing not implemented"
convertExpr (Ast.UnOp AddrOf _) = error "Address-of not implemented"
convertExpr (Ast.Call (Id f) args)
  | f `elem` suppliedFuncs = do
      -- Supplied functions are not declared in the program, so their
      -- arguments are converted but not checked, and the result is taken
      -- to be an int.
      args' <- mapM convertExpr args
      return (IntType, TAst.TCall f args')
  | otherwise = do
      ctx <- get
      case find ((== f) . TAst.funcName) (funcs ctx) of
        Just fn -> do
          args' <- mapM convertExpr args
          -- List equality also rejects a wrong number of arguments.
          if map fst args' == map bindType (TAst.funcArgs fn)
            then return (TAst.funcRetType fn, TAst.TCall f args')
            else error $ "Type mismatch in call to " ++ show f
        Nothing ->
          error $
            "Function " ++ show f ++ " not in scope. Available functions: "
              ++ show (map TAst.funcName (funcs ctx))
convertExpr (Ast.Index arr idx) = do
  arr' <- convertExpr arr
  idx' <- convertExpr idx
  elemType <- indexedType arr idx (fst arr') (fst idx')
  return (elemType, TAst.TIndex arr' idx')
convertExpr (Ast.Cast t e) = do
  e' <- convertExpr e
  return (t, TAst.TCast t e')
convertExpr (Ast.Sizeof t) = return (IntType, TAst.TSizeof t)
convertExpr (Ast.Member e (Id m)) = do
  e' <- convertExpr e
  fieldType <- memberType e (fst e') m
  return (fieldType, TAst.TMember e' m)
convertExpr (Ast.StructInit name fields) = do
  ctx <- get
  case find (\(Struct n _) -> n == name) (structs ctx) of
    Just (Struct _ binds) -> do
      -- Each initialiser must name a declared field and match its type.
      fields' <-
        mapM
          ( \(n, e) -> do
              e' <- convertExpr e
              case find (\(Bind n' _) -> n == n') binds of
                Just (Bind _ t)
                  | fst e' == t -> return (n, e')
                  | otherwise ->
                      error $ "Type mismatch in struct initialization: " ++ show e
                Nothing ->
                  error $ "Field " ++ show n ++ " not in struct " ++ show name
          )
          fields
      return (StructType name, TAst.TStructInit name fields')
    Nothing -> error $ "Struct " ++ show name ++ " not in scope"
convertExpr e =
  error $ "Invalid or Unimplemented conversion for expression " ++ show e

-- | Arithmetic operator: both operands must have the same type, which is
-- also the type of the result.
arithOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
arithOp o l r = do
  l' <- convertExpr l
  r' <- convertExpr r
  if fst l' == fst r'
    then return (fst l', TAst.TBinOp o l' r')
    else error $ "Type mismatch: " ++ show l ++ " and " ++ show r

-- | Comparison operator: both operands must have the same type; the result
-- is boolean.
compOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
compOp o l r = do
  l' <- convertExpr l
  r' <- convertExpr r
  if fst l' == fst r'
    then return (BoolType, TAst.TBinOp o l' r')
    else error $ "Type mismatch: " ++ show l ++ " and " ++ show r

-- | Logical operator: both operands must be boolean; the result is boolean.
boolOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
boolOp o l r = do
  l' <- convertExpr l
  r' <- convertExpr r
  if fst l' == fst r' && fst l' == BoolType
    then return (BoolType, TAst.TBinOp o l' r')
    else error $ "Type mismatch: " ++ show l ++ " and " ++ show r

-- | Bitwise operators are not implemented yet.
bitOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
bitOp _ _ _ = error "Bit operations not implemented"

-- | Shift operators are not implemented yet.
shiftOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
shiftOp _ _ _ = error "Shift operations not implemented"