# Copyright (c) 2022-2024 Chris Pressey, Cat's Eye Technologies
# This file is distributed under a 2-clause BSD license. See LICENSES directory:
# SPDX-License-Identifier: LicenseRef-BSD-2-Clause-X-Eqthy
from copy import copy
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
class Term:
    """Common base class of the two term variants, `Ctor` and `Variable`.

    It carries no data or behaviour of its own; it exists so that both
    variants can be referred to by a single type.
    """
@dataclass(frozen=True)
class Ctor(Term):
    """A constructor term: a symbol applied to an ordered list of subterms.

    A `Ctor` with no subterms is rendered as the bare symbol.

    NOTE: `frozen=True` only prevents rebinding the attributes; the
    `subterms` list itself is still mutable, and because it is a list,
    calling `hash()` on an instance raises `TypeError`.
    """

    # The constructor symbol.
    ctor: str
    # The arguments, in order (may be empty).
    subterms: List[Term]
@dataclass(frozen=True)
class Variable(Term):
    """A variable term, identified solely by its name."""

    # The variable's name; also used as the key in `Unifier.bindings`.
    name: str
@dataclass(frozen=True)
class Eqn:
    """An equation relating two terms; rendered as ``lhs = rhs``."""

    # Left-hand side of the equation.
    lhs: Term
    # Right-hand side of the equation.
    rhs: Term
@dataclass(frozen=True)
class RewriteRule:
    """A rewrite rule; rendered as ``pattern => substitution``."""

    # The term shape to look for.
    pattern: Term
    # The term that takes the place of whatever matched `pattern`.
    substitution: Term
@dataclass(frozen=True)
class Unifier:
    """The outcome of matching a pattern against a term.

    NOTE: `frozen=True` only prevents rebinding the attributes; the
    `bindings` dict itself remains mutable.
    """

    # True when the match succeeded.
    success: bool
    # Maps variable names to the terms they were bound to.
    bindings: Dict[str, Term]
# A path into a term: the successive argument positions to follow,
# starting from the root. The empty list denotes the term itself.
TermIndex = List[int]

# Shared value representing a failed match. It is a single module-level
# object, so its (empty) bindings dict must never be mutated.
unify_fail = Unifier(success=False, bindings={})
def render(t: Union[Term, Eqn, RewriteRule, Unifier, str]) -> str:
    """Return a human-readable string for `t`.

    * `Ctor`: ``name(arg, arg, ...)``, or just ``name`` without subterms.
    * `Variable`: its name.
    * `Eqn`: ``lhs = rhs``.
    * `RewriteRule`: ``pattern => substitution``.
    * `Unifier`: ``#F`` on failure, otherwise the `str()` of a dict that
      maps each variable name to its rendered binding.
    * Anything else: `str(t)`.
    """
    if isinstance(t, Ctor):
        if not t.subterms:
            return t.ctor
        rendered_args = ', '.join(render(arg) for arg in t.subterms)
        return "{}({})".format(t.ctor, rendered_args)
    if isinstance(t, Variable):
        return t.name
    if isinstance(t, Eqn):
        return "{} = {}".format(render(t.lhs), render(t.rhs))
    if isinstance(t, RewriteRule):
        return "{} => {}".format(render(t.pattern), render(t.substitution))
    if isinstance(t, Unifier):
        if not t.success:
            return '#F'
        return str({name: render(bound) for name, bound in t.bindings.items()})
    return str(t)
def merge_unifiers(first: Unifier, next: Unifier) -> Unifier:
    """Combine two unifiers into one.

    Returns `unify_fail` if either input failed, or if both bind the same
    variable name to unequal terms. Otherwise returns a new successful
    unifier holding the union of the bindings; neither input is modified.
    """
    if not (first.success and next.success):
        return unify_fail

    # Work on a copy so that `first` is left untouched.
    merged = copy(first.bindings)
    for name, bound in next.bindings.items():
        conflicting = name in merged and merged[name] != bound
        if conflicting:
            return unify_fail
        merged[name] = bound
    return Unifier(success=True, bindings=merged)
def match(pattern: Term, term: Term) -> Unifier:
    """Match `pattern` against `term` (one-way: only `pattern` binds).

    A `Variable` pattern matches any term and binds its name to it.
    A `Ctor` pattern matches only a `Ctor` term with the same symbol and
    the same number of subterms, whose subterms match pairwise with
    mutually consistent bindings (so a variable that occurs twice in the
    pattern must match equal terms both times).

    Returns a successful `Unifier` carrying the bindings, or `unify_fail`.
    """
    if isinstance(pattern, Variable):
        return Unifier(success=True, bindings={pattern.name: term})

    assert isinstance(pattern, Ctor)
    if (
        not isinstance(term, Ctor)
        or term.ctor != pattern.ctor
        or len(term.subterms) != len(pattern.subterms)
    ):
        return unify_fail

    unifier = Unifier(success=True, bindings={})
    for subpattern, subterm in zip(pattern.subterms, term.subterms):
        unifier = merge_unifiers(unifier, match(subpattern, subterm))
        if not unifier.success:
            # A failed merge can never become successful again, so stop
            # here instead of matching the remaining argument pairs.
            return unifier
    return unifier
def all_matches(pattern: Term, term: Term, index: Union[TermIndex, None] = None) -> List[Tuple[TermIndex, Unifier]]:
    """Find every position in `term` at which `pattern` matches.

    `index` is the path of `term` within the overall term being searched;
    it defaults to the empty path (the root).

    Returns a list of ``(path, unifier)`` pairs, the term itself first
    (if it matches), followed by matches inside its subterms in order.
    """
    here = [] if index is None else index

    unifier = match(pattern, term)
    found = [(here, unifier)] if unifier.success else []

    if isinstance(term, Ctor):
        for position, child in enumerate(term.subterms):
            # `here + [position]` builds a fresh list, so paths are not shared.
            found.extend(all_matches(pattern, child, here + [position]))
    return found
def expand(term: Term, unifier: Unifier) -> Term:
    """Substitute the bindings of `unifier` for the variables in `term`.

    If `unifier` is a failure, `term` is returned unchanged. Variables
    without a binding are left as they are. Bound terms are inserted
    as-is; they are not themselves expanded again.

    Raises `NotImplementedError` if `term` is neither a `Variable` nor a
    `Ctor` (and the unifier succeeded).
    """
    if not unifier.success:
        return term
    if isinstance(term, Variable):
        return unifier.bindings.get(term.name, term)
    if isinstance(term, Ctor):
        expanded_args = [expand(arg, unifier) for arg in term.subterms]
        return Ctor(term.ctor, expanded_args)
    raise NotImplementedError(str(term))
def subterm_at_index(term: Term, index: TermIndex) -> Term:
    """Return the subterm of `term` found by following the path `index`.

    An empty path yields `term` itself.

    Raises `KeyError` if the path tries to descend into a term that is
    not a `Ctor`.
    """
    current = term
    remaining = index
    while remaining:
        if not isinstance(current, Ctor):
            raise KeyError('{} at {}'.format(str(current), remaining))
        current = current.subterms[remaining[0]]
        remaining = remaining[1:]
    return current
def update_at_index(term: Term, subterm: Term, index: TermIndex) -> Term:
    """Return a copy of `term` with the subterm at path `index` replaced.

    An empty path yields `subterm` itself. `term` is not modified: every
    `Ctor` along the path is rebuilt around a copied subterm list.

    Raises `KeyError` if the path tries to descend into a term that is
    not a `Ctor`.
    """
    if not index:
        return subterm
    if not isinstance(term, Ctor):
        raise KeyError('{} at {}'.format(str(term), index))

    head, rest = index[0], index[1:]
    updated_child = update_at_index(term.subterms[head], subterm, rest)
    # Copy before assigning so the original term's list stays intact.
    children = copy(term.subterms)
    children[head] = updated_child
    return Ctor(term.ctor, children)
def replace(term: Term, target: Term, replacement: Term) -> Term:
    """Return `term` with every occurrence of `target` replaced.

    Occurrences are found by equality, searching from the root downwards.
    A replaced occurrence is not searched further, and `replacement`
    itself is inserted as-is.
    """
    if term == target:
        return replacement
    if not isinstance(term, Ctor):
        return term
    new_args = [replace(arg, target, replacement) for arg in term.subterms]
    return Ctor(term.ctor, new_args)
def all_rewrites(pattern: Term, substitution: Term, term: Term) -> List[Term]:
    """Given a rule (a pattern and a substitution) and a term, return
    a list of the terms that would result from rewriting the term
    in all the possible ways by the rule.

    Each result rewrites exactly one matching position; the results are
    in the order in which `all_matches` reports the positions."""
    # For every subterm position where the rule's pattern matches, put the
    # rule's substitution (expanded with that match's bindings) there.
    return [
        update_at_index(term, expand(substitution, unifier), index)
        for index, unifier in all_matches(pattern, term)
    ]
def apply_substs_to_rule(rule: RewriteRule, substs: List[Eqn]) -> RewriteRule:
    """Instantiate variables on both sides of `rule`.

    Each element of `substs` is an equation whose left-hand side is a
    `Variable` (asserted) and whose right-hand side is the term to put in
    its place. If the same variable appears more than once, the last
    equation wins. With no substitutions, `rule` itself is returned.
    """
    if not substs:
        return rule

    bindings = {}
    for subst in substs:
        variable = subst.lhs
        assert isinstance(variable, Variable)
        bindings[variable.name] = subst.rhs

    unifier = Unifier(success=True, bindings=bindings)
    new_pattern = expand(rule.pattern, unifier)
    new_substitution = expand(rule.substitution, unifier)
    return RewriteRule(pattern=new_pattern, substitution=new_substitution)