git @ Cat's Eye Technologies Tamsin / master src / tamsin / backends / c.py
master

Tree @master (Download .tar.gz)

c.py @masterraw · history · blame

# encoding: UTF-8

# Copyright (c)2014 Chris Pressey, Cat's Eye Technologies.
# Distributed under a BSD-style license; see LICENSE for more information.

"""
C-generating backend for tamsin.py.  Generated program must be linked
with -ltamsin.

"""

from tamsin.codenode import (
    CodeNode, Program, Prototype, Subroutine,
    Block, If, While, And, Not, Return, Builtin, Call, Truth, Falsity,
    DeclareLocal, GetVar, SetVar, Concat, VariableRef,
    Unifier, PatternMatch, NoMatch, GetMatchedVar,
    DeclState, SaveState, RestoreState,
    MkAtom, MkConstructor,
)
from tamsin.ast import PatternVariableNode
from tamsin.term import Atom, Constructor, Variable
import tamsin.sysmod


PRELUDE = r'''
/*
 * Generated code!  Edit at your own risk!
 * Must be linked with -ltamsin to build.
 */
#include <assert.h>
#include <tamsin.h>

/* global scanner */

struct scanner * scanner;

/* global state: result of last action */

int ok;
const struct term *result;

'''

POSTLUDE = r'''
const struct term *bufterm = NULL;

int read_file(FILE *input) {
    char *buffer = malloc(8193);

    assert(input != NULL);

    while (!feof(input)) {
        int num_read = fread(buffer, 1, 8192, input);
        if (bufterm == NULL) {
            bufterm = term_new_atom(buffer, num_read);
        } else {
            bufterm = term_concat(bufterm, term_new_atom(buffer, num_read));
        }
    }

    free(buffer);
}

int main(int argc, char **argv) {

    if (argc == 1) {
        read_file(stdin);
    } else {
        int i;

        for (i = 1; i < argc; i++) {
            FILE *input = fopen(argv[i], "r");
            read_file(input);
            fclose(input);
        }
    }

    scanner = scanner_new(bufterm->atom, bufterm->size);
    ok = 0;
    result = term_new_atom_from_cstring("nil");

    prod_main_main();

#ifdef HITS_AND_MISSES
    fprintf(stderr, "hits: %d, misses: %d\n", hits, misses);
#endif

    if (ok) {
        term_fput(result, stdout);
        fwrite("\n", 1, 1, stdout);
        exit(0);
    } else {
        term_fput(result, stderr);
        fwrite("\n", 1, 1, stderr);
        exit(1);
    }
}
'''

class Emitter(object):
    def __init__(self, codenode, outfile):
        self.codenode = codenode
        self.outfile = outfile
        self.indent_ = 0
        self.current_prod = None
        self.current_branch = None
        self.currmod = None
        self.name_index = 0

    def new_name(self):
        name = "temp%s" % self.name_index
        self.name_index += 1
        return name

    def indent(self):
        self.indent_ += 1

    def outdent(self):
        self.indent_ -= 1

    def emit(self, *args):
        s = "    " * self.indent_ + ''.join(args)
        self.outfile.write(s)

    def emitln(self, *args):
        s = "    " * self.indent_ + ''.join(args) + "\n"
        self.outfile.write(s)

    # kontinue the line
    def emitk(self, *args):
        self.outfile.write(''.join(args))

    def emitkln(self, *args):
        self.outfile.write(''.join(args) + "\n")

    def go(self):
        self.outfile.write(PRELUDE)
        self.traverse(self.codenode)
        self.outfile.write(POSTLUDE)

    def traverse(self, codenode):
        if isinstance(codenode, Program):
            for arg in codenode.args:
                self.traverse(arg)
        elif isinstance(codenode, Prototype):
            self.emitln("void prod_%s_%s(%s);" % (
                codenode['module'].name, codenode['prod'].name,
                ', '.join(["const struct term *"
                           for f in codenode['formals']])
            ))
        elif isinstance(codenode, Subroutine):
            fmls = []
            for (i, f) in enumerate(codenode.formals):
                fmls.append("const struct term *i%s" % i)
            fmls = ', '.join(fmls)

            self.emitln("void prod_%s_%s(%s) {" %
                (codenode.module.name, codenode.prod.name, fmls)
            )
            self.indent()
            for children in codenode.children:
                self.traverse(children)  
            self.outdent()
            self.emitln("}")
        elif isinstance(codenode, Unifier):
            self.emitln("/* %r */" % codenode)
        elif isinstance(codenode, If):
            self.emit("if (")
            self.traverse(codenode[0])
            self.emitkln(") {")
            self.indent()
            self.traverse(codenode[1])
            self.outdent()
            if len(codenode.args) == 3:
                self.emitln("} else {")
                self.indent()
                self.traverse(codenode[2])
                self.outdent()
            self.emitln("}")
        elif isinstance(codenode, While):
            self.emit("while (")
            self.traverse(codenode[0])
            self.emitkln(") {")
            self.indent()
            self.traverse(codenode[1])
            self.outdent()
            self.emitln("}")
        elif isinstance(codenode, Not):
            self.emitk("(!(")
            self.traverse(codenode[0])
            self.emitk("))")
        elif isinstance(codenode, And):
            self.emitk("(")
            self.traverse(codenode[0])
            self.emitk(" && ")
            self.traverse(codenode[1])
            self.emitk(")")
        elif isinstance(codenode, PatternMatch):
            self.emitk("PATTERNMATCH")
        elif isinstance(codenode, Truth):
            self.emitk("1")
        elif isinstance(codenode, Falsity):
            self.emitk("0")
        elif isinstance(codenode, Block):
            for arg in codenode.args:
                self.traverse(arg)
        elif isinstance(codenode, DeclareLocal):
            self.emit("const struct term *%s" % codenode[0])
            if len(codenode.args) == 2:
                self.emitk(' = ');
                self.traverse(codenode[1])
            self.emitkln(';')
        elif isinstance(codenode, Call):
            self.emitln("prod_%s_%s(%s);" %
                (codenode['module'], codenode['name'], '')
            )
        elif isinstance(codenode, GetVar):
            self.emitk(codenode.name)
        elif isinstance(codenode, SetVar):
            self.emitln("/* %r */" % codenode)
            self.emit('')
            self.traverse(codenode.ref)
            self.emitk(' = ')
            self.traverse(codenode.expr)
            self.emitkln(';')

            #name = self.compile_r(ast.texpr)
            #lname = self.emit_lvalue(ast.variable)
            #self.emit("%s = %s;" % (lname, name))
            self.emit("result = ")
            self.traverse(codenode.ref)
            self.emitkln(";")
            self.emitln("ok = 1;")

        elif isinstance(codenode, Concat):
            self.emit('const struct term *%s = term_concat(term_flatten(', codenode.name)
            self.traverse(codenode.lhs)
            self.emitk('), term_flatten(')
            self.traverse(codenode.rhs)
            self.emitkln('));')
        elif isinstance(codenode, Builtin):
            if codenode['name'] == 'print':
                self.emit("result = ")
                self.traverse(codenode[0])
                self.emitkln(';')
                self.emitln("term_fput(result, stdout);")
                self.emitln(r'fwrite("\n", 1, 1, stdout);')
                self.emitln("ok = 1;")
            elif codenode['name'] == 'return':
                self.emit("result = ")
                self.traverse(codenode[0])
                self.emitkln(';')
            elif codenode['name'] == 'expect':
                self.emit('tamsin_expect(scanner, ')
                self.traverse(codenode[0])
                self.emitkln(');')
            elif codenode['name'] == 'any':
                self.emitln('tamsin_any(scanner);')
            else:
                raise NotImplementedError(repr(codenode))
        elif isinstance(codenode, MkAtom):
            self.emitk('term_new_atom_from_cstring("%s")' % codenode[0])
        elif isinstance(codenode, VariableRef):
            self.emitk(codenode[0])
        elif isinstance(codenode, MkConstructor):
            #self.emitk(codenode.text)  # FIXME
            termlist_name = self.new_name()
            self.emitln('struct termlist *%s = NULL;' % termlist_name);
            #for c in reversed(codenode.contents):
            #    subname = self.compile_r(c)
            #    self.emit('termlist_add_term(&%s, %s);' % (termlist_name, subname))
            name = self.new_name()
            self.emitln('const struct term *%s = term_new_constructor("%s", %s, %s);' %
                (name, escaped(codenode.text), len(codenode.text), termlist_name)
            )
        elif isinstance(codenode, Return):
            self.emitln("return;")
        elif isinstance(codenode, NoMatch):
            self.emitln('result = term_new_atom_from_cstring'
                      '("No \'%s\' production matched arguments ");' %
                      codenode['prod'].name)
            for i in xrange(0, len(codenode['formals'])):
                self.emitln('result = term_concat(result, term_flatten(i%d));' % i)
                self.emitln('result = term_concat(result, term_new_atom_from_cstring(", "));')
            self.emitln("ok = 0;")
        elif isinstance(codenode, DeclState):
            for local in []:  # self.current_branch.locals_:
                self.emitln("const struct term *save_%s;" % local)
            self.emitln("int position;")
            self.emitln("int reset_position;")
            self.emitln("const char *buffer;")
            self.emitln("int buffer_size;")
            self.emitln("")
        elif isinstance(codenode, SaveState):
            for local in []:  # self.current_branch.locals_:
                self.emitln("save_%s = %s;" % (local, local))
            self.emitln("position = scanner->position;")
            self.emitln("reset_position = scanner->reset_position;")
            self.emitln("buffer = scanner->buffer;")
            self.emitln("buffer_size = scanner->size;")
        elif isinstance(codenode, RestoreState):
            self.emitln("scanner->position = position;")
            self.emitln("scanner->reset_position = reset_position;")
            self.emitln("scanner->buffer = buffer;")
            self.emitln("scanner->size = buffer_size;")
            for local in []:  # self.current_branch.locals_:
                self.emitln("%s = save_%s;" % (local, local))
        elif isinstance(codenode, GetMatchedVar):
            ref = codenode[0]
            assert isinstance(ref, PatternVariableNode)
            self.emitln(
                "const struct term *%s = unifier[%s];" % (ref.name, ref.index)
            )
        else:
            raise NotImplementedError(repr(codenode))


def escaped(s):
    a = ''
    i = 0
    while i < len(s):
        c = s[i]
        # gcc appears to have some issues with \xXX... perhaps it
        # consumes greater or fewer than two hex digits...?
        if ord(c) < 32 or ord(c) > 126:
            a += "\\%03o" % ord(c)
        elif c == "\\":
            a += r"\\"
        elif c == '"':
            a += r'\"'
        else:
            a += c
        i += 1
    return a