#!/usr/bin/env python

# SAWBONES: Cat's Eye Technologies' reference implementation of SICKBAY.
# This work is in the public domain; see UNLICENSE for more information.

from optparse import OptionParser
import random
import re
import sys


# Set to True to make Lines and exec_program print tracing output.
DEBUG = False


class AST(object):
    """A node of the abstract syntax tree.

    Every node has a `type` tag, an optional `value`, and a list of
    `children`.  When no child list is supplied, each node gets its own
    fresh empty list.
    """
    def __init__(self, type, children=None, value=None):
        self.type = type
        self.value = value
        self.children = [] if children is None else children

    def add_child(self, item):
        """Append `item` to this node's children."""
        self.children.append(item)

    def __repr__(self):
        """Return a compact rendering of the node.

        The child list is omitted when it is empty and a value is present;
        the value is omitted when it is None.
        """
        fields = [repr(self.type)]
        if self.children or self.value is None:
            fields.append(repr(self.children))
        if self.value is not None:
            fields.append('value=%r' % (self.value,))
        return 'AST(%s)' % ','.join(fields)


class Scanner(object):
    """
    Tokenizer that works through `text` one token at a time.

    The current token is exposed as `token` (its text) and `type` (its
    category); the not-yet-scanned remainder of the input is kept in `text`.
    Constructing a Scanner immediately scans the first token.

    >>> a = Scanner("  (++ )  FOO  ")
    >>> a.token
    '('
    >>> a.type
    'goose egg'
    >>> a.scan()
    >>> a.on("+")
    True
    >>> a.on_type('operator')
    True
    >>> a.check_type('identifier')
    Traceback (most recent call last):
    ...
    SyntaxError: Expected identifier, but found operator ('+')
    >>> a.scan()
    >>> a.scan()
    >>> a.consume(".")
    False
    >>> a.consume(")")
    True
    >>> a.expect("FOO")
    >>> a.type
    'EOF'
    >>> a.expect("bar")
    Traceback (most recent call last):
    ...
    SyntaxError: Expected 'bar', but found 'None'

    """

    # Token rules tried in order by scan(), once leading blanks are gone:
    # (pattern, token type, extra keyword arguments for scan_pattern).
    _TOKEN_RULES = (
        (r'REM[^\r\n]*', 'remark', {}),
        (r'[\r\n]+', 'newline', {}),
        (r'\*|\/|\+|\-', 'operator', {}),
        (r':|\;|\(|\)|\=', 'goose egg', {}),
        (r'\d+', 'integer literal', {}),
        # The token is the text between the quotes, hence the shifted groups.
        (r'\"(.*?)\"', 'string literal', {'token_group': 2, 'rest_group': 3}),
        (r'[A-Z][A-Z0-9]?\%', 'identifier', {}),
        (r'[A-Z]+\$?\%?', 'command', {}),
        (r'.', 'unknown character', {}),
    )

    def __init__(self, text):
        self.text = text
        self.token = None
        self.type = None
        self.scan()

    def scan_pattern(self, pattern, type, token_group=1, rest_group=2):
        """Try to match `pattern` at the start of the remaining text.

        On success, record the token (group `token_group`) and its `type`,
        replace the remaining text with group `rest_group`, and return True.
        On failure, change nothing and return False.
        """
        # NOTE: `$` also matches just before a newline that ends the string,
        # so a successful match drops one trailing newline from the input.
        found = re.match(r'^(' + pattern + r')(.*?)$', self.text, re.DOTALL)
        if found is None:
            return False
        self.type = type
        self.token = found.group(token_group)
        self.text = found.group(rest_group)
        return True

    def scan(self):
        """Advance to the next token, skipping leading spaces and tabs.

        At the end of the input the token becomes None and the type 'EOF'.
        """
        self.scan_pattern(r'[ \t]*', 'whitespace')
        if not self.text:
            self.token, self.type = None, 'EOF'
            return
        for (pattern, kind, groups) in self._TOKEN_RULES:
            if self.scan_pattern(pattern, kind, **groups):
                return
        raise ValueError("this should never happen, self.text=(%s)" % self.text)

    def expect(self, token):
        """Skip past the current token if it equals `token`, else raise SyntaxError."""
        if self.token != token:
            raise SyntaxError("Expected '%s', but found '%s'" %
                              (token, self.token))
        self.scan()

    def expect_type(self, type):
        """Skip past the current token if it has type `type`, else raise SyntaxError."""
        self.check_type(type)
        self.scan()

    def on(self, token):
        """Return whether the current token's text equals `token`."""
        return self.token == token

    def on_type(self, type):
        """Return whether the current token's type equals `type`."""
        return self.type == type

    def check_type(self, type):
        """Raise SyntaxError unless the current token has type `type`."""
        if self.type != type:
            raise SyntaxError("Expected %s, but found %s ('%s')" %
                              (type, self.type, self.token))

    def consume(self, token):
        """If the current token equals `token`, skip past it and return True.

        Otherwise leave the scanner untouched and return False.
        """
        matched = self.token == token
        if matched:
            self.scan()
        return matched


