diff options
-rw-r--r-- | crowbar_reference_compiler/ast.py | 283 | ||||
-rw-r--r-- | tests/test_ast.py | 356 | ||||
-rw-r--r-- | tests/test_parsing.py | 67 |
3 files changed, 598 insertions, 108 deletions
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py index 4b9c21d..2a6963c 100644 --- a/crowbar_reference_compiler/ast.py +++ b/crowbar_reference_compiler/ast.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union +import typing +from typing import ClassVar, List, Tuple, Union from parsimonious import NodeVisitor # type: ignore +from parsimonious.expressions import Compound, OneOf, Optional, Sequence, TokenMatcher, ZeroOrMore # type: ignore from .scanner import scan from .parser import parse_header @@ -29,6 +31,63 @@ class VariableExpression(Expression): @dataclass +class AddExpression(Expression): + term1: Expression + term2: Expression + + +@dataclass +class MultiplyExpression(Expression): + factor1: Expression + factor2: Expression + + +@dataclass +class StructPointerElementExpression(Expression): + base: Expression + element: str + + +@dataclass +class ArrayIndexExpression(Expression): + array: Expression + index: Expression + + +@dataclass +class FunctionCallExpression(Expression): + function: Expression + arguments: List[Expression] + + +@dataclass +class LogicalNotExpression(Expression): + body: Expression + + +@dataclass +class NegativeExpression(Expression): + body: Expression + + +@dataclass +class AddressOfExpression(Expression): + body: Expression + + +@dataclass +class SizeofExpression(Expression): + body: Union[Type, Expression] + + +@dataclass +class ComparisonExpression(Expression): + value1: Expression + op: str + value2: Expression + + +@dataclass class BasicType(Type): name: str @@ -89,13 +148,13 @@ class ExpressionStatement(Statement): class IfStatement(Statement): condition: Expression then: List[Statement] - els: Optional[List[Statement]] + els: typing.Optional[List[Statement]] @dataclass class SwitchStatement(Statement): expression: Expression - body: List[Union[Optional[Expression], Statement]] + body: List[Union[typing.Optional[Expression], Statement]] @dataclass @@ -122,7 +181,7 @@ class VariableDeclaration(Declaration, HeaderFileElement): @dataclass -class VariableDefinition(Declaration, ImplementationFileElement, Statement): +class VariableDefinition(Declaration, HeaderFileElement, ImplementationFileElement, Statement): """Represents the definition of a variable.""" type: Type value: Expression @@ -152,7 +211,7 @@ class BreakStatement(Statement): @dataclass class ReturnStatement(Statement): - body: Optional[Expression] + body: typing.Optional[Expression] @dataclass @@ -177,20 +236,20 @@ class CrementAssignment(AssignmentStatement): @dataclass class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): """Represents the declaration of a struct type.""" - fields: Optional[List[VariableDeclaration]] + fields: typing.Optional[List[VariableDeclaration]] @dataclass class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): """Represents the declaration of an enum type.""" - values: List[Tuple[str, Optional[int]]] + values: List[Tuple[str, typing.Optional[int]]] @dataclass class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): """Represents the declaration of a union type.""" - tag: Optional[VariableDeclaration] - cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]] + tag: typing.Optional[VariableDeclaration] + cases: Union[List[VariableDeclaration], List[Tuple[Expression, typing.Optional[VariableDeclaration]]]] @dataclass @@ -210,6 +269,7 @@ class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileEleme @dataclass class HeaderFile: + grammar: ClassVar[str] = "HeaderFile <- IncludeStatement* HeaderFileElement+" includes: List['HeaderFile'] contents: List[HeaderFileElement] @@ -227,8 +287,6 @@ class ASTBuilder(NodeVisitor): def visit_HeaderFile(self, node, visited_children) -> HeaderFile: includes, elements = visited_children - if not isinstance(includes, list): - includes = [] return HeaderFile(includes, elements) def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile: @@ -253,36 +311,38 @@ class ASTBuilder(NodeVisitor): def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration: struct, name, lbrace, fields, rbrace = visited_children assert struct.type == 'struct' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data - if not isinstance(fields, list): - fields = [fields] return StructDeclaration(name, fields) def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration: opaque, struct, name, semi = visited_children assert opaque.type == 'opaque' assert struct.type == 'struct' - assert semi.type == ';' + assert name.type == 'identifier' name = name.data + assert semi.type == ';' return StructDeclaration(name, None) def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration: enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children assert enum.type == 'enum' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data values = [first_member] for _, v in extra_members: values.append(v) return EnumDeclaration(name, values) - def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]: + def visit_EnumMember(self, node, visited_children) -> Tuple[str, typing.Optional[Expression]]: name, equals_value = visited_children + assert name.type == 'identifier' name = name.data - if len(equals_value) == 0: + if equals_value is None: return name, None _, value = equals_value return name, value @@ -290,9 +350,10 @@ class ASTBuilder(NodeVisitor): def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration: union, name, lbrace, tag, body, rbrace = visited_children assert union.type == 'union' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data expected_tagname, body = body if tag.name != expected_tagname: raise NameError(f"tag {tag} does not match switch argument {expected_tagname}") @@ -300,7 +361,7 @@ class ASTBuilder(NodeVisitor): body = [body] return UnionDeclaration(name, tag, body) - def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]: + def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]: switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children assert switch.type == 'switch' assert lparen.type == '(' @@ -309,7 +370,7 @@ class ASTBuilder(NodeVisitor): assert rbrace.type == '}' return tag.data, body - def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]: + def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, typing.Optional[VariableDeclaration]]: cases, var = visited_children if isinstance(cases, list): cases = cases[0] @@ -323,17 +384,16 @@ class ASTBuilder(NodeVisitor): visited_children = visited_children[0] # TODO don't explode on 'default:' case, expr, colon = visited_children - while isinstance(expr, list): - expr = expr[0] return expr def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration: fragile, union, name, lbrace, body, rbrace = visited_children assert fragile.type == 'fragile' assert union.type == 'union' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data return UnionDeclaration(name, None, body) def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration: @@ -343,15 +403,17 @@ class ASTBuilder(NodeVisitor): def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition: type, name, eq, value, semi = visited_children + assert name.type == 'identifier' + name = name.data assert eq.type == '=' assert semi.type == ';' - name = name.data return VariableDefinition(name, type, value) def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration: type, name, semi = visited_children - assert semi.type == ';' + assert name.type == 'identifier' name = name.data + assert semi.type == ';' return VariableDeclaration(name, type) def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition: @@ -363,9 +425,53 @@ class ASTBuilder(NodeVisitor): assert name.type == 'identifier' name = name.data assert lparen.type == '(' + if args is None: + args = [] assert rparen.type == ')' return FunctionDeclaration(name, return_type, args) + def visit_SignatureArguments(self, node, visited_children) -> List[VariableDeclaration]: + first_type, first_name, rest, comma = visited_children + result = [VariableDeclaration(first_name.data, first_type)] + for comma, ty, name in rest: + result.append(VariableDeclaration(name.data, ty)) + return result + + def visit_IfStatement(self, node, visited_children): + kwd, lparen, condition, rparen, then, els = visited_children + assert kwd.type == 'if' + assert lparen.type == '(' + assert rparen.type == ')' + if els is not None: + kwd, els = els + assert kwd.type == 'else' + return IfStatement(condition, then, els) + + def visit_ReturnStatement(self, node, visited_children): + ret, body, semi = visited_children + assert ret.type == 'return' + assert semi.type == ';' + return ReturnStatement(body) + + def visit_DirectAssignmentBody(self, node, visited_children): + dest, eq, value = visited_children + assert eq.type == '=' + return DirectAssignment(dest, value) + + def visit_UpdateAssignmentBody(self, node, visited_children): + dest, op, value = visited_children + return UpdateAssignment(dest, op.type, value) + + def visit_AssignmentStatement(self, node, visited_children): + assignment, semi = visited_children + assert semi.type == ';' + return assignment + + def visit_ExpressionStatement(self, node, visited_children): + expression, semi = visited_children + assert semi.type == ';' + return ExpressionStatement(expression) + def visit_BasicType(self, node, visited_children) -> Type: while isinstance(visited_children, list) and len(visited_children) == 1: visited_children = visited_children[0] @@ -379,17 +485,23 @@ class ASTBuilder(NodeVisitor): else: category, name = visited_children category = category.type + assert name.type == 'identifier' name = name.data return BasicType(f"{category} {name}") return BasicType(visited_children.type) + def visit_ConstType(self, node, visited_children) -> ConstType: + const, contents = visited_children + assert const.type == 'const' + return ConstType(contents) + + def visit_FunctionType(self, node, visited_children): + raise NotImplementedError('function types') + def visit_ArrayType(self, node, visited_children) -> ArrayType: contents, lbracket, size, rbracket = visited_children assert lbracket.type == '[' assert rbracket.type == ']' - # TODO don't explode on nontrivial expression - while isinstance(size, list): - size = size[0] return ArrayType(contents, size) def visit_PointerType(self, node, visited_children) -> PointerType: @@ -397,6 +509,12 @@ class ASTBuilder(NodeVisitor): assert splat.type == '*' return PointerType(contents) + def visit_Block(self, node, visited_children) -> List[Expression]: + lbrace, body, rbrace = visited_children + assert lbrace.type == '{' + assert rbrace.type == '}' + return body + def visit_AtomicExpression(self, node, visited_children) -> Expression: if isinstance(visited_children, list) and len(visited_children) == 3: lparen, body, rparen = visited_children @@ -414,17 +532,106 @@ class ASTBuilder(NodeVisitor): return ConstantExpression(body.type) raise NotImplementedError() + def visit_StructPointerElementSuffix(self, node, visited_children): + separator, element = visited_children + assert separator.type == '->' + return lambda base: StructPointerElementExpression(base, element.data) + + def visit_CommasExpressionList(self, node, visited_children): + first, rest, comma = visited_children + result = [first] + for comma, next in rest: + result.append(next) + return result + + def visit_FunctionCallSuffix(self, node, visited_children): + lparen, args, rparen = visited_children + assert lparen.type == '(' + assert rparen.type == ')' + if args is None: + args = [] + return lambda base: FunctionCallExpression(base, args) + + def visit_ArrayIndexSuffix(self, node, visited_children): + lbracket, index, rbracket = visited_children + assert lbracket.type == '[' + assert rbracket.type == ']' + return lambda base: ArrayIndexExpression(base, index) + + def visit_ObjectExpression(self, node, visited_children) -> Expression: + if isinstance(visited_children, list): + base, suffix = visited_children[0] + if len(suffix) > 0: + for suffix in suffix: + base = suffix(base) + return base + raise NotImplementedError('array/struct literals') + + def visit_NegativeExpression(self, node, visited_children): + minus, body = visited_children + assert minus.type == '-' + return NegativeExpression(body) + + def visit_AddressOfExpression(self, node, visited_children): + ampersand, body = visited_children + assert ampersand.type == '&' + return AddressOfExpression(body) + + def visit_LogicalNotExpression(self, node, visited_children): + bang, body = visited_children + assert bang.type == '!' + return LogicalNotExpression(body) + + def visit_SizeofExpression(self, node, visited_children): + sizeof, argument = visited_children[0] + assert sizeof.type == 'sizeof' + return SizeofExpression(argument) + + def visit_TermExpression(self, node, visited_children) -> Expression: + base, suffix = visited_children + if suffix is not None: + for op, factor in suffix: + if op.type == '*': + base = MultiplyExpression(base, factor) + else: + raise NotImplementedError('term suffix ' + op) + return base + + def visit_ArithmeticExpression(self, node, visited_children) -> Expression: + base, suffix = visited_children + if suffix is not None: + for op, term in suffix: + if op.type == '+': + base = AddExpression(base, term) + else: + raise NotImplementedError('arithmetic suffix ' + op) + return base + + def visit_GreaterEqExpression(self, node, visited_children): + value1, op, value2 = visited_children + assert op.type == '>=' + return ComparisonExpression(value1, '>=', value2) + + def visit_LessEqExpression(self, node, visited_children): + value1, op, value2 = visited_children + assert op.type == '<=' + return ComparisonExpression(value1, '<=', value2) + def generic_visit(self, node, visited_children): - """ The generic visit method. """ - if not visited_children: - if len(node.text) == 0: - return [] - if len(node.text) == 1: - return node.text[0] - raise ValueError('just a node: ' + str(node)) - if len(visited_children) == 1: + if isinstance(node.expr, TokenMatcher): + return node.text[0] + if isinstance(node.expr, OneOf): + return visited_children[0] + if isinstance(node.expr, Optional): + if len(visited_children) == 0: + return None return visited_children[0] - return visited_children + if isinstance(node.expr, Sequence) and node.expr.name != '': + raise NotImplementedError('visit for sequence ' + str(node.expr)) + if isinstance(node.expr, Compound): + return visited_children + print(node.expr) + return super(ASTBuilder, self).generic_visit(node, visited_children) def build_ast(parse_tree, include_dirs): diff --git a/tests/test_ast.py b/tests/test_ast.py index b2c592d..b2c33c9 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,12 +1,43 @@ +import dataclasses import unittest -from crowbar_reference_compiler import build_ast, parse_header, scan -from crowbar_reference_compiler.ast import ArrayType, BasicType, ConstantExpression, EnumDeclaration, HeaderFile, \ - PointerType, StructDeclaration, UnionDeclaration, VariableDeclaration, VariableExpression +from crowbar_reference_compiler import build_ast, parse_header, parse_implementation, scan +from crowbar_reference_compiler.ast import ( + AddExpression, + AddressOfExpression, + ArrayIndexExpression, + ArrayType, + DirectAssignment, + BasicType, + ComparisonExpression, + ConstType, + ConstantExpression, + EnumDeclaration, + ExpressionStatement, + FunctionCallExpression, + FunctionDeclaration, + FunctionDefinition, + HeaderFile, + IfStatement, + ImplementationFile, + LogicalNotExpression, + MultiplyExpression, + NegativeExpression, + PointerType, + ReturnStatement, + SizeofExpression, + StructDeclaration, + StructPointerElementExpression, + UnionDeclaration, + UpdateAssignment, + VariableDeclaration, + VariableDefinition, + VariableExpression +) class TestAST(unittest.TestCase): - def test_kitchen_sink(self): + def test_type_kitchen_sink(self): code = r""" struct normal { bool fake; @@ -50,5 +81,322 @@ fragile union not_robust { self.assertEqual(decls, HeaderFile([], [normal, ope, sample, robust, not_robust])) +class TestRealCode(unittest.TestCase): + str_hro_code = r""" +struct str { + (uint8[size])* str; + uintsize len; + uintsize size; +} + +struct str *str_create(); +void str_free(struct str *str); +void str_reset(struct str *str); +intsize str_append_ch(struct str *str, uint32 ch); +""" + + unicode_hro_code = r""" +// Technically UTF-8 supports up to 6 byte codepoints, but Unicode itself +// doesn't really bother with more than 4. +const intsize UTF8_MAX_SIZE = 4; + +const uint8 UTF8_INVALID = 0x80; + +/** + * Grabs the next UTF-8 character and advances the string pointer + */ +uint32 utf8_decode(((const uint8)*)* str); + +/** + * Encodes a character as UTF-8 and returns the length of that character. + */ +intsize utf8_encode(uint8 *str, uint32 ch); + +/** + * Returns the size of the next UTF-8 character + */ +intsize utf8_size((const uint8)* str); + +/** + * Returns the size of a UTF-8 character + */ +intsize utf8_chsize(uint32 ch); + +/** + * Reads and returns the next character from the file. + */ +uint32 utf8_fgetch(struct FILE *f); + +/** + * Writes this character to the file and returns the number of bytes written. + */ +intsize utf8_fputch(struct FILE *f, uint32 ch); +""" + + string_cro_code = r""" +//include "stdlib.hro"; +//include "stdint.hro"; +include "str.hro"; +include "unicode.hro"; + +bool ensure_capacity(struct str *str, intsize len) { + if (len + 1 >= str->size) { + (uint8[str->size * 2])* new = realloc(str->str, str->size * 2); + if (!new) { + return false; + } + str->str = new; + str->size *= 2; + } + return true; +} + +struct str *str_create() { + struct str *str = calloc(1, sizeof(struct str)); + str->str = malloc(16); + str->size = 16; + str->len = 0; + str->str[0] = '\0'; + return str; +} + +void str_free(struct str *str) { + if (!str) { + return; + } + free(str->str); + free(str); +} + +intsize str_append_ch(struct str *str, uint32 ch) { + intsize size = utf8_chsize(ch); + if (size <= 0) { + return -1; + } + if (!ensure_capacity(str, str->len + size)) { + return -1; + } + utf8_encode(&str->str[str->len], ch); + str->len += size; + str->str[str->len] = '\0'; + return size; +} +""" + + def test_str_hro(self): + code = self.str_hro_code + tokens = scan(code) + parse_tree = parse_header(tokens) + ast = build_ast(parse_tree, []) + struct_str = StructDeclaration('str', [ + VariableDeclaration('str', PointerType(ArrayType(BasicType('uint8'), VariableExpression('size')))), + VariableDeclaration('len', BasicType('uintsize')), + VariableDeclaration('size', BasicType('uintsize')), + ]) + + pointer_to_struct_str = PointerType(BasicType('struct str')) + + str_create = FunctionDeclaration('str_create', pointer_to_struct_str, []) + str_free = FunctionDeclaration('str_free', BasicType('void'), [ + VariableDeclaration('str', pointer_to_struct_str) + ]) + str_reset = FunctionDeclaration('str_reset', BasicType('void'), [ + VariableDeclaration('str', pointer_to_struct_str) + ]) + str_append_ch = FunctionDeclaration('str_append_ch', BasicType('intsize'), [ + VariableDeclaration('str', pointer_to_struct_str), + VariableDeclaration('ch', BasicType('uint32')) + ]) + + self.assertEqual(ast, HeaderFile([], [struct_str, str_create, str_free, str_reset, str_append_ch])) + + def test_unicode_hro(self): + code = self.unicode_hro_code + tokens = scan(code) + parse_tree = parse_header(tokens) + ast = build_ast(parse_tree, []) + + utf8_max_size = VariableDefinition('UTF8_MAX_SIZE', ConstType(BasicType('intsize')), ConstantExpression('4')) + utf8_invalid = VariableDefinition('UTF8_INVALID', ConstType(BasicType('uint8')), ConstantExpression('0x80')) + utf8_decode = FunctionDeclaration('utf8_decode', BasicType('uint32'), [ + VariableDeclaration('str', PointerType(PointerType(ConstType(BasicType('uint8'))))), + ]) + utf8_encode = FunctionDeclaration('utf8_encode', BasicType('intsize'), [ + VariableDeclaration('str', PointerType(BasicType('uint8'))), + VariableDeclaration('ch', BasicType('uint32')), + ]) + utf8_size = FunctionDeclaration('utf8_size', BasicType('intsize'), [ + VariableDeclaration('str', PointerType(ConstType(BasicType('uint8')))), + ]) + utf8_chsize = FunctionDeclaration('utf8_chsize', BasicType('intsize'), [ + VariableDeclaration('ch', BasicType('uint32')), + ]) + utf8_fgetch = FunctionDeclaration('utf8_fgetch', BasicType('uint32'), [ + VariableDeclaration('f', PointerType(BasicType('struct FILE'))), + ]) + utf8_fputch = FunctionDeclaration('utf8_fputch', BasicType('intsize'), [ + VariableDeclaration('f', PointerType(BasicType('struct FILE'))), + VariableDeclaration('ch', BasicType('uint32')), + ]) + + self.assertEqual(ast, HeaderFile([], [utf8_max_size, utf8_invalid, utf8_decode, utf8_encode, utf8_size, + utf8_chsize, utf8_fgetch, utf8_fputch])) + + def test_string_cro(self): + import tempfile + + code = self.string_cro_code + tokens = scan(code) + parse_tree = parse_implementation(tokens) + with tempfile.TemporaryDirectory() as include_dir: + with open(f"{include_dir}/str.hro", 'w', encoding='utf-8') as f: + f.write(self.str_hro_code) + with open(f"{include_dir}/unicode.hro", 'w', encoding='utf-8') as f: + f.write(self.unicode_hro_code) + ast = build_ast(parse_tree, [include_dir]) + + included_str_hro = build_ast(parse_header(scan(self.str_hro_code)), []) + included_unicode_hro = build_ast(parse_header(scan(self.unicode_hro_code)), []) + + expected = ImplementationFile([included_str_hro, included_unicode_hro], [ + FunctionDefinition('ensure_capacity', BasicType('bool'), [ + VariableDeclaration('str', PointerType(BasicType('struct str'))), + VariableDeclaration('len', BasicType('intsize')), + ], [ + IfStatement(ComparisonExpression( + AddExpression(VariableExpression('len'), ConstantExpression('1')), + '>=', + StructPointerElementExpression(VariableExpression('str'), 'size') + ), [ + VariableDefinition( + 'new', + PointerType(ArrayType( + BasicType('uint8'), + MultiplyExpression( + StructPointerElementExpression(VariableExpression('str'), 'size'), + ConstantExpression('2') + ) + )), + FunctionCallExpression(VariableExpression('realloc'), [ + StructPointerElementExpression(VariableExpression('str'), 'str'), + MultiplyExpression( + StructPointerElementExpression(VariableExpression('str'), 'size'), + ConstantExpression('2') + ) + ]) + ), + IfStatement(LogicalNotExpression(VariableExpression('new')), [ + ReturnStatement(ConstantExpression('false')) + ], None), + DirectAssignment( + StructPointerElementExpression(VariableExpression('str'), 'str'), + VariableExpression('new'), + ), + UpdateAssignment( + StructPointerElementExpression(VariableExpression('str'), 'size'), + '*=', + ConstantExpression('2'), + ) + ], None), + ReturnStatement(ConstantExpression('true')), + ]), + FunctionDefinition('str_create', PointerType(BasicType('struct str')), [], [ + VariableDefinition( + 'str', + PointerType(BasicType('struct str')), + FunctionCallExpression( + VariableExpression('calloc'), + [ + ConstantExpression('1'), + SizeofExpression(BasicType('struct str')), + ] + ) + ), + DirectAssignment( + StructPointerElementExpression(VariableExpression('str'), 'str'), + FunctionCallExpression(VariableExpression('malloc'), [ConstantExpression('16')]), + ), + DirectAssignment( + StructPointerElementExpression(VariableExpression('str'), 'size'), + ConstantExpression('16'), + ), + DirectAssignment( + StructPointerElementExpression(VariableExpression('str'), 'len'), + ConstantExpression('0'), + ), + DirectAssignment( + ArrayIndexExpression( + StructPointerElementExpression(VariableExpression('str'), 'str'), + ConstantExpression('0'), + ), + ConstantExpression(r"'\0'"), + ), + ReturnStatement(VariableExpression('str')), + ]), + FunctionDefinition('str_free', BasicType('void'), [ + VariableDeclaration('str', PointerType(BasicType('struct str'))) + ], [ + IfStatement(LogicalNotExpression(VariableExpression('str')), [ + ReturnStatement(None) + ], None), + ExpressionStatement(FunctionCallExpression( + VariableExpression('free'), + [StructPointerElementExpression(VariableExpression('str'), 'str')] + )), + ExpressionStatement(FunctionCallExpression( + VariableExpression('free'), + [VariableExpression('str')] + )) + ]), + FunctionDefinition('str_append_ch', BasicType('intsize'), [ + VariableDeclaration('str', PointerType(BasicType('struct str'))), + VariableDeclaration('ch', BasicType('uint32')), + ], [ + VariableDefinition('size', BasicType('intsize'), FunctionCallExpression( + VariableExpression('utf8_chsize'), + [VariableExpression('ch')] + )), + IfStatement(ComparisonExpression(VariableExpression('size'), '<=', ConstantExpression('0')), [ + ReturnStatement(NegativeExpression(ConstantExpression('1'))) + ], None), + IfStatement(LogicalNotExpression(FunctionCallExpression( + VariableExpression('ensure_capacity'), + [VariableExpression('str'), AddExpression( + StructPointerElementExpression( + VariableExpression('str'), + 'len' + ), + VariableExpression('size'), + )] + )), [ + ReturnStatement(NegativeExpression(ConstantExpression('1'))) + ], None), + ExpressionStatement(FunctionCallExpression(VariableExpression('utf8_encode'), [ + AddressOfExpression(ArrayIndexExpression( + StructPointerElementExpression(VariableExpression('str'), 'str'), + StructPointerElementExpression(VariableExpression('str'), 'len'), + )), + VariableExpression('ch'), + ])), + UpdateAssignment( + StructPointerElementExpression(VariableExpression('str'), 'len'), + '+=', + VariableExpression('size') + ), + DirectAssignment( + ArrayIndexExpression( + StructPointerElementExpression(VariableExpression('str'), 'str'), + StructPointerElementExpression(VariableExpression('str'), 'len'), + ), + ConstantExpression(r"'\0'"), + ), + ReturnStatement(VariableExpression('size')) + ]) + ]) + + self.assertDictEqual(dataclasses.asdict(ast), dataclasses.asdict(expected)) + self.assertEqual(ast, expected) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_parsing.py b/tests/test_parsing.py index 7463fe7..9b151ae 100644 --- a/tests/test_parsing.py +++ b/tests/test_parsing.py @@ -1,73 +1,8 @@ import unittest -from crowbar_reference_compiler import parse_header, parse_implementation, scan +from crowbar_reference_compiler import parse_header, scan class TestParsing(unittest.TestCase): def test_basic(self): print(parse_header(scan("int8 x();"))) - - def test_scdoc_str(self): - # adapted from https://git.sr.ht/~sircmpwn/scdoc/tree/master/include/str.h - print(parse_header(scan(r""" -struct str { - (uint8[size])* str; - uintsize len; - uintsize size; -} - -struct str *str_create(); -void str_free(struct str *str); -void str_reset(struct str *str); -intsize str_append_ch(struct str *str, uint32 ch); -"""))) - # adapted from https://git.sr.ht/~sircmpwn/scdoc/tree/master/src/string.c - print(parse_implementation(scan(r""" -include "stdlib.hro"; -include "stdint.hro"; -include "str.hro"; -include "unicode.hro"; - -bool ensure_capacity(struct str *str, intsize len) { - if (len + 1 >= str->size) { - (uint8[str->size * 2])* new = realloc(str->str, str->size * 2); - if (!new) { - return false; - } - str->str = new; - str->size *= 2; - } - return true; -} - -struct str *str_create() { - struct str *str = calloc(1, sizeof(struct str)); - str->str = malloc(16); - str->size = 16; - str->len = 0; - str->str[0] = '\0'; - return str; -} - -void str_free(struct str *str) { - if (!str) { - return; - } - free(str->str); - free(str); -} - -intsize str_append_ch(struct str *str, uint32 ch) { - intsize size = utf8_chsize(ch); - if (size <= 0) { - return -1; - } - if (!ensure_capacity(str, str->len + size)) { - return -1; - } - utf8_encode(&str->str[str->len], ch); - str->len += size; - str->str[str->len] = '\0'; - return size; -} -"""))) |