"""Tests for ``build_ast`` in the Crowbar reference compiler.

Each test scans and parses a piece of Crowbar source, builds its AST, and
compares the result against a hand-constructed expected tree.
"""
import dataclasses
import tempfile
import unittest
from pathlib import Path

from crowbar_reference_compiler import build_ast, parse_header, parse_implementation, scan
from crowbar_reference_compiler.ast import (
    AddExpression, AddressOfExpression, ArrayIndexExpression, ArrayType, DirectAssignment, BasicType,
    ComparisonExpression, ConstType, ConstantExpression, EnumDeclaration, ExpressionStatement,
    FunctionCallExpression, FunctionDeclaration, FunctionDefinition, HeaderFile, IfStatement,
    ImplementationFile, LogicalNotExpression, MultiplyExpression, NegativeExpression, PointerType,
    ReturnStatement, SizeofExpression, StructDeclaration, StructPointerElementExpression,
    UnionDeclaration, UpdateAssignment, VariableDeclaration, VariableDefinition, VariableExpression
)


class TestAST(unittest.TestCase):
    """AST construction for individual language features."""

    def test_type_kitchen_sink(self):
        """Every kind of type declaration a header can contain.

        Covers a plain struct, an opaque struct, an enum, a tagged union with a
        ``switch`` on its tag, and a fragile (untagged) union.
        """
        code = r"""
struct normal {
    bool fake;
    (uint8[3])* data;
}

opaque struct ope;

enum sample {
    Testing,
}

union robust {
    enum sample tag;

    switch (tag) {
        case Testing:
            bool testPassed;
    }
}

fragile union not_robust {
    int8 sample;
    bool nope;
}
"""
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        decls = build_ast(parse_tree, [])

        normal = StructDeclaration('normal', [
            VariableDeclaration('fake', BasicType('bool')),
            VariableDeclaration('data', PointerType(ArrayType(BasicType('uint8'), ConstantExpression('3')))),
        ])
        # An opaque struct is declared without a member list.
        ope = StructDeclaration('ope', None)
        sample = EnumDeclaration('sample', [('Testing', None)])
        # Tagged union: the tag declaration, then (case value, member) pairs from the switch.
        robust = UnionDeclaration(
            'robust',
            VariableDeclaration('tag', BasicType('enum sample')),
            [(VariableExpression('Testing'), VariableDeclaration('testPassed', BasicType('bool')))],
        )
        # Fragile union: no tag, just a flat list of members.
        not_robust = UnionDeclaration('not_robust', None, [
            VariableDeclaration('sample', BasicType('int8')),
            VariableDeclaration('nope', BasicType('bool')),
        ])
        self.assertEqual(decls, HeaderFile([], [normal, ope, sample, robust, not_robust]))