class Parser(object):
    """
    Recursive-descent parser that turns program text into AST nodes.

    Each parsing method consumes tokens from `self.scanner` and returns the
    AST node (or, for program(), the list of lines) it recognized.

    >>> a = Parser("123")
    >>> a.expr()
    AST('IntLit',value=123)

    >>> a = Parser("A% ( 7 )")
    >>> a.intvar()
    AST('IntVar',[AST('IntLit',value=7)],value='A%')

    >>> a = Parser("(A%/(7+B%(4)))")
    >>> a.expr()
    AST('IntOp',[AST('IntVar',[AST('IntLit',value=0)],value='A%'), AST('IntOp',[AST('IntLit',value=7), AST('IntVar',[AST('IntLit',value=4)],value='B%')],value='+')],value='/')

    >>> a = Parser("(0-100) PRINT 4;:REM YES:PRINT 5")
    >>> a.line()
    AST('Line',[AST('IntOp',[AST('IntLit',value=0), AST('IntLit',value=100)],value='-'), AST('PrintIntExpr',[AST('IntLit',value=4)]), AST('Remark',value='REM YES:PRINT 5')])

    """
    def __init__(self, text):
        self.scanner = Scanner(text)

    def program(self):
        """Parse lines until EOF.

        Returns a list of (position, Line node) pairs, where position is the
        zero-based order in which the line appeared in the text.
        """
        numbered = []
        while not self.scanner.on_type('EOF'):
            numbered.append((len(numbered), self.line()))
        return numbered

    def line(self):
        """Parse a line: a line-number expression, then ':'-separated statements.

        The line must be followed by a newline token or the end of input.
        The resulting 'Line' node holds the expression as its first child
        and the statements after it.
        """
        number = self.expr()
        stmts = [self.stmt()]
        while self.scanner.consume(':'):
            stmts.append(self.stmt())
        if not self.scanner.on_type('EOF'):
            self.scanner.expect_type('newline')
        return AST('Line', [number] + stmts)

    def stmt(self):
        """Parse one statement; raise ValueError if none is recognized."""
        scanner = self.scanner
        if scanner.on_type('remark'):
            remark = AST('Remark', value=scanner.token)
            scanner.scan()
            return remark
        if scanner.consume("PRINT"):
            return self._print_stmt()
        if scanner.consume("INPUT"):
            if scanner.consume("CHR$"):
                return AST('InputChar', [self.intvar()])
            return AST('InputInt', [self.intvar()])
        if scanner.consume("GOTO"):
            return AST('Goto', [self.intlit()])
        if scanner.consume("GOSUB"):
            return AST('Gosub', [self.intlit()])
        # END is parsed as a synonym for RETURN.
        if scanner.consume("RETURN") or scanner.consume("END"):
            return AST('Return')
        if scanner.consume("PROLONG"):
            return AST('Prolong', [self.intlit()])
        if scanner.consume("CUTSHORT"):
            return AST('Cutshort')
        if scanner.consume("DIM"):
            scanner.expect("RING")
            scanner.expect("(")
            size = self.expr()
            scanner.expect(")")
            return AST('DimRing', [size])
        if scanner.consume("LET"):
            target = self.intvar()
            scanner.expect('=')
            return AST('Assign', [target, self.expr()])
        raise ValueError("stmt? %s" % scanner.token)

    def _print_stmt(self):
        """Parse what follows the PRINT keyword.

        The subject is a string literal, CHR$ plus an expression, or a bare
        expression.  Unless a ';' follows, the node is wrapped in a
        'PrintNewline' node.
        """
        scanner = self.scanner
        if scanner.on_type('string literal'):
            node = AST('PrintString', value=scanner.token)
            scanner.scan()
        elif scanner.consume("CHR$"):
            node = AST('PrintChar', [self.expr()])
        else:
            node = AST('PrintIntExpr', [self.expr()])
        if not scanner.consume(';'):
            node = AST('PrintNewline', [node])
        return node

    def expr(self):
        """Parse an integer expression.

        An expression is a variable, an integer literal, RND%(expr), or a
        fully parenthesized binary operation (expr op expr).
        """
        scanner = self.scanner
        if scanner.on_type('identifier'):
            return self.intvar()
        if scanner.on_type('integer literal'):
            return self.intlit()
        if scanner.consume('RND%'):
            scanner.expect('(')
            bound = self.expr()
            scanner.expect(')')
            return AST('Random', [bound])
        scanner.expect('(')
        lhs = self.expr()
        scanner.check_type('operator')
        op = scanner.token
        scanner.scan()
        rhs = self.expr()
        scanner.expect(')')
        return AST('IntOp', [lhs, rhs], value=op)

    def intlit(self):
        """Parse an integer literal into an 'IntLit' node holding an int."""
        digits = self.scanner.token
        self.scanner.expect_type('integer literal')
        return AST('IntLit', value=int(digits))

    def intvar(self):
        """Parse an integer variable with an optional parenthesized index.

        When no index is written, the index defaults to the literal 0.
        """
        name = self.scanner.token
        self.scanner.expect_type('identifier')
        if self.scanner.consume('('):
            index = self.expr()
            self.scanner.expect(')')
        else:
            index = AST('IntLit', value=0)
        return AST('IntVar', [index], value=name)


