git @ Cat's Eye Technologies Eqthy / master src / eqthy / terms.py
master

Tree @master (Download .tar.gz)

terms.py @master · raw · history · blame

# Copyright (c) 2022-2024 Chris Pressey, Cat's Eye Technologies
# This file is distributed under a 2-clause BSD license.  See LICENSES directory:
# SPDX-License-Identifier: LicenseRef-BSD-2-Clause-X-Eqthy

from copy import copy
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


class Term:
    """Base class for the term types defined in this module (Ctor and Variable)."""


@dataclass(frozen=True)
class Ctor(Term):
    """A named constructor applied to a (possibly empty) list of subterms."""
    ctor: str            # constructor name, e.g. "f" in f(X, c)
    subterms: List[Term]  # argument terms, in order

    def __hash__(self) -> int:
        # The dataclass-generated __hash__ hashes the fields as a tuple, which
        # raises TypeError because ``subterms`` is a list. Hashing a tuple
        # snapshot agrees with the generated __eq__ (equal lists give equal
        # tuples), so equal Ctors always hash equally.
        # NOTE: ``subterms`` is still a mutable list; mutating it in place
        # after the Ctor has been hashed (e.g. put in a set) would break
        # set/dict invariants, so treat Ctors as immutable.
        return hash((self.ctor, tuple(self.subterms)))


@dataclass(frozen=True)
class Variable(Term):
    """A named variable; as a pattern it can be bound to any term by ``match``."""
    name: str


@dataclass(eq=True, frozen=True)
class Eqn:
    """An equation between two terms; ``render`` shows it as ``lhs = rhs``."""
    lhs: Term  # left-hand side of the equation
    rhs: Term  # right-hand side of the equation


@dataclass(eq=True, frozen=True)
class RewriteRule:
    """A rewrite rule from a pattern to a substitution.

    ``render`` shows it as ``pattern => substitution``.
    """
    pattern: Term       # term whose variables get bound when it matches
    substitution: Term  # term to expand with those bindings


@dataclass(eq=True, frozen=True)
class Unifier:
    """The outcome of matching: a success flag plus variable bindings.

    When ``success`` is false the bindings carry no meaning. ``frozen`` only
    prevents reassigning the fields; the ``bindings`` dict itself is still
    mutable, so treat it as read-only.
    """
    success: bool
    bindings: Dict[str, Term]  # variable name -> term bound to it


# A path from the root of a term down to one of its subterms: each entry is
# the position of the child to descend into, and [] denotes the root itself.
TermIndex = List[int]


# The single shared "no match" result. Its bindings dict is shared by every
# holder of this object, so it must never be mutated.
unify_fail = Unifier(False, {})


def render(t: Union[Term, Eqn, RewriteRule, Unifier, str]) -> str:
    """Return a human-readable string for a term, equation, rule or unifier.

    - Ctor: ``f(a, b)``, or just ``f`` when it has no subterms.
    - Variable: its name.
    - Eqn: ``lhs = rhs``; RewriteRule: ``pattern => substitution``.
    - Unifier: ``#F`` on failure, otherwise the ``str`` of a dict mapping
      each variable name to its rendered binding.
    - Anything else: ``str(t)``.
    """
    if isinstance(t, Ctor):
        if not t.subterms:
            return t.ctor
        args = ', '.join(render(sub) for sub in t.subterms)
        return "{}({})".format(t.ctor, args)
    if isinstance(t, Variable):
        return t.name
    if isinstance(t, Eqn):
        return "{} = {}".format(render(t.lhs), render(t.rhs))
    if isinstance(t, RewriteRule):
        return "{} => {}".format(render(t.pattern), render(t.substitution))
    if isinstance(t, Unifier):
        if t.success:
            return str({name: render(value) for name, value in t.bindings.items()})
        return '#F'
    return str(t)


def merge_unifiers(first: Unifier, next: Unifier) -> Unifier:
    """Combine two unifiers into one.

    Returns ``unify_fail`` if either input failed, or if both bind the same
    variable to terms that compare unequal. Otherwise returns a new
    successful Unifier holding a shallow copy of ``first``'s bindings
    extended with ``next``'s. Neither input is modified.
    """
    if not (first.success and next.success):
        return unify_fail
    merged = copy(first.bindings)
    for name, bound_term in next.bindings.items():
        if name in merged and merged[name] != bound_term:
            # One variable bound to two different terms: inconsistent.
            return unify_fail
        merged[name] = bound_term
    return Unifier(success=True, bindings=merged)


def match(pattern: Term, term: Term) -> Unifier:
    """Match ``pattern`` against ``term`` (one-way: only pattern variables bind).

    A Variable in the pattern binds to whatever term sits at its position.
    A Ctor in the pattern matches only a Ctor in the term with the same name
    and the same number of subterms, whose subterms match pairwise. If a
    variable occurs more than once in the pattern, all occurrences must bind
    to equal terms (enforced via ``merge_unifiers``).

    Returns a successful Unifier with the bindings, or ``unify_fail``.
    """
    if isinstance(pattern, Variable):
        return Unifier(success=True, bindings={pattern.name: term})

    assert isinstance(pattern, Ctor)
    if (
        not isinstance(term, Ctor)
        or term.ctor != pattern.ctor
        or len(term.subterms) != len(pattern.subterms)
    ):
        return unify_fail

    unifier = Unifier(success=True, bindings={})
    for subpattern, subterm in zip(pattern.subterms, term.subterms):
        unifier = merge_unifiers(unifier, match(subpattern, subterm))
        if not unifier.success:
            # A failed sub-match or conflicting binding can never be undone
            # by the remaining subterms, so stop instead of matching them.
            return unify_fail
    return unifier


