Adding type coercion and deduction mechanism.
This commit is contained in:
parent
cac89bb476
commit
4321fa5dbf
|
@ -31,6 +31,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
//TODO remove
|
//TODO remove
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
@ -93,6 +94,78 @@ bool reportError(std::initializer_list<std::string> const& what) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define I64 llvm::Type::getInt64Ty(getGlobalContext())
|
||||||
|
#define I8 llvm::Type::getInt8Ty(getGlobalContext())
|
||||||
|
#define I1 llvm::Type::getInt1Ty(getGlobalContext())
|
||||||
|
#define F llvm::Type::getFloatTy(getGlobalContext())
|
||||||
|
#define D llvm::Type::getDoubleTy(getGlobalContext())
|
||||||
|
|
||||||
|
static const std::unordered_map<llvm::Type*, std::unordered_map<llvm::Type*, llvm::Type*>> TYPE_MAP = {
|
||||||
|
{I64, { {I8, I64}, {I1, I64}, { F, D}, {D, D}}},
|
||||||
|
{ I8, {{I64, I64}, {I1, I8}, { F, F}, {D, D}}},
|
||||||
|
{ I1, {{I64, I64}, {I8, I8}, { F, F}, {D, D}}},
|
||||||
|
{ F, {{I64, D}, {I8, F}, {I1, F}, {D, D}}},
|
||||||
|
{ D, {{I64, D}, {I8, D}, {I1, D}, { F, D} }}
|
||||||
|
};
|
||||||
|
|
||||||
|
static
|
||||||
|
llvm::Type* deduceResultType(llvm::Value *left, llvm::Value *right) {
|
||||||
|
llvm::Type *lt = left->getType();
|
||||||
|
llvm::Type *rt = right->getType();
|
||||||
|
|
||||||
|
if (lt == rt) return rt;
|
||||||
|
|
||||||
|
auto subTable = TYPE_MAP.find(lt);
|
||||||
|
if (subTable != TYPE_MAP.end()) {
|
||||||
|
auto resultType = subTable->second.find(rt);
|
||||||
|
if (resultType != subTable->second.end()) return resultType->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef I64
|
||||||
|
#undef I8
|
||||||
|
#undef I1
|
||||||
|
#undef F
|
||||||
|
#undef D
|
||||||
|
|
||||||
|
static inline
|
||||||
|
bool isFP(llvm::Type *type) {
|
||||||
|
return type->isFloatTy() || type->isDoubleTy();
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline
|
||||||
|
bool isInt(llvm::Type *type) {
|
||||||
|
return type->isIntegerTy();
|
||||||
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
llvm::Value* coerce(BitcodeEmitter::Private *d, llvm::Value *val, llvm::Type *toType) {
|
||||||
|
llvm::Type *fromType = val->getType();
|
||||||
|
|
||||||
|
if (fromType == toType) return val;
|
||||||
|
|
||||||
|
if (isInt(toType)) {
|
||||||
|
if (isFP(fromType)) {
|
||||||
|
return d->builder.CreateFPToSI(val, toType);
|
||||||
|
} else if (isInt(fromType)) {
|
||||||
|
return d->builder.CreateSExtOrBitCast(val, toType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (isFP(toType) && isInt(fromType)) {
|
||||||
|
return d->builder.CreateFPToSI(val, toType);
|
||||||
|
}
|
||||||
|
else if (fromType->isFloatTy() && toType->isDoubleTy()) {
|
||||||
|
return d->builder.CreateFPExt(val, toType);
|
||||||
|
}
|
||||||
|
else if (fromType->isDoubleTy() && toType->isFloatTy()) {
|
||||||
|
return d->builder.CreateFPTrunc(val, toType);
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
BitcodeEmitter::BitcodeEmitter() {
|
BitcodeEmitter::BitcodeEmitter() {
|
||||||
module = std::unique_ptr<llvm::Module>(
|
module = std::unique_ptr<llvm::Module>(
|
||||||
new llvm::Module("monicelli", getGlobalContext())
|
new llvm::Module("monicelli", getGlobalContext())
|
||||||
|
@ -145,15 +218,27 @@ bool BitcodeEmitter::emit(Loop const& node) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
bool convertAndStore(BitcodeEmitter::Private *d, llvm::AllocaInst *dest, llvm::Value *expression) {
|
||||||
|
llvm::Type *varType = dest->getAllocatedType();
|
||||||
|
expression = coerce(d, expression, varType);
|
||||||
|
if (expression == nullptr) return false;
|
||||||
|
d->builder.CreateStore(expression, dest);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool BitcodeEmitter::emit(VarDeclaration const& node) {
|
bool BitcodeEmitter::emit(VarDeclaration const& node) {
|
||||||
llvm::Function *father = d->builder.GetInsertBlock()->getParent();
|
llvm::Function *father = d->builder.GetInsertBlock()->getParent();
|
||||||
llvm::AllocaInst *alloc = allocateVar(
|
llvm::Type *varType = LLVMType(node.getType());
|
||||||
father, node.getId(), LLVMType(node.getType())
|
llvm::AllocaInst *alloc = allocateVar(father, node.getId(), varType);
|
||||||
);
|
|
||||||
|
|
||||||
if (node.getInitializer()) {
|
if (node.getInitializer()) {
|
||||||
GUARDED(node.getInitializer()->emit(this));
|
GUARDED(node.getInitializer()->emit(this));
|
||||||
d->builder.CreateStore(d->retval, alloc);
|
if (!convertAndStore(d, alloc, d->retval)) {
|
||||||
|
return reportError({
|
||||||
|
"Invalid inizializer for variable", node.getId().getValue()
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO pointers
|
// TODO pointers
|
||||||
|
@ -164,7 +249,6 @@ bool BitcodeEmitter::emit(VarDeclaration const& node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BitcodeEmitter::emit(Assignment const& node) {
|
bool BitcodeEmitter::emit(Assignment const& node) {
|
||||||
GUARDED(node.getValue().emit(this));
|
|
||||||
auto var = d->scope.lookup(node.getName().getValue());
|
auto var = d->scope.lookup(node.getName().getValue());
|
||||||
|
|
||||||
if (!var) {
|
if (!var) {
|
||||||
|
@ -173,7 +257,12 @@ bool BitcodeEmitter::emit(Assignment const& node) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
d->builder.CreateStore(d->retval, *var);
|
GUARDED(node.getValue().emit(this));
|
||||||
|
if (!convertAndStore(d, *var, d->retval)) {
|
||||||
|
return reportError({
|
||||||
|
"Invalid assignment to variable", node.getName().getValue()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -360,6 +449,10 @@ bool BitcodeEmitter::emit(Module const& node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BitcodeEmitter::emit(Program const& program) {
|
bool BitcodeEmitter::emit(Program const& program) {
|
||||||
|
// for (Module const& module: program.getModules()) {
|
||||||
|
// GUARDED(module.emit(this));
|
||||||
|
// }
|
||||||
|
|
||||||
for (Function const* function: program.getFunctions()) {
|
for (Function const* function: program.getFunctions()) {
|
||||||
GUARDED(function->emit(this));
|
GUARDED(function->emit(this));
|
||||||
}
|
}
|
||||||
|
@ -368,10 +461,6 @@ bool BitcodeEmitter::emit(Program const& program) {
|
||||||
GUARDED(program.getMain()->emit(this));
|
GUARDED(program.getMain()->emit(this));
|
||||||
}
|
}
|
||||||
|
|
||||||
// for (Module const& module: program.getModules()) {
|
|
||||||
// GUARDED(module.emit(this));
|
|
||||||
// }
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -405,16 +494,6 @@ bool BitcodeEmitter::emit(Float const& node) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static
|
|
||||||
llvm::Type* deduceResultType(llvm::Value *left, llvm::Value *right) {
|
|
||||||
return llvm::Type::getInt64Ty(getGlobalContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
static
|
|
||||||
llvm::Value* coerce(llvm::Value *val, llvm::Type *type) {
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
#define HANDLE(intop, fpop) \
|
#define HANDLE(intop, fpop) \
|
||||||
if (fp) { \
|
if (fp) { \
|
||||||
d->retval = d->builder.Create##fpop(left, right); \
|
d->retval = d->builder.Create##fpop(left, right); \
|
||||||
|
@ -432,10 +511,19 @@ llvm::Value* coerce(llvm::Value *val, llvm::Type *type) {
|
||||||
static
|
static
|
||||||
bool createOp(BitcodeEmitter::Private *d, llvm::Value *left, Operator op, llvm::Value *right) {
|
bool createOp(BitcodeEmitter::Private *d, llvm::Value *left, Operator op, llvm::Value *right) {
|
||||||
llvm::Type *retType = deduceResultType(left, right);
|
llvm::Type *retType = deduceResultType(left, right);
|
||||||
bool fp = retType->isFloatTy() || retType->isDoubleTy();
|
|
||||||
|
|
||||||
left = coerce(left, retType);
|
if (retType == nullptr) {
|
||||||
right = coerce(right, retType);
|
return reportError({"Cannot combine operators."});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool fp = isFP(retType);
|
||||||
|
|
||||||
|
left = coerce(d, left, retType);
|
||||||
|
right = coerce(d, right, retType);
|
||||||
|
|
||||||
|
if (left == nullptr || right == nullptr) {
|
||||||
|
return reportError({"Cannot convert operators to result type."});
|
||||||
|
}
|
||||||
|
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case Operator::PLUS:
|
case Operator::PLUS:
|
||||||
|
|
Reference in New Issue
Block a user