Remove TAst

This commit is contained in:
2025-08-31 20:56:31 -04:00
parent 78bfec0953
commit 1be6175120
6 changed files with 68 additions and 394 deletions

View File

@@ -12,7 +12,6 @@ import Windows12.Parser (programP)
import System.Environment (getArgs) import System.Environment (getArgs)
import LLVM.Pretty import LLVM.Pretty
import Windows12.Ast import Windows12.Ast
import Windows12.Semant (convert)
import Windows12.CodeGen (codegen) import Windows12.CodeGen (codegen)
@@ -27,7 +26,4 @@ main = do
test <- T.readFile inputFile test <- T.readFile inputFile
case parse programP inputFile test of case parse programP inputFile test of
Left err -> print err Left err -> print err
Right ast ->
case convert ast of
Left err -> putStrLn err
Right ast -> TL.writeFile outputFile (ppllvm (codegen (cs inputFile) ast)) Right ast -> TL.writeFile outputFile (ppllvm (codegen (cs inputFile) ast))

View File

@@ -4,5 +4,3 @@ import Windows12.Ast
import Windows12.Lexer import Windows12.Lexer
import Windows12.Parser import Windows12.Parser
import Windows12.CodeGen import Windows12.CodeGen
import Windows12.TAst
import Windows12.Semant

View File

