"""Small parser-combinator toolkit operating on ``bytes``.

A *parser* is any callable taking the remaining input (``bytes``) and
returning either ``None`` (failure) or a ``(value, remaining_input)`` tuple
(success).  The functions below build such parsers and combine them.
"""

from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

__all__ = [
    'ParseResult',
    'Parser',
    'as_predicate',
    'alt',
    'tag',
    'itag',
    'take_till1',
    'take_while0',
    'take_while1',
    'take_n',
    'any_char',
    'all_consuming',
    'map_parser',
    'and_then',
    'opt',
    'verify',
    'many0',
    'many1',
    'many_m_n',
    'delimited',
    'pair',
    'triple',
    'preceded',
    'followed',
    'separated_pair',
    'separated_triple',
    'separated_many0',
    'separated_many1',
    'string_concat',
]

T = TypeVar('T')
T1 = TypeVar('T1')
T2 = TypeVar('T2')
T3 = TypeVar('T3')

# ``None`` signals failure; otherwise ``(parsed value, unconsumed input)``.
ParseResult = Optional[Tuple[T, bytes]]
Parser = Callable[[bytes], ParseResult[T]]


def as_predicate(parser: Parser[T]) -> Callable[[int], bool]:
    """Turn *parser* into a predicate over a single byte value.

    The returned callable wraps its ``int`` argument into a one-byte
    ``bytes`` object and reports whether *parser* succeeds on it.
    """
    def check(text: int) -> bool:
        return parser(bytes([text])) is not None

    return check


def alt(*parsers: Parser[T]) -> Parser[T]:
    """Try each parser in order and return the first successful result.

    Every alternative is applied to the same, original input.  Fails
    (returns ``None``) when no alternative succeeds, including when no
    parsers were supplied at all.
    """
    def parse(text: bytes) -> ParseResult[T]:
        for parser in parsers:
            result = parser(text)
            if result is not None:
                return result
        return None

    return parse


def tag(tag_text: bytes) -> Parser[bytes]:
    """Match the exact byte string *tag_text* at the start of the input."""
    def parse(text: bytes) -> ParseResult[bytes]:
        if text.startswith(tag_text):
            return tag_text, text[len(tag_text):]
        return None

    return parse


def itag(tag_text: bytes) -> Parser[bytes]:
    """Case-insensitive variant of :func:`tag`.

    Only the leading ``len(tag_text)`` bytes of the input are examined, so
    bytes after the tag never need to be valid UTF-8.  ASCII letters are
    compared case-insensitively on the raw bytes; if that does not match,
    the prefix and the tag are decoded as UTF-8 and compared with
    ``str.casefold``.  A prefix that cannot be decoded is a non-match.

    On success the parsed value is *tag_text* itself (not the input's
    spelling of it).
    """
    tag_len = len(tag_text)
    ascii_folded_tag = tag_text.lower()
    try:
        unicode_folded_tag: Optional[str] = tag_text.decode().casefold()
    except UnicodeDecodeError:
        # Tag is not UTF-8: only the byte-wise ASCII comparison applies.
        unicode_folded_tag = None

    def parse(text: bytes) -> ParseResult[bytes]:
        head = text[:tag_len]
        if len(head) != tag_len:
            return None
        # Fast path: bytes.lower() folds ASCII letters only.
        if head.lower() == ascii_folded_tag:
            return tag_text, text[tag_len:]
        if unicode_folded_tag is None:
            return None
        try:
            head_str = head.decode()
        except UnicodeDecodeError:
            return None
        if head_str.casefold() == unicode_folded_tag:
            return tag_text, text[tag_len:]
        return None

    return parse


def take_while0(predicate: Callable[[int], bool]) -> Parser[bytes]:
    """Consume the longest (possibly empty) prefix whose bytes satisfy *predicate*.

    This parser always succeeds.
    """
    def parse(text: bytes) -> ParseResult[bytes]:
        for i, byte in enumerate(text):
            if not predicate(byte):
                return text[:i], text[i:]
        return text, b""

    return parse


def take_while1(predicate: Callable[[int], bool]) -> Parser[bytes]:
    """Like :func:`take_while0`, but fail unless at least one byte matches."""
    def parse(text: bytes) -> ParseResult[bytes]:
        if not text or not predicate(text[0]):
            return None
        for i in range(1, len(text)):
            if not predicate(text[i]):
                return text[:i], text[i:]
        return text, b""

    return parse


def take_till1(predicate: Callable[[int], bool]) -> Parser[bytes]:
    """Consume one or more bytes up to the first byte satisfying *predicate*."""
    return take_while1(lambda x: not predicate(x))


def take_n(n: int) -> Parser[bytes]:
    """Consume exactly *n* bytes; fail if fewer than *n* are available."""
    def parse(text: bytes) -> ParseResult[bytes]:
        if len(text) < n:
            return None
        return text[:n], text[n:]

    return parse


def any_char(text: bytes) -> ParseResult[bytes]:
    """Consume a single byte (returned as a length-1 ``bytes``); fail on empty input."""
    if len(text) > 0:
        return text[0:1], text[1:]
    return None


