aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crowbar_reference_compiler/ast.py283
-rw-r--r--tests/test_ast.py356
-rw-r--r--tests/test_parsing.py67
3 files changed, 598 insertions, 108 deletions
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
index 4b9c21d..2a6963c 100644
--- a/crowbar_reference_compiler/ast.py
+++ b/crowbar_reference_compiler/ast.py
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+import typing
+from typing import ClassVar, List, Tuple, Union
from parsimonious import NodeVisitor # type: ignore
+from parsimonious.expressions import Compound, OneOf, Optional, Sequence, TokenMatcher, ZeroOrMore # type: ignore
from .scanner import scan
from .parser import parse_header
@@ -29,6 +31,63 @@ class VariableExpression(Expression):
@dataclass
+class AddExpression(Expression):
+ term1: Expression
+ term2: Expression
+
+
+@dataclass
+class MultiplyExpression(Expression):
+ factor1: Expression
+ factor2: Expression
+
+
+@dataclass
+class StructPointerElementExpression(Expression):
+ base: Expression
+ element: str
+
+
+@dataclass
+class ArrayIndexExpression(Expression):
+ array: Expression
+ index: Expression
+
+
+@dataclass
+class FunctionCallExpression(Expression):
+ function: Expression
+ arguments: List[Expression]
+
+
+@dataclass
+class LogicalNotExpression(Expression):
+ body: Expression
+
+
+@dataclass
+class NegativeExpression(Expression):
+ body: Expression
+
+
+@dataclass
+class AddressOfExpression(Expression):
+ body: Expression
+
+
+@dataclass
+class SizeofExpression(Expression):
+ body: Union[Type, Expression]
+
+
+@dataclass
+class ComparisonExpression(Expression):
+ value1: Expression
+ op: str
+ value2: Expression
+
+
+@dataclass
class BasicType(Type):
name: str
@@ -89,13 +148,13 @@ class ExpressionStatement(Statement):
class IfStatement(Statement):
condition: Expression
then: List[Statement]
- els: Optional[List[Statement]]
+ els: typing.Optional[List[Statement]]
@dataclass
class SwitchStatement(Statement):
expression: Expression
- body: List[Union[Optional[Expression], Statement]]
+ body: List[Union[typing.Optional[Expression], Statement]]
@dataclass
@@ -122,7 +181,7 @@ class VariableDeclaration(Declaration, HeaderFileElement):
@dataclass
-class VariableDefinition(Declaration, ImplementationFileElement, Statement):
+class VariableDefinition(Declaration, HeaderFileElement, ImplementationFileElement, Statement):
"""Represents the definition of a variable."""
type: Type
value: Expression
@@ -152,7 +211,7 @@ class BreakStatement(Statement):
@dataclass
class ReturnStatement(Statement):
- body: Optional[Expression]
+ body: typing.Optional[Expression]
@dataclass
@@ -177,20 +236,20 @@ class CrementAssignment(AssignmentStatement):
@dataclass
class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
"""Represents the declaration of a struct type."""
- fields: Optional[List[VariableDeclaration]]
+ fields: typing.Optional[List[VariableDeclaration]]
@dataclass
class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
"""Represents the declaration of an enum type."""
- values: List[Tuple[str, Optional[int]]]
+ values: List[Tuple[str, typing.Optional[int]]]
@dataclass
class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
"""Represents the declaration of a union type."""
- tag: Optional[VariableDeclaration]
- cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]]
+ tag: typing.Optional[VariableDeclaration]
+ cases: Union[List[VariableDeclaration], List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]
@dataclass
@@ -210,6 +269,7 @@ class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileEleme
@dataclass
class HeaderFile:
+ grammar: ClassVar[str] = "HeaderFile <- IncludeStatement* HeaderFileElement+"
includes: List['HeaderFile']
contents: List[HeaderFileElement]
@@ -227,8 +287,6 @@ class ASTBuilder(NodeVisitor):
def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
includes, elements = visited_children
- if not isinstance(includes, list):
- includes = []
return HeaderFile(includes, elements)
def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:
@@ -253,36 +311,38 @@ class ASTBuilder(NodeVisitor):
def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
struct, name, lbrace, fields, rbrace = visited_children
assert struct.type == 'struct'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
- if not isinstance(fields, list):
- fields = [fields]
return StructDeclaration(name, fields)
def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
opaque, struct, name, semi = visited_children
assert opaque.type == 'opaque'
assert struct.type == 'struct'
- assert semi.type == ';'
+ assert name.type == 'identifier'
name = name.data
+ assert semi.type == ';'
return StructDeclaration(name, None)
def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
assert enum.type == 'enum'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
values = [first_member]
for _, v in extra_members:
values.append(v)
return EnumDeclaration(name, values)
- def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]:
+ def visit_EnumMember(self, node, visited_children) -> Tuple[str, typing.Optional[Expression]]:
name, equals_value = visited_children
+ assert name.type == 'identifier'
name = name.data
- if len(equals_value) == 0:
+ if equals_value is None:
return name, None
_, value = equals_value
return name, value
@@ -290,9 +350,10 @@ class ASTBuilder(NodeVisitor):
def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
union, name, lbrace, tag, body, rbrace = visited_children
assert union.type == 'union'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
expected_tagname, body = body
if tag.name != expected_tagname:
raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
@@ -300,7 +361,7 @@ class ASTBuilder(NodeVisitor):
body = [body]
return UnionDeclaration(name, tag, body)
- def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]:
+ def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]:
switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
assert switch.type == 'switch'
assert lparen.type == '('
@@ -309,7 +370,7 @@ class ASTBuilder(NodeVisitor):
assert rbrace.type == '}'
return tag.data, body
- def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]:
+ def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, typing.Optional[VariableDeclaration]]:
cases, var = visited_children
if isinstance(cases, list):
cases = cases[0]
@@ -323,17 +384,16 @@ class ASTBuilder(NodeVisitor):
visited_children = visited_children[0]
# TODO don't explode on 'default:'
case, expr, colon = visited_children
- while isinstance(expr, list):
- expr = expr[0]
return expr
def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
fragile, union, name, lbrace, body, rbrace = visited_children
assert fragile.type == 'fragile'
assert union.type == 'union'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
return UnionDeclaration(name, None, body)
def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:
@@ -343,15 +403,17 @@ class ASTBuilder(NodeVisitor):
def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
type, name, eq, value, semi = visited_children
+ assert name.type == 'identifier'
+ name = name.data
assert eq.type == '='
assert semi.type == ';'
- name = name.data
return VariableDefinition(name, type, value)
def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
type, name, semi = visited_children
- assert semi.type == ';'
+ assert name.type == 'identifier'
name = name.data
+ assert semi.type == ';'
return VariableDeclaration(name, type)
def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:
@@ -363,9 +425,53 @@ class ASTBuilder(NodeVisitor):
assert name.type == 'identifier'
name = name.data
assert lparen.type == '('
+ if args is None:
+ args = []
assert rparen.type == ')'
return FunctionDeclaration(name, return_type, args)
+ def visit_SignatureArguments(self, node, visited_children) -> List[VariableDeclaration]:
+ first_type, first_name, rest, comma = visited_children
+ result = [VariableDeclaration(first_name.data, first_type)]
+ for comma, ty, name in rest:
+ result.append(VariableDeclaration(name.data, ty))
+ return result
+
+ def visit_IfStatement(self, node, visited_children):
+ kwd, lparen, condition, rparen, then, els = visited_children
+ assert kwd.type == 'if'
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ if els is not None:
+ kwd, els = els
+ assert kwd.type == 'else'
+ return IfStatement(condition, then, els)
+
+ def visit_ReturnStatement(self, node, visited_children):
+ ret, body, semi = visited_children
+ assert ret.type == 'return'
+ assert semi.type == ';'
+ return ReturnStatement(body)
+
+ def visit_DirectAssignmentBody(self, node, visited_children):
+ dest, eq, value = visited_children
+ assert eq.type == '='
+ return DirectAssignment(dest, value)
+
+ def visit_UpdateAssignmentBody(self, node, visited_children):
+ dest, op, value = visited_children
+ return UpdateAssignment(dest, op.type, value)
+
+ def visit_AssignmentStatement(self, node, visited_children):
+ assignment, semi = visited_children
+ assert semi.type == ';'
+ return assignment
+
+ def visit_ExpressionStatement(self, node, visited_children):
+ expression, semi = visited_children
+ assert semi.type == ';'
+ return ExpressionStatement(expression)
+
def visit_BasicType(self, node, visited_children) -> Type:
while isinstance(visited_children, list) and len(visited_children) == 1:
visited_children = visited_children[0]
@@ -379,17 +485,23 @@ class ASTBuilder(NodeVisitor):
else:
category, name = visited_children
category = category.type
+ assert name.type == 'identifier'
name = name.data
return BasicType(f"{category} {name}")
return BasicType(visited_children.type)
+ def visit_ConstType(self, node, visited_children) -> ConstType:
+ const, contents = visited_children
+ assert const.type == 'const'
+ return ConstType(contents)
+
+ def visit_FunctionType(self, node, visited_children):
+ raise NotImplementedError('function types')
+
def visit_ArrayType(self, node, visited_children) -> ArrayType:
contents, lbracket, size, rbracket = visited_children
assert lbracket.type == '['
assert rbracket.type == ']'
- # TODO don't explode on nontrivial expression
- while isinstance(size, list):
- size = size[0]
return ArrayType(contents, size)
def visit_PointerType(self, node, visited_children) -> PointerType:
@@ -397,6 +509,12 @@ class ASTBuilder(NodeVisitor):
assert splat.type == '*'
return PointerType(contents)
+ def visit_Block(self, node, visited_children) -> List[Expression]:
+ lbrace, body, rbrace = visited_children
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ return body
+
def visit_AtomicExpression(self, node, visited_children) -> Expression:
if isinstance(visited_children, list) and len(visited_children) == 3:
lparen, body, rparen = visited_children
@@ -414,17 +532,106 @@ class ASTBuilder(NodeVisitor):
return ConstantExpression(body.type)
raise NotImplementedError()
+ def visit_StructPointerElementSuffix(self, node, visited_children):
+ separator, element = visited_children
+ assert separator.type == '->'
+ return lambda base: StructPointerElementExpression(base, element.data)
+
+ def visit_CommasExpressionList(self, node, visited_children):
+ first, rest, comma = visited_children
+ result = [first]
+ for comma, next in rest:
+ result.append(next)
+ return result
+
+ def visit_FunctionCallSuffix(self, node, visited_children):
+ lparen, args, rparen = visited_children
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ if args is None:
+ args = []
+ return lambda base: FunctionCallExpression(base, args)
+
+ def visit_ArrayIndexSuffix(self, node, visited_children):
+ lbracket, index, rbracket = visited_children
+ assert lbracket.type == '['
+ assert rbracket.type == ']'
+ return lambda base: ArrayIndexExpression(base, index)
+
+ def visit_ObjectExpression(self, node, visited_children) -> Expression:
+ if isinstance(visited_children, list):
+ base, suffix = visited_children[0]
+ if len(suffix) > 0:
+ for suffix in suffix:
+ base = suffix(base)
+ return base
+ raise NotImplementedError('array/struct literals')
+
+ def visit_NegativeExpression(self, node, visited_children):
+ minus, body = visited_children
+ assert minus.type == '-'
+ return NegativeExpression(body)
+
+ def visit_AddressOfExpression(self, node, visited_children):
+ ampersand, body = visited_children
+ assert ampersand.type == '&'
+ return AddressOfExpression(body)
+
+ def visit_LogicalNotExpression(self, node, visited_children):
+ bang, body = visited_children
+ assert bang.type == '!'
+ return LogicalNotExpression(body)
+
+ def visit_SizeofExpression(self, node, visited_children):
+ sizeof, argument = visited_children[0]
+ assert sizeof.type == 'sizeof'
+ return SizeofExpression(argument)
+
+ def visit_TermExpression(self, node, visited_children) -> Expression:
+ base, suffix = visited_children
+ if suffix is not None:
+ for op, factor in suffix:
+ if op.type == '*':
+ base = MultiplyExpression(base, factor)
+ else:
+ raise NotImplementedError('term suffix ' + op)
+ return base
+
+ def visit_ArithmeticExpression(self, node, visited_children) -> Expression:
+ base, suffix = visited_children
+ if suffix is not None:
+ for op, term in suffix:
+ if op.type == '+':
+ base = AddExpression(base, term)
+ else:
+ raise NotImplementedError('arithmetic suffix ' + op)
+ return base
+
+ def visit_GreaterEqExpression(self, node, visited_children):
+ value1, op, value2 = visited_children
+ assert op.type == '>='
+ return ComparisonExpression(value1, '>=', value2)
+
+ def visit_LessEqExpression(self, node, visited_children):
+ value1, op, value2 = visited_children
+ assert op.type == '<='
+ return ComparisonExpression(value1, '<=', value2)
+
def generic_visit(self, node, visited_children):
- """ The generic visit method. """
- if not visited_children:
- if len(node.text) == 0:
- return []
- if len(node.text) == 1:
- return node.text[0]
- raise ValueError('just a node: ' + str(node))
- if len(visited_children) == 1:
+ if isinstance(node.expr, TokenMatcher):
+ return node.text[0]
+ if isinstance(node.expr, OneOf):
+ return visited_children[0]
+ if isinstance(node.expr, Optional):
+ if len(visited_children) == 0:
+ return None
return visited_children[0]
- return visited_children
+ if isinstance(node.expr, Sequence) and node.expr.name != '':
+ raise NotImplementedError('visit for sequence ' + str(node.expr))
+ if isinstance(node.expr, Compound):
+ return visited_children
+ print(node.expr)
+ return super(ASTBuilder, self).generic_visit(node, visited_children)
def build_ast(parse_tree, include_dirs):
diff --git a/tests/test_ast.py b/tests/test_ast.py
index b2c592d..b2c33c9 100644
--- a/tests/test_ast.py
+++ b/tests/test_ast.py
@@ -1,12 +1,43 @@
+import dataclasses
import unittest
-from crowbar_reference_compiler import build_ast, parse_header, scan
-from crowbar_reference_compiler.ast import ArrayType, BasicType, ConstantExpression, EnumDeclaration, HeaderFile, \
- PointerType, StructDeclaration, UnionDeclaration, VariableDeclaration, VariableExpression
+from crowbar_reference_compiler import build_ast, parse_header, parse_implementation, scan
+from crowbar_reference_compiler.ast import (
+ AddExpression,
+ AddressOfExpression,
+ ArrayIndexExpression,
+ ArrayType,
+ DirectAssignment,
+ BasicType,
+ ComparisonExpression,
+ ConstType,
+ ConstantExpression,
+ EnumDeclaration,
+ ExpressionStatement,
+ FunctionCallExpression,
+ FunctionDeclaration,
+ FunctionDefinition,
+ HeaderFile,
+ IfStatement,
+ ImplementationFile,
+ LogicalNotExpression,
+ MultiplyExpression,
+ NegativeExpression,
+ PointerType,
+ ReturnStatement,
+ SizeofExpression,
+ StructDeclaration,
+ StructPointerElementExpression,
+ UnionDeclaration,
+ UpdateAssignment,
+ VariableDeclaration,
+ VariableDefinition,
+ VariableExpression
+)
class TestAST(unittest.TestCase):
- def test_kitchen_sink(self):
+ def test_type_kitchen_sink(self):
code = r"""
struct normal {
bool fake;
@@ -50,5 +81,322 @@ fragile union not_robust {
self.assertEqual(decls, HeaderFile([], [normal, ope, sample, robust, not_robust]))
+class TestRealCode(unittest.TestCase):
+ str_hro_code = r"""
+struct str {
+ (uint8[size])* str;
+ uintsize len;
+ uintsize size;
+}
+
+struct str *str_create();
+void str_free(struct str *str);
+void str_reset(struct str *str);
+intsize str_append_ch(struct str *str, uint32 ch);
+"""
+
+ unicode_hro_code = r"""
+// Technically UTF-8 supports up to 6 byte codepoints, but Unicode itself
+// doesn't really bother with more than 4.
+const intsize UTF8_MAX_SIZE = 4;
+
+const uint8 UTF8_INVALID = 0x80;
+
+/**
+ * Grabs the next UTF-8 character and advances the string pointer
+ */
+uint32 utf8_decode(((const uint8)*)* str);
+
+/**
+ * Encodes a character as UTF-8 and returns the length of that character.
+ */
+intsize utf8_encode(uint8 *str, uint32 ch);
+
+/**
+ * Returns the size of the next UTF-8 character
+ */
+intsize utf8_size((const uint8)* str);
+
+/**
+ * Returns the size of a UTF-8 character
+ */
+intsize utf8_chsize(uint32 ch);
+
+/**
+ * Reads and returns the next character from the file.
+ */
+uint32 utf8_fgetch(struct FILE *f);
+
+/**
+ * Writes this character to the file and returns the number of bytes written.
+ */
+intsize utf8_fputch(struct FILE *f, uint32 ch);
+"""
+
+ string_cro_code = r"""
+//include "stdlib.hro";
+//include "stdint.hro";
+include "str.hro";
+include "unicode.hro";
+
+bool ensure_capacity(struct str *str, intsize len) {
+ if (len + 1 >= str->size) {
+ (uint8[str->size * 2])* new = realloc(str->str, str->size * 2);
+ if (!new) {
+ return false;
+ }
+ str->str = new;
+ str->size *= 2;
+ }
+ return true;
+}
+
+struct str *str_create() {
+ struct str *str = calloc(1, sizeof(struct str));
+ str->str = malloc(16);
+ str->size = 16;
+ str->len = 0;
+ str->str[0] = '\0';
+ return str;
+}
+
+void str_free(struct str *str) {
+ if (!str) {
+ return;
+ }
+ free(str->str);
+ free(str);
+}
+
+intsize str_append_ch(struct str *str, uint32 ch) {
+ intsize size = utf8_chsize(ch);
+ if (size <= 0) {
+ return -1;
+ }
+ if (!ensure_capacity(str, str->len + size)) {
+ return -1;
+ }
+ utf8_encode(&str->str[str->len], ch);
+ str->len += size;
+ str->str[str->len] = '\0';
+ return size;
+}
+"""
+
+ def test_str_hro(self):
+ code = self.str_hro_code
+ tokens = scan(code)
+ parse_tree = parse_header(tokens)
+ ast = build_ast(parse_tree, [])
+ struct_str = StructDeclaration('str', [
+ VariableDeclaration('str', PointerType(ArrayType(BasicType('uint8'), VariableExpression('size')))),
+ VariableDeclaration('len', BasicType('uintsize')),
+ VariableDeclaration('size', BasicType('uintsize')),
+ ])
+
+ pointer_to_struct_str = PointerType(BasicType('struct str'))
+
+ str_create = FunctionDeclaration('str_create', pointer_to_struct_str, [])
+ str_free = FunctionDeclaration('str_free', BasicType('void'), [
+ VariableDeclaration('str', pointer_to_struct_str)
+ ])
+ str_reset = FunctionDeclaration('str_reset', BasicType('void'), [
+ VariableDeclaration('str', pointer_to_struct_str)
+ ])
+ str_append_ch = FunctionDeclaration('str_append_ch', BasicType('intsize'), [
+ VariableDeclaration('str', pointer_to_struct_str),
+ VariableDeclaration('ch', BasicType('uint32'))
+ ])
+
+ self.assertEqual(ast, HeaderFile([], [struct_str, str_create, str_free, str_reset, str_append_ch]))
+
+ def test_unicode_hro(self):
+ code = self.unicode_hro_code
+ tokens = scan(code)
+ parse_tree = parse_header(tokens)
+ ast = build_ast(parse_tree, [])
+
+ utf8_max_size = VariableDefinition('UTF8_MAX_SIZE', ConstType(BasicType('intsize')), ConstantExpression('4'))
+ utf8_invalid = VariableDefinition('UTF8_INVALID', ConstType(BasicType('uint8')), ConstantExpression('0x80'))
+ utf8_decode = FunctionDeclaration('utf8_decode', BasicType('uint32'), [
+ VariableDeclaration('str', PointerType(PointerType(ConstType(BasicType('uint8'))))),
+ ])
+ utf8_encode = FunctionDeclaration('utf8_encode', BasicType('intsize'), [
+ VariableDeclaration('str', PointerType(BasicType('uint8'))),
+ VariableDeclaration('ch', BasicType('uint32')),
+ ])
+ utf8_size = FunctionDeclaration('utf8_size', BasicType('intsize'), [
+ VariableDeclaration('str', PointerType(ConstType(BasicType('uint8')))),
+ ])
+ utf8_chsize = FunctionDeclaration('utf8_chsize', BasicType('intsize'), [
+ VariableDeclaration('ch', BasicType('uint32')),
+ ])
+ utf8_fgetch = FunctionDeclaration('utf8_fgetch', BasicType('uint32'), [
+ VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
+ ])
+ utf8_fputch = FunctionDeclaration('utf8_fputch', BasicType('intsize'), [
+ VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
+ VariableDeclaration('ch', BasicType('uint32')),
+ ])
+
+ self.assertEqual(ast, HeaderFile([], [utf8_max_size, utf8_invalid, utf8_decode, utf8_encode, utf8_size,
+ utf8_chsize, utf8_fgetch, utf8_fputch]))
+
+ def test_string_cro(self):
+ import tempfile
+
+ code = self.string_cro_code
+ tokens = scan(code)
+ parse_tree = parse_implementation(tokens)
+ with tempfile.TemporaryDirectory() as include_dir:
+ with open(f"{include_dir}/str.hro", 'w', encoding='utf-8') as f:
+ f.write(self.str_hro_code)
+ with open(f"{include_dir}/unicode.hro", 'w', encoding='utf-8') as f:
+ f.write(self.unicode_hro_code)
+ ast = build_ast(parse_tree, [include_dir])
+
+ included_str_hro = build_ast(parse_header(scan(self.str_hro_code)), [])
+ included_unicode_hro = build_ast(parse_header(scan(self.unicode_hro_code)), [])
+
+ expected = ImplementationFile([included_str_hro, included_unicode_hro], [
+ FunctionDefinition('ensure_capacity', BasicType('bool'), [
+ VariableDeclaration('str', PointerType(BasicType('struct str'))),
+ VariableDeclaration('len', BasicType('intsize')),
+ ], [
+ IfStatement(ComparisonExpression(
+ AddExpression(VariableExpression('len'), ConstantExpression('1')),
+ '>=',
+ StructPointerElementExpression(VariableExpression('str'), 'size')
+ ), [
+ VariableDefinition(
+ 'new',
+ PointerType(ArrayType(
+ BasicType('uint8'),
+ MultiplyExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ ConstantExpression('2')
+ )
+ )),
+ FunctionCallExpression(VariableExpression('realloc'), [
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ MultiplyExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ ConstantExpression('2')
+ )
+ ])
+ ),
+ IfStatement(LogicalNotExpression(VariableExpression('new')), [
+ ReturnStatement(ConstantExpression('false'))
+ ], None),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ VariableExpression('new'),
+ ),
+ UpdateAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ '*=',
+ ConstantExpression('2'),
+ )
+ ], None),
+ ReturnStatement(ConstantExpression('true')),
+ ]),
+ FunctionDefinition('str_create', PointerType(BasicType('struct str')), [], [
+ VariableDefinition(
+ 'str',
+ PointerType(BasicType('struct str')),
+ FunctionCallExpression(
+ VariableExpression('calloc'),
+ [
+ ConstantExpression('1'),
+ SizeofExpression(BasicType('struct str')),
+ ]
+ )
+ ),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ FunctionCallExpression(VariableExpression('malloc'), [ConstantExpression('16')]),
+ ),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ ConstantExpression('16'),
+ ),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ ConstantExpression('0'),
+ ),
+ DirectAssignment(
+ ArrayIndexExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ ConstantExpression('0'),
+ ),
+ ConstantExpression(r"'\0'"),
+ ),
+ ReturnStatement(VariableExpression('str')),
+ ]),
+ FunctionDefinition('str_free', BasicType('void'), [
+ VariableDeclaration('str', PointerType(BasicType('struct str')))
+ ], [
+ IfStatement(LogicalNotExpression(VariableExpression('str')), [
+ ReturnStatement(None)
+ ], None),
+ ExpressionStatement(FunctionCallExpression(
+ VariableExpression('free'),
+ [StructPointerElementExpression(VariableExpression('str'), 'str')]
+ )),
+ ExpressionStatement(FunctionCallExpression(
+ VariableExpression('free'),
+ [VariableExpression('str')]
+ ))
+ ]),
+ FunctionDefinition('str_append_ch', BasicType('intsize'), [
+ VariableDeclaration('str', PointerType(BasicType('struct str'))),
+ VariableDeclaration('ch', BasicType('uint32')),
+ ], [
+ VariableDefinition('size', BasicType('intsize'), FunctionCallExpression(
+ VariableExpression('utf8_chsize'),
+ [VariableExpression('ch')]
+ )),
+ IfStatement(ComparisonExpression(VariableExpression('size'), '<=', ConstantExpression('0')), [
+ ReturnStatement(NegativeExpression(ConstantExpression('1')))
+ ], None),
+ IfStatement(LogicalNotExpression(FunctionCallExpression(
+ VariableExpression('ensure_capacity'),
+ [VariableExpression('str'), AddExpression(
+ StructPointerElementExpression(
+ VariableExpression('str'),
+ 'len'
+ ),
+ VariableExpression('size'),
+ )]
+ )), [
+ ReturnStatement(NegativeExpression(ConstantExpression('1')))
+ ], None),
+ ExpressionStatement(FunctionCallExpression(VariableExpression('utf8_encode'), [
+ AddressOfExpression(ArrayIndexExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ )),
+ VariableExpression('ch'),
+ ])),
+ UpdateAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ '+=',
+ VariableExpression('size')
+ ),
+ DirectAssignment(
+ ArrayIndexExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ ),
+ ConstantExpression(r"'\0'"),
+ ),
+ ReturnStatement(VariableExpression('size'))
+ ])
+ ])
+
+ self.assertDictEqual(dataclasses.asdict(ast), dataclasses.asdict(expected))
+ self.assertEqual(ast, expected)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_parsing.py b/tests/test_parsing.py
index 7463fe7..9b151ae 100644
--- a/tests/test_parsing.py
+++ b/tests/test_parsing.py
@@ -1,73 +1,8 @@
import unittest
-from crowbar_reference_compiler import parse_header, parse_implementation, scan
+from crowbar_reference_compiler import parse_header, scan
class TestParsing(unittest.TestCase):
def test_basic(self):
print(parse_header(scan("int8 x();")))
-
- def test_scdoc_str(self):
- # adapted from https://git.sr.ht/~sircmpwn/scdoc/tree/master/include/str.h
- print(parse_header(scan(r"""
-struct str {
- (uint8[size])* str;
- uintsize len;
- uintsize size;
-}
-
-struct str *str_create();
-void str_free(struct str *str);
-void str_reset(struct str *str);
-intsize str_append_ch(struct str *str, uint32 ch);
-""")))
- # adapted from https://git.sr.ht/~sircmpwn/scdoc/tree/master/src/string.c
- print(parse_implementation(scan(r"""
-include "stdlib.hro";
-include "stdint.hro";
-include "str.hro";
-include "unicode.hro";
-
-bool ensure_capacity(struct str *str, intsize len) {
- if (len + 1 >= str->size) {
- (uint8[str->size * 2])* new = realloc(str->str, str->size * 2);
- if (!new) {
- return false;
- }
- str->str = new;
- str->size *= 2;
- }
- return true;
-}
-
-struct str *str_create() {
- struct str *str = calloc(1, sizeof(struct str));
- str->str = malloc(16);
- str->size = 16;
- str->len = 0;
- str->str[0] = '\0';
- return str;
-}
-
-void str_free(struct str *str) {
- if (!str) {
- return;
- }
- free(str->str);
- free(str);
-}
-
-intsize str_append_ch(struct str *str, uint32 ch) {
- intsize size = utf8_chsize(ch);
- if (size <= 0) {
- return -1;
- }
- if (!ensure_capacity(str, str->len + size)) {
- return -1;
- }
- utf8_encode(&str->str[str->len], ch);
- str->len += size;
- str->str[str->len] = '\0';
- return size;
-}
-""")))