def eval_expr(e, store):
    """
    Evaluate the expression node `e` against `store` and return an integer.

    `store` maps a variable name to a dict of index -> value; a variable or
    index that is absent reads as 0.  'Random' nodes yield
    random.randint(0, n-1) where n is the value of the child expression.
    Division is floor division, and dividing by zero yields 0.

    Raises NotImplementedError for an unknown operator or an unknown node
    type (rather than silently returning None for the latter).

    >>> eval_expr(Parser("123").expr(), {})
    123

    >>> eval_expr(Parser("(((10-3)*7)+1)").expr(), {})
    50

    >>> eval_expr(Parser("A%").expr(), {})
    0

    >>> eval_expr(Parser("A%").expr(), {'A%': {99: 99}})
    0

    >>> eval_expr(Parser("A%").expr(), {'A%': {0: 77}})
    77

    >>> eval_expr(Parser("A% (7)").expr(), {'A%': {7: 88}})
    88

    """
    if e.type == 'IntLit':
        return e.value
    if e.type == 'IntVar':
        index = eval_expr(e.children[0], store)
        return store.get(e.value, {}).get(index, 0)
    if e.type == 'Random':
        rg = eval_expr(e.children[0], store)
        return random.randint(0, rg - 1)
    if e.type == 'IntOp':
        lhs = eval_expr(e.children[0], store)
        rhs = eval_expr(e.children[1], store)
        if e.value == '+':
            return lhs + rhs
        if e.value == '-':
            return lhs - rhs
        if e.value == '*':
            return lhs * rhs
        if e.value == '/':
            # Division by zero is defined to give 0 instead of failing.
            return 0 if rhs == 0 else lhs // rhs
        raise NotImplementedError(e.value)
    # Fail loudly on a node type this evaluator does not know about.
    raise NotImplementedError(e.type)