@@ -6,9 +6,7 @@
module Windows12.CodeGen where module Windows12.CodeGen where
import Windows12.Ast (BinOp(..), UnOp(..), AssignOp(..), Type(..), import Windows12.Ast;
Bind(..), TLStruct(..), TLEnum(..))
import Windows12.TAst
import LLVM.AST hiding (ArrayType, VoidType, Call, function) import LLVM.AST hiding (ArrayType, VoidType, Call, function)
import LLVM.AST.Type (i32, i1, i8, double, ptr, void) import LLVM.AST.Type (i32, i1, i8, double, ptr, void)
@@ -27,7 +25,7 @@ import Data.String.Conversions
import Data.String import Data.String
-- Global program context, used to keep track of operands -- Global program context, used to keep track of operands
data Ctx = Ctx { operands :: [(Text, Operand)], data Ctx = Ctx { operands :: [(Text, (Maybe Windows12.Ast.Type, Operand))],
structs :: [TLStruct], structs :: [TLStruct],
enums :: [TLEnum], enums :: [TLEnum],
strings :: [(Text, Operand)] } strings :: [(Text, Operand)] }
@@ -41,19 +39,19 @@ instance ConvertibleStrings Text ShortByteString where
convertString = Data.String.fromString . Data.Text.unpack convertString = Data.String.fromString . Data.Text.unpack
-- Put an operand into the context with a name -- Put an operand into the context with a name
createOperand :: MonadState Ctx m => Text -> Operand -> m () createOperand :: MonadState Ctx m => Text -> Maybe Windows12.Ast.Type -> Operand -> m ()
createOperand name op = do createOperand name op_type op = do
ctx <- get ctx <- get
put $ ctx { operands = (name, op) : operands ctx } put $ ctx { operands = (name, (op_type, op)) : operands ctx }
-- Take in a source file name, the AST, and return the LLVM IR module -- Take in a source file name, the AST, and return the LLVM IR module
codegen :: Text -> TProgram -> Module codegen :: Text -> Program -> Module
codegen filename (TProgram structs enums funcs) = codegen filename (Program structs enums funcs) =
flip evalState (Ctx [] [] [] []) flip evalState (Ctx [] [] [] [])
$ buildModuleT (cs filename) $ buildModuleT (cs filename)
$ do $ do
printf <- externVarArgs (mkName "printf") [ptr i8] i32 printf <- externVarArgs (mkName "printf") [ptr i8] i32
createOperand "printf" printf createOperand "printf" Nothing printf
mapM_ emitTypeDef structs mapM_ emitTypeDef structs
mapM_ codegenFunc funcs mapM_ codegenFunc funcs
@@ -100,27 +98,26 @@ size (StructType name) = do
size (EnumType _) = return 8 size (EnumType _) = return 8
size VoidType = return 0 size VoidType = return 0
-- CodeGen for LValues -- CodeGen for LValues
codegenLVal :: TLVal -> IRBuilder Operand codegenLVal :: Expr -> IRBuilder Operand
codegenLVal (t, (TId name)) = do codegenLVal (Id name) = do
ctx <- get ctx <- get
case lookup name (operands ctx) of case lookup name (operands ctx) of
Just op -> return op Just (_type, op) -> return op
Nothing -> error $ "Variable " ++ show name ++ " not found" Nothing -> error $ "Variable " ++ show name ++ " not found"
-- TODO support members of members -- TODO support members of members
codegenLVal ((StructType t), (LTMember ((_, TId sName)) field)) = do codegenLVal (Member (Id sName) (Id field)) = do
ctx <- get ctx <- get
case lookup sName (operands ctx) of case lookup sName (operands ctx) of
Just struct -> do Just ((Just (StructType op_type)), struct) -> do
fields <- getStructFields t fields <- getStructFields op_type
offset <- structFieldOffset (Struct sName fields) field offset <- structFieldOffset (Struct sName fields) field
gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))] gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))]
Nothing -> error $ "Struct " ++ show sName ++ " not found" Nothing -> error $ "Struct " ++ show sName ++ " not found"
codeGenLVal (t, (TDeref e)) = codegenExpr e codeGenLVal _ = error "Unimplemented or invalid LValue"
codeGenLVal (t, _) = error "Unimplemented or invalid LValue"
-- Given a struct and a field name, return the offset of the field in the struct. -- Given a struct and a field name, return the offset of the field in the struct.
-- In LLVM each field is actually size 1 -- In LLVM each field is actually size 1
@@ -129,12 +126,12 @@ structFieldOffset (Struct name fields) field = do
return $ length $ takeWhile (\(Bind n _) -> n /= field) fields return $ length $ takeWhile (\(Bind n _) -> n /= field) fields
-- CodeGen for expressions -- CodeGen for expressions
codegenExpr :: TExpr -> IRBuilder Operand codegenExpr :: Expr -> IRBuilder Operand
codegenExpr (t, (TVar name)) = flip load 0 =<< codegenLVal (t, (TId name)) codegenExpr (Id name) = flip load 0 =<< codegenLVal (Id name)
codegenExpr (t, (TIntLit i)) = return $ ConstantOperand (C.Int 32 (fromIntegral i)) codegenExpr (IntLit i) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
codegenExpr (t, (TUIntLit i)) = return $ ConstantOperand (C.Int 32 (fromIntegral i)) codegenExpr (UIntLit i) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
codegenExpr (t, (TFloatLit f)) = undefined -- TODO floats codegenExpr (FloatLit f) = undefined -- TODO floats
codegenExpr (t, (TStrLit s)) = do codegenExpr (StrLit s) = do
strs <- gets strings strs <- gets strings
case lookup s strs of case lookup s strs of
-- If the string is already in the context, return it -- If the string is already in the context, return it
@@ -145,10 +142,11 @@ codegenExpr (t, (TStrLit s)) = do
op <- globalStringPtr (cs s) str_name op <- globalStringPtr (cs s) str_name
modify $ \ctx -> ctx { strings = (s, (ConstantOperand op)) : strs } modify $ \ctx -> ctx { strings = (s, (ConstantOperand op)) : strs }
return (ConstantOperand op) return (ConstantOperand op)
codegenExpr (t, (TBoolLit b)) = return $ ConstantOperand (C.Int 1 (if b then 1 else 0))
codegenExpr (t, (TCharLit c)) = return $ ConstantOperand (C.Int 8 (fromIntegral (fromEnum c)))
codegenExpr (t, (TBinOp op lhs rhs)) = do codegenExpr (BoolLit b) = return $ ConstantOperand (C.Int 1 (if b then 1 else 0))
codegenExpr (CharLit c) = return $ ConstantOperand (C.Int 8 (fromIntegral (fromEnum c)))
codegenExpr (BinOp op lhs rhs) = do
lhs' <- codegenExpr lhs lhs' <- codegenExpr lhs
rhs' <- codegenExpr rhs rhs' <- codegenExpr rhs
@@ -190,33 +188,33 @@ codegenExpr (t, (TBinOp op lhs rhs)) = do
other -> error $ "Operator " ++ show other ++ " not implemented" other -> error $ "Operator " ++ show other ++ " not implemented"
codegenExpr (t, (TUnOp op e)) = undefined -- TODO handle unary operators codegenExpr (UnOp op e) = undefined -- TODO handle unary operators
-- Function calls: look up the function in operands, then call it with the args -- Function calls: look up the function in operands, then call it with the args
codegenExpr (t, (TCall f args)) = do codegenExpr (Call (Id f) args) = do
ctx <- get ctx <- get
f <- case lookup f (operands ctx) of f <- case lookup f (operands ctx) of
Just f -> return f Just (_type, f) -> return f
Nothing -> error $ "Function " ++ show f ++ " not found" Nothing -> error $ "Function " ++ show f ++ " not found"
args <- mapM (fmap (, []) . codegenExpr) args args <- mapM (fmap (, []) . codegenExpr) args
call f args call f args
codegenExpr (t, (TIndex arr idx)) = undefined -- TODO arrays codegenExpr (Index arr idx) = undefined -- TODO arrays
-- Get the address of the struct field and load it -- Get the address of the struct field and load it
codegenExpr (t, (TMember ((StructType sName), (TVar sVarName)) m)) = do codegenExpr (Member (Id sVarName) (Id field)) = do
ctx <- get ctx <- get
case lookup sVarName (operands ctx) of case lookup sVarName (operands ctx) of
Just struct -> do Just ((Just (StructType op_type)), struct) -> do
fields <- getStructFields sName fields <- getStructFields op_type
offset <- structFieldOffset (Struct sVarName fields) m offset <- structFieldOffset (Struct op_type fields) field
addr <- gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))] addr <- gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))]
load addr 0 load addr 0
Nothing -> error $ "Struct operand " ++ show sVarName ++ " not found" Nothing -> error $ "Struct operand " ++ show sVarName ++ " not found"
codegenExpr (_, (TCast t e)) = undefined -- TODO casts codegenExpr (Cast t e) = undefined -- TODO casts
codegenExpr (_, (TSizeof t)) = ConstantOperand . C.Int 32 . fromIntegral <$> size t codegenExpr (Sizeof t) = ConstantOperand . C.Int 32 . fromIntegral <$> size t
mkTerminator :: IRBuilder () -> IRBuilder () mkTerminator :: IRBuilder () -> IRBuilder ()
mkTerminator instr = do mkTerminator instr = do
@@ -224,81 +222,81 @@ mkTerminator instr = do
unless check instr unless check instr
-- Codegen for statements -- Codegen for statements
codegenStmt :: TStmt -> IRBuilder () codegenStmt :: Stmt -> IRBuilder ()
-- For expression statements, just evaluate the expression and discard the result -- For expression statements, just evaluate the expression and discard the result
codegenStmt (TExprStmt e) = do codegenStmt (Expr e) = do
_expr <- codegenExpr e _expr <- codegenExpr e
return () return ()
codegenStmt (TReturn e) = ret =<< codegenExpr e codegenStmt (Return e) = ret =<< codegenExpr e
-- Generate if statements, with a merge block at the end -- Generate if statements, with a merge block at the end
codegenStmt (TIf cond t f) = mdo codegenStmt (If cond t f) = mdo
cond' <- codegenExpr cond cond' <- codegenExpr cond
condBr cond' then' else' condBr cond' then' else'
then' <- block `named` "then" then' <- block `named` "then"
codegenStmt (TBlock t) codegenStmt (Block t)
mkTerminator $ br merge mkTerminator $ br merge
else' <- block `named` "else" else' <- block `named` "else"
codegenStmt (case f of codegenStmt (case f of
Just f' -> TBlock f' Just f' -> Block f'
Nothing -> TBlock []) Nothing -> Block [])
mkTerminator $ br merge mkTerminator $ br merge
merge <- block `named` "merge" merge <- block `named` "merge"
return () return ()
-- Generate while loops, with a merge block at the end -- Generate while loops, with a merge block at the end
codegenStmt (TWhile cond body) = mdo codegenStmt (While cond body) = mdo
br condBlock br condBlock
condBlock <- block `named` "cond" condBlock <- block `named` "cond"
cond' <- codegenExpr cond cond' <- codegenExpr cond
condBr cond' loop end condBr cond' loop end
loop <- block `named` "loop" loop <- block `named` "loop"
codegenStmt (TBlock body) codegenStmt (Block body)
mkTerminator $ br condBlock mkTerminator $ br condBlock
end <- block `named` "end" end <- block `named` "end"
return () return ()
codegenStmt (TAssign BaseAssign l e) = do codegenStmt (Assign BaseAssign l e) = do
op <- codegenExpr e op <- codegenExpr e
var <- codegenLVal l var <- codegenLVal l
store var 0 op store var 0 op
codegenStmt (TAssign AddAssign l e) = do codegenStmt (Assign AddAssign l e) = do
op <- codegenExpr e op <- codegenExpr e
var <- codegenLVal l var <- codegenLVal l
val <- load var 0 val <- load var 0
store var 0 =<< add val op store var 0 =<< add val op
codegenStmt (TAssign SubAssign l e) = do codegenStmt (Assign SubAssign l e) = do
op <- codegenExpr e op <- codegenExpr e
var <- codegenLVal l var <- codegenLVal l
val <- load var 0 val <- load var 0
store var 0 =<< sub val op store var 0 =<< sub val op
-- A block is just a list of statements -- A block is just a list of statements
codegenStmt (TBlock stmts) = mapM_ codegenStmt stmts codegenStmt (Block stmts) = mapM_ codegenStmt stmts
-- Since the vars are already allocated by genBody, we just need to assign the value -- Since the vars are already allocated by genBody, we just need to assign the value
codegenStmt (TDeclVar name t (Just e)) = codegenStmt (TAssign BaseAssign (t, (TId name)) e) codegenStmt (Var name t (Just e)) = codegenStmt (Assign BaseAssign (Id name) e)
-- Do nothing with variable declaration if no expression is given -- Do nothing with variable declaration if no expression is given
-- This is because allocation is done already -- This is because allocation is done already
codegenStmt (TDeclVar name _ Nothing) = return () codegenStmt (Var name _ Nothing) = return ()
codegenStmt s = error $ "Unimplemented or invalid statement " ++ show s codegenStmt s = error $ "Unimplemented or invalid statement " ++ show s
-- Generate code for a function -- Generate code for a function
-- First create the function, then allocate space for the arguments and locals -- First create the function, then allocate space for the arguments and locals
codegenFunc :: TTLFunc -> ModuleBuilder () codegenFunc :: TLFunc -> ModuleBuilder ()
codegenFunc func@(TTLFunc name args retType body) = mdo codegenFunc func@(Func name args retType body) = mdo
createOperand name f createOperand name Nothing f
(f, strs) <- do (f, strs) <- do
params' <- mapM mkParam args params' <- mapM mkParam args
retType' <- convertType retType retType' <- convertType retType
@@ -311,31 +309,32 @@ codegenFunc func@(TTLFunc name args retType body) = mdo
genBody :: [Operand] -> IRBuilder () genBody :: [Operand] -> IRBuilder ()
genBody ops = do genBody ops = do
forM_ (zip ops args) $ \(op, (Bind name t)) -> do forM_ (zip ops args) $ \(op, Bind name t) -> do
addr <- alloca (typeOf op) Nothing 0 addr <- alloca (typeOf op) Nothing 0
store addr 0 op store addr 0 op
createOperand name addr createOperand name (Just t) addr
forM_ (getLocals func) $ \(Bind name t) -> do forM_ (getLocals func) $ \(Bind name t) -> do
ltype <- convertType t ltype <- convertType t
addr <- alloca ltype Nothing 0 addr <- alloca ltype Nothing 0
createOperand name addr createOperand name (Just t) addr
codegenStmt (TBlock body) codegenStmt (Block body)
-- Given a function, get all the local variables -- Given a function, get all the local variables
-- Used so allocation can be done before the function body -- Used so allocation can be done before the function body
getLocals :: TTLFunc -> [Bind] getLocals :: TLFunc -> [Bind]
getLocals (TTLFunc _ args _ body) = blockGetLocals body getLocals (Func _ args _ body) = blockGetLocals body
blockGetLocals :: [TStmt] -> [Bind] blockGetLocals :: [Stmt] -> [Bind]
blockGetLocals = concatMap stmtGetLocals blockGetLocals = concatMap stmtGetLocals
stmtGetLocals :: TStmt -> [Bind] stmtGetLocals :: Stmt -> [Bind]
stmtGetLocals (TDeclVar n t _) = [Bind n t] stmtGetLocals (Var n (Just t) _) = [Bind n t]
stmtGetLocals (TBlock stmts) = blockGetLocals stmts stmtGetLocals (Var n Nothing _) = error $ "Explicit typing required (var " ++ show n ++ ")"
stmtGetLocals (TIf _ t f) = blockGetLocals t ++ maybe [] blockGetLocals f stmtGetLocals (Block stmts) = blockGetLocals stmts
stmtGetLocals (TWhile _ body) = blockGetLocals body stmtGetLocals (If _ t f) = blockGetLocals t ++ maybe [] blockGetLocals f
stmtGetLocals (While _ body) = blockGetLocals body
stmtGetLocals _ = [] stmtGetLocals _ = []
-- Create structs -- Create structs

