git @ Cat's Eye Technologies Castile / master src / castile / checker.py
master

Tree @master (Download .tar.gz)

checker.py @masterraw · history · blame

from castile.builtins import BUILTINS
from castile.context import ScopedContext
from castile.types import (
    Integer, String, Void, Boolean, Function, Union, Struct
)


class CastileTypeError(ValueError):
    def __init__(self, ast, message, *args, **kwargs):
        message = 'line {}: {}'.format(ast.line, message)
        super(CastileTypeError, self).__init__(message, *args, **kwargs)


class StructDefinition(object):
    def __init__(self, name, field_names, content_types, scope_idents):
        self.name = name
        self.field_names = field_names  # dict of name -> position
        self.content_types = content_types  # list of types in order
        self.scope_idents = scope_idents  # list of identifiers, or None

    def field_names_in_order(self):
        m = {}
        for k, v in self.field_names.items():
            m[v] = k
        names = []
        for i in range(len(m)):
            names.append(m[i])
        return names


class TypeChecker(object):
    def __init__(self):
        global_context = {}
        for (name, (value, type)) in BUILTINS.items():
            global_context[name] = type
        self.context = ScopedContext(global_context, level='global')
        self.toplevel_context = ScopedContext({}, self.context, level='toplevel')
        self.context = self.toplevel_context
        self.current_defn = None

        self.forwards = {}
        self.structs = {}  # struct name -> StructDefinition
        self.return_type = None
        self.within_control = False

        self.verbose = False

    def set(self, name, type):
        self.context[name] = type
        if self.verbose:
            print('%s: %s' % (name, type))
        return type

    def assert_eq(self, ast, t1, t2):
        if t1 == t2:
            return
        raise CastileTypeError(ast, "type mismatch: %s != %s" % (t1, t2))

    def collect_structs(self, ast):
        for child in ast.children:
            if child.tag == 'StructDefn':
                self.collect_struct(child)

    def collect_struct(self, ast):
        name = ast.value
        if name in self.structs:
            raise CastileTypeError(ast, 'duplicate struct %s' % name)
        struct_fields = {}
        type_exprs = []
        i = 0
        field_defns = ast.children[0].children
        scope_idents = None
        if len(ast.children) > 1:
            scope_idents = [a.value for a in ast.children[1].children]
        for child in field_defns:
            assert child.tag == 'FieldDefn', child.tag
            field_name = child.value
            if field_name in struct_fields:
                raise CastileTypeError(child, 'already-defined field %s' % field_name)
            struct_fields[field_name] = i
            i += 1
            type_exprs.append(self.type_of(child.children[0]))
        self.structs[name] = StructDefinition(ast.value, struct_fields, type_exprs, scope_idents)

    def resolve_structs(self, ast):
        if isinstance(ast.type, Struct):
            if ast.type.name not in self.structs:
                raise CastileTypeError(ast, 'undefined struct %s' % ast.type.name)
            ast.type.defn = self.structs[ast.type.name]
        for child in ast.children:
            self.resolve_structs(child)

    # context is modified as side-effect of traversal
    def type_of(self, ast):
        if ast.tag == 'Op':
            if ast.value in ('and', 'or'):
                self.assert_eq(ast, self.type_of(ast.children[0]), Boolean())
                self.assert_eq(ast, self.type_of(ast.children[1]), Boolean())
                ast.type = Boolean()
            elif ast.value in ('+', '-', '*', '/'):
                type1 = self.type_of(ast.children[0])
                type2 = self.type_of(ast.children[1])
                self.assert_eq(ast, type1, type2)
                self.assert_eq(ast, type1, Integer())
                ast.type = Integer()
            elif ast.value in ('==', '!=', '>', '>=', '<', '<='):
                type1 = self.type_of(ast.children[0])
                type2 = self.type_of(ast.children[1])
                self.assert_eq(ast, type1, type2)
                if isinstance(type1, Struct):
                    raise CastileTypeError(ast, "structs cannot be compared")
                if isinstance(type1, Union) and type1.contains_instance_of(Struct):
                    raise CastileTypeError(ast, "unions containing structs cannot be compared")
                ast.type = Boolean()
        elif ast.tag == 'Not':
            type1 = self.type_of(ast.children[0])
            self.assert_eq(ast, type1, Boolean())
            ast.type = Boolean()
        elif ast.tag == 'IntLit':
            ast.type = Integer()
        elif ast.tag == 'StrLit':
            ast.type = String()
        elif ast.tag == 'BoolLit':
            ast.type = Boolean()
        elif ast.tag == 'FunLit':
            save_context = self.context
            self.context = ScopedContext(
                {}, self.toplevel_context, level='argument'
            )
            self.return_type = None
            arg_types = self.type_of(ast.children[0])  # args
            t = self.type_of(ast.children[1])  # body
            self.assert_eq(ast, t, Void())
            self.context = save_context
            return_type = self.return_type
            self.return_type = None
            if return_type is None:
                return_type = Void()
            ast.type = Function(arg_types, return_type)
        elif ast.tag == 'Args':
            types = []
            for child in ast.children:
                types.append(self.type_of(child))
            return types  # NOT A TYPE
        elif ast.tag == 'Arg':
            ast.type = self.set(ast.value, self.type_of(ast.children[0]))
        elif ast.tag == 'Type':
            map = {
                'integer': Integer(),
                'boolean': Boolean(),
                'string': String(),
                'void': Void(),
            }
            ast.type = map[ast.value]
        elif ast.tag == 'Body':
            self.context = ScopedContext({}, self.context,
                                         level='local')
            self.assert_eq(ast, self.type_of(ast.children[1]), Void())
            self.context = self.context.parent
            ast.type = Void()
        elif ast.tag == 'FunType':
            return_type = self.type_of(ast.children[0])
            ast.type = Function(
                [self.type_of(c) for c in ast.children[1:]], return_type
            )
        elif ast.tag == 'UnionType':
            types = []
            for c in ast.children:
                type_ = self.type_of(c)
                if type_ in types:
                    raise CastileTypeError(c, "bad union type")
                types.append(type_)
            ast.type = Union(types)
        elif ast.tag == 'StructType':
            ast.type = Struct(ast.value)
        elif ast.tag == 'VarRef':
            ast.type = self.context[ast.value]
            ast.aux = self.context.level(ast.value)
        elif ast.tag == 'None':
            ast.type = Void()
        elif ast.tag == 'FunCall':
            t1 = self.type_of(ast.children[0])
            assert isinstance(t1, Function), \
                '%r is not a function' % t1
            if len(t1.arg_types) != len(ast.children) - 1:
                raise CastileTypeError(ast, "argument mismatch")
            i = 0
            for child in ast.children[1:]:
                self.assert_eq(ast, self.type_of(child), t1.arg_types[i])
                i += 1
            ast.type = t1.return_type
        elif ast.tag == 'Return':
            t1 = self.type_of(ast.children[0])
            if self.return_type is None:
                self.return_type = t1
            else:
                self.assert_eq(ast, t1, self.return_type)
            ast.type = Void()
        elif ast.tag == 'Break':
            ast.type = Void()
        elif ast.tag == 'If':
            within_control = self.within_control
            self.within_control = True
            t1 = self.type_of(ast.children[0])
            assert t1 == Boolean()
            t2 = self.type_of(ast.children[1])
            if len(ast.children) == 3:
                # TODO useless!  is void.
                t3 = self.type_of(ast.children[2])
                self.assert_eq(ast, t2, t3)
                ast.type = t2
            else:
                ast.type = Void()
            self.within_control = within_control
        elif ast.tag == 'While':
            within_control = self.within_control
            self.within_control = True
            t1 = self.type_of(ast.children[0])
            self.assert_eq(ast, t1, Boolean())
            t2 = self.type_of(ast.children[1])
            ast.type = Void()
            self.within_control = within_control
        elif ast.tag == 'Block':
            for child in ast.children:
                self.assert_eq(ast, self.type_of(child), Void())
            ast.type = Void()
        elif ast.tag == 'Assignment':
            t2 = self.type_of(ast.children[1])
            t1 = None
            name = ast.children[0].value
            if ast.aux == 'defining instance':
                if self.within_control:
                    raise CastileTypeError(ast, 'definition of %s within control block' % name)
                if name in self.context:
                    raise CastileTypeError(ast, 'definition of %s shadows previous' % name)
                self.set(name, t2)
                t1 = t2
            else:
                if name not in self.context:
                    raise CastileTypeError(ast, 'variable %s used before definition' % name)
                t1 = self.type_of(ast.children[0])
            self.assert_eq(ast, t1, t2)
            # not quite useless now (typecase still likes this)
            if self.context.level(ast.children[0].value) != 'local':
                raise CastileTypeError(ast, 'cannot assign to non-local')
            ast.type = Void()
        elif ast.tag == 'Make':
            t = self.type_of(ast.children[0])
            if t.name not in self.structs:
                raise CastileTypeError(ast, "undefined struct %s" % t.name)
            struct_defn = self.structs[t.name]
            if struct_defn.scope_idents is not None:
                if self.current_defn not in struct_defn.scope_idents:
                    raise CastileTypeError(ast, "inaccessible struct %s for make: %s not in %s" %
                        (t.name, self.current_defn, struct_defn.scope_idents)
                    )
            if len(struct_defn.content_types) != len(ast.children) - 1:
                raise CastileTypeError(ast, "argument mismatch; expected {}, got {} in {}".format(
                    len(struct_defn.content_types), len(ast.children) - 1, ast
                ))
            i = 0
            for defn in ast.children[1:]:
                name = defn.value
                t1 = self.type_of(defn)
                pos = struct_defn.field_names[name]
                defn.aux = pos
                self.assert_eq(ast, t1, struct_defn.content_types[pos])
                i += 1
            ast.type = t
        elif ast.tag == 'FieldInit':
            ast.type = self.type_of(ast.children[0])
        elif ast.tag == 'Index':
            t = self.type_of(ast.children[0])
            struct_defn = self.structs[t.name]
            if struct_defn.scope_idents is not None:
                if self.current_defn not in struct_defn.scope_idents:
                    raise CastileTypeError(ast, "inaccessible struct %s for access: %s not in %s" %
                        (t.name, self.current_defn, struct_defn.scope_idents)
                    )
            field_name = ast.value
            struct_fields = struct_defn.field_names
            if field_name not in struct_fields:
                raise CastileTypeError(ast, "undefined field")
            index = struct_fields[field_name]
            # we make this value available to compiler backends
            ast.aux = index
            # we look up the type from the StructDefinition
            ast.type = struct_defn.content_types[index]
        elif ast.tag == 'TypeCase':
            t1 = self.type_of(ast.children[0])
            t2 = self.type_of(ast.children[1])
            if not isinstance(t1, Union):
                raise CastileTypeError(ast, 'bad typecase, %s not a union' % t1)
            if not t1.contains(t2):
                raise CastileTypeError(ast, 'bad typecase, %s not in %s' % (t2, t1))
            # typecheck t3 with variable in children[0] having type t2
            assert ast.children[0].tag == 'VarRef'
            within_control = self.within_control
            self.within_control = True
            self.context = ScopedContext({}, self.context, level='typecase')
            self.context[ast.children[0].value] = t2
            ast.type = self.type_of(ast.children[2])
            self.context = self.context.parent
            self.within_control = within_control
        elif ast.tag == 'Program':
            for defn in ast.children:
                self.assert_eq(ast, self.type_of(defn), Void())
            ast.type = Void()
            self.resolve_structs(ast)
        elif ast.tag == 'Defn':
            self.current_defn = ast.value
            t = self.type_of(ast.children[0])
            self.current_defn = None
            if ast.value in self.forwards:
                self.assert_eq(ast, self.forwards[ast.value], t)
                del self.forwards[ast.value]
            else:
                self.set(ast.value, t)
            if ast.value == 'main':
                # any return type is fine, for now, so,
                # we compare it against itself
                rt = t.return_type
                self.assert_eq(ast, t, Function([], rt))
            ast.type = Void()
        elif ast.tag == 'Forward':
            t = self.type_of(ast.children[0])
            self.forwards[ast.value] = t
            self.set(ast.value, t)
            ast.type = Void()
        elif ast.tag == 'StructDefn':
            ast.type = Void()
        elif ast.tag == 'TypeCast':
            value_t = self.type_of(ast.children[0])
            union_t = self.type_of(ast.children[1])
            if not isinstance(union_t, Union):
                raise CastileTypeError(ast, 'bad cast, not a union: %s' % union_t)
            if not union_t.contains(value_t):
                raise CastileTypeError(
                    ast, 'bad cast, %s does not include %s' % (union_t, value_t)
                )
            ast.type = union_t
        else:
            raise NotImplementedError(repr(ast))
        return ast.type