def eval_ref(e, store):
    """
    Resolve the variable node `e` to a (name, index) pair.

    The index expression is evaluated against `store`.  Any node that is
    not an 'IntVar' raises NotImplementedError.

    >>> eval_ref(Parser("A%(B%)").expr(), {'B%': {0: 88}})
    ('A%', 88)

    """
    if e.type != 'IntVar':
        raise NotImplementedError(e.value)
    return (e.value, eval_expr(e.children[0], store))


def read_int(infile):
    """Read one decimal integer, optionally preceded by '-', from `infile`.

    Leading spaces, tabs and line breaks are skipped.  The digits must be
    followed by a whitespace character (which is consumed) or by the end of
    the input.

    Raises ValueError('bad integer') if no digits are found or if the
    digits are followed by any other character.
    """
    whitespace = (' ', '\t', '\r', '\n')
    digits = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

    c = infile.read(1)
    while c in whitespace:
        c = infile.read(1)

    sign = ''
    if c == '-':
        sign = '-'
        c = infile.read(1)

    collected = []
    while c in digits:
        collected.append(c)
        c = infile.read(1)

    # read(1) returns '' at end of input, which terminates a number just as
    # validly as trailing whitespace does.
    terminated = c == '' or c in whitespace
    if not collected or not terminated:
        raise ValueError('bad integer')
    return int(sign + ''.join(collected))


