Implement conversion from Ast to TAst
This commit is contained in:
parent
fe335fa16e
commit
d53362f882
@ -5,3 +5,4 @@ import Windows12.Lexer
|
|||||||
import Windows12.Parser
|
import Windows12.Parser
|
||||||
import Windows12.CodeGen
|
import Windows12.CodeGen
|
||||||
import Windows12.TAst
|
import Windows12.TAst
|
||||||
|
import Windows12.Semant
|
||||||
|
263
src/Windows12/Semant.hs
Normal file
263
src/Windows12/Semant.hs
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module Windows12.Semant where
|
||||||
|
|
||||||
|
import Data.Text (Text)
|
||||||
|
|
||||||
|
import Control.Monad.State
|
||||||
|
|
||||||
|
import Data.List (find)
|
||||||
|
|
||||||
|
import Windows12.Ast as Ast
|
||||||
|
import Windows12.TAst as TAst
|
||||||
|
|
||||||
|
suppliedFuncs :: [Text]
|
||||||
|
suppliedFuncs = ["printf"]
|
||||||
|
|
||||||
|
-- Convert an Ast to a TAst
|
||||||
|
-- Performs type inference and type checking
|
||||||
|
|
||||||
|
data Ctx = Ctx { structs :: [TLStruct],
|
||||||
|
enums :: [TLEnum],
|
||||||
|
funcs :: [TTLFunc],
|
||||||
|
vars :: [(Text, Type)] }
|
||||||
|
deriving (Eq, Show)
|
||||||
|
|
||||||
|
-- Main conversion function. May return an error message if the program
|
||||||
|
-- is not well-typed.
|
||||||
|
convert :: Ast.Program -> Either String TAst.TProgram
|
||||||
|
convert (Ast.Program structs enums funcs) = do
|
||||||
|
let ctx = Ctx structs enums [] []
|
||||||
|
let (funcs', _) = runState (mapM convertFunc funcs) ctx
|
||||||
|
return $ TAst.TProgram structs enums funcs'
|
||||||
|
|
||||||
|
-- Convert a TLFunc (Top Level Function) to a TTLFunc (Typed Top Level Function)
|
||||||
|
-- Note that the function must be added to the context before converting statements
|
||||||
|
-- of the function. This is because the function may call itself recursively.
|
||||||
|
-- After converting the function, the function's statements are converted
|
||||||
|
-- and added to the context.
|
||||||
|
convertFunc :: MonadState Ctx m => Ast.TLFunc -> m TAst.TTLFunc
|
||||||
|
convertFunc (Ast.Func name args retType body) = do
|
||||||
|
args' <- mapM (\(Bind name t) -> return (name, t)) args
|
||||||
|
oldFuncs <- gets funcs
|
||||||
|
modify (\ctx -> ctx { funcs = funcs ctx ++ [TTLFunc name args retType []], vars = args' })
|
||||||
|
body' <- mapM convertStmt body
|
||||||
|
ctx <- get
|
||||||
|
let func = (last $ funcs ctx) { TAst.funcBody = body' }
|
||||||
|
put $ ctx { funcs = oldFuncs ++ [func] }
|
||||||
|
return func
|
||||||
|
|
||||||
|
-- Convert a statement
|
||||||
|
convertStmt :: MonadState Ctx m => Ast.Stmt -> m TAst.TStmt
|
||||||
|
|
||||||
|
convertStmt (Ast.Expr expr) = do
|
||||||
|
expr' <- convertExpr expr
|
||||||
|
return $ TAst.TExprStmt expr'
|
||||||
|
|
||||||
|
convertStmt (Ast.Return expr) = do
|
||||||
|
expr' <- convertExpr expr
|
||||||
|
return $ TAst.TReturn expr'
|
||||||
|
|
||||||
|
convertStmt (Ast.If cond thenStmts elseStmts) = do
|
||||||
|
thenStmts' <- mapM convertStmt thenStmts
|
||||||
|
elseStmts' <- mapM convertStmt $ maybe [] id elseStmts
|
||||||
|
cond' <- convertExpr cond
|
||||||
|
return $ TAst.TIf cond' thenStmts' (Just elseStmts')
|
||||||
|
|
||||||
|
convertStmt (Ast.While cond stmts) = do
|
||||||
|
stmts' <- mapM convertStmt stmts
|
||||||
|
cond' <- convertExpr cond
|
||||||
|
return $ TAst.TWhile cond' stmts'
|
||||||
|
|
||||||
|
convertStmt (Ast.Assign op lval expr) = do
|
||||||
|
lval' <- convertLVal lval
|
||||||
|
expr' <- convertExpr expr
|
||||||
|
return $ TAst.TAssign op lval' expr'
|
||||||
|
|
||||||
|
convertStmt (Ast.Block stmts) = do
|
||||||
|
stmts' <- mapM convertStmt stmts
|
||||||
|
return $ TAst.TBlock stmts'
|
||||||
|
|
||||||
|
convertStmt (Ast.Var name (Just t) maybeExpr) = do
|
||||||
|
expr' <- maybe (return Nothing) (fmap Just . convertExpr) maybeExpr
|
||||||
|
modify (\ctx -> ctx { vars = (name, t) : vars ctx })
|
||||||
|
return $ TAst.TDeclVar name t expr'
|
||||||
|
|
||||||
|
-- TODO
|
||||||
|
convertStmt (Ast.Var name Nothing maybeExpr) = error "Type inference not implemented"
|
||||||
|
|
||||||
|
-- Convert an expression to an LValue
|
||||||
|
-- Only certain expressions are allowed as LValues
|
||||||
|
convertLVal :: MonadState Ctx m => Ast.Expr -> m TAst.TLVal
|
||||||
|
convertLVal (Ast.Id name) = do
|
||||||
|
ctx <- get
|
||||||
|
case lookup name (vars ctx) of
|
||||||
|
Just t -> return (t, TAst.TId name)
|
||||||
|
Nothing -> error $ "Variable " ++ show name ++ " not in scope"
|
||||||
|
|
||||||
|
convertLVal (Ast.Index arr idx) = do
|
||||||
|
arr' <- convertLVal arr
|
||||||
|
idx' <- convertExpr idx
|
||||||
|
return (fst arr', TAst.LTIndex arr' idx')
|
||||||
|
|
||||||
|
convertLVal (Ast.Member e (Id m)) = do
|
||||||
|
e' <- convertLVal e
|
||||||
|
return (fst e', TAst.LTMember e' m)
|
||||||
|
|
||||||
|
convertLVal (Ast.Member e m) = do error $ "Invalid member access " ++ show m ++ " on " ++ show e
|
||||||
|
|
||||||
|
convertLVal (Ast.UnOp Ast.Deref e) = error "Dereferencing not implemented"
|
||||||
|
|
||||||
|
convertLVal e = do error $ "Invalid or unimplemented LValue " ++ show e
|
||||||
|
|
||||||
|
-- Convert an expression
|
||||||
|
convertExpr :: MonadState Ctx m => Ast.Expr -> m TAst.TExpr
|
||||||
|
convertExpr (Ast.Id name) = do
|
||||||
|
ctx <- get
|
||||||
|
case lookup name (vars ctx) of
|
||||||
|
Just t -> return (t, TAst.TVar name)
|
||||||
|
Nothing -> error $ "Variable " ++ show name ++ " not in scope"
|
||||||
|
|
||||||
|
convertExpr (Ast.IntLit x) = return (IntType, TAst.TIntLit x)
|
||||||
|
convertExpr (Ast.UIntLit x) = return (UIntType, TAst.TUIntLit x)
|
||||||
|
convertExpr (Ast.FloatLit x) = return (FloatType, TAst.TFloatLit x)
|
||||||
|
convertExpr (Ast.StrLit x) = return (StrType, TAst.TStrLit x)
|
||||||
|
convertExpr (Ast.BoolLit x) = return (BoolType, TAst.TBoolLit x)
|
||||||
|
convertExpr (Ast.CharLit x) = return (CharType, TAst.TCharLit x)
|
||||||
|
|
||||||
|
convertExpr (Ast.BinOp Add l r) = arithOp Add l r
|
||||||
|
convertExpr (Ast.BinOp Sub l r) = arithOp Sub l r
|
||||||
|
convertExpr (Ast.BinOp Mul l r) = arithOp Mul l r
|
||||||
|
convertExpr (Ast.BinOp Div l r) = arithOp Div l r
|
||||||
|
convertExpr (Ast.BinOp Mod l r) = arithOp Mod l r
|
||||||
|
|
||||||
|
convertExpr (Ast.BinOp Eq l r) = compOp Eq l r
|
||||||
|
convertExpr (Ast.BinOp Ne l r) = compOp Ne l r
|
||||||
|
convertExpr (Ast.BinOp Lt l r) = compOp Lt l r
|
||||||
|
convertExpr (Ast.BinOp Gt l r) = compOp Gt l r
|
||||||
|
convertExpr (Ast.BinOp Le l r) = compOp Le l r
|
||||||
|
convertExpr (Ast.BinOp Ge l r) = compOp Ge l r
|
||||||
|
|
||||||
|
convertExpr (Ast.BinOp And l r) = boolOp And l r
|
||||||
|
convertExpr (Ast.BinOp Or l r) = boolOp Or l r
|
||||||
|
|
||||||
|
convertExpr (Ast.BinOp BitAnd l r) = bitOp BitAnd l r
|
||||||
|
convertExpr (Ast.BinOp BitOr l r) = bitOp BitOr l r
|
||||||
|
convertExpr (Ast.BinOp BitXor l r) = bitOp BitXor l r
|
||||||
|
|
||||||
|
convertExpr (Ast.BinOp ShiftL l r) = shiftOp ShiftL l r
|
||||||
|
convertExpr (Ast.BinOp ShiftR l r) = shiftOp ShiftR l r
|
||||||
|
|
||||||
|
convertExpr (Ast.UnOp Neg e) = do
|
||||||
|
e' <- convertExpr e
|
||||||
|
if fst e' `elem` [IntType, UIntType, FloatType]
|
||||||
|
then return (fst e', TAst.TUnOp Neg e')
|
||||||
|
else error $ "Type mismatch: " ++ show e
|
||||||
|
|
||||||
|
convertExpr (Ast.UnOp Not e) = do
|
||||||
|
e' <- convertExpr e
|
||||||
|
if fst e' == BoolType
|
||||||
|
then return (BoolType, TAst.TUnOp Not e')
|
||||||
|
else error $ "Type mismatch: " ++ show e
|
||||||
|
|
||||||
|
convertExpr (Ast.UnOp BitNot e) = undefined
|
||||||
|
convertExpr (Ast.UnOp Deref e) = undefined
|
||||||
|
convertExpr (Ast.UnOp AddrOf e) = undefined
|
||||||
|
|
||||||
|
-- TODO type check function return
|
||||||
|
-- TODO ensure returns on all paths
|
||||||
|
-- Lower priority since LLVM checks this also
|
||||||
|
convertExpr (Ast.Call (Id f) args) = do
|
||||||
|
ctx <- get
|
||||||
|
if f == "printf"
|
||||||
|
then do
|
||||||
|
args' <- mapM convertExpr args
|
||||||
|
return (IntType, TAst.TCall "printf" args')
|
||||||
|
else case find (\(TTLFunc n a r _) -> n == f) (funcs ctx) of
|
||||||
|
Just t -> do
|
||||||
|
args' <- mapM convertExpr args
|
||||||
|
if length args' == length (TAst.funcArgs t) && all (\(t1, t2) -> t1 == t2) (zip (map fst args') (map bindType (TAst.funcArgs t)))
|
||||||
|
then return (TAst.funcRetType t, TAst.TCall f args')
|
||||||
|
else error $ "Type mismatch in call to " ++ show f
|
||||||
|
Nothing -> error $ "Function " ++ show f ++ " not in scope. Available functions: " ++ show (map TAst.funcName (funcs ctx))
|
||||||
|
|
||||||
|
convertExpr (Ast.Index arr idx) = do
|
||||||
|
arr' <- convertExpr arr
|
||||||
|
idx' <- convertExpr idx
|
||||||
|
case fst arr' of
|
||||||
|
ArrayType t -> if fst idx' == IntType
|
||||||
|
then return (t, TAst.TIndex arr' idx')
|
||||||
|
else error $ "Index must be an integer: " ++ show idx
|
||||||
|
_ -> error $ "Indexing non-array: " ++ show arr
|
||||||
|
|
||||||
|
convertExpr (Ast.Cast t e) = do
|
||||||
|
e' <- convertExpr e
|
||||||
|
return (t, TAst.TCast t e')
|
||||||
|
|
||||||
|
convertExpr (Ast.Sizeof t) = return (IntType, TAst.TSizeof t)
|
||||||
|
|
||||||
|
convertExpr (Ast.Member e (Id m)) = do
|
||||||
|
e' <- convertExpr e
|
||||||
|
case fst e' of
|
||||||
|
StructType name -> do
|
||||||
|
ctx <- get
|
||||||
|
case find (\(Struct n _) -> n == name) (structs ctx) of
|
||||||
|
Just (Struct _ binds) -> case find (\(Bind n t) -> n == m) binds of
|
||||||
|
Just (Bind _ t) -> return (t, TAst.TMember e' m)
|
||||||
|
Nothing -> error $ "Field " ++ show m ++ " not in struct " ++ show name
|
||||||
|
Nothing -> error $ "Struct " ++ show name ++ " not in scope"
|
||||||
|
_ -> error $ "Member access on non-struct " ++ show e
|
||||||
|
|
||||||
|
convertExpr (Ast.StructInit name fields) = do
|
||||||
|
ctx <- get
|
||||||
|
case find (\(Struct n _) -> n == name) (structs ctx) of
|
||||||
|
Just (Struct _ binds) -> do
|
||||||
|
fields' <- mapM (\(n, e) -> do
|
||||||
|
e' <- convertExpr e
|
||||||
|
case find (\(Bind n' t) -> n == n') binds of
|
||||||
|
Just (Bind _ t) -> if fst e' == t
|
||||||
|
then return (n, e')
|
||||||
|
else error $ "Type mismatch in struct initialization: " ++ show e
|
||||||
|
Nothing -> error $ "Field " ++ show n ++ " not in struct " ++ show name) fields
|
||||||
|
|
||||||
|
return (StructType name, TAst.TStructInit name fields')
|
||||||
|
Nothing -> error $ "Struct " ++ show name ++ " not in scope"
|
||||||
|
|
||||||
|
convertExpr e = error $ "Invalid or Unimplemented conversion for expression " ++ show e
|
||||||
|
|
||||||
|
-- Ensure that the types of the left and right expressions are the same
|
||||||
|
-- and return the type of the result
|
||||||
|
arithOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
|
||||||
|
arithOp o l r = do
|
||||||
|
l' <- convertExpr l
|
||||||
|
r' <- convertExpr r
|
||||||
|
if fst l' == fst r'
|
||||||
|
then return (fst l', TAst.TBinOp o l' r')
|
||||||
|
else error $ "Type mismatch: " ++ show l ++ " and " ++ show r
|
||||||
|
|
||||||
|
-- Ensure that the types of the left and right expressions are the same
|
||||||
|
-- and return a boolean type
|
||||||
|
compOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
|
||||||
|
compOp o l r = do
|
||||||
|
l' <- convertExpr l
|
||||||
|
r' <- convertExpr r
|
||||||
|
if fst l' == fst r'
|
||||||
|
then return (BoolType, TAst.TBinOp o l' r')
|
||||||
|
else error $ "Type mismatch: " ++ show l ++ " and " ++ show r
|
||||||
|
|
||||||
|
-- Ensure that the types of both expressions are boolean
|
||||||
|
-- and return a boolean type
|
||||||
|
boolOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
|
||||||
|
boolOp o l r = do
|
||||||
|
l' <- convertExpr l
|
||||||
|
r' <- convertExpr r
|
||||||
|
if fst l' == fst r' && fst l' == BoolType
|
||||||
|
then return (BoolType, TAst.TBinOp o l' r')
|
||||||
|
else error $ "Type mismatch: " ++ show l ++ " and " ++ show r
|
||||||
|
|
||||||
|
bitOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
|
||||||
|
bitOp o l r = do error $ "Bit operations not implemented"
|
||||||
|
|
||||||
|
shiftOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
|
||||||
|
shiftOp o l r = do error $ "Shift operations not implemented"
|
@ -70,6 +70,7 @@ executable windows12
|
|||||||
Windows12.Parser
|
Windows12.Parser
|
||||||
Windows12.CodeGen
|
Windows12.CodeGen
|
||||||
Windows12.TAst
|
Windows12.TAst
|
||||||
|
Windows12.Semant
|
||||||
|
|
||||||
-- LANGUAGE extensions used by modules in this package.
|
-- LANGUAGE extensions used by modules in this package.
|
||||||
-- other-extensions:
|
-- other-extensions:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user