Adding type coercion and deduction mechanism.

This commit is contained in:
Stefano Sanfilippo 2015-03-06 16:10:35 +01:00
parent cac89bb476
commit 4321fa5dbf

View File

@ -31,6 +31,7 @@
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
#include <initializer_list> #include <initializer_list>
#include <unordered_map>
//TODO remove //TODO remove
#include <iostream> #include <iostream>
@ -93,6 +94,78 @@ bool reportError(std::initializer_list<std::string> const& what) {
return false; return false;
} }
#define I64 llvm::Type::getInt64Ty(getGlobalContext())
#define I8 llvm::Type::getInt8Ty(getGlobalContext())
#define I1 llvm::Type::getInt1Ty(getGlobalContext())
#define F llvm::Type::getFloatTy(getGlobalContext())
#define D llvm::Type::getDoubleTy(getGlobalContext())
static const std::unordered_map<llvm::Type*, std::unordered_map<llvm::Type*, llvm::Type*>> TYPE_MAP = {
{I64, { {I8, I64}, {I1, I64}, { F, D}, {D, D}}},
{ I8, {{I64, I64}, {I1, I8}, { F, F}, {D, D}}},
{ I1, {{I64, I64}, {I8, I8}, { F, F}, {D, D}}},
{ F, {{I64, D}, {I8, F}, {I1, F}, {D, D}}},
{ D, {{I64, D}, {I8, D}, {I1, D}, { F, D} }}
};
static
llvm::Type* deduceResultType(llvm::Value *left, llvm::Value *right) {
llvm::Type *lt = left->getType();
llvm::Type *rt = right->getType();
if (lt == rt) return rt;
auto subTable = TYPE_MAP.find(lt);
if (subTable != TYPE_MAP.end()) {
auto resultType = subTable->second.find(rt);
if (resultType != subTable->second.end()) return resultType->second;
}
return nullptr;
}
#undef I64
#undef I8
#undef I1
#undef F
#undef D
static inline
bool isFP(llvm::Type *type) {
return type->isFloatTy() || type->isDoubleTy();
}
static inline
bool isInt(llvm::Type *type) {
return type->isIntegerTy();
}
static
llvm::Value* coerce(BitcodeEmitter::Private *d, llvm::Value *val, llvm::Type *toType) {
llvm::Type *fromType = val->getType();
if (fromType == toType) return val;
if (isInt(toType)) {
if (isFP(fromType)) {
return d->builder.CreateFPToSI(val, toType);
} else if (isInt(fromType)) {
return d->builder.CreateSExtOrBitCast(val, toType);
}
}
else if (isFP(toType) && isInt(fromType)) {
return d->builder.CreateFPToSI(val, toType);
}
else if (fromType->isFloatTy() && toType->isDoubleTy()) {
return d->builder.CreateFPExt(val, toType);
}
else if (fromType->isDoubleTy() && toType->isFloatTy()) {
return d->builder.CreateFPTrunc(val, toType);
}
return nullptr;
}
BitcodeEmitter::BitcodeEmitter() { BitcodeEmitter::BitcodeEmitter() {
module = std::unique_ptr<llvm::Module>( module = std::unique_ptr<llvm::Module>(
new llvm::Module("monicelli", getGlobalContext()) new llvm::Module("monicelli", getGlobalContext())
@ -145,15 +218,27 @@ bool BitcodeEmitter::emit(Loop const& node) {
return true; return true;
} }
static
bool convertAndStore(BitcodeEmitter::Private *d, llvm::AllocaInst *dest, llvm::Value *expression) {
llvm::Type *varType = dest->getAllocatedType();
expression = coerce(d, expression, varType);
if (expression == nullptr) return false;
d->builder.CreateStore(expression, dest);
return true;
}
bool BitcodeEmitter::emit(VarDeclaration const& node) { bool BitcodeEmitter::emit(VarDeclaration const& node) {
llvm::Function *father = d->builder.GetInsertBlock()->getParent(); llvm::Function *father = d->builder.GetInsertBlock()->getParent();
llvm::AllocaInst *alloc = allocateVar( llvm::Type *varType = LLVMType(node.getType());
father, node.getId(), LLVMType(node.getType()) llvm::AllocaInst *alloc = allocateVar(father, node.getId(), varType);
);
if (node.getInitializer()) { if (node.getInitializer()) {
GUARDED(node.getInitializer()->emit(this)); GUARDED(node.getInitializer()->emit(this));
d->builder.CreateStore(d->retval, alloc); if (!convertAndStore(d, alloc, d->retval)) {
return reportError({
"Invalid inizializer for variable", node.getId().getValue()
});
}
} }
// TODO pointers // TODO pointers
@ -164,7 +249,6 @@ bool BitcodeEmitter::emit(VarDeclaration const& node) {
} }
bool BitcodeEmitter::emit(Assignment const& node) { bool BitcodeEmitter::emit(Assignment const& node) {
GUARDED(node.getValue().emit(this));
auto var = d->scope.lookup(node.getName().getValue()); auto var = d->scope.lookup(node.getName().getValue());
if (!var) { if (!var) {
@ -173,7 +257,12 @@ bool BitcodeEmitter::emit(Assignment const& node) {
}); });
} }
d->builder.CreateStore(d->retval, *var); GUARDED(node.getValue().emit(this));
if (!convertAndStore(d, *var, d->retval)) {
return reportError({
"Invalid assignment to variable", node.getName().getValue()
});
}
return true; return true;
} }
@ -360,6 +449,10 @@ bool BitcodeEmitter::emit(Module const& node) {
} }
bool BitcodeEmitter::emit(Program const& program) { bool BitcodeEmitter::emit(Program const& program) {
// for (Module const& module: program.getModules()) {
// GUARDED(module.emit(this));
// }
for (Function const* function: program.getFunctions()) { for (Function const* function: program.getFunctions()) {
GUARDED(function->emit(this)); GUARDED(function->emit(this));
} }
@ -368,10 +461,6 @@ bool BitcodeEmitter::emit(Program const& program) {
GUARDED(program.getMain()->emit(this)); GUARDED(program.getMain()->emit(this));
} }
// for (Module const& module: program.getModules()) {
// GUARDED(module.emit(this));
// }
return true; return true;
} }
@ -405,16 +494,6 @@ bool BitcodeEmitter::emit(Float const& node) {
return true; return true;
} }
static
llvm::Type* deduceResultType(llvm::Value *left, llvm::Value *right) {
return llvm::Type::getInt64Ty(getGlobalContext());
}
static
llvm::Value* coerce(llvm::Value *val, llvm::Type *type) {
return val;
}
#define HANDLE(intop, fpop) \ #define HANDLE(intop, fpop) \
if (fp) { \ if (fp) { \
d->retval = d->builder.Create##fpop(left, right); \ d->retval = d->builder.Create##fpop(left, right); \
@ -432,10 +511,19 @@ llvm::Value* coerce(llvm::Value *val, llvm::Type *type) {
static static
bool createOp(BitcodeEmitter::Private *d, llvm::Value *left, Operator op, llvm::Value *right) { bool createOp(BitcodeEmitter::Private *d, llvm::Value *left, Operator op, llvm::Value *right) {
llvm::Type *retType = deduceResultType(left, right); llvm::Type *retType = deduceResultType(left, right);
bool fp = retType->isFloatTy() || retType->isDoubleTy();
left = coerce(left, retType); if (retType == nullptr) {
right = coerce(right, retType); return reportError({"Cannot combine operators."});
}
bool fp = isFP(retType);
left = coerce(d, left, retType);
right = coerce(d, right, retType);
if (left == nullptr || right == nullptr) {
return reportError({"Cannot convert operators to result type."});
}
switch (op) { switch (op) {
case Operator::PLUS: case Operator::PLUS: