# git @ Cat's Eye Technologies Castile / master src / castile / backends / stackmac.py
# master
#
# Tree @master (Download .tar.gz)
#
# stackmac.py @master raw · history · blame

from castile.types import Void, Union, Struct

# Compile to some hypothetical stack-based machine.
# Not yet in a good way.

# A big difference between this and the higher-level backends is that
# values of type void have size zero, i.e. nothing is pushed onto the
# stack for them.  (Giving them a nominal value would possibly make
# this easier.)

# In a function like this:
#
# fun(x,y,z) {
#   var a = 0;
#   var b = 0;
#   ...
# }
#
# x is at baseptr - 3
# y is at baseptr - 2
# z is at baseptr - 1
# old baseptr is saved at baseptr + 0
# a is at baseptr + 1
# b is at baseptr + 2

# The callee is responsible for popping its locals and the given arguments
# off the stack, and for pushing its return value(s) into the space that
# the first argument(s) were occupying.

# Castile binary operator -> stack-machine mnemonic.  Operators missing
# from this table are emitted verbatim by the 'Op' case of Compiler.compile.
OPS = dict([
    ('+', 'add'),
    ('-', 'sub'),
    ('*', 'mul'),
    ('/', 'div'),
    ('==', 'eq'),
    ('!=', 'ne'),
    ('>', 'gt'),
    ('>=', 'gte'),
    ('<', 'lt'),
    ('<=', 'lte'),
])


