Skip to content

Commit

Permalink
Add == for booleans (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli committed Dec 26, 2023
1 parent b37d0ff commit d62daf8
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 25 deletions.
10 changes: 7 additions & 3 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,13 @@ class AstToIR:
if rhs_type->kind == TypeKind::Enum:
rhs_type = int_type

got_numbers = lhs_type->is_number_type() and rhs_type->is_number_type()
got_pointers = lhs_type->is_pointer_type() and rhs_type->is_pointer_type()
assert got_numbers or got_pointers
if lhs_type == &bool_type and rhs_type == &bool_type:
# bools are 1-bit integers in llvm
if op == AstExpressionKind::Eq:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs, rhs, "eq")
if op == AstExpressionKind::Ne:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs, rhs, "ne")
assert False

if lhs_type->kind == TypeKind::FloatingPoint and rhs_type->kind == TypeKind::FloatingPoint:
if op == AstExpressionKind::Add:
Expand Down
7 changes: 5 additions & 2 deletions self_hosted/typecheck.jou
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def check_binop(
do_what = "compare"
result_is_bool = True

got_bools = lhs_types->original_type == &bool_type and rhs_types->original_type == &bool_type
got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type()
got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type()
got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum
Expand All @@ -703,7 +704,7 @@ def check_binop(
)

if (
(not got_numbers and not got_enums and not got_pointers)
(not got_bools and not got_numbers and not got_enums and not got_pointers)
or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers)
):
message: byte[500]
Expand All @@ -714,7 +715,9 @@ def check_binop(
)
fail(location, message)

if got_integers:
if got_bools:
cast_type = &bool_type
elif got_integers:
size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits)
if (
lhs_types->original_type->kind == TypeKind::SignedInteger
Expand Down
82 changes: 64 additions & 18 deletions src/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,41 @@ static const LocalVariable *build_cast(
assert(0);
}

static const LocalVariable *build_bool_eq(struct State *st, Location location, const LocalVariable *a, const LocalVariable *b)
{
assert(a->type == boolType);
assert(b->type == boolType);

/*
Pseudo code:
if a:
result = b
else:
result = not b
*/
const LocalVariable *result = add_local_var(st, boolType);

CfBlock *atrue = add_block(st);
CfBlock *afalse = add_block(st);
CfBlock *done = add_block(st);

// if a:
add_jump(st, a, atrue, afalse, atrue);

// result = b
add_unary_op(st, location, CF_VARCPY, b, result);

// else:
add_jump(st, NULL, done, done, afalse);

// result = not b
add_unary_op(st, location, CF_BOOL_NEGATE, b, result);

add_jump(st, NULL, done, done, done);
return result;
}

static const LocalVariable *build_binop(
struct State *st,
enum AstExpressionKind op,
Expand All @@ -175,32 +210,43 @@ static const LocalVariable *build_binop(
const LocalVariable *rhs,
const Type *result_type)
{
bool got_bools = lhs->type == boolType && rhs->type == boolType;
bool got_numbers = is_number_type(lhs->type) && is_number_type(rhs->type);
bool got_pointers = is_pointer_type(lhs->type) && is_pointer_type(rhs->type);
assert(got_numbers || got_pointers);
assert(got_bools || got_numbers || got_pointers);

enum CfInstructionKind k;
bool negate = false;
bool swap = false;

switch(op) {
case AST_EXPR_ADD: k = CF_NUM_ADD; break;
case AST_EXPR_SUB: k = CF_NUM_SUB; break;
case AST_EXPR_MUL: k = CF_NUM_MUL; break;
case AST_EXPR_DIV: k = CF_NUM_DIV; break;
case AST_EXPR_MOD: k = CF_NUM_MOD; break;
case AST_EXPR_EQ: k = CF_NUM_EQ; break;
case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break;
case AST_EXPR_LT: k = CF_NUM_LT; break;
case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break;
case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break;
case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break;
default: assert(0);
const LocalVariable *destvar;
if (got_bools) {
assert(result_type == boolType);
destvar = build_bool_eq(st, location, lhs, rhs);
switch(op) {
case AST_EXPR_EQ: break;
case AST_EXPR_NE: negate=true; break;
default: assert(0); break;
}
} else {
destvar = add_local_var(st, result_type);
enum CfInstructionKind k;
switch(op) {
case AST_EXPR_ADD: k = CF_NUM_ADD; break;
case AST_EXPR_SUB: k = CF_NUM_SUB; break;
case AST_EXPR_MUL: k = CF_NUM_MUL; break;
case AST_EXPR_DIV: k = CF_NUM_DIV; break;
case AST_EXPR_MOD: k = CF_NUM_MOD; break;
case AST_EXPR_EQ: k = CF_NUM_EQ; break;
case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break;
case AST_EXPR_LT: k = CF_NUM_LT; break;
case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break;
case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break;
case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break;
default: assert(0);
}
add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar);
}

const LocalVariable *destvar = add_local_var(st, result_type);
add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar);

if (!negate)
return destvar;

Expand Down
7 changes: 5 additions & 2 deletions src/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ static const Type *check_binop(
assert(0);
}

bool got_bools = lhstypes->type == boolType && rhstypes->type == boolType;
bool got_integers = is_integer_type(lhstypes->type) && is_integer_type(rhstypes->type);
bool got_numbers = is_number_type(lhstypes->type) && is_number_type(rhstypes->type);
bool got_enums = lhstypes->type->kind == TYPE_ENUM && rhstypes->type->kind == TYPE_ENUM;
Expand All @@ -711,13 +712,15 @@ static const Type *check_binop(
);

if (
(!got_numbers && !got_enums && !got_pointers)
|| (got_enums && op != AST_EXPR_EQ && op != AST_EXPR_NE)
(!got_bools && !got_numbers && !got_enums && !got_pointers)
|| ((got_bools || got_enums) && op != AST_EXPR_EQ && op != AST_EXPR_NE)
|| (got_pointers && op != AST_EXPR_EQ && op != AST_EXPR_NE && op != AST_EXPR_GT && op != AST_EXPR_GE && op != AST_EXPR_LT && op != AST_EXPR_LE)
)
fail(location, "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name);

const Type *cast_type = NULL;
if (got_bools)
cast_type = boolType;
if (got_integers) {
cast_type = get_integer_type(
max(lhstypes->type->data.width_in_bits, rhstypes->type->data.width_in_bits),
Expand Down
19 changes: 19 additions & 0 deletions tests/should_succeed/bool_eq_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import "stdlib/io.jou"

def do_stuff(a: bool, b: bool) -> None:
if a == b:
puts("Hi")

def main() -> int:
do_stuff(False, False) # Output: Hi

# Output: 1001
printf(
"%d%d%d%d\n",
True == True, # Warning: this code will never run
True == False, # Warning: this code will never run
False == True, # Warning: this code will never run
False == False, # Warning: this code will never run
)

return 0
3 changes: 3 additions & 0 deletions tests/wrong_type/bool_plus_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def main() -> int:
x = True + False # Error: wrong types: cannot add bool and bool
return 0

0 comments on commit d62daf8

Please sign in to comment.