def all_matches(pattern: Term, term: Term, index: Union[TermIndex, None] = None) -> List[Tuple[TermIndex, Unifier]]:
    """Find every position within ``term`` where ``pattern`` matches.

    Returns ``(index, unifier)`` pairs in pre-order: a term is reported
    before its subterms, and subterms are visited left to right. ``index``
    is the path of ``term`` inside some enclosing term; callers normally
    omit it, and it then defaults to the root path ``[]``.
    """
    here = [] if index is None else index
    found = []

    unifier = match(pattern, term)
    if unifier.success:
        found.append((here, unifier))

    if isinstance(term, Ctor):
        for position, child in enumerate(term.subterms):
            found.extend(all_matches(pattern, child, here + [position]))

    return found


def expand(term: Term, unifier: Unifier) -> Term:
    """Substitute the unifier's bindings for the variables in ``term``.

    A failed unifier leaves ``term`` unchanged. Variables without a binding
    are kept as they are. Bound values are inserted as-is, without being
    expanded again. Raises NotImplementedError for terms that are neither
    Variable nor Ctor.
    """
    if not unifier.success:
        return term
    if isinstance(term, Variable):
        return unifier.bindings.get(term.name, term)
    if isinstance(term, Ctor):
        expanded = [expand(child, unifier) for child in term.subterms]
        return Ctor(term.ctor, expanded)
    raise NotImplementedError(str(term))


def subterm_at_index(term: Term, index: TermIndex) -> Term:
    """Return the subterm of ``term`` reached by following the path ``index``.

    Raises KeyError if the path tries to descend into a non-Ctor term (the
    message names that term and the unconsumed rest of the path). An
    out-of-range position surfaces as IndexError from list indexing.
    """
    current = term
    remaining = index
    while remaining:
        if not isinstance(current, Ctor):
            raise KeyError('{} at {}'.format(str(current), remaining))
        current = current.subterms[remaining[0]]
        remaining = remaining[1:]
    return current


def update_at_index(term: Term, subterm: Term, index: TermIndex) -> Term:
    """Return ``term`` with the subterm at path ``index`` replaced by ``subterm``.

    ``term`` is not modified: the subterm lists along the path are copied
    and new Ctors are built, while untouched siblings are shared. Raises
    KeyError if the path tries to descend into a non-Ctor term.
    """
    if not index:
        return subterm
    if not isinstance(term, Ctor):
        raise KeyError('{} at {}'.format(str(term), index))
    head, rest = index[0], index[1:]
    children = copy(term.subterms)
    children[head] = update_at_index(term.subterms[head], subterm, rest)
    return Ctor(term.ctor, children)


def replace(term: Term, target: Term, replacement: Term) -> Term:
    """Replace each occurrence of ``target`` inside ``term`` with ``replacement``.

    Occurrences are found with ``==``. A subterm equal to ``target`` is
    swapped out wholesale, so neither its interior nor the inserted
    replacement is searched further.
    """
    if term == target:
        return replacement
    if not isinstance(term, Ctor):
        return term
    return Ctor(term.ctor, [replace(child, target, replacement) for child in term.subterms])


def all_rewrites(pattern: Term, substitution: Term, term: Term) -> List[Term]:
    """List every term obtainable from ``term`` by one application of the rule.

    The rule is given as ``pattern`` and ``substitution``. There is one
    result per position where ``pattern`` matches, in the order
    ``all_matches`` reports them. Each result replaces only that single
    subterm with ``substitution`` expanded under that match's bindings.
    """
    return [
        update_at_index(term, expand(substitution, unifier), index)
        for index, unifier in all_matches(pattern, term)
    ]


def apply_substs_to_rule(rule: RewriteRule, substs: List[Eqn]) -> RewriteRule:
    """Instantiate ``rule`` with the variable substitutions in ``substs``.

    Each substitution is an Eqn whose left-hand side must be a Variable.
    All substitutions are applied simultaneously, through one unifier, to
    both sides of the rule. If a variable appears more than once, the last
    entry wins. An empty ``substs`` returns ``rule`` itself.
    """
    if not substs:
        return rule
    bindings = {}
    for eqn in substs:
        lhs = eqn.lhs
        assert isinstance(lhs, Variable)
        bindings[lhs.name] = eqn.rhs
    unifier = Unifier(success=True, bindings=bindings)
    new_pattern = expand(rule.pattern, unifier)
    new_substitution = expand(rule.substitution, unifier)
    return RewriteRule(pattern=new_pattern, substitution=new_substitution)