aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_ast.py')
-rw-r--r--tests/test_ast.py356
1 files changed, 352 insertions, 4 deletions
diff --git a/tests/test_ast.py b/tests/test_ast.py
index b2c592d..b2c33c9 100644
--- a/tests/test_ast.py
+++ b/tests/test_ast.py
@@ -1,12 +1,43 @@
+import dataclasses
import unittest
-from crowbar_reference_compiler import build_ast, parse_header, scan
-from crowbar_reference_compiler.ast import ArrayType, BasicType, ConstantExpression, EnumDeclaration, HeaderFile, \
- PointerType, StructDeclaration, UnionDeclaration, VariableDeclaration, VariableExpression
+from crowbar_reference_compiler import build_ast, parse_header, parse_implementation, scan
+from crowbar_reference_compiler.ast import (
+ AddExpression,
+ AddressOfExpression,
+ ArrayIndexExpression,
+ ArrayType,
+ DirectAssignment,
+ BasicType,
+ ComparisonExpression,
+ ConstType,
+ ConstantExpression,
+ EnumDeclaration,
+ ExpressionStatement,
+ FunctionCallExpression,
+ FunctionDeclaration,
+ FunctionDefinition,
+ HeaderFile,
+ IfStatement,
+ ImplementationFile,
+ LogicalNotExpression,
+ MultiplyExpression,
+ NegativeExpression,
+ PointerType,
+ ReturnStatement,
+ SizeofExpression,
+ StructDeclaration,
+ StructPointerElementExpression,
+ UnionDeclaration,
+ UpdateAssignment,
+ VariableDeclaration,
+ VariableDefinition,
+ VariableExpression
+)
class TestAST(unittest.TestCase):
- def test_kitchen_sink(self):
+ def test_type_kitchen_sink(self):
code = r"""
struct normal {
bool fake;
@@ -50,5 +81,322 @@ fragile union not_robust {
self.assertEqual(decls, HeaderFile([], [normal, ope, sample, robust, not_robust]))
+class TestRealCode(unittest.TestCase):  # AST tests driven by the embedded source fixtures str_hro_code, unicode_hro_code and string_cro_code
+ str_hro_code = r"""
+struct str {
+ (uint8[size])* str;
+ uintsize len;
+ uintsize size;
+}
+
+struct str *str_create();
+void str_free(struct str *str);
+void str_reset(struct str *str);
+intsize str_append_ch(struct str *str, uint32 ch);
+"""  # header fixture: parsed by test_str_hro and written out as str.hro by test_string_cro
+
+ unicode_hro_code = r"""
+// Technically UTF-8 supports up to 6 byte codepoints, but Unicode itself
+// doesn't really bother with more than 4.
+const intsize UTF8_MAX_SIZE = 4;
+
+const uint8 UTF8_INVALID = 0x80;
+
+/**
+ * Grabs the next UTF-8 character and advances the string pointer
+ */
+uint32 utf8_decode(((const uint8)*)* str);
+
+/**
+ * Encodes a character as UTF-8 and returns the length of that character.
+ */
+intsize utf8_encode(uint8 *str, uint32 ch);
+
+/**
+ * Returns the size of the next UTF-8 character
+ */
+intsize utf8_size((const uint8)* str);
+
+/**
+ * Returns the size of a UTF-8 character
+ */
+intsize utf8_chsize(uint32 ch);
+
+/**
+ * Reads and returns the next character from the file.
+ */
+uint32 utf8_fgetch(struct FILE *f);
+
+/**
+ * Writes this character to the file and returns the number of bytes written.
+ */
+intsize utf8_fputch(struct FILE *f, uint32 ch);
+"""  # header fixture: parsed by test_unicode_hro and written out as unicode.hro by test_string_cro
+
+ string_cro_code = r"""
+//include "stdlib.hro";
+//include "stdint.hro";
+include "str.hro";
+include "unicode.hro";
+
+bool ensure_capacity(struct str *str, intsize len) {
+ if (len + 1 >= str->size) {
+ (uint8[str->size * 2])* new = realloc(str->str, str->size * 2);
+ if (!new) {
+ return false;
+ }
+ str->str = new;
+ str->size *= 2;
+ }
+ return true;
+}
+
+struct str *str_create() {
+ struct str *str = calloc(1, sizeof(struct str));
+ str->str = malloc(16);
+ str->size = 16;
+ str->len = 0;
+ str->str[0] = '\0';
+ return str;
+}
+
+void str_free(struct str *str) {
+ if (!str) {
+ return;
+ }
+ free(str->str);
+ free(str);
+}
+
+intsize str_append_ch(struct str *str, uint32 ch) {
+ intsize size = utf8_chsize(ch);
+ if (size <= 0) {
+ return -1;
+ }
+ if (!ensure_capacity(str, str->len + size)) {
+ return -1;
+ }
+ utf8_encode(&str->str[str->len], ch);
+ str->len += size;
+ str->str[str->len] = '\0';
+ return size;
+}
+"""  # implementation fixture parsed by test_string_cro; its two uncommented include lines name str.hro and unicode.hro
+
+ def test_str_hro(self):  # scan -> parse_header -> build_ast on str_hro_code, then compare with a hand-built HeaderFile
+ code = self.str_hro_code
+ tokens = scan(code)
+ parse_tree = parse_header(tokens)
+ ast = build_ast(parse_tree, [])  # second argument is an empty list here; test_string_cro passes [include_dir] in this position
+ struct_str = StructDeclaration('str', [
+ VariableDeclaration('str', PointerType(ArrayType(BasicType('uint8'), VariableExpression('size')))),  # (uint8[size])* -> pointer to array whose length is the expression `size`
+ VariableDeclaration('len', BasicType('uintsize')),
+ VariableDeclaration('size', BasicType('uintsize')),
+ ])
+
+ pointer_to_struct_str = PointerType(BasicType('struct str'))  # reused for the str_create return type and every `str` parameter
+
+ str_create = FunctionDeclaration('str_create', pointer_to_struct_str, [])
+ str_free = FunctionDeclaration('str_free', BasicType('void'), [
+ VariableDeclaration('str', pointer_to_struct_str)
+ ])
+ str_reset = FunctionDeclaration('str_reset', BasicType('void'), [
+ VariableDeclaration('str', pointer_to_struct_str)
+ ])
+ str_append_ch = FunctionDeclaration('str_append_ch', BasicType('intsize'), [
+ VariableDeclaration('str', pointer_to_struct_str),
+ VariableDeclaration('ch', BasicType('uint32'))
+ ])
+
+ self.assertEqual(ast, HeaderFile([], [struct_str, str_create, str_free, str_reset, str_append_ch]))  # declarations listed in the same order as in the fixture; first HeaderFile argument is an empty list
+
+ def test_unicode_hro(self):  # scan -> parse_header -> build_ast on unicode_hro_code, then compare with a hand-built HeaderFile
+ code = self.unicode_hro_code
+ tokens = scan(code)
+ parse_tree = parse_header(tokens)
+ ast = build_ast(parse_tree, [])
+
+ utf8_max_size = VariableDefinition('UTF8_MAX_SIZE', ConstType(BasicType('intsize')), ConstantExpression('4'))  # `const intsize` is expected as ConstType wrapping BasicType; the literal is expected as the string '4'
+ utf8_invalid = VariableDefinition('UTF8_INVALID', ConstType(BasicType('uint8')), ConstantExpression('0x80'))  # the hex literal is expected as its source text '0x80'
+ utf8_decode = FunctionDeclaration('utf8_decode', BasicType('uint32'), [
+ VariableDeclaration('str', PointerType(PointerType(ConstType(BasicType('uint8'))))),  # ((const uint8)*)* -> pointer to pointer to const uint8
+ ])
+ utf8_encode = FunctionDeclaration('utf8_encode', BasicType('intsize'), [
+ VariableDeclaration('str', PointerType(BasicType('uint8'))),
+ VariableDeclaration('ch', BasicType('uint32')),
+ ])
+ utf8_size = FunctionDeclaration('utf8_size', BasicType('intsize'), [
+ VariableDeclaration('str', PointerType(ConstType(BasicType('uint8')))),
+ ])
+ utf8_chsize = FunctionDeclaration('utf8_chsize', BasicType('intsize'), [
+ VariableDeclaration('ch', BasicType('uint32')),
+ ])
+ utf8_fgetch = FunctionDeclaration('utf8_fgetch', BasicType('uint32'), [
+ VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
+ ])
+ utf8_fputch = FunctionDeclaration('utf8_fputch', BasicType('intsize'), [
+ VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
+ VariableDeclaration('ch', BasicType('uint32')),
+ ])
+
+ self.assertEqual(ast, HeaderFile([], [utf8_max_size, utf8_invalid, utf8_decode, utf8_encode, utf8_size,
+ utf8_chsize, utf8_fgetch, utf8_fputch]))  # expected entries follow the fixture's declaration order; the fixture's comments have no expected entries
+
+ def test_string_cro(self):  # builds the AST for string_cro_code with both header fixtures written to disk, then compares with a hand-built ImplementationFile
+ import tempfile
+
+ code = self.string_cro_code
+ tokens = scan(code)
+ parse_tree = parse_implementation(tokens)
+ with tempfile.TemporaryDirectory() as include_dir:  # temporary directory that receives the files named by the fixture's include lines
+ with open(f"{include_dir}/str.hro", 'w', encoding='utf-8') as f:
+ f.write(self.str_hro_code)
+ with open(f"{include_dir}/unicode.hro", 'w', encoding='utf-8') as f:
+ f.write(self.unicode_hro_code)
+ ast = build_ast(parse_tree, [include_dir])  # include_dir is the only entry of the list argument
+
+ included_str_hro = build_ast(parse_header(scan(self.str_hro_code)), [])  # expected include entries are produced by parsing the same header fixtures directly
+ included_unicode_hro = build_ast(parse_header(scan(self.unicode_hro_code)), [])
+
+ expected = ImplementationFile([included_str_hro, included_unicode_hro], [  # first argument: header ASTs in include order; second: the four function definitions
+ FunctionDefinition('ensure_capacity', BasicType('bool'), [
+ VariableDeclaration('str', PointerType(BasicType('struct str'))),
+ VariableDeclaration('len', BasicType('intsize')),
+ ], [
+ IfStatement(ComparisonExpression(
+ AddExpression(VariableExpression('len'), ConstantExpression('1')),
+ '>=',
+ StructPointerElementExpression(VariableExpression('str'), 'size')
+ ), [
+ VariableDefinition(
+ 'new',
+ PointerType(ArrayType(
+ BasicType('uint8'),
+ MultiplyExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ ConstantExpression('2')
+ )
+ )),
+ FunctionCallExpression(VariableExpression('realloc'), [
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ MultiplyExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ ConstantExpression('2')
+ )
+ ])
+ ),
+ IfStatement(LogicalNotExpression(VariableExpression('new')), [
+ ReturnStatement(ConstantExpression('false'))
+ ], None),  # third IfStatement argument is None; the corresponding source `if` has no else
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ VariableExpression('new'),
+ ),
+ UpdateAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ '*=',
+ ConstantExpression('2'),
+ )
+ ], None),
+ ReturnStatement(ConstantExpression('true')),
+ ]),
+ FunctionDefinition('str_create', PointerType(BasicType('struct str')), [], [
+ VariableDefinition(
+ 'str',
+ PointerType(BasicType('struct str')),
+ FunctionCallExpression(
+ VariableExpression('calloc'),
+ [
+ ConstantExpression('1'),
+ SizeofExpression(BasicType('struct str')),
+ ]
+ )
+ ),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ FunctionCallExpression(VariableExpression('malloc'), [ConstantExpression('16')]),
+ ),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'size'),
+ ConstantExpression('16'),
+ ),
+ DirectAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ ConstantExpression('0'),
+ ),
+ DirectAssignment(
+ ArrayIndexExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ ConstantExpression('0'),
+ ),
+ ConstantExpression(r"'\0'"),  # the character literal is expected with its quotes and backslash escape intact
+ ),
+ ReturnStatement(VariableExpression('str')),
+ ]),
+ FunctionDefinition('str_free', BasicType('void'), [
+ VariableDeclaration('str', PointerType(BasicType('struct str')))
+ ], [
+ IfStatement(LogicalNotExpression(VariableExpression('str')), [
+ ReturnStatement(None)  # the bare `return;` in str_free is expected to carry None
+ ], None),
+ ExpressionStatement(FunctionCallExpression(
+ VariableExpression('free'),
+ [StructPointerElementExpression(VariableExpression('str'), 'str')]
+ )),
+ ExpressionStatement(FunctionCallExpression(
+ VariableExpression('free'),
+ [VariableExpression('str')]
+ ))
+ ]),
+ FunctionDefinition('str_append_ch', BasicType('intsize'), [
+ VariableDeclaration('str', PointerType(BasicType('struct str'))),
+ VariableDeclaration('ch', BasicType('uint32')),
+ ], [
+ VariableDefinition('size', BasicType('intsize'), FunctionCallExpression(
+ VariableExpression('utf8_chsize'),
+ [VariableExpression('ch')]
+ )),
+ IfStatement(ComparisonExpression(VariableExpression('size'), '<=', ConstantExpression('0')), [
+ ReturnStatement(NegativeExpression(ConstantExpression('1')))  # `-1` is expected as NegativeExpression around the constant '1'
+ ], None),
+ IfStatement(LogicalNotExpression(FunctionCallExpression(
+ VariableExpression('ensure_capacity'),
+ [VariableExpression('str'), AddExpression(
+ StructPointerElementExpression(
+ VariableExpression('str'),
+ 'len'
+ ),
+ VariableExpression('size'),
+ )]
+ )), [
+ ReturnStatement(NegativeExpression(ConstantExpression('1')))
+ ], None),
+ ExpressionStatement(FunctionCallExpression(VariableExpression('utf8_encode'), [
+ AddressOfExpression(ArrayIndexExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ )),
+ VariableExpression('ch'),
+ ])),
+ UpdateAssignment(
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ '+=',
+ VariableExpression('size')
+ ),
+ DirectAssignment(
+ ArrayIndexExpression(
+ StructPointerElementExpression(VariableExpression('str'), 'str'),
+ StructPointerElementExpression(VariableExpression('str'), 'len'),
+ ),
+ ConstantExpression(r"'\0'"),
+ ),
+ ReturnStatement(VariableExpression('size'))
+ ])
+ ])
+
+ self.assertDictEqual(dataclasses.asdict(ast), dataclasses.asdict(expected))  # NOTE(review): presumably run first for a more readable failure diff than assertEqual gives - confirm
+ self.assertEqual(ast, expected)
+
+
if __name__ == '__main__':  # run the test cases when this file is executed as a script
unittest.main()