View File

@@ -1,263 +0,0 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Windows12.Semant where
import Data.Text (Text)
import Control.Monad.State
import Data.List (find)
import Windows12.Ast as Ast
import Windows12.TAst as TAst
-- | Names of externally supplied functions (currently just @printf@).
suppliedFuncs :: [Text]
suppliedFuncs = ["printf"]
-- Convert an Ast to a TAst
-- Performs type inference and type checking
-- | State threaded through the Ast -> TAst conversion.
data Ctx = Ctx
  { structs :: [TLStruct]     -- ^ Struct declarations of the program.
  , enums   :: [TLEnum]       -- ^ Enum declarations of the program.
  , funcs   :: [TTLFunc]      -- ^ Functions registered so far, in conversion order.
  , vars    :: [(Text, Type)] -- ^ Variables in scope; new declarations are consed on the front.
  }
  deriving (Eq, Show)
-- | Entry point of the semantic pass: turn a parsed program into its typed
-- counterpart.
--
-- Functions are converted in order, starting from a context that holds the
-- program's structs and enums and no functions or variables. The final
-- context is discarded.
--
-- Note: this function itself always yields 'Right'. The conversion helpers
-- report ill-typed programs by calling 'error' rather than by returning 'Left'.
convert :: Ast.Program -> Either String TAst.TProgram
convert (Ast.Program structs enums funcs) =
  Right (TAst.TProgram structs enums typedFuncs)
  where
    initialCtx = Ctx structs enums [] []
    typedFuncs = evalState (mapM convertFunc funcs) initialCtx