class Compiler(object):
    """Compile a Castile AST to assembly text for a hypothetical stack machine.

    Generated code is written line by line to ``out`` (any object with a
    ``write(str)`` method).  While compiling, the instance tracks generated
    labels, numeric type tags, and the layout of the current function's
    stack frame (see the module-level comment for the frame layout).
    """

    def __init__(self, out):
        self.out = out
        self.labels = {}        # label prefix -> next numeric suffix to use
        self.loop_end = None    # label that `break` jumps to (innermost loop)
        self.fun_lit = None     # string id'ing current fun being gend
        self.fun_argsize = 0    # number of stack slots taken by args
        self.global_pos = 0     # globals at the bottom of the stack
        self.local_pos = 0      # locals after the passed arguments
        self.tags = {}          # numeric tags established so far
        self.tag_count = 0      # next tag to generate

    def size_of(self, type):
        """Return how many stack slots a value of ``type`` occupies.

        Void occupies no slots, and neither does a struct whose contents are
        all zero-sized.  Every other type (non-empty structs, unions, and
        everything else) occupies exactly one slot.
        """
        if type == Void():
            return 0
        elif isinstance(type, Struct):
            if self.struct_size(type) == 0:
                return 0
            else:
                return 1
        elif isinstance(type, Union):
            # TODO might be unboxed, all on stack, in future
            return 1
        else:
            return 1

    def struct_size(self, type):
        """Return the total slot size of the fields of struct ``type``."""
        assert isinstance(type, Struct)
        size = 0
        for t in type.defn.content_types:
            size += self.size_of(t)
        return size

    def get_label(self, pref):
        """Return a fresh label ``<pref>_<n>``; ``n`` counts up per prefix."""
        count = self.labels.get(pref, 0)
        label = '%s_%d' % (pref, count)
        self.labels[pref] = count + 1
        return label

    def get_tag(self, value):
        """Return the numeric tag for ``value``, allocating a new one if unseen."""
        if value not in self.tags:
            self.tags[value] = self.tag_count
            self.tag_count += 1
        return self.tags[value]

    def compile(self, ast):
        """Emit code for ``ast`` (and, recursively, its children) to ``self.out``.

        Dispatches on ``ast.tag``.  Raises NotImplementedError for a tag this
        backend does not handle.
        """
        if ast.tag == 'Program':
            self.out.write("""\
; AUTOMATICALLY GENERATED -- EDIT AT YOUR OWN RISK

""")
            for child in ast.children:
                self.compile(child)
            self.out.write("""\
; ...
global_pos=%d
; call main
get_global main_index
call
""" % self.global_pos)
        elif ast.tag == 'Defn':
            # Each toplevel definition gets the next global slot index.
            self.out.write('%s_index=%d\n' % (ast.value, self.global_pos))
            self.global_pos += 1
            self.compile(ast.children[0])
        elif ast.tag in ('StructDefn', 'Forward'):
            pass
        elif ast.tag == 'FunLit':
            # The function body is emitted inline, so jump over it; afterwards
            # the function's label is pushed as the expression's value.
            past_fun = self.get_label('past_fun')
            self.out.write('jmp %s\n' % past_fun)
            # Save the enclosing function's compilation state so that a
            # nested function literal does not clobber it.
            save_fun = self.fun_lit
            save_argsize = self.fun_argsize
            save_local_pos = self.local_pos
            self.fun_lit = self.get_label('fun_lit')
            self.local_pos = 1  # baseptr + 0 holds the saved old baseptr
            self.out.write('%s:\n' % self.fun_lit)
            # also saves the old baseptr right here
            self.out.write('set_baseptr\n')
            self.compile(ast.children[0])
            self.compile(ast.children[1])
            # TODO copy the result value(S) to the first arg position
            # (for now the opcode handles that)
            self.out.write('exeunt_%s:\n' % self.fun_lit)
            returnsize = self.size_of(ast.type.return_type)
            self.out.write('set_returnsize %d\n' % returnsize)
            self.out.write('clear_baseptr %d\n' % (0 - self.fun_argsize))
            self.out.write('rts\n')
            self.out.write('%s:\n' % past_fun)
            self.out.write('push %s\n' % self.fun_lit)
            self.local_pos = save_local_pos
            self.fun_argsize = save_argsize
            self.fun_lit = save_fun
        elif ast.tag == 'Args':
            argsize = 0
            for child in ast.children:
                assert child.tag == 'Arg'
                argsize += self.size_of(child.type)
            self.fun_argsize = argsize
            # The first argument is pushed first, so it is the DEEPEST: it
            # gets the most negative offset from baseptr, and each later
            # argument sits above it by the size of the one before.
            pos = 0 - self.fun_argsize
            for child in ast.children:
                self.out.write(
                    '%s_local_%s=%d\n' % (self.fun_lit, child.value, pos)
                )
                # Advance by this argument's own size (not the Args node's).
                pos += self.size_of(child.type)
        elif ast.tag == 'Body':
            self.compile(ast.children[0])
            self.compile(ast.children[1])
        elif ast.tag == 'VarDecls':
            for child in ast.children:
                self.compile(child)
        elif ast.tag == 'VarDecl':
            # Reserve one slot for the local, initialised to 0.
            self.out.write('push 0\n')
            self.out.write(
                '%s_local_%s=%s\n' % (self.fun_lit, ast.value, self.local_pos)
            )
            self.local_pos += 1
        elif ast.tag == 'Block':
            for child in ast.children:
                self.compile(child)
        elif ast.tag == 'While':
            start = self.get_label('loop_start')
            end = self.get_label('loop_end')
            # Remember the enclosing loop's exit so `break` after this
            # (possibly nested) loop still targets the right label.
            saved_loop_end = self.loop_end
            self.loop_end = end
            self.out.write('%s:\n' % start)
            self.compile(ast.children[0])
            self.out.write('bzero %s\n' % end)
            self.compile(ast.children[1])
            self.out.write('jmp %s\n' % start)
            self.out.write('%s:\n' % end)
            self.loop_end = saved_loop_end
        elif ast.tag == 'Op':
            self.compile(ast.children[0])
            self.compile(ast.children[1])
            # Unknown operators are emitted verbatim as the opcode.
            self.out.write('%s\n' % OPS.get(ast.value, ast.value))
        elif ast.tag == 'VarRef':
            if ast.aux == 'global':
                self.out.write('builtin_%s\n' % ast.value)
            elif ast.aux == 'toplevel':
                self.out.write('get_global %s_index\n' % ast.value)
            else:
                self.out.write('get_local %s_local_%s\n' % (self.fun_lit, ast.value))
        elif ast.tag == 'FunCall':
            # Arguments are pushed first-to-last, then the callee itself.
            for child in ast.children[1:]:
                self.out.write('; push argument\n')
                self.compile(child)
            self.out.write('; push function\n')
            self.compile(ast.children[0])
            self.out.write('call\n')
        elif ast.tag == 'If':
            else_part = self.get_label('else_part')
            end_if = self.get_label('end_if')
            self.compile(ast.children[0])
            self.out.write('bzero %s\n' % else_part)
            self.compile(ast.children[1])
            self.out.write('jmp %s\n' % end_if)
            self.out.write('%s:\n' % else_part)
            if len(ast.children) == 3:
                self.compile(ast.children[2])
            self.out.write('%s:\n' % end_if)
        elif ast.tag == 'Return':
            self.compile(ast.children[0])
            self.out.write('jmp exeunt_%s\n' % self.fun_lit)
        elif ast.tag == 'Break':
            self.out.write('jmp %s\n' % self.loop_end)
        elif ast.tag == 'Not':
            self.compile(ast.children[0])
            self.out.write('not\n')
        elif ast.tag == 'None':
            pass  # sizeof(void) == 0
        elif ast.tag == 'BoolLit':
            # true is all-bits-set (-1), false is 0.
            if ast.value:
                self.out.write("push -1\n")
            else:
                self.out.write("push 0\n")
        elif ast.tag == 'IntLit':
            self.out.write('push %s\n' % ast.value)
        elif ast.tag == 'StrLit':
            self.out.write('push %r\n' % ast.value)
        elif ast.tag == 'Assignment':
            self.compile(ast.children[1])
            self.out.write('; assign to...\n')
            assert ast.children[0].tag == 'VarRef'
            self.out.write('set_local %s_local_%s\n' % (self.fun_lit, ast.children[0].value))
        elif ast.tag == 'Make':
            # TODO store in the order defined in the struct?
            fields = {}
            for child in ast.children[1:]:
                fields[child.aux] = child   # FieldInit.aux = position in struct
            for position in sorted(fields):
                self.compile(fields[position])
            self.out.write('make_struct %d\n' % self.struct_size(ast.type))
        elif ast.tag == 'FieldInit':
            self.compile(ast.children[0])
        elif ast.tag == 'Index':
            self.compile(ast.children[0])
            self.out.write('get_field %d\n' % ast.aux)
        elif ast.tag == 'TypeCast':
            self.compile(ast.children[0])
            t = str(ast.children[0].type)
            # If the LHS is not already a union type, promote it to a tagged value.
            if not isinstance(ast.children[0].type, Union):
                self.out.write('; tag with "%s"\n' % t)
                if self.size_of(ast.children[0].type) == 0:
                    # special case.  there is nothing on the stack
                    self.out.write('push 0\n')
                tag = self.get_tag(t)
                self.out.write('tag %d\n' % tag)
        elif ast.tag == 'TypeCase':
            end_typecase = self.get_label('end_typecase')
            self.compile(ast.children[0])
            # Keep one copy of the tagged value on the stack for the whole
            # typecase; compare its tag against the case's type tag.
            self.out.write('dup\n')
            self.out.write('get_tag\n')
            tag = self.get_tag(str(ast.children[1].type))
            self.out.write('push %d\n' % tag)
            self.out.write('eq\n')
            self.out.write('bzero %s\n' % end_typecase)
            # set the value to the untagged value of the value
            self.out.write('dup\n')
            self.out.write('get_value\n')
            assert ast.children[0].tag == 'VarRef'
            self.out.write('set_local %s_local_%s\n' % (self.fun_lit, ast.children[0].value))

            self.compile(ast.children[2])
            # now restore the value, with what was saved on the stack
            self.out.write('dup\n')
            self.out.write('set_local %s_local_%s\n' % (self.fun_lit, ast.children[0].value))

            self.out.write('%s:\n' % end_typecase)
            # Discard the saved copy on both the matched and unmatched paths.
            self.out.write('pop 1\n')
        else:
            raise NotImplementedError(repr(ast))