def exec_stmt(stmt, store, ring, infile=sys.stdin, outfile=sys.stdout):
    r"""
    Execute one statement node.

    Output goes to `outfile`, input is read from `infile`, variables live in
    `store` (name -> {index: value}) and `ring` is the call ring buffer.

    Returns ('goto', n), ('gosub', n) or ('return', 0) when the statement
    asks the caller to transfer control, and None otherwise.  Raises
    NotImplementedError for an unknown statement type.

    >>> try:
    ...     # Python 2
    ...     from StringIO import StringIO
    ... except ImportError:
    ...     # Python 3
    ...     from io import StringIO

    >>> out = StringIO()
    >>> exec_stmt(Parser('PRINT "HI"').stmt(), {'B%': {0: 88}}, None, outfile=out)
    >>> out.getvalue()
    'HI\n'

    >>> out = StringIO()
    >>> exec_stmt(Parser('PRINT (2+3);').stmt(), {'B%': {0: 88}}, None, outfile=out)
    >>> out.getvalue()
    '5'

    >>> out = StringIO()
    >>> exec_stmt(Parser('PRINT (B%/4);').stmt(), {'B%': {0: 88}}, None, outfile=out)
    >>> out.getvalue()
    '22'

    >>> out = StringIO()
    >>> exec_stmt(Parser('PRINT CHR$ 65;').stmt(), {'B%': {0: 88}}, None, outfile=out)
    >>> out.getvalue()
    'A'

    >>> store = {}
    >>> exec_stmt(Parser('LET A% = 15').stmt(), store, None)
    >>> store
    {'A%': {0: 15}}

    >>> store = {'B%': {9: 88}}
    >>> inp = StringIO('too')
    >>> exec_stmt(Parser('INPUT CHR$ B%(9)').stmt(), store, None, infile=inp)
    >>> store
    {'B%': {9: 116}}
    >>> exec_stmt(Parser('INPUT CHR$ B%(9)').stmt(), store, None, infile=inp)
    >>> store
    {'B%': {9: 111}}

    >>> store = {}
    >>> inp = StringIO('  -23  A')
    >>> exec_stmt(Parser('INPUT A%').stmt(), store, None, infile=inp)
    >>> store
    {'A%': {0: -23}}
    >>> exec_stmt(Parser('INPUT CHR$ A%').stmt(), store, None, infile=inp)
    >>> store
    {'A%': {0: 32}}
    >>> exec_stmt(Parser('INPUT CHR$ A%').stmt(), store, None, infile=inp)
    >>> store
    {'A%': {0: 65}}

    >>> store = {}
    >>> inp = StringIO('5x')
    >>> exec_stmt(Parser('INPUT A%').stmt(), store, None, infile=inp)
    Traceback (most recent call last):
    ...
    ValueError: bad integer

    """
    kind = stmt.type

    # --- Output -----------------------------------------------------------
    if kind == 'PrintNewline':
        # Run the wrapped print statement, then finish the line.
        exec_stmt(stmt.children[0], store, ring, infile, outfile)
        outfile.write('\n')
        return None
    if kind in ('PrintString', 'PrintChar', 'PrintIntExpr'):
        if kind == 'PrintString':
            text = stmt.value
        elif kind == 'PrintChar':
            text = chr(eval_expr(stmt.children[0], store))
        else:
            text = str(eval_expr(stmt.children[0], store))
        outfile.write(text)
        return None

    # --- Statements that write a variable -----------------------------------
    if kind in ('InputChar', 'InputInt', 'Assign'):
        target = stmt.children[0]
        if kind == 'Assign':
            # The right-hand side is evaluated before the target's index.
            value = eval_expr(stmt.children[1], store)
            (name, index) = eval_ref(target, store)
        else:
            # The target's index is evaluated before any input is read.
            (name, index) = eval_ref(target, store)
            if kind == 'InputChar':
                value = ord(infile.read(1))
            else:
                value = read_int(infile)
        store.setdefault(name, {})[index] = value
        return None

    # --- Control transfer requests ------------------------------------------
    if kind == 'Goto':
        return ('goto', eval_expr(stmt.children[0], store))
    if kind == 'Gosub':
        return ('gosub', eval_expr(stmt.children[0], store))
    if kind == 'Return':
        return ('return', 0)

    # --- Ring buffer manipulation -------------------------------------------
    if kind == 'Prolong':
        ring.prolong(eval_expr(stmt.children[0], store))
        return None
    if kind == 'Cutshort':
        # With nothing to cut, CUTSHORT behaves like RETURN.
        if ring.is_empty():
            return ('return', 0)
        ring.cutshort()
        return None
    if kind == 'DimRing':
        ring.dim(eval_expr(stmt.children[0], store))
        return None

    if kind == 'Remark':
        return None
    raise NotImplementedError(stmt.type)


def exec_line(line, store, ring, infile=sys.stdin, outfile=sys.stdout):
    r"""
    Execute the statements of a 'Line' node in order.

    The first child of `line` is the line-number expression and is skipped.
    Execution stops at the first statement that returns a control-transfer
    result, which is passed back to the caller; the remaining statements on
    the line are not run.  Returns None if every statement ran to completion.

    >>> try:
    ...     # Python 2
    ...     from StringIO import StringIO
    ... except ImportError:
    ...     # Python 3
    ...     from io import StringIO

    >>> out = StringIO()
    >>> exec_line(Parser('10 PRINT "HI";:PRINT "LO"').line(), {}, None, outfile=out)
    >>> out.getvalue()
    'HILO\n'

    >>> out = StringIO()
    >>> exec_line(Parser('(0-50) PRINT "HI":GOTO 50').line(), {}, None, outfile=out)
    ('goto', 50)
    >>> out.getvalue()
    'HI\n'

    >>> out = StringIO()
    >>> exec_line(Parser('10 GOSUB 50:PRINT "HI"').line(), {}, None, outfile=out)
    ('gosub', 50)
    >>> out.getvalue()
    ''

    >>> out = StringIO()
    >>> exec_line(Parser('1 LET A%(23)=5:PRINT A%(23)').line(), {}, None, outfile=out)
    >>> out.getvalue()
    '5\n'

    """
    statements = line.children[1:]
    for statement in statements:
        outcome = exec_stmt(statement, store, ring,
                            infile=infile, outfile=outfile)
        if outcome is not None:
            return outcome
    return None


