git @ Cat's Eye Technologies Tamsin / master src / tamsin / desugarer.py
master

Tree @master (Download .tar.gz)

desugarer.py @master · raw · history · blame

# encoding: UTF-8

# Copyright (c)2014 Chris Pressey, Cat's Eye Technologies.
# Distributed under a BSD-style license; see LICENSE for more information.

from tamsin.ast import (
    Program, Module, Production, ProdBranch,
    And, Or, Not, While, Call, Send, Set,
    Using, On, Concat, Fold, Prodref,
    TermNode, VariableNode, PatternVariableNode, AtomNode, ConstructorNode
)
from tamsin.event import EventProducer


class Desugarer(EventProducer):
    """The Desugarer takes an AST, walks it, and returns a new AST.

    It is responsible for:

    * Desugaring Fold() nodes into an equivalent combination of Set(),
      While(), Send(), Concat()/ConstructorNode() and a call to $:return.
    * Merging Production() nodes that share a name within a Module into a
      single Production() whose branches appear in source order.
    * Turning VariableNode() nodes into PatternVariableNode() nodes when
      they occur in a pattern (a production branch's formals, or the
      pattern of a Send()), numbering them with a counter that is reset
      at the start of each production branch.

    """
    def __init__(self, program, listeners=None):
        self.listeners = listeners
        # Stored for callers; the desugaring below does not consult it.
        self.program = program
        # True while a pattern is being desugared; controls how
        # VariableNode() is rewritten.
        self.pattern = False
        # Next index to hand out to a PatternVariableNode().
        self.index = 0

    def desugar(self, ast):
        """Return a desugared copy of the AST node `ast`.

        Raises NotImplementedError for node types this pass does not handle.
        """
        if isinstance(ast, Program):
            return Program([self.desugar(m) for m in ast.modlist])
        elif isinstance(ast, Module):
            return self._desugar_module(ast)
        elif isinstance(ast, Production):
            return Production(ast.name, [self.desugar(x) for x in ast.branches])
        elif isinstance(ast, ProdBranch):
            # Pattern-variable numbering restarts for every branch; Send()
            # patterns in the body continue from where the formals left off.
            self.index = 0
            formals = [self._desugar_pattern(f) for f in ast.formals]
            # NOTE(review): the second ProdBranch argument is always rebuilt
            # as [] here; presumably a later pass fills it in -- confirm.
            return ProdBranch(formals, [], self.desugar(ast.body))
        elif isinstance(ast, Or):
            return Or(self.desugar(ast.lhs), self.desugar(ast.rhs))
        elif isinstance(ast, And):
            return And(self.desugar(ast.lhs), self.desugar(ast.rhs))
        elif isinstance(ast, Using):
            return Using(self.desugar(ast.rule), ast.prodref)
        elif isinstance(ast, On):
            return On(self.desugar(ast.rule), self.desugar(ast.texpr))
        elif isinstance(ast, Call):
            return ast
        elif isinstance(ast, Send):
            # The pattern is desugared before the rule on purpose: this fixes
            # the order in which pattern-variable indices are assigned.
            pattern = self._desugar_pattern(ast.pattern)
            return Send(self.desugar(ast.rule), pattern)
        elif isinstance(ast, Set):
            return Set(ast.variable, self.desugar(ast.texpr))
        elif isinstance(ast, Not):
            return Not(self.desugar(ast.rule))
        elif isinstance(ast, While):
            return While(self.desugar(ast.rule))
        elif isinstance(ast, Concat):
            return Concat(self.desugar(ast.lhs), self.desugar(ast.rhs))
        elif isinstance(ast, AtomNode):
            return ast
        elif isinstance(ast, ConstructorNode):
            return ConstructorNode(ast.text,
                                   [self.desugar(x) for x in ast.contents])
        elif isinstance(ast, VariableNode):
            if self.pattern:
                index = self.index
                self.index += 1
                return PatternVariableNode(ast.name, index)
            return ast
        elif isinstance(ast, Fold):
            return self._desugar_fold(ast)
        else:
            raise NotImplementedError(repr(ast))

    def _desugar_pattern(self, term):
        """Desugar `term` in pattern mode, always leaving pattern mode after.

        The try/finally guarantees that an exception raised mid-pattern does
        not leave this Desugarer converting every later VariableNode.
        """
        self.pattern = True
        try:
            return self.desugar(term)
        finally:
            self.pattern = False

    def _desugar_module(self, module):
        """Desugar a Module, merging productions that share a name.

        The first occurrence of a name keeps its position; branches of later
        same-named productions are appended to it in source order.
        """
        merged = []
        by_name = {}  # production name -> Production already in `merged`
        for prod in module.prodlist:
            prod = self.desugar(prod)
            existing = by_name.get(prod.name)
            if existing is None:
                by_name[prod.name] = prod
                merged.append(prod)
            else:
                existing.branches.extend(prod.branches)
        return Module(module.name, merged)

    def _desugar_fold(self, fold):
        """Rewrite a Fold() into primitive nodes.

        The result has the shape

            And(And(Set(_1, initial),
                    While(And(Send(rule, _2), Set(_1, acc)))),
                Call($:return, [_1]))

        where `acc` is Concat(_1, _2), or ConstructorNode(tag, [_2, _1])
        when the fold carries a tag atom.
        """
        under1 = VariableNode('_1')
        under2 = VariableNode('_2')
        set_ = Set(under1, fold.initial)
        # NOTE(review): _2 is left as a plain VariableNode pattern rather
        # than a PatternVariableNode (it does not go through
        # _desugar_pattern); presumably consumers treat it as a simple
        # assignment -- confirm.
        send_ = Send(self.desugar(fold.rule), under2)
        if fold.tag is None:
            acc_ = Set(under1, Concat(under1, under2))
        else:
            assert isinstance(fold.tag, AtomNode)
            acc_ = Set(under1,
                       ConstructorNode(fold.tag.text, [under2, under1]))
        return_ = Call(Prodref('$', 'return'), [under1])
        return And(And(set_, While(And(send_, acc_))), return_)