class TestRealCode(unittest.TestCase):
    """AST construction for realistic header (.hro) and implementation (.cro) files."""

    # The expected trees below are large. Without this, unittest truncates the
    # assertDictEqual diff in test_string_cro and hides where the trees diverge.
    maxDiff = None

    str_hro_code = r"""
struct str {
    (uint8[size])* str;
    uintsize len;
    uintsize size;
}

struct str *str_create();
void str_free(struct str *str);
void str_reset(struct str *str);
intsize str_append_ch(struct str *str, uint32 ch);
"""

    unicode_hro_code = r"""
// Technically UTF-8 supports up to 6 byte codepoints, but Unicode itself
// doesn't really bother with more than 4.
const intsize UTF8_MAX_SIZE = 4;

const uint8 UTF8_INVALID = 0x80;

/**
 * Grabs the next UTF-8 character and advances the string pointer
 */
uint32 utf8_decode(((const uint8)*)* str);

/**
 * Encodes a character as UTF-8 and returns the length of that character.
 */
intsize utf8_encode(uint8 *str, uint32 ch);

/**
 * Returns the size of the next UTF-8 character
 */
intsize utf8_size((const uint8)* str);

/**
 * Returns the size of a UTF-8 character
 */
intsize utf8_chsize(uint32 ch);

/**
 * Reads and returns the next character from the file.
 */
uint32 utf8_fgetch(struct FILE *f);

/**
 * Writes this character to the file and returns the number of bytes written.
 */
intsize utf8_fputch(struct FILE *f, uint32 ch);
"""

    string_cro_code = r"""
//include "stdlib.hro";
//include "stdint.hro";
include "str.hro";
include "unicode.hro";

bool ensure_capacity(struct str *str, intsize len) {
    if (len + 1 >= str->size) {
        (uint8[str->size * 2])* new = realloc(str->str, str->size * 2);
        if (!new) {
            return false;
        }
        str->str = new;
        str->size *= 2;
    }
    return true;
}

struct str *str_create() {
    struct str *str = calloc(1, sizeof(struct str));
    str->str = malloc(16);
    str->size = 16;
    str->len = 0;
    str->str[0] = '\0';
    return str;
}

void str_free(struct str *str) {
    if (!str) {
        return;
    }
    free(str->str);
    free(str);
}

intsize str_append_ch(struct str *str, uint32 ch) {
    intsize size = utf8_chsize(ch);
    if (size <= 0) {
        return -1;
    }

    if (!ensure_capacity(str, str->len + size)) {
        return -1;
    }

    utf8_encode(&str->str[str->len], ch);
    str->len += size;
    str->str[str->len] = '\0';
    return size;
}
"""

    def test_str_hro(self):
        """A header with a struct whose array member is sized by a sibling field."""
        code = self.str_hro_code
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        ast = build_ast(parse_tree, [])

        struct_str = StructDeclaration('str', [
            # `size` names another member, so the array length is a variable
            # expression rather than a constant.
            VariableDeclaration('str', PointerType(ArrayType(BasicType('uint8'), VariableExpression('size')))),
            VariableDeclaration('len', BasicType('uintsize')),
            VariableDeclaration('size', BasicType('uintsize')),
        ])
        pointer_to_struct_str = PointerType(BasicType('struct str'))
        str_create = FunctionDeclaration('str_create', pointer_to_struct_str, [])
        str_free = FunctionDeclaration('str_free', BasicType('void'), [
            VariableDeclaration('str', pointer_to_struct_str)
        ])
        str_reset = FunctionDeclaration('str_reset', BasicType('void'), [
            VariableDeclaration('str', pointer_to_struct_str)
        ])
        str_append_ch = FunctionDeclaration('str_append_ch', BasicType('intsize'), [
            VariableDeclaration('str', pointer_to_struct_str),
            VariableDeclaration('ch', BasicType('uint32'))
        ])
        self.assertEqual(ast, HeaderFile([], [struct_str, str_create, str_free, str_reset, str_append_ch]))

    def test_unicode_hro(self):
        """A header with const definitions, nested const pointers, and comments.

        The expected tree contains only the declarations, so the ``//`` and
        ``/** */`` comments must not produce any nodes.
        """
        code = self.unicode_hro_code
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        ast = build_ast(parse_tree, [])

        utf8_max_size = VariableDefinition('UTF8_MAX_SIZE', ConstType(BasicType('intsize')), ConstantExpression('4'))
        utf8_invalid = VariableDefinition('UTF8_INVALID', ConstType(BasicType('uint8')), ConstantExpression('0x80'))
        utf8_decode = FunctionDeclaration('utf8_decode', BasicType('uint32'), [
            VariableDeclaration('str', PointerType(PointerType(ConstType(BasicType('uint8'))))),
        ])
        utf8_encode = FunctionDeclaration('utf8_encode', BasicType('intsize'), [
            VariableDeclaration('str', PointerType(BasicType('uint8'))),
            VariableDeclaration('ch', BasicType('uint32')),
        ])
        utf8_size = FunctionDeclaration('utf8_size', BasicType('intsize'), [
            VariableDeclaration('str', PointerType(ConstType(BasicType('uint8')))),
        ])
        utf8_chsize = FunctionDeclaration('utf8_chsize', BasicType('intsize'), [
            VariableDeclaration('ch', BasicType('uint32')),
        ])
        utf8_fgetch = FunctionDeclaration('utf8_fgetch', BasicType('uint32'), [
            VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
        ])
        utf8_fputch = FunctionDeclaration('utf8_fputch', BasicType('intsize'), [
            VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
            VariableDeclaration('ch', BasicType('uint32')),
        ])
        self.assertEqual(ast, HeaderFile([], [utf8_max_size, utf8_invalid, utf8_decode, utf8_encode, utf8_size,
                                              utf8_chsize, utf8_fgetch, utf8_fputch]))

    def test_string_cro(self):
        """An implementation file whose ``include``s resolve via an include directory.

        The two headers are written into a temporary directory that is passed to
        ``build_ast``. The expected include list is those same headers built on
        their own.
        """
        code = self.string_cro_code
        tokens = scan(code)
        parse_tree = parse_implementation(tokens)
        with tempfile.TemporaryDirectory() as include_dir:
            include_path = Path(include_dir)
            (include_path / 'str.hro').write_text(self.str_hro_code, encoding='utf-8')
            (include_path / 'unicode.hro').write_text(self.unicode_hro_code, encoding='utf-8')
            # Includes are resolved while building the AST, so this must run
            # before the temporary directory is removed.
            ast = build_ast(parse_tree, [include_dir])

        included_str_hro = build_ast(parse_header(scan(self.str_hro_code)), [])
        included_unicode_hro = build_ast(parse_header(scan(self.unicode_hro_code)), [])
        expected = ImplementationFile([included_str_hro, included_unicode_hro], [
            FunctionDefinition('ensure_capacity', BasicType('bool'), [
                VariableDeclaration('str', PointerType(BasicType('struct str'))),
                VariableDeclaration('len', BasicType('intsize')),
            ], [
                IfStatement(ComparisonExpression(
                    AddExpression(VariableExpression('len'), ConstantExpression('1')),
                    '>=',
                    StructPointerElementExpression(VariableExpression('str'), 'size')
                ), [
                    VariableDefinition(
                        'new',
                        PointerType(ArrayType(
                            BasicType('uint8'),
                            MultiplyExpression(
                                StructPointerElementExpression(VariableExpression('str'), 'size'),
                                ConstantExpression('2')
                            )
                        )),
                        FunctionCallExpression(VariableExpression('realloc'), [
                            StructPointerElementExpression(VariableExpression('str'), 'str'),
                            MultiplyExpression(
                                StructPointerElementExpression(VariableExpression('str'), 'size'),
                                ConstantExpression('2')
                            )
                        ])
                    ),
                    IfStatement(LogicalNotExpression(VariableExpression('new')), [
                        ReturnStatement(ConstantExpression('false'))
                    ], None),
                    DirectAssignment(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        VariableExpression('new'),
                    ),
                    UpdateAssignment(
                        StructPointerElementExpression(VariableExpression('str'), 'size'),
                        '*=',
                        ConstantExpression('2'),
                    )
                ], None),
                ReturnStatement(ConstantExpression('true')),
            ]),
            FunctionDefinition('str_create', PointerType(BasicType('struct str')), [], [
                VariableDefinition(
                    'str',
                    PointerType(BasicType('struct str')),
                    FunctionCallExpression(
                        VariableExpression('calloc'),
                        [
                            ConstantExpression('1'),
                            SizeofExpression(BasicType('struct str')),
                        ]
                    )
                ),
                DirectAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'str'),
                    FunctionCallExpression(VariableExpression('malloc'), [ConstantExpression('16')]),
                ),
                DirectAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'size'),
                    ConstantExpression('16'),
                ),
                DirectAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'len'),
                    ConstantExpression('0'),
                ),
                DirectAssignment(
                    ArrayIndexExpression(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        ConstantExpression('0'),
                    ),
                    ConstantExpression(r"'\0'"),
                ),
                ReturnStatement(VariableExpression('str')),
            ]),
            FunctionDefinition('str_free', BasicType('void'), [
                VariableDeclaration('str', PointerType(BasicType('struct str')))
            ], [
                IfStatement(LogicalNotExpression(VariableExpression('str')), [
                    ReturnStatement(None)
                ], None),
                ExpressionStatement(FunctionCallExpression(
                    VariableExpression('free'),
                    [StructPointerElementExpression(VariableExpression('str'), 'str')]
                )),
                ExpressionStatement(FunctionCallExpression(
                    VariableExpression('free'),
                    [VariableExpression('str')]
                ))
            ]),
            FunctionDefinition('str_append_ch', BasicType('intsize'), [
                VariableDeclaration('str', PointerType(BasicType('struct str'))),
                VariableDeclaration('ch', BasicType('uint32')),
            ], [
                VariableDefinition('size', BasicType('intsize'), FunctionCallExpression(
                    VariableExpression('utf8_chsize'),
                    [VariableExpression('ch')]
                )),
                IfStatement(ComparisonExpression(VariableExpression('size'), '<=', ConstantExpression('0')), [
                    ReturnStatement(NegativeExpression(ConstantExpression('1')))
                ], None),
                IfStatement(LogicalNotExpression(FunctionCallExpression(
                    VariableExpression('ensure_capacity'),
                    [VariableExpression('str'), AddExpression(
                        StructPointerElementExpression(
                            VariableExpression('str'),
                            'len'
                        ),
                        VariableExpression('size'),
                    )]
                )), [
                    ReturnStatement(NegativeExpression(ConstantExpression('1')))
                ], None),
                ExpressionStatement(FunctionCallExpression(VariableExpression('utf8_encode'), [
                    AddressOfExpression(ArrayIndexExpression(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        StructPointerElementExpression(VariableExpression('str'), 'len'),
                    )),
                    VariableExpression('ch'),
                ])),
                UpdateAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'len'),
                    '+=',
                    VariableExpression('size')
                ),
                DirectAssignment(
                    ArrayIndexExpression(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        StructPointerElementExpression(VariableExpression('str'), 'len'),
                    ),
                    ConstantExpression(r"'\0'"),
                ),
                ReturnStatement(VariableExpression('size'))
            ])
        ])
        # Compare the plain-dict forms first: assertDictEqual prints a structural
        # diff on failure, which plain dataclass equality does not.
        self.assertDictEqual(dataclasses.asdict(ast), dataclasses.asdict(expected))
        self.assertEqual(ast, expected)


if __name__ == '__main__':
    unittest.main()