class RingBuffer(object):
    """
    Bounded double-ended buffer of line numbers.

    push/pop work on the top of the buffer, prolong/cutshort on the bottom.
    The capacity is fixed once by dim(); if an operation is attempted before
    dim() was called, a capacity of 10 is applied first.

    push and prolong raise ValueError("ring buffer full") without modifying
    the buffer when it is already at capacity; pop and cutshort raise
    ValueError("ring buffer empty") when there is nothing to remove.

    >>> r = RingBuffer()
    >>> r.push(100)
    >>> r.push(200)
    >>> r.pop()
    200
    >>> r.push(300)
    >>> r.cutshort()
    100
    >>> r.prolong(90)
    >>> r.pop()
    300
    >>> r.pop()
    90
    >>> r.pop()
    Traceback (most recent call last):
    ...
    ValueError: ring buffer empty

    >>> r = RingBuffer()
    >>> r.push(100)
    >>> r.capacity
    10
    >>> r.capacity = 2
    >>> r.push(200)
    >>> r.push(300)
    Traceback (most recent call last):
    ...
    ValueError: ring buffer full

    """
    def __init__(self):
        # Backed by a plain list bounded by `capacity`, not by real
        # circular storage.
        self.buffer = []
        self.capacity = None

    def is_empty(self):
        """Return True when the buffer holds no items."""
        return not self.buffer

    def dim(self, capacity):
        """Fix the capacity; raise ValueError if it was already fixed."""
        if self.capacity is not None:
            raise ValueError("already dim'ed")
        self.capacity = capacity

    def _ensure_dimmed(self):
        """Apply the default capacity of 10 if dim() was never called."""
        # Compare against None explicitly: a capacity of 0 is a legitimate,
        # already-dimensioned state and must not trigger a second dim().
        if self.capacity is None:
            self.dim(10)

    def _check_room(self):
        """Raise ValueError if one more item would exceed the capacity."""
        if len(self.buffer) >= self.capacity:
            raise ValueError("ring buffer full")

    def _check_nonempty(self):
        """Raise ValueError if there is nothing to remove."""
        if not self.buffer:
            raise ValueError("ring buffer empty")

    def pop(self):
        """Remove and return the item at the top of the buffer."""
        self._ensure_dimmed()
        self._check_nonempty()
        return self.buffer.pop()

    def push(self, x):
        """Add `x` at the top of the buffer."""
        self._ensure_dimmed()
        # Check before mutating so a failed push leaves the buffer intact.
        self._check_room()
        self.buffer.append(x)

    def cutshort(self):
        """Remove and return the item at the bottom of the buffer."""
        self._ensure_dimmed()
        self._check_nonempty()
        return self.buffer.pop(0)

    def prolong(self, x):
        """Add `x` at the bottom of the buffer."""
        self._ensure_dimmed()
        # Check before mutating so a failed prolong leaves the buffer intact.
        self._check_room()
        self.buffer.insert(0, x)


class Lines(object):
    """Index of a parsed program by evaluated line number.

    `prog` is a sequence of (position, Line node) pairs.  Each line's
    number expression (its first child) is evaluated against `store`, and
    the lines are kept in `self.lines` as (line number, position, node)
    tuples sorted by line number, then position.  `smallest` and `largest`
    hold the extreme line numbers, or None for an empty program.
    """
    def __init__(self, prog, store):
        entries = [
            (eval_expr(line.children[0], store), text_line, line)
            for (text_line, line) in prog
        ]
        self.lines = sorted(entries)
        if self.lines:
            # Sorted by line number first, so the ends hold the extremes.
            self.smallest = self.lines[0][0]
            self.largest = self.lines[-1][0]
        else:
            self.smallest = None
            self.largest = None
        if DEBUG:
            for (line_no, text_line, line) in self.lines:
                print(line_no, text_line, str(line.children[1])[:50] + '...')

    def seek_line_number(self, line_number):
        """Return the lowest line number strictly greater than `line_number`.

        Returns None when no such line exists.
        """
        for entry in self.lines:
            if entry[0] > line_number:
                return entry[0]
        return None

    def __getitem__(self, line_number):
        """Return the first Line node numbered `line_number`.

        Raises ValueError if the program has no line with that number.
        """
        for (number, _position, node) in self.lines:
            if number == line_number:
                return node
        raise ValueError("line number %d not present" % line_number)


