{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Windows12.CodeGen where

import Windows12.Ast (BinOp(..), UnOp(..), AssignOp(..), Type(..), Bind(..), TLStruct(..), TLEnum(..))
import Windows12.TAst
import LLVM.AST hiding (ArrayType, VoidType, Call, function)
import LLVM.AST.Type (i32, i1, i8, double, ptr, void)
import qualified LLVM.AST.Constant as C
import LLVM.IRBuilder hiding (double, IRBuilder, ModuleBuilder)
import LLVM.AST.Typed (typeOf)
import LLVM.Prelude (ShortByteString)
import qualified LLVM.AST.IntegerPredicate as IP
import qualified LLVM.AST.FloatingPointPredicate as FP
import Control.Monad.State hiding (void)
import Data.Text (Text, unpack)
import Data.String.Conversions
import Data.String

-- | Global program context threaded through code generation.
--
-- * 'operands' maps source-level names (variables and functions) to the LLVM
--   operand registered for them; newer entries are consed onto the front, so
--   'lookup' finds the most recent binding first.
-- * 'structs' holds the struct definitions registered so far.
-- * 'enums' holds enum definitions.
-- * 'strings' maps string literals to the operand of their global constant.
data Ctx = Ctx
  { operands :: [(Text, Operand)]
  , structs :: [TLStruct]
  , enums :: [TLEnum]
  , strings :: [(Text, Operand)]
  }
  deriving (Eq, Show)

-- | Module-level builder carrying the 'Ctx' state.
type ModuleBuilder = ModuleBuilderT (State Ctx)

-- | Instruction-level builder layered over 'ModuleBuilder'.
type IRBuilder = IRBuilderT ModuleBuilder

-- | Allow easy string conversion ('cs') from 'Text' to 'ShortByteString'.
instance ConvertibleStrings Text ShortByteString where
  convertString = Data.String.fromString . Data.Text.unpack

-- | Put an operand into the context under the given name.
-- The new binding is placed in front of any existing one with the same name.
createOperand :: MonadState Ctx m => Text -> Operand -> m ()
createOperand name op =
  modify $ \ctx -> ctx { operands = (name, op) : operands ctx }

-- | Take in a source file name and the typed AST, and return the LLVM IR module.
--
-- @printf@ is declared as an external variadic function and registered as an
-- operand before struct type definitions and functions are emitted.
codegen :: Text -> TProgram -> Module
codegen filename (TProgram structDefs _enumDefs funcs) =
  flip evalState (Ctx [] [] [] []) $
    buildModuleT (cs filename) $ do
      printf <- externVarArgs (mkName "printf") [ptr i8] i32
      createOperand "printf" printf
      mapM_ emitTypeDef structDefs
      mapM_ codegenFunc funcs

-- | Given a struct name, search the context for the struct and return its fields.
--
-- Calls 'error' if no struct, or more than one struct, with that name is
-- registered in the context.
getStructFields :: MonadState Ctx m => Text -> m [Bind]
getStructFields name = do
  known <- gets structs
  case [fields | Struct n fields <- known, n == name] of
    [] ->
      error $
        "Struct " ++ show name ++ " not found. Valid structs: "
          ++ show (map (\(Struct n _) -> n) known)
    [fields] -> return fields
    _ -> error $ "Multiple structs with name " ++ show name

-- | Convert a Windows12 type to an LLVM type.
--
-- Strings and arrays are lowered to pointers, structs to packed structure
-- types built from their field types, and enums to @i32@.
convertType :: MonadState Ctx m => Windows12.Ast.Type -> m LLVM.AST.Type
convertType IntType = return i32
convertType UIntType = return i32
convertType FloatType = return double
convertType StrType = convertType (PtrType CharType)
convertType BoolType = return i1
convertType CharType = return i8
convertType (PtrType t) = ptr <$> convertType t
convertType (ArrayType t) = convertType (PtrType t)
convertType (StructType name) = do
  fields <- getStructFields name
  types <- mapM (convertType . bindType) fields
  return $ StructureType True types -- True indicates packed
convertType (EnumType _) = return i32
convertType VoidType = return void

-- | Get the size of a type in bytes.
--
-- A struct's size is the sum of its field sizes (structs are emitted packed
-- by 'convertType').
size :: MonadState Ctx m => Windows12.Ast.Type -> m Int
size IntType = return 4
size UIntType = return 4
size FloatType = return 8
size StrType = size (PtrType CharType)
size BoolType = return 1
size CharType = return 1
-- NOTE(review): pointers are sized as 4 bytes; confirm this matches the
-- pointer width of the intended target.
size (PtrType _) = return 4
size (ArrayType t) = size (PtrType t)
size (StructType name) = do
  fields <- getStructFields name
  sizes <- mapM (size . bindType) fields
  return $ sum sizes
-- Enums are lowered to i32 by 'convertType', so they occupy 4 bytes.
size (EnumType _) = return 4
size VoidType = return 0

-- | CodeGen for LValues: produce the address operand an lvalue refers to.
--
-- Calls 'error' when a named variable or struct is not registered in the
-- context, or when the lvalue form is not supported.
codegenLVal :: TLVal -> IRBuilder Operand
codegenLVal (_, TId name) = do
  ctx <- get
  case lookup name (operands ctx) of
    Just op -> return op
    Nothing -> error $ "Variable " ++ show name ++ " not found"
-- TODO support members of members
codegenLVal (StructType t, LTMember (_, TId sName) field) = do
  ctx <- get
  case lookup sName (operands ctx) of
    Just struct -> do
      fields <- getStructFields t
      offset <- structFieldOffset (Struct sName fields) field
      gep struct
        [ ConstantOperand (C.Int 32 0)
        , ConstantOperand (C.Int 32 (fromIntegral offset))
        ]
    Nothing -> error $ "Struct " ++ show sName ++ " not found"
-- Dereferencing: the value of the inner expression is used as the address.
codegenLVal (_, TDeref e) = codegenExpr e
codegenLVal (_, _) = error "Unimplemented or invalid LValue"

-- Given a struct and a field name, return the offset of the field in the struct.
-- In LLVM each field is actually size 1, so the "offset" of a field is simply
-- its position (0-based index) in the struct's field list.
-- Calls 'error' (when the result is demanded) if the struct has no field with
-- the given name; the struct name in the message is whatever name the caller
-- put in the 'Struct' value.
structFieldOffset :: MonadState Ctx m => TLStruct -> Text -> m Int
structFieldOffset (Struct name fields) field =
  return $
    case break (\(Bind n _) -> n == field) fields of
      (_, []) ->
        error $ "Field " ++ show field ++ " not found in struct " ++ show name
      (preceding, _) -> length preceding

-- | CodeGen for expressions: emit the instructions computing an expression
-- and return the operand holding its value.
--
-- Expression forms that are not implemented yet abort with 'error'.
codegenExpr :: TExpr -> IRBuilder Operand
-- A variable is read by loading from the address of the matching lvalue.
codegenExpr (t, TVar name) = flip load 0 =<< codegenLVal (t, TId name)
codegenExpr (_, TIntLit i) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
codegenExpr (_, TUIntLit i) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
codegenExpr (_, TFloatLit _) =
  error "Float literals not implemented" -- TODO floats
codegenExpr (_, TStrLit s) = do
  strs <- gets strings
  case lookup s strs of
    -- If the string is already in the context, return it
    Just str -> return str
    -- Otherwise, create a new global string and add it to the context
    Nothing -> do
      let strName = mkName ("str." <> show (length strs))
      op <- globalStringPtr (cs s) strName
      modify $ \ctx -> ctx { strings = (s, ConstantOperand op) : strs }
      return (ConstantOperand op)
codegenExpr (_, TBoolLit b) = return $ ConstantOperand (C.Int 1 (if b then 1 else 0))
codegenExpr (_, TCharLit c) = return $ ConstantOperand (C.Int 8 (fromIntegral (fromEnum c)))
codegenExpr (_, TBinOp op lhs rhs) = do
  lhs' <- codegenExpr lhs
  rhs' <- codegenExpr rhs
  -- Every implemented operator requires both operands to be i32; anything
  -- else is rejected with an operator-specific message.
  -- TODO pointers, floating points
  let int32Op label instr =
        case (typeOf lhs', typeOf rhs') of
          (IntegerType 32, IntegerType 32) -> instr lhs' rhs'
          _ -> error $ "Invalid types for " ++ label
  case op of
    Windows12.Ast.Add -> int32Op "add" add
    Windows12.Ast.Sub -> int32Op "sub" sub
    Windows12.Ast.Mul -> int32Op "mul" mul
    Windows12.Ast.Div -> int32Op "div" sdiv
    Windows12.Ast.Mod -> int32Op "mod" srem
    Windows12.Ast.Eq -> int32Op "eq" (icmp IP.EQ)
    Windows12.Ast.Ne -> int32Op "ne" (icmp IP.NE)
    Windows12.Ast.Lt -> int32Op "lt" (icmp IP.SLT)
    Windows12.Ast.Gt -> int32Op "gt" (icmp IP.SGT)
    Windows12.Ast.Le -> int32Op "le" (icmp IP.SLE)
    Windows12.Ast.Ge -> int32Op "ge" (icmp IP.SGE)
    other -> error $ "Operator " ++ show other ++ " not implemented"
codegenExpr (_, TUnOp _ _) =
  error "Unary operators not implemented" -- TODO handle unary operators
-- Function calls: look up the function in operands, then call it with the args
codegenExpr (_, TCall fname args) = do
  ctx <- get
  callee <- case lookup fname (operands ctx) of
    Just op -> return op
    Nothing -> error $ "Function " ++ show fname ++ " not found"
  args' <- mapM (fmap (, []) . codegenExpr) args
  call callee args'
codegenExpr (_, TIndex _ _) =
  error "Array indexing not implemented" -- TODO arrays
-- Get the address of the struct field and load it
codegenExpr (_, TMember (StructType sName, TVar sVarName) m) = do
  ctx <- get
  case lookup sVarName (operands ctx) of
    Just struct -> do
      fields <- getStructFields sName
      offset <- structFieldOffset (Struct sVarName fields) m
      addr <- gep struct
        [ ConstantOperand (C.Int 32 0)
        , ConstantOperand (C.Int 32 (fromIntegral offset))
        ]
      load addr 0
    Nothing -> error $ "Struct operand " ++ show sVarName ++ " not found"
codegenExpr (_, TCast _ _) =
  error "Casts not implemented" -- TODO casts
codegenExpr (_, TSizeof t) = ConstantOperand . C.Int 32 . fromIntegral <$> size t
-- Anything else (e.g. member access on something other than a struct
-- variable) is not supported.
codegenExpr _ = error "Unimplemented or invalid expression"

-- | Run the given instruction only if the current block has no terminator yet.
mkTerminator :: IRBuilder () -> IRBuilder ()
mkTerminator instr = do
  terminated <- hasTerminator
  unless terminated instr

-- | Codegen for statements.
codegenStmt :: TStmt -> IRBuilder ()
-- For expression statements, just evaluate the expression and discard the result
codegenStmt (TExprStmt e) = do
  _ <- codegenExpr e
  return ()
codegenStmt (TReturn e) = ret =<< codegenExpr e
-- Generate if statements, with a merge block at the end.
-- A missing else branch is treated as an empty block.
codegenStmt (TIf cond t f) = mdo
  cond' <- codegenExpr cond
  condBr cond' then' else'
  then' <- block `named` "then"
  codegenStmt (TBlock t)
  mkTerminator $ br merge
  else' <- block `named` "else"
  codegenStmt (TBlock (maybe [] id f))
  mkTerminator $ br merge
  merge <- block `named` "merge"
  return ()
-- Generate while loops, with a merge block at the end
codegenStmt (TWhile cond body) = mdo
  br condBlock
  condBlock <- block `named` "cond"
  cond' <- codegenExpr cond
  condBr cond' loop end
  loop <- block `named` "loop"
  codegenStmt (TBlock body)
  mkTerminator $ br condBlock
  end <- block `named` "end"
  return ()
-- Plain assignment to a variable: evaluate the value, then store it.
codegenStmt (TAssign BaseAssign l@(_, TId _) e) = do
  op <- codegenExpr e
  var <- codegenLVal l
  store var 0 op
-- Plain assignment to a field of a struct variable.
codegenStmt (TAssign BaseAssign l@(StructType _, LTMember (_, TId _) _) e) = do
  op <- codegenExpr e
  fieldAddr <- codegenLVal l
  store fieldAddr 0 op
-- Compound assignments load the current value, combine it, and store it back.
codegenStmt (TAssign AddAssign l@(_, TId _) e) = do
  op <- codegenExpr e
  var <- codegenLVal l
  val <- load var 0
  store var 0 =<< add val op
codegenStmt (TAssign SubAssign l@(_, TId _) e) = do
  op <- codegenExpr e
  var <- codegenLVal l
  val <- load var 0
  store var 0 =<< sub val op
-- A block is just a list of statements
codegenStmt (TBlock stmts) = mapM_ codegenStmt stmts
-- Since the vars are already allocated by genBody, we just need to assign the value
codegenStmt (TDeclVar name t (Just e)) =
  codegenStmt (TAssign BaseAssign (t, TId name) e)
-- Do nothing with variable declaration if no expression is given
-- This is because allocation is done already
codegenStmt (TDeclVar _ _ Nothing) = return ()
codegenStmt s = error $ "Unimplemented or invalid statement " ++ show s

-- | Generate code for a function.
-- First create the function, then allocate space for the arguments and locals.
--
-- The function's operand is registered under its name before the body is
-- generated (via 'mdo'), so the body can refer to the function itself.
-- The statement order inside the 'mdo' block is significant.
codegenFunc :: TTLFunc -> ModuleBuilder ()
codegenFunc func@(TTLFunc name args retType body) = mdo
  createOperand name f
  (f, strs) <- do
    params' <- mapM mkParam args
    retType' <- convertType retType
    fn <- function (mkName (cs name)) params' retType' genBody
    known <- gets strings
    return (fn, known)
  modify $ \ctx -> ctx { strings = strs }
  where
    -- Pair each parameter's LLVM type with its name.
    mkParam (Bind pname t) =
      (,) <$> convertType t <*> pure (ParameterName (cs pname))

    genBody :: [Operand] -> IRBuilder ()
    genBody ops = do
      -- Copy every incoming argument into its own stack slot and register
      -- that slot under the parameter's name.
      forM_ (zip ops args) $ \(op, Bind pname _) -> do
        addr <- alloca (typeOf op) Nothing 0
        store addr 0 op
        createOperand pname addr
      -- Allocate every local declared anywhere in the body up front.
      forM_ (getLocals func) $ \(Bind lname t) -> do
        ltype <- convertType t
        addr <- alloca ltype Nothing 0
        createOperand lname addr
      codegenStmt (TBlock body)

-- | Given a function, get all the local variables.
-- Used so allocation can be done before the function body.
getLocals :: TTLFunc -> [Bind]
getLocals (TTLFunc _ _ _ body) = blockGetLocals body

-- | Collect the locals declared by a list of statements, in order.
blockGetLocals :: [TStmt] -> [Bind]
blockGetLocals = concatMap stmtGetLocals

-- | Collect the locals declared by one statement, descending into nested
-- blocks, both branches of an if, and while bodies.
stmtGetLocals :: TStmt -> [Bind]
stmtGetLocals (TDeclVar n t _) = [Bind n t]
stmtGetLocals (TBlock stmts) = blockGetLocals stmts
stmtGetLocals (TIf _ t f) = blockGetLocals t ++ maybe [] blockGetLocals f
stmtGetLocals (TWhile _ body) = blockGetLocals body
stmtGetLocals _ = []

-- | Create structs: register the struct in the context, then emit a named
-- LLVM type definition @struct.<name>@ for it.
emitTypeDef :: TLStruct -> ModuleBuilder LLVM.AST.Type
emitTypeDef (Struct name fields) = do
  modify $ \ctx -> ctx { structs = Struct name fields : structs ctx }
  sType <- convertType (StructType name)
  typedef (mkName (cs ("struct." <> name))) (Just sType)