git @ Cat's Eye Technologies Cleandown / master src / marko / helpers.py
master

Tree @master (Download .tar.gz)

helpers.py @master · raw · history · blame

# Copyright (c) 2019 Frost Ming
#
# SPDX-License-Identifier: LicenseRef-MIT-X-Marko

"""
Helper functions and data structures
"""

from __future__ import annotations

import dataclasses
import re
from functools import partial
from importlib import import_module
from typing import TYPE_CHECKING, overload

from marko.renderer import Renderer

if TYPE_CHECKING:
    from typing import Any, Callable, Container, Iterable, TypeVar

    from .element import Element

    RendererFunc = Callable[[Any, Element], Any]
    TRenderer = TypeVar("TRenderer", bound=RendererFunc)
    D = TypeVar("D", bound="_RendererDispatcher")


def camel_to_snake_case(name: str) -> str:
    """Takes a camelCased (or PascalCased) string and converts to snake_case.

    Words are split at lower-to-upper transitions and before the last capital
    of an acronym run, e.g. ``"HTMLBlock"`` -> ``"html_block"`` and
    ``"camelCase"`` -> ``"camel_case"``. Characters other than ASCII letters
    (digits, underscores, ...) never belong to a word and are discarded.
    """
    # Alternatives, tried left to right at each position:
    #   [A-Z]?[a-z]+      a capitalised word, or a word with no leading capital
    #                     (previously such a word, e.g. "camel" in "camelCase",
    #                     was silently dropped)
    #   [A-Z]+(?![a-z])   an acronym run, stopping before the capital that
    #                     begins the next capitalised word
    pattern = r"[A-Z]?[a-z]+|[A-Z]+(?![a-z])"
    return "_".join(map(str.lower, re.findall(pattern, name)))


def is_paired(text: Iterable[str], open: str = "(", close: str = ")") -> bool:
    """Return whether the *open*/*close* delimiters in *text* are balanced.

    A backslash escapes the character that follows it, so escaped
    delimiters are ignored. The result is ``False`` as soon as a *close*
    appears with no pending *open*; otherwise it is ``True`` only if no
    *open* is left unmatched once *text* is exhausted.
    """
    depth = 0
    chars = iter(text)
    for ch in chars:
        if ch == "\\":
            # Consume the escaped character without inspecting it.
            next(chars, None)
        elif ch == open:
            depth += 1
        elif ch == close:
            if not depth:
                return False
            depth -= 1
    return depth == 0


def find_next(
    text: str,
    target: Container[str],
    start: int = 0,
    end: int | None = None,
    disallowed: Container[str] = (),
) -> int:
    """Return the index of the next unescaped character of *target* in *text*.

    Positions from *start* up to (but excluding) *end* are scanned; *end*
    defaults to ``len(text)``. A backslash makes the character after it
    inert: it can match neither *target* nor *disallowed*. Returns ``-2``
    if an unescaped character from *disallowed* is met before any target,
    and ``-1`` if the range is exhausted without a match. Membership in
    *target* is checked before *disallowed* and before the backslash test.
    """
    stop = len(text) if end is None else end
    skip_next = False
    for pos in range(start, stop):
        ch = text[pos]
        if skip_next:
            skip_next = False
            continue
        if ch in target:
            return pos
        if ch in disallowed:
            return -2
        skip_next = ch == "\\"
    return -1


def partition_by_spaces(text: str, spaces: str = " \t") -> tuple[str, str, str]:
    """Split *text* around its first run of whitespace characters.

    Returns ``(head, run, tail)``, where *run* is the first maximal run of
    characters drawn from *spaces*. When *text* contains no such character,
    the result is ``(text, "", "")``.
    """
    size = len(text)
    run_start = 0
    while run_start < size and text[run_start] not in spaces:
        run_start += 1
    if run_start == size:
        return text, "", ""
    run_end = run_start
    while run_end < size and text[run_end] in spaces:
        run_end += 1
    return text[:run_start], text[run_start:run_end], text[run_end:]