def exec_program(prog, store, ring, infile=sys.stdin, outfile=sys.stdout):
    """Run a parsed program until it ends.

    `prog` is the list of (position, Line node) pairs produced by the
    parser, `store` holds the variables and `ring` is the call ring buffer.
    Execution starts at the smallest line number and returns when control
    runs off the end of the program, or a return is requested, while the
    ring buffer is empty.  An empty program returns immediately.

    Raises ValueError when control is transferred to a line number that is
    not present, or when a return finds no line beyond the popped number.
    """
    def re_tu_rn():
        # Pop a destination off the ring and resume at the first line whose
        # number is strictly greater than it; None means "nothing to return
        # to", which ends the program.
        if ring.is_empty():
            return None
        dest = ring.pop()
        line_number = lines.seek_line_number(dest)
        if line_number is None:
            raise ValueError("line number %d beyond end of program" % dest)
        return line_number

    lines = Lines(prog, store)
    line_number = lines.smallest
    if line_number is None:
        # No lines at all: nothing to execute.
        return
    while True:
        line = lines[line_number]
        r = exec_line(line, store, ring, infile=infile, outfile=outfile)
        # Line numbers are expressions, so re-evaluate them against the
        # store after every executed line.
        lines = Lines(prog, store)
        if r is None:
            # Fall through to the next higher line; at the end of the
            # program, behave like a return.
            line_number = lines.seek_line_number(line_number)
            if line_number is None:
                line_number = re_tu_rn()
                if line_number is None:
                    return
        elif r[0] == 'goto':
            line_number = r[1]
        elif r[0] == 'gosub':
            # NOTE(review): the pushed value is line_number + 1 and
            # re_tu_rn() seeks strictly beyond it, so a line numbered
            # exactly line_number + 1 is skipped on return -- confirm this
            # is the intended behavior.
            ring.push(line_number + 1)
            line_number = r[1]
        elif r[0] == 'return':
            line_number = re_tu_rn()
            if line_number is None:
                return
        else:
            raise NotImplementedError
        if DEBUG:
            print("AT LINE %d" % line_number)


def main(argv):
    """Command-line entry point.

    `argv` is the full argument vector, including the program name.  With
    --test, runs the module's doctests and exits with status 0 if they all
    pass and 1 otherwise.  Otherwise the first positional argument names a
    source file, which is parsed and then either dumped as an AST
    (--show-ast) or executed.  A missing filename is reported as a usage
    error.  Always leaves through sys.exit.
    """
    optparser = OptionParser(__doc__)
    optparser.add_option("-a", "--show-ast",
                         action="store_true", dest="show_ast", default=False,
                         help="show parsed AST instead of evaluating")
    optparser.add_option("-t", "--test",
                         action="store_true", dest="test", default=False,
                         help="run test cases and exit")
    (options, args) = optparser.parse_args(argv[1:])
    if options.test:
        import doctest
        (fails, _attempted) = doctest.testmod(verbose=True)
        if fails == 0:
            print("All tests passed.")
            sys.exit(0)
        else:
            sys.exit(1)
    if not args:
        # error() prints the usage message and exits with status 2.
        optparser.error("no source file given")
    with open(args[0]) as source:
        text = source.read()
    p = Parser(text)
    prog = p.program()
    if options.show_ast:
        from pprint import pprint
        pprint(prog)
        sys.exit(0)
    ring = RingBuffer()
    exec_program(prog, {}, ring)
    sys.exit(0)


# Run the interpreter only when executed as a script, not when imported.
if __name__ == "__main__":
    main(sys.argv)