def all_consuming(parser: Parser[T], *, debug: bool = False) -> Parser[T]:
    """Succeed only if *parser* succeeds and leaves no input behind.

    With ``debug=True`` the reason for a failure is printed to stdout.
    """
    def parse(text: bytes) -> ParseResult[T]:
        parsed_result = parser(text)
        if parsed_result is None:
            if debug:
                print('all_consuming: parser failed')
            return None
        result, extra = parsed_result
        if len(extra) > 0:
            if debug:
                print('all_consuming: leftover text {}'.format(repr(extra)))
            return None
        return result, b''

    return parse


def map_parser(parser: Parser[T1], mapper: Callable[[T1], T2]) -> Parser[T2]:
    """Apply *mapper* to the value produced by *parser* on success."""
    def parse(text: bytes) -> ParseResult[T2]:
        parsed_result = parser(text)
        if parsed_result is None:
            return None
        result, extra = parsed_result
        return mapper(result), extra

    return parse


def and_then(first_parser: Parser[T1],
             get_second_parser: Callable[[T1], Parser[T2]]) -> Parser[T2]:
    """Choose a second parser from the first parser's value and run it.

    NOTE: the second parser is applied to the *original* input, not to what
    *first_parser* left over; the first parser's remaining input is
    discarded.  NOTE(review): this differs from a typical "flat map"
    combinator -- confirm with callers before changing it.
    """
    def parse(text: bytes) -> ParseResult[T2]:
        parsed_result = first_parser(text)
        if parsed_result is None:
            return None
        result, _ = parsed_result
        return get_second_parser(result)(text)

    return parse


def opt(parser: Parser[T]) -> Parser[Optional[T]]:
    """Make *parser* optional: on failure yield ``None`` and consume nothing."""
    def parse(text: bytes) -> ParseResult[Optional[T]]:
        result = parser(text)
        if result is None:
            return None, text
        return result

    return parse


def verify(parser: Parser[T], predicate: Callable[[T], bool]) -> Parser[T]:
    """Run *parser* and fail unless its value satisfies *predicate*."""
    def parse(text: bytes) -> ParseResult[T]:
        parsed_result = parser(text)
        if parsed_result is None:
            return None
        result, extra = parsed_result
        if predicate(result):
            return result, extra
        return None

    return parse


def many0(parser: Parser[T]) -> Parser[List[T]]:
    """Apply *parser* repeatedly, collecting zero or more values.

    Repetition stops when *parser* fails or when it succeeds without
    consuming any input (which would otherwise loop forever); such a
    non-consuming match is not added to the result.  Always succeeds.
    """
    def parse(text: bytes) -> ParseResult[List[T]]:
        result: List[T] = []
        while True:
            parser_result = parser(text)
            if parser_result is None:
                break
            this_result, rest = parser_result
            if len(rest) == len(text):
                # No progress: bail out instead of spinning forever.
                break
            result.append(this_result)
            text = rest
        return result, text

    return parse


def many1(parser: Parser[T]) -> Parser[List[T]]:
    """Apply *parser* repeatedly, requiring at least one success.

    The first match is always accepted.  Subsequent repetition stops when
    *parser* fails or succeeds without consuming input (guarding against
    an infinite loop).
    """
    def parse(text: bytes) -> ParseResult[List[T]]:
        parser_result = parser(text)
        if parser_result is None:
            return None
        this_result, extra = parser_result
        result = [this_result]
        while True:
            parser_result = parser(extra)
            if parser_result is None:
                break
            this_result, rest = parser_result
            if len(rest) == len(extra):
                # No progress: bail out instead of spinning forever.
                break
            result.append(this_result)
            extra = rest
        return result, extra

    return parse


def many_m_n(parser: Parser[T], min_inclusive: int,
             max_inclusive: int) -> Parser[List[T]]:
    """Apply *parser* between *min_inclusive* and *max_inclusive* times.

    Fails if fewer than *min_inclusive* matches are found; stops collecting
    once *max_inclusive* values have been gathered.
    """
    def parse(text: bytes) -> ParseResult[List[T]]:
        result: List[T] = []
        # Mandatory repetitions: any failure here fails the whole parser.
        while len(result) < min_inclusive:
            parser_result = parser(text)
            if parser_result is None:
                return None
            this_result, text = parser_result
            result.append(this_result)
        # Optional repetitions: stop quietly at the first failure.
        while len(result) < max_inclusive:
            parser_result = parser(text)
            if parser_result is None:
                break
            this_result, text = parser_result
            result.append(this_result)
        return result, text

    return parse


def separated_many0(parser: Parser[T],
                    separator_parser: Parser) -> Parser[List[T]]:
    """Parse zero or more *parser* values separated by *separator_parser*.

    Separator values are discarded.  Always succeeds.

    NOTE: a trailing separator is consumed even when no element follows it
    (``a,b,`` leaves nothing after the final comma).  Iteration also stops
    when an element plus separator together consume no input, to avoid
    looping forever.
    """
    def parse(text: bytes) -> ParseResult[List[T]]:
        result: List[T] = []
        while True:
            length_before = len(text)
            parser_result = parser(text)
            if parser_result is None:
                break
            this_result, text = parser_result
            result.append(this_result)
            separator_result = separator_parser(text)
            if separator_result is None:
                break
            _, text = separator_result
            if len(text) == length_before:
                # Neither element nor separator advanced: stop here.
                break
        return result, text

    return parse


