# Copyright (c) 2025-2026, Chris Pressey, Cat's Eye Technologies.
# This file is distributed under a 2-clause BSD license. See LICENSES/ dir.
# SPDX-License-Identifier: LicenseRef-BSD-2-Clause-X-Lome
import logging
from typing import Dict
from .ast import Rule
from .term import Ctor, Term, is_variable, is_decoration, render_term
from .pattern import match_term, substitute_term
logger = logging.getLogger(__name__)
class ReductionError(ValueError):
    """Raised when a proof transformer cannot be looked up or applied."""
class ProofSystem:
    """Registry of proof transformers (rules) plus the logic to apply them.

    Rules are kept in ``self.rules``, keyed by the symbol of each rule's
    left-hand side.  A decoration term is evaluated by dropping the first
    character of its symbol (the ``*`` prefix), looking up the rule
    registered under the remaining name, and applying that rule.
    """

    def __init__(self) -> None:
        # Maps a transformer name (the rule's LHS symbol) to its rule.
        self.rules: Dict[str, Rule] = {}

    def evaluate(self, decoration: Term) -> Term:
        """Apply the transformer named by ``decoration`` and return the result.

        Args:
            decoration: a decoration term; must be a ``Ctor`` for which
                ``is_decoration`` holds.

        Returns:
            The term produced by ``apply_rule`` for the matching rule.

        Raises:
            ReductionError: if no rule is registered under the decoration's
                name, or if ``apply_rule`` rejects the application.
        """
        # Internal invariants; the isinstance check also narrows the type.
        assert is_decoration(decoration)
        assert isinstance(decoration, Ctor)
        def_name = decoration.symbol[1:]  # Remove the '*' prefix
        try:
            rule = self.rules[def_name]
        except KeyError:
            raise ReductionError(f"Unknown transformer: {def_name}") from None
        return self.apply_rule(decoration, rule)

    def apply_rule(self, decoration: Ctor, rule: Rule) -> Term:
        """Apply ``rule`` to the target term carried by ``decoration``.

        The first subterm of ``decoration`` is the target; the first subterm
        of ``rule.lhs`` is the pattern it is matched against.  On success the
        match bindings are substituted into ``rule.rhs``.  If ``rule.lhs``
        has exactly two subterms, the second is a parameter variable that is
        then replaced by the second subterm of ``decoration``.

        Args:
            decoration: the decoration term being evaluated.
            rule: the rule registered for this decoration's name.

        Returns:
            The right-hand side of ``rule`` after substitution.

        Raises:
            ReductionError: if ``decoration`` or ``rule.lhs`` has no
                subterms, if the pattern does not match the target, if the
                rule's parameter is not a variable, or if the rule takes a
                parameter and ``decoration`` does not have exactly two
                subterms.
        """
        # Rendering whole terms is wasted work unless DEBUG output is
        # actually going to be emitted, so check the level once up front.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            logger.debug(
                "applying rule: %s => %s",
                render_term(rule.lhs),
                render_term(rule.rhs),
            )
            logger.debug("applying to decoration: %s", render_term(decoration))

        if not decoration.subterms:
            raise ReductionError(f"Decoration {decoration.symbol} has no subterms")
        if not rule.lhs.subterms:
            # Report a malformed rule the same way as a malformed decoration
            # instead of letting the indexing below raise IndexError.
            raise ReductionError(
                f"Proof transformer {rule.lhs.symbol} has no pattern to match against"
            )

        target_term = decoration.subterms[0]
        pattern = rule.lhs.subterms[0]

        match_result = match_term(pattern, target_term)
        if not match_result.success:
            raise ReductionError(
                f"Cannot apply {decoration.symbol[1:]} to "
                f"{render_term(target_term)}: pattern "
                f"{render_term(rule.lhs)} does not match it"
            )

        result = substitute_term(rule.rhs, match_result.bindings)
        if debug_enabled:
            logger.debug("result of applying: %s", render_term(result))

        if len(rule.lhs.subterms) == 2:
            parameter = rule.lhs.subterms[1]
            assert isinstance(parameter, Ctor)
            if not is_variable(parameter):
                raise ReductionError(
                    f"Proof transformer {render_term(rule.lhs)} "
                    "has a parameter that is not a variable"
                )
            if len(decoration.subterms) != 2:
                raise ReductionError(
                    f"Proof transformer {render_term(rule.lhs)} "
                    "has a parameter but no parameter given in "
                    f"term decoration {render_term(decoration)}"
                )
            # NOTE(review): this is a second pass over the already-substituted
            # result, so it presumably also reaches into terms introduced by
            # the match bindings -- confirm that is intended.
            result = substitute_term(result, {
                parameter.symbol: decoration.subterms[1]
            })
            if debug_enabled:
                logger.debug(
                    "result of substituting parameter: %s", render_term(result)
                )

        return result

    def add_rule(self, rule: Rule) -> None:
        """
        Add a user-defined transformer to the proof system.
        The name is extracted from the LHS symbol; a rule already stored
        under the same name is replaced.
        """
        name = rule.lhs.symbol
        self.rules[name] = rule
def reduce_term(term: Term, proof_system: ProofSystem) -> Term:
    """Return ``term`` with its decorations replaced by their evaluations.

    A term that is not a ``Ctor``, or a ``Ctor`` that is not a decoration
    and has no subterms, is returned unchanged.  A decoration is passed to
    ``proof_system.evaluate`` and that result is returned as-is.  Any other
    ``Ctor`` is rebuilt under the same symbol with each of its subterms
    reduced recursively.
    """
    if isinstance(term, Ctor):
        if is_decoration(term):
            return proof_system.evaluate(term)
        if term.subterms:
            reduced_children = [
                reduce_term(child, proof_system) for child in term.subterms
            ]
            return Ctor(term.symbol, reduced_children)
    # Not a Ctor, or a plain Ctor with no subterms: nothing to reduce.
    return term