aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crowbar_reference_compiler/__init__.py2
-rw-r--r--crowbar_reference_compiler/ast.py90
-rw-r--r--crowbar_reference_compiler/ssagen.py221
3 files changed, 295 insertions, 18 deletions
diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py
index 1410bf7..46d0115 100644
--- a/crowbar_reference_compiler/__init__.py
+++ b/crowbar_reference_compiler/__init__.py
@@ -62,7 +62,7 @@ def main():
if args.out is None:
args.out = args.input.replace('.cro', '.o')
extra_gcc_flags.append('-c')
- gcc_result = subprocess.run(['gcc', '-x', 'assembler', '-o', args.out, '-'], input=asm, text=True)
+ gcc_result = subprocess.run(['gcc', '-x', 'assembler', '-o', args.out, *extra_gcc_flags, '-'], input=asm, text=True)
sys.exit(gcc_result.returncode)
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
index 86e64fe..37ce0da 100644
--- a/crowbar_reference_compiler/ast.py
+++ b/crowbar_reference_compiler/ast.py
@@ -12,23 +12,48 @@ from .parser import parse_header
@dataclass
class Type:
- pass
+ def size_bytes(self, declarations: List['Declaration']) -> int:
+ raise NotImplementedError('type.size_bytes() on ' + str(type(self)) + ' not implemented')
@dataclass
class Expression:
- pass
+ def type(self, declarations: List['Declaration']) -> Type:
+ raise NotImplementedError('expression.type() on ' + str(type(self)) + ' not implemented')
@dataclass
class ConstantExpression(Expression):
value: str
+ def type(self, _: List['Declaration']) -> Type:
+ if self.value.startswith('"'):
+ return PointerType(ConstType(BasicType('char')))
+ elif self.value.startswith("'"):
+ return BasicType('char')
+ elif self.value in ['true', 'false']:
+ return BasicType('bool')
+ elif '.' in self.value:
+ return BasicType('float?') # TODO infer size
+ else:
+ return BasicType('int?') # TODO infer size and signedness
+
@dataclass
class VariableExpression(Expression):
name: str
+ def type(self, declarations: List['Declaration']) -> Type:
+ for decl in declarations:
+ if decl.name == self.name:
+ if isinstance(decl, VariableDeclaration):
+ return decl.type
+ elif isinstance(decl, VariableDefinition):
+ return decl.type
+ elif isinstance(decl, FunctionDeclaration) or isinstance(decl, FunctionDefinition):
+ return FunctionType(decl.return_type, [arg.type for arg in decl.args])
+ raise KeyError('unknown variable ' + self.name)
+
@dataclass
class AddExpression(Expression):
@@ -37,6 +62,12 @@ class AddExpression(Expression):
@dataclass
+class SubtractExpression(Expression):
+ term1: Expression
+ term2: Expression
+
+
+@dataclass
class MultiplyExpression(Expression):
factor1: Expression
factor2: Expression
@@ -47,6 +78,22 @@ class StructPointerElementExpression(Expression):
base: Expression
element: str
+ def type(self, declarations: List['Declaration']) -> Type:
+ base_type = self.base.type(declarations)
+ assert isinstance(base_type, PointerType)
+ assert isinstance(base_type.target, BasicType)
+ hopefully_struct, struct_name = base_type.target.name.split(' ')
+ assert hopefully_struct == 'struct'
+ for decl in declarations:
+ if isinstance(decl, StructDeclaration) and decl.name == struct_name:
+ if decl.fields is None:
+ raise KeyError('struct ' + struct_name + ' is opaque')
+ for elem in decl.fields:
+ if elem.name == self.element:
+ return elem.type
+ raise KeyError('element ' + self.element + ' not found in struct ' + struct_name)
+ raise KeyError('struct ' + struct_name + ' not found')
+
@dataclass
class ArrayIndexExpression(Expression):
@@ -91,6 +138,20 @@ class ComparisonExpression(Expression):
class BasicType(Type):
name: str
+ def size_bytes(self, declarations: List['Declaration']) -> int:
+ if self.name == 'uint8':
+ return 1
+ elif self.name == 'uintsize':
+ return 8
+ elif self.name.startswith('struct'):
+ _, struct_name = self.name.split(' ')
+ for decl in declarations:
+ if isinstance(decl, StructDeclaration) and decl.name == struct_name:
+ if decl.fields is None:
+ raise KeyError('struct ' + struct_name + ' is opaque')
+ return sum(field.type.size_bytes(declarations) for field in decl.fields)
+ raise NotImplementedError('size of ' + str(self) + ' not yet found')
+
@dataclass
class ConstType(Type):
@@ -101,6 +162,9 @@ class ConstType(Type):
class PointerType(Type):
target: Type
+ def size_bytes(self, declarations: List['Declaration']) -> int:
+ return 8 # TODO figure out 32 bit vs 64 bit
+
@dataclass
class ArrayType(Type):
@@ -226,6 +290,14 @@ class UpdateAssignment(AssignmentStatement):
operation: str
value: Expression
+ def deconstruct(self) -> DirectAssignment:
+ if self.operation == '+=':
+ return DirectAssignment(self.destination, AddExpression(self.destination, self.value))
+ elif self.operation == '*=':
+ return DirectAssignment(self.destination, MultiplyExpression(self.destination, self.value))
+ else:
+ raise NotImplementedError('UpdateAssignment deconstruct with ' + self.operation)
+
@dataclass
class CrementAssignment(AssignmentStatement):
@@ -273,12 +345,24 @@ class HeaderFile:
includes: List['HeaderFile']
contents: List[HeaderFileElement]
+ def get_declarations(self) -> List[Declaration]:
+ included_declarations = [x.get_declarations() for x in self.includes]
+ own_declarations = [x for x in self.contents if isinstance(x, Declaration)]
+ all_declarations = included_declarations + [own_declarations]
+ return [x for l in all_declarations for x in l]
+
@dataclass
class ImplementationFile:
includes: List[HeaderFile]
contents: List[ImplementationFileElement]
+ def get_declarations(self) -> List[Declaration]:
+ included_declarations = [x.get_declarations() for x in self.includes]
+ own_declarations = [x for x in self.contents if isinstance(x, Declaration)]
+ all_declarations = included_declarations + [own_declarations]
+ return [x for l in all_declarations for x in l]
+
# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class ASTBuilder(NodeVisitor):
@@ -605,6 +689,8 @@ class ASTBuilder(NodeVisitor):
for op, term in suffix:
if op.type == '+':
base = AddExpression(base, term)
+ elif op.type == '-':
+ base = SubtractExpression(base, term)
else:
raise NotImplementedError('arithmetic suffix ' + op)
return base
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py
index 508025e..b326239 100644
--- a/crowbar_reference_compiler/ssagen.py
+++ b/crowbar_reference_compiler/ssagen.py
@@ -1,9 +1,13 @@
+import dataclasses
from dataclasses import dataclass
from functools import singledispatch
from typing import List
from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
- VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression
+ VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, \
+ AddExpression, StructPointerElementExpression, Declaration, PointerType, StructDeclaration, VariableDefinition, \
+ MultiplyExpression, LogicalNotExpression, DirectAssignment, UpdateAssignment, SizeofExpression, Expression, \
+ ConstType, ArrayIndexExpression, ArrayType, NegativeExpression, SubtractExpression, AddressOfExpression
@dataclass
@@ -25,13 +29,14 @@ class SsaResult:
@dataclass
class CompileContext:
+ declarations: List[Declaration]
next_data: int = 0
next_temp: int = 0
next_label: int = 0
def build_ssa(file: ImplementationFile) -> str:
- result = compile_to_ssa(file, CompileContext())
+ result = compile_to_ssa(file, CompileContext(file.get_declarations()))
data = '\n'.join(result.data)
code = '\n'.join(result.code)
return data + '\n\n' + code
@@ -53,12 +58,20 @@ def _(target: ImplementationFile, context: CompileContext):
@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
result = SsaResult([], [])
+ context = dataclasses.replace(context, declarations=target.args + context.declarations)
for statement in target.body:
result += compile_to_ssa(statement, context)
+ if isinstance(statement, Declaration):
+ context = dataclasses.replace(context, declarations=[statement]+context.declarations)
+ if not result.code[-1].startswith('ret'):
+ result.code.append('ret')
code = [' ' + instr for instr in result.code]
- assert len(target.args) == 0
- assert target.return_type == BasicType('int32')
- code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
+ # TODO types
+ args = ','.join(f"l %{x.name}" for x in target.args)
+ ret_type = ''
+ if target.return_type != BasicType('void'):
+ ret_type = 'l'
+ code = [f"export function {ret_type} ${target.name}({args}) {{", "@start", *code, "}"]
return SsaResult(result.data, code)
@@ -82,14 +95,28 @@ def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
@compile_to_ssa.register
def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
- if target.value.startswith('"'):
+ if target.type(context.declarations) == PointerType(ConstType(BasicType('char'))):
data_dest = context.next_data
context.next_data += 1
data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
temp = context.next_temp
context.next_temp += 1
code = [f"%t{temp} =l copy $data{data_dest}"]
- else:
+ elif target.type(context.declarations) == BasicType('char'):
+ data = []
+ temp = context.next_temp
+ context.next_temp += 1
+ code = [f"%t{temp} =l copy {ord(target.value[1])}"] # TODO handle escape sequences
+ elif target.type(context.declarations) == BasicType('bool'):
+ data = []
+ temp = context.next_temp
+ context.next_temp += 1
+ if target.value == 'true':
+ value = 1
+ else:
+ value = 0
+ code = [f"%t{temp} =l copy {value}"]
+ elif target.type(context.declarations) == BasicType('int?'):
assert not target.value.startswith('0b')
assert not target.value.startswith('0B')
assert not target.value.startswith('0o')
@@ -102,7 +129,9 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
data = []
temp = context.next_temp
context.next_temp += 1
- code = [f"%t{temp} =w copy {target.value}"]
+ code = [f"%t{temp} =l copy {target.value}"]
+ else:
+ raise NotImplementedError('compiling ' + str(target))
return SsaResult(data, code)
@@ -130,11 +159,14 @@ def _(target: IfStatement, context: CompileContext) -> SsaResult:
result.code.append(f"@l{true_label}")
for statement in target.then:
result += compile_to_ssa(statement, context)
- result.code.append(f"jmp @l{after_label}")
+ if not result.code[-1].startswith('ret'):
+ result.code.append(f"jmp @l{after_label}")
result.code.append(f"@l{false_label}")
- for statement in target.els:
- result += compile_to_ssa(statement, context)
- result.code.append(f"jmp @l{after_label}")
+ if target.els is not None:
+ for statement in target.els:
+ result += compile_to_ssa(statement, context)
+ if not result.code[-1].startswith('ret'):
+ result.code.append(f"jmp @l{after_label}")
result.code.append(f"@l{after_label}")
return result
@@ -147,10 +179,16 @@ def _(target: ComparisonExpression, context: CompileContext) -> SsaResult:
value2_dest = context.next_temp - 1
result_dest = context.next_temp
context.next_temp += 1
+ # TODO types, and signedness
if target.op == '==':
- result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}")
+ op = "ceqw"
+ elif target.op == '>=':
+ op = "cugew"
+ elif target.op == '<=':
+ op = "culew"
else:
raise NotImplementedError('comparison ' + target.op)
+ result.code.append(f"%t{result_dest} =l {op} %t{value1_dest}, %t{value2_dest}")
return result
@@ -163,7 +201,33 @@ def _(target: AddExpression, context: CompileContext) -> SsaResult:
result_reg = context.next_temp
context.next_temp += 1
# TODO make sure the types are correct
- result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}")
+ result.code.append(f"%t{result_reg} =l add %t{value1_dest}, %t{value2_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: SubtractExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.term1, context)
+ value1_dest = context.next_temp - 1
+ result += compile_to_ssa(target.term2, context)
+ value2_dest = context.next_temp - 1
+ result_reg = context.next_temp
+ context.next_temp += 1
+ # TODO make sure the types are correct
+ result.code.append(f"%t{result_reg} =l sub %t{value1_dest}, %t{value2_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: MultiplyExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.factor1, context)
+ value1_dest = context.next_temp - 1
+ result += compile_to_ssa(target.factor2, context)
+ value2_dest = context.next_temp - 1
+ result_reg = context.next_temp
+ context.next_temp += 1
+ # TODO make sure the types are correct
+ result.code.append(f"%t{result_reg} =l mul %t{value1_dest}, %t{value2_dest}")
return result
@@ -172,4 +236,131 @@ def _(target: VariableExpression, context: CompileContext) -> SsaResult:
# TODO make sure any of this is reasonable
result = context.next_temp
context.next_temp += 1
- return SsaResult([], [f"%t{result} =w copy %{target.name}"])
+ return SsaResult([], [f"%t{result} =l copy %{target.name}"])
+
+
+@compile_to_ssa.register
+def _(target: VariableDefinition, context: CompileContext) -> SsaResult:
+ # TODO figure some shit out
+ result = compile_to_ssa(target.value, context)
+ result_dest = context.next_temp - 1
+ result.code.append(f"%{target.name} =l copy %t{result_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: LogicalNotExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.body, context)
+ inner_result_dest = context.next_temp - 1
+ result_dest = context.next_temp
+ context.next_temp += 1
+ result.code.append(f"%t{result_dest} =l ceqw %t{inner_result_dest}, 0")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: NegativeExpression, context: CompileContext) -> SsaResult:
+ return compile_to_ssa(SubtractExpression(ConstantExpression('0'), target.body), context)
+
+
+@compile_to_ssa.register
+def _(target: ArrayIndexExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.array, context)
+ base = context.next_temp - 1
+ result += compile_to_ssa(target.index, context)
+ index = context.next_temp - 1
+ array_type = target.array.type(context.declarations)
+ if isinstance(array_type, PointerType):
+ array_type = array_type.target
+ assert isinstance(array_type, ArrayType)
+ content_type = array_type.contents
+ scale = content_type.size_bytes(context.declarations)
+ offset = context.next_temp
+ context.next_temp += 1
+ address = context.next_temp
+ context.next_temp += 1
+ dest = context.next_temp
+ context.next_temp += 1
+ # TODO types
+ result.code.append(f"%t{offset} =l mul %t{index}, {scale}")
+ result.code.append(f"%t{address} =l add %t{base}, %t{offset}")
+ result.code.append(f"%t{dest} =l loadsw %t{address}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: StructPointerElementExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.base, context)
+ base_dest = context.next_temp - 1
+ # hoooo boy.
+ base_type = target.base.type(context.declarations)
+ assert isinstance(base_type, PointerType)
+ assert isinstance(base_type.target, BasicType)
+ hopefully_struct, struct_name = base_type.target.name.split(' ')
+ assert hopefully_struct == 'struct'
+ target_struct = None
+ for decl in context.declarations:
+ if isinstance(decl, StructDeclaration) and decl.name == struct_name:
+ if decl.fields is None:
+ raise KeyError('struct ' + struct_name + ' is opaque')
+ target_struct = decl
+ break
+ if target_struct is None:
+ raise KeyError('struct ' + struct_name + ' not found')
+ offset = 0
+ for field in target_struct.fields:
+ if field.name == target.element:
+ break
+ else:
+ offset += field.type.size_bytes(context.declarations)
+ temp = context.next_temp
+ context.next_temp += 1
+ result_dest = context.next_temp
+ context.next_temp += 1
+ # TODO types
+ result.code.append(f"%t{temp} =l add %t{base_dest}, {offset}")
+ result.code.append(f"%t{result_dest} =l loadsw %t{temp}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: AddressOfExpression, context: CompileContext) -> SsaResult:
+ if isinstance(target.body, StructPointerElementExpression) or isinstance(target.body, ArrayIndexExpression):
+ result = compile_to_ssa(target.body, context)
+ result.code.pop()
+ context.next_temp -= 1
+ else:
+ raise NotImplementedError('address of ' + str(type(target.body)))
+ return result
+
+
+@compile_to_ssa.register
+def _(target: DirectAssignment, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.value, context)
+ result_dest = context.next_temp - 1
+ if isinstance(target.destination, VariableExpression):
+ raise NotImplementedError('assign directly to variable')
+ elif isinstance(target.destination, StructPointerElementExpression) or isinstance(target.destination, ArrayIndexExpression):
+ sub_result = compile_to_ssa(target.destination, context)
+ last_instr = sub_result.code.pop()
+ _, _, _, location = last_instr.split(' ')
+ # TODO type
+ sub_result.code.append(f"storew %t{result_dest}, {location}")
+ result += sub_result
+ else:
+ raise NotImplementedError('assign to ' + str(type(target.destination)))
+ return result
+
+
+@compile_to_ssa.register
+def _(target: UpdateAssignment, context: CompileContext) -> SsaResult:
+ return compile_to_ssa(target.deconstruct(), context)
+
+
+@compile_to_ssa.register
+def _(target: SizeofExpression, context: CompileContext) -> SsaResult:
+ target = target.body
+ if isinstance(target, Expression):
+ target = target.type(context.declarations)
+ size = target.size_bytes(context.declarations)
+ return compile_to_ssa(ConstantExpression(str(size)), context)