def separated_many1(parser: Parser[T],
                    separator_parser: Parser) -> Parser[List[T]]:
    """Like :func:`separated_many0`, but fail when no element was parsed."""
    return verify(separated_many0(parser, separator_parser),
                  lambda result: len(result) > 0)


def delimited(before_parser: Parser[T1], parser: Parser[T],
              after_parser: Parser[T2]) -> Parser[T]:
    """Run three parsers in sequence and keep only the middle value."""
    def parse(text: bytes) -> ParseResult[T]:
        before_result = before_parser(text)
        if before_result is None:
            return None
        _, extra = before_result
        parsed_result = parser(extra)
        if parsed_result is None:
            return None
        result, extra = parsed_result
        after_result = after_parser(extra)
        if after_result is None:
            return None
        _, extra = after_result
        return result, extra

    return parse


def pair(first_parser: Parser[T1],
         second_parser: Parser[T2]) -> Parser[Tuple[T1, T2]]:
    """Run two parsers in sequence and return both values as a tuple."""
    def parse(text: bytes) -> ParseResult[Tuple[T1, T2]]:
        first_parsed_result = first_parser(text)
        if first_parsed_result is None:
            return None
        first_result, extra = first_parsed_result
        second_parsed_result = second_parser(extra)
        if second_parsed_result is None:
            return None
        second_result, extra = second_parsed_result
        return (first_result, second_result), extra

    return parse


def triple(first_parser: Parser[T1], second_parser: Parser[T2],
           third_parser: Parser[T3]) -> Parser[Tuple[T1, T2, T3]]:
    """Run three parsers in sequence and return all values as a tuple."""
    def parse(text: bytes) -> ParseResult[Tuple[T1, T2, T3]]:
        first_parsed_result = first_parser(text)
        if first_parsed_result is None:
            return None
        first_result, extra = first_parsed_result
        second_parsed_result = second_parser(extra)
        if second_parsed_result is None:
            return None
        second_result, extra = second_parsed_result
        third_parsed_result = third_parser(extra)
        if third_parsed_result is None:
            return None
        third_result, extra = third_parsed_result
        return (first_result, second_result, third_result), extra

    return parse


def preceded(before_parser: Parser[T1], parser: Parser[T]) -> Parser[T]:
    """Run *before_parser* then *parser*; keep only *parser*'s value."""
    def second(x: Tuple[T1, T]) -> T:
        return x[1]

    return map_parser(pair(before_parser, parser), second)


def followed(parser: Parser[T], after_parser: Parser[T1]) -> Parser[T]:
    """Run *parser* then *after_parser*; keep only *parser*'s value."""
    def first(x: Tuple[T, T1]) -> T:
        return x[0]

    return map_parser(pair(parser, after_parser), first)


def separated_pair(first_parser: Parser[T1], between_parser: Parser[T],
                   second_parser: Parser[T2]) -> Parser[Tuple[T1, T2]]:
    """Parse ``first between second`` and return ``(first, second)``."""
    def parse(text: bytes) -> ParseResult[Tuple[T1, T2]]:
        first_parsed_result = first_parser(text)
        if first_parsed_result is None:
            return None
        first_result, extra = first_parsed_result
        between_result = between_parser(extra)
        if between_result is None:
            return None
        _, extra = between_result
        second_parsed_result = second_parser(extra)
        if second_parsed_result is None:
            return None
        second_result, extra = second_parsed_result
        return (first_result, second_result), extra

    return parse


def separated_triple(first_parser: Parser[T1], between12_parser: Parser,
                     second_parser: Parser[T2], between23_parser: Parser,
                     third_parser: Parser[T3]) -> Parser[Tuple[T1, T2, T3]]:
    """Parse three values with a separator between each adjacent pair.

    The two separator values are discarded; the result is
    ``(first, second, third)``.
    """
    def parse(text: bytes) -> ParseResult[Tuple[T1, T2, T3]]:
        first_parsed_result = first_parser(text)
        if first_parsed_result is None:
            return None
        first_result, extra = first_parsed_result
        between_result = between12_parser(extra)
        if between_result is None:
            return None
        _, extra = between_result
        second_parsed_result = second_parser(extra)
        if second_parsed_result is None:
            return None
        second_result, extra = second_parsed_result
        between_result = between23_parser(extra)
        if between_result is None:
            return None
        _, extra = between_result
        third_parsed_result = third_parser(extra)
        if third_parsed_result is None:
            return None
        third_result, extra = third_parsed_result
        return (first_result, second_result, third_result), extra

    return parse


def string_concat(parser: Parser[Sequence[bytes]]) -> Parser[bytes]:
    """Join the sequence of ``bytes`` produced by *parser* into one ``bytes``."""
    return map_parser(parser, b''.join)