-- | Convert a 'TLFunc' (top-level function) to a 'TTLFunc' (typed top-level
-- function).
--
-- A stub of the function (same name, arguments and return type, empty body)
-- is appended to the context /before/ its statements are converted, so that
-- recursive calls resolve. At the same time the variable scope is replaced by
-- the function's parameters. Once the body has been converted, the stub in
-- the context is replaced by the finished function, which is also returned.
convertFunc :: MonadState Ctx m => Ast.TLFunc -> m TAst.TTLFunc
convertFunc (Ast.Func name args retType body) = do
  oldFuncs <- gets funcs
  let params = map (\(Bind n t) -> (n, t)) args
      stub = TTLFunc name args retType []
  modify (\ctx -> ctx { funcs = funcs ctx ++ [stub], vars = params })
  body' <- mapM convertStmt body
  -- Reuse the stub directly instead of fetching it back with the partial
  -- 'last': statement conversion never touches the function list.
  let func = stub { TAst.funcBody = body' }
  modify (\ctx -> ctx { funcs = oldFuncs ++ [func] })
  return func
-- | Convert a statement to its typed form.
--
-- Conditions of @if@ and @while@ are converted /before/ their bodies, so that
-- variables declared inside a body are not yet in scope while the condition
-- is being checked.
--
-- An @if@ without an else branch is converted to a 'TIf' whose else branch is
-- @Just []@.
--
-- A variable declaration adds the variable to the scope only after its
-- initialiser (if any) has been converted. Declarations without an explicit
-- type are not supported and raise an error.
convertStmt :: MonadState Ctx m => Ast.Stmt -> m TAst.TStmt
convertStmt (Ast.Expr expr) = TAst.TExprStmt <$> convertExpr expr
convertStmt (Ast.Return expr) = TAst.TReturn <$> convertExpr expr
convertStmt (Ast.If cond thenStmts elseStmts) = do
  cond' <- convertExpr cond
  thenStmts' <- mapM convertStmt thenStmts
  elseStmts' <- mapM convertStmt (maybe [] id elseStmts)
  return $ TAst.TIf cond' thenStmts' (Just elseStmts')
convertStmt (Ast.While cond stmts) = do
  cond' <- convertExpr cond
  stmts' <- mapM convertStmt stmts
  return $ TAst.TWhile cond' stmts'
convertStmt (Ast.Assign op lval expr) = do
  lval' <- convertLVal lval
  expr' <- convertExpr expr
  return $ TAst.TAssign op lval' expr'
convertStmt (Ast.Block stmts) = TAst.TBlock <$> mapM convertStmt stmts
convertStmt (Ast.Var name (Just t) maybeExpr) = do
  expr' <- traverse convertExpr maybeExpr
  modify (\ctx -> ctx { vars = (name, t) : vars ctx })
  return $ TAst.TDeclVar name t expr'
-- TODO
convertStmt (Ast.Var _ Nothing _) = error "Type inference not implemented"
-- | Convert an expression that appears in l-value position.
--
-- Accepted forms are plain identifiers, indexing and member access with an
-- identifier as the member name; everything else raises an error.
--
-- The type paired with an index or member l-value is the type of the
-- l-value being indexed / accessed, not the element or field type.
-- NOTE(review): 'convertExpr' reports the element type for indexing --
-- confirm whether this difference is intended.
convertLVal :: MonadState Ctx m => Ast.Expr -> m TAst.TLVal
convertLVal (Ast.Id name) = do
  scope <- gets vars
  maybe
    (error $ "Variable " ++ show name ++ " not in scope")
    (\t -> return (t, TAst.TId name))
    (lookup name scope)
convertLVal (Ast.Index arr idx) = do
  base <- convertLVal arr
  position <- convertExpr idx
  return (fst base, TAst.LTIndex base position)
convertLVal (Ast.Member e (Id m)) = do
  base <- convertLVal e
  return (fst base, TAst.LTMember base m)
convertLVal (Ast.Member e m) =
  error $ "Invalid member access " ++ show m ++ " on " ++ show e
convertLVal (Ast.UnOp Ast.Deref _) = error "Dereferencing not implemented"
convertLVal e = error $ "Invalid or unimplemented LValue " ++ show e
-- | Convert an expression, computing its type and checking it on the way.
--
-- Returns the typed expression paired with its type. Ill-typed or
-- unsupported expressions raise an error.
--
-- * Identifiers are looked up in the current variable scope.
-- * Binary operators are dispatched by family to 'arithOp', 'compOp',
--   'boolOp', 'bitOp' and 'shiftOp'.
-- * Calls to a function listed in 'suppliedFuncs' are not checked against a
--   signature and are typed as 'IntType'; calls to program functions must
--   match the declared argument types exactly.
-- * Indexing requires an array operand and an 'IntType' index.
-- * Member access and struct initialisation are resolved against the struct
--   declarations in the context.
convertExpr :: MonadState Ctx m => Ast.Expr -> m TAst.TExpr
convertExpr (Ast.Id name) = do
  scope <- gets vars
  case lookup name scope of
    Just t -> return (t, TAst.TVar name)
    Nothing -> error $ "Variable " ++ show name ++ " not in scope"
convertExpr (Ast.IntLit x) = return (IntType, TAst.TIntLit x)
convertExpr (Ast.UIntLit x) = return (UIntType, TAst.TUIntLit x)
convertExpr (Ast.FloatLit x) = return (FloatType, TAst.TFloatLit x)
convertExpr (Ast.StrLit x) = return (StrType, TAst.TStrLit x)
convertExpr (Ast.BoolLit x) = return (BoolType, TAst.TBoolLit x)
convertExpr (Ast.CharLit x) = return (CharType, TAst.TCharLit x)
-- An operator that matches none of the guards falls through to the
-- catch-all equation at the bottom.
convertExpr (Ast.BinOp op l r)
  | op `elem` [Add, Sub, Mul, Div, Mod] = arithOp op l r
  | op `elem` [Eq, Ne, Lt, Gt, Le, Ge] = compOp op l r
  | op `elem` [And, Or] = boolOp op l r
  | op `elem` [BitAnd, BitOr, BitXor] = bitOp op l r
  | op `elem` [ShiftL, ShiftR] = shiftOp op l r
convertExpr (Ast.UnOp Neg e) = do
  e' <- convertExpr e
  if fst e' `elem` [IntType, UIntType, FloatType]
    then return (fst e', TAst.TUnOp Neg e')
    else error $ "Type mismatch: " ++ show e
convertExpr (Ast.UnOp Not e) = do
  e' <- convertExpr e
  if fst e' == BoolType
    then return (BoolType, TAst.TUnOp Not e')
    else error $ "Type mismatch: " ++ show e
convertExpr (Ast.UnOp BitNot _) = error "Bitwise negation not implemented"
convertExpr (Ast.UnOp Deref _) = error "Dereferencing not implemented"
convertExpr (Ast.UnOp AddrOf _) = error "Address-of not implemented"
-- TODO type check function return
-- TODO ensure returns on all paths
-- Lower priority since LLVM checks this also
convertExpr (Ast.Call (Id f) args)
  -- Supplied functions (currently only printf) have no recorded signature:
  -- the arguments are converted but not checked, and the result is IntType.
  | f `elem` suppliedFuncs = do
      args' <- mapM convertExpr args
      return (IntType, TAst.TCall f args')
  | otherwise = do
      known <- gets funcs
      case find ((== f) . TAst.funcName) known of
        Just callee -> do
          args' <- mapM convertExpr args
          -- List equality covers both the arity and the per-argument check.
          if map fst args' == map bindType (TAst.funcArgs callee)
            then return (TAst.funcRetType callee, TAst.TCall f args')
            else error $ "Type mismatch in call to " ++ show f
        Nothing ->
          error $ "Function " ++ show f ++ " not in scope. Available functions: "
            ++ show (map TAst.funcName known)
convertExpr (Ast.Index arr idx) = do
  arr' <- convertExpr arr
  idx' <- convertExpr idx
  case fst arr' of
    ArrayType t
      | fst idx' == IntType -> return (t, TAst.TIndex arr' idx')
      | otherwise -> error $ "Index must be an integer: " ++ show idx
    _ -> error $ "Indexing non-array: " ++ show arr
convertExpr (Ast.Cast t e) = do
  e' <- convertExpr e
  return (t, TAst.TCast t e')
convertExpr (Ast.Sizeof t) = return (IntType, TAst.TSizeof t)
convertExpr (Ast.Member e (Id m)) = do
  e' <- convertExpr e
  case fst e' of
    StructType name -> do
      declared <- gets structs
      case find (\(Struct n _) -> n == name) declared of
        Just (Struct _ binds) ->
          case find (\(Bind n _) -> n == m) binds of
            Just (Bind _ t) -> return (t, TAst.TMember e' m)
            Nothing -> error $ "Field " ++ show m ++ " not in struct " ++ show name
        Nothing -> error $ "Struct " ++ show name ++ " not in scope"
    _ -> error $ "Member access on non-struct " ++ show e
convertExpr (Ast.StructInit name fields) = do
  declared <- gets structs
  case find (\(Struct n _) -> n == name) declared of
    Just (Struct _ binds) -> do
      fields' <- mapM (convertField binds) fields
      return (StructType name, TAst.TStructInit name fields')
    Nothing -> error $ "Struct " ++ show name ++ " not in scope"
  where
    -- Convert one initialiser and check it against the declared field type.
    convertField binds (n, e) = do
      e' <- convertExpr e
      case find (\(Bind n' _) -> n == n') binds of
        Just (Bind _ t)
          | fst e' == t -> return (n, e')
          | otherwise -> error $ "Type mismatch in struct initialization: " ++ show e
        Nothing -> error $ "Field " ++ show n ++ " not in struct " ++ show name
convertExpr e = error $ "Invalid or Unimplemented conversion for expression " ++ show e
-- | Convert an arithmetic operation. Both operands must have the same type,
-- which is also the type of the result; otherwise an error is raised.
arithOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
arithOp o l r = do
  lhs <- convertExpr l
  rhs <- convertExpr r
  let resultType = fst lhs
  if resultType /= fst rhs
    then error $ "Type mismatch: " ++ show l ++ " and " ++ show r
    else return (resultType, TAst.TBinOp o lhs rhs)
-- | Convert a comparison. Both operands must have the same type; the result
-- is always 'BoolType'. A mismatch raises an error.
compOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
compOp o l r = do
  lhs <- convertExpr l
  rhs <- convertExpr r
  check lhs rhs
  where
    check lhs rhs
      | fst lhs == fst rhs = return (BoolType, TAst.TBinOp o lhs rhs)
      | otherwise = error $ "Type mismatch: " ++ show l ++ " and " ++ show r
-- | Convert a logical operation. Both operands must be 'BoolType', and so is
-- the result; anything else raises an error.
boolOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
boolOp o l r = do
  lhs <- convertExpr l
  rhs <- convertExpr r
  if all ((== BoolType) . fst) [lhs, rhs]
    then return (BoolType, TAst.TBinOp o lhs rhs)
    else error $ "Type mismatch: " ++ show l ++ " and " ++ show r
-- | Placeholder for bitwise operators: always raises an error.
bitOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
bitOp _ _ _ = error "Bit operations not implemented"
-- | Placeholder for shift operators: always raises an error.
shiftOp :: MonadState Ctx m => Ast.BinOp -> Ast.Expr -> Ast.Expr -> m TAst.TExpr
shiftOp _ _ _ = error "Shift operations not implemented"

View File

@@ -1,54 +0,0 @@
module Windows12.TAst where
import Data.Text (Text)
import Windows12.Ast as Ast
-- "Typed AST". A second AST that contains more type information
-- Makes verification easier, and is needed to determine type
-- of structs when accessing members in CodeGen
-- | An expression together with its type.
type TExpr = (Type, TExpr')

-- | The expression forms of the typed AST. Sub-expressions are 'TExpr's and
-- therefore carry their own types.
data TExpr'
  = TVar Text                        -- ^ Variable reference by name.
  | TIntLit Int
  | TUIntLit Word
  | TFloatLit Double
  | TStrLit Text
  | TBoolLit Bool
  | TCharLit Char
  | TBinOp BinOp TExpr TExpr         -- ^ Operator, left operand, right operand.
  | TUnOp UnOp TExpr
  | TCall Text [TExpr]               -- ^ Function name and arguments.
  | TIndex TExpr TExpr               -- ^ Indexed expression and index.
  | TMember TExpr Text               -- ^ Accessed expression and member name.
  | TCast Type TExpr                 -- ^ Target type and operand.
  | TSizeof Type
  | TStructInit Text [(Text, TExpr)] -- ^ Struct name and (field, value) pairs.
  deriving (Show, Eq)
-- | An l-value together with a type.
type TLVal = (Type, TLVal')

-- | The forms an assignable location can take.
data TLVal'
  = TDeref TExpr        -- ^ Dereference of an expression.
  | TId Text            -- ^ Variable by name.
  | LTIndex TLVal TExpr -- ^ Indexed l-value and index expression.
  | LTMember TLVal Text -- ^ Accessed l-value and member name.
  deriving (Show, Eq)
-- | Statements of the typed AST.
data TStmt
  = TExprStmt TExpr                   -- ^ Expression used as a statement.
  | TReturn TExpr
  | TIf TExpr [TStmt] (Maybe [TStmt]) -- ^ Condition, then-branch, optional else-branch.
  | TWhile TExpr [TStmt]              -- ^ Condition and loop body.
  | TAssign AssignOp TLVal TExpr      -- ^ Assignment operator, target, value.
  | TBlock [TStmt]
  | TDeclVar Text Type (Maybe TExpr)  -- ^ Name, declared type, optional initialiser.
  deriving (Show, Eq)
-- | A typed top-level function.
data TTLFunc = TTLFunc
  { funcName    :: Text    -- ^ Function name.
  , funcArgs    :: [Bind]  -- ^ Parameters.
  , funcRetType :: Type    -- ^ Return type.
  , funcBody    :: [TStmt] -- ^ Typed body statements.
  }
  deriving (Show, Eq)
-- | A whole typed program: its structs, enums and typed functions.
data TProgram
  = TProgram [TLStruct] [TLEnum] [TTLFunc]
  deriving (Show, Eq)

View File

@@ -69,8 +69,6 @@ executable windows12
Windows12.Lexer Windows12.Lexer
Windows12.Parser Windows12.Parser
Windows12.CodeGen Windows12.CodeGen
Windows12.TAst
Windows12.Semant
-- LANGUAGE extensions used by modules in this package. -- LANGUAGE extensions used by modules in this package.
-- other-extensions: -- other-extensions: