diff --git a/src/codegen.cpp b/src/codegen.cpp index 3d685fd..7a13f71 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -136,6 +136,7 @@ private: const char* getSourceBaseType(llvm::Type* type); std::string getSourceType(llvm::Type* type); + llvm::Value* evalBooleanCondition(const Expression* condition_expression); llvm::Value* evalTruthiness(llvm::Value* val); llvm::Function* current_function() { return builder_.GetInsertBlock()->getParent(); } @@ -329,6 +330,17 @@ llvm::Value* IRGenerator::visitAssignStatement(const AssignStatement* a) { return nullptr; } +llvm::Value* IRGenerator::evalBooleanCondition(const Expression* condition_expression) { + llvm::Value* condition = visit(condition_expression); + auto condition_type = condition->getType(); + condition = evalTruthiness(condition); + if (!condition) { + error(condition_expression, "cannot convert expression of type", + getSourceType(condition_type), "to boolean."); + } + return condition; +} + llvm::Value* IRGenerator::evalTruthiness(llvm::Value* val) { if (llvm::isa(val)) return val; auto val_type = val->getType(); @@ -354,13 +366,8 @@ llvm::Value* IRGenerator::visitBranchStatement(const BranchStatement* b) { builder_.SetInsertPoint(case_cond_bb); for (const BranchCase& branch_case : b->cases()) { - llvm::Value* condition = visit(branch_case.getExpression()); - auto condition_type = condition->getType(); - condition = evalTruthiness(condition); - if (!condition) { - error(branch_case.getExpression(), "cannot convert expression of type", - getSourceType(condition_type), "to boolean."); - } + llvm::Value* condition = evalBooleanCondition(branch_case.getExpression()); + case_cond_bb = llvm::BasicBlock::Create(context_, "branch.case.cond"); llvm::BasicBlock* case_body_bb = llvm::BasicBlock::Create(context_, "branch.case.body", current_function()); @@ -411,13 +418,7 @@ llvm::Value* IRGenerator::visitLoopStatement(const LoopStatement* l) { builder_.CreateBr(condition_bb); builder_.SetInsertPoint(condition_bb); - auto condition = visit(l->getCondition()); - auto condition_type = condition->getType(); - condition = evalTruthiness(condition); - if (!condition) { - error(l->getCondition(), "cannot convert expression of type", getSourceType(condition_type), - "to boolean"); - } + llvm::Value* condition = evalBooleanCondition(l->getCondition()); builder_.CreateCondBr(condition, body_bb, after_bb); current_function()->getBasicBlockList().push_back(after_bb);