@dataclasses.dataclass(frozen=True, eq=True)
class MarkoExtension:
    """Container for the pieces an extension contributes.

    Each field defaults to a fresh, independent empty list, so an extension
    only needs to fill in the parts it actually provides. The instance is
    frozen: fields cannot be rebound after construction.
    """

    # NOTE(review): frozen + eq makes dataclasses generate __hash__ from the
    # fields; because they are lists, hash() on an instance raises TypeError.

    # Parser mixin classes.
    parser_mixins: list[type] = dataclasses.field(default_factory=list)
    # Renderer mixin classes.
    renderer_mixins: list[type] = dataclasses.field(default_factory=list)
    # Element classes.
    elements: list[type[Element]] = dataclasses.field(default_factory=list)


def load_extension(name: str, **kwargs: Any) -> MarkoExtension:
    """Load extension object from a string.

    A bare name (one without ``.``) is first tried as the built-in module
    ``marko.ext.<name>``. If that module does not exist, *name* is imported
    as an ordinary module path. The module's ``make_extension`` is then
    called with *kwargs* and its result is returned.

    :raises ImportError: if no module can be imported for *name*, or if the
        built-in ``marko.ext.<name>`` exists but fails to import.
    :raises AttributeError: if the imported module has no ``make_extension``.
    """
    module = None
    if "." not in name:
        builtin = f"marko.ext.{name}"
        try:
            module = import_module(builtin)
        except ImportError as e:
            missing = e.name if isinstance(e, ModuleNotFoundError) else None
            # Fall back only when the built-in module (or a parent package of
            # it) is itself absent. Any other failure -- e.g. a third-party
            # dependency missing inside an existing built-in extension -- is
            # the real problem and must not be hidden behind the fallback.
            if not missing or not (
                builtin == missing or builtin.startswith(missing + ".")
            ):
                raise ImportError(f"Extension {name} cannot be imported") from e
    if module is None:
        try:
            module = import_module(name)
        except ImportError as e:
            raise ImportError(f"Extension {name} cannot be imported") from e

    # Guard only the attribute lookup, so an AttributeError raised while
    # make_extension itself runs propagates unchanged instead of being
    # misreported as a missing attribute.
    try:
        make_extension = module.make_extension
    except AttributeError:
        raise AttributeError(
            f"Module {name} does not have 'make_extension' attribute."
        ) from None
    return make_extension(**kwargs)


class _RendererDispatcher:
    name: str

    def __init__(
        self, types: type[Renderer] | tuple[type[Renderer], ...], func: RendererFunc
    ) -> None:
        from marko.ast_renderer import ASTRenderer, XMLRenderer

        self._mapping = {types: func}
        self._mapping.setdefault((ASTRenderer, XMLRenderer), self.render_ast)

    def dispatch(
        self: D, types: type[Renderer] | tuple[type[Renderer], ...]
    ) -> Callable[[RendererFunc], D]:
        def decorator(func: RendererFunc) -> D:
            self._mapping[types] = func
            return self

        return decorator

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    @staticmethod
    def render_ast(self, element: Element) -> Any:
        return self.render_children(element)

    def super_render(self, r: Any, element: Element) -> Any:
        try:
            return getattr(super(type(r), r), self.name)(element)
        except AttributeError:
            raise NotImplementedError(f"Unsupported renderer {type(r)}") from None

    @overload
    def __get__(self: D, obj: None, owner: type) -> D: ...

    @overload
    def __get__(self: D, obj: Renderer, owner: type) -> RendererFunc: ...

    def __get__(self: D, obj: Renderer | None, owner: type) -> RendererFunc | D:
        if obj is None:
            return self
        for types, func in self._mapping.items():
            if isinstance(obj, types):
                return partial(func, obj)
        return partial(self.super_render, obj)


def render_dispatch(
    types: type[Renderer] | tuple[type[Renderer], ...]
) -> Callable[[RendererFunc], _RendererDispatcher]:
    """Decorator factory that turns a render method into a type dispatcher.

    The decorated function becomes the implementation used for renderers
    that are ``isinstance`` of *types*. More implementations can be
    registered through the resulting dispatcher's ``dispatch`` method.

    :param types: a renderer class, or a tuple of renderer classes.
    :returns: a decorator wrapping the function in a ``_RendererDispatcher``.
    """
    return lambda func: _RendererDispatcher(types, func)