git @ Cat's Eye Technologies Chainscape / master src / transition-matrix.py
master

Tree @master (Download .tar.gz)

transition-matrix.py @masterraw · history · blame

# Copyright (c) 2024 Chris Pressey, Cat's Eye Technologies
# This file is distributed under an MIT license.  See LICENSES directory.
# SPDX-License-Identifier: LicenseRef-MIT-X-Chainscape

from argparse import ArgumentParser

import random


class Census(dict):
    """Multiset mapping members to counts, grown through ``accum``.

    ``&`` (``intersect``) keeps, for every member, the smaller of the two
    counts; ``|`` (``union``) keeps the larger.  Members whose combined
    count is not positive are left out of the result.
    """

    def accum(self, key, amount):
        """Increase the count of *key* by *amount*.

        Amounts that are not positive are ignored, so ``accum`` never
        creates a member with a zero or negative count.
        """
        if amount > 0:
            self[key] = self.get(key, 0) + amount

    def _combine(self, other, pick):
        # Shared body of intersect/union: apply *pick* (min or max) to the
        # two counts of every member present on either side, treating a
        # missing member as 0, and keep only positive results.
        combined = type(self)()
        for member in set(self) | set(other):
            count = pick(self.get(member, 0), other.get(member, 0))
            if count > 0:
                combined[member] = count
        return combined

    def intersect(self, other):
        """Return a new census holding the per-member minimum count."""
        return self._combine(other, min)

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new census holding the per-member maximum count."""
        return self._combine(other, max)

    def __or__(self, other):
        return self.union(other)

    def dump(self, prefix=""):
        """Print one line per member (sorted), each starting with *prefix*.

        The member is left-aligned in a 20-column field, followed by its count.
        """
        for member in sorted(self):
            print('{}{:<20} {}'.format(prefix, member, self[member]))


class TransitionMatrix(dict):
    """Transition counts: maps each word to a Census of the words seen after it.

    ``&`` combines two matrices row by row with Census ``&`` (per-successor
    minimum count); ``|`` uses Census ``|`` (per-successor maximum).  Rows
    that end up with no successors are dropped from the result.
    """

    def associate(self, a, b):
        """Count one occurrence of *b* following *a*.

        A falsy predecessor (``None``, ``""``) is ignored, so no row is ever
        created for it.
        """
        if not a:
            return
        if a not in self:
            self[a] = Census()
        self[a].accum(b, 1)

    def _combine(self, other, merge):
        # Apply *merge* to the two successor censuses of every word that has
        # a row on either side (a missing row counts as an empty Census);
        # words whose merged census is empty are left out.
        combined = type(self)()
        for word in set(self) | set(other):
            census = merge(self.get(word, Census()), other.get(word, Census()))
            if len(census) > 0:
                combined[word] = census
        return combined

    def intersect(self, other):
        """Return a new matrix with the row-wise Census intersection."""
        return self._combine(other, lambda mine, theirs: mine & theirs)

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new matrix with the row-wise Census union."""
        return self._combine(other, lambda mine, theirs: mine | theirs)

    def __or__(self, other):
        return self.union(other)

    def dump(self, words=None):
        """Print each row's word followed by its indented successor counts.

        When *words* is non-empty, only rows whose word is in *words* are
        printed; ``None`` or an empty collection prints every row.
        """
        indent = " " * 20
        for a in sorted(self):
            if not words or a in words:
                print(a)
                self[a].dump(indent)


class Distribution(dict):
    """Positive weights over members, used as a discrete probability distribution.

    Values are unnormalized weights.  ``total`` is their sum, and
    ``get_normalized`` divides a member's weight by it.  ``&`` and ``|``
    combine two distributions member by member with the probability
    formulas for independent events, both occurring (p1 * p2) and either
    occurring (p1 + p2 - p1 * p2).  The resulting probabilities become the
    new weights.

    Weights should be added through ``_set`` (as ``from_census`` does) so the
    cached total is invalidated.  Assigning keys directly does not reset it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily computed sum of the weights; None means "recompute".
        self._total = None

    @classmethod
    def from_census(cls, d):
        """Build a distribution from a mapping of member -> count.

        Entries whose count is not positive are skipped.
        """
        result = cls()
        for k, v in d.items():
            result._set(k, v)
        return result

    @property
    def total(self):
        """Sum of all weights (cached until the next ``_set``)."""
        if self._total is None:
            self._total = sum(self.values())
        return self._total

    def _set(self, key, amount):
        # Only positive weights are stored; anything else is ignored.
        if amount > 0.0:
            self[key] = amount
            self._total = None

    def get_normalized(self, key):
        """Return the weight of *key* divided by the total.

        Returns 0.0 for an absent key or an empty distribution.
        """
        total = self.total
        return self.get(key, 0) / total if total > 0.0 else 0.0

    def intersect(self, other):
        """Return a new distribution weighting each member by p1 * p2."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            p1 = self.get_normalized(key)
            p2 = other.get_normalized(key)
            probability_intersection = p1 * p2
            result._set(key, probability_intersection)
        return result

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new distribution weighting each member by p1 + p2 - p1*p2."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            p1 = self.get_normalized(key)
            p2 = other.get_normalized(key)
            probability_union = p1 + p2 - (p1 * p2)
            result._set(key, probability_union)
        return result

    def __or__(self, other):
        return self.union(other)

    def select(self):
        """Pick a member at random, with probability proportional to its weight.

        Members are scanned in sorted order, and the module-level ``random``
        generator is used (so ``random.seed`` makes the choice reproducible).
        Returns None only when the distribution is empty.
        """
        threshold = random.random() * self.total
        members = sorted(self.items())
        for member, weight in members:
            threshold -= weight
            if threshold <= 0:
                return member
        # Floating-point rounding can leave a tiny positive remainder after
        # subtracting every weight.  The cached total is summed in insertion
        # order, but weights are subtracted here in sorted order.  Such a draw
        # belongs to the last member, not to "nothing".
        return members[-1][0] if members else None

    def dump(self, prefix=""):
        """Print each member (sorted) with its weight as a percentage of the total."""
        for member, count in sorted(self.items()):
            print('{}{:<20} {:.2%}'.format(prefix, member, count / self.total))


class MarkovChain(dict):
    """Maps each word to a Distribution over the words that may follow it.

    ``&`` and ``|`` combine two chains word by word using Distribution's
    ``&``/``|``.  Words whose combined distribution is empty are dropped.
    """

    @classmethod
    def from_matrix(cls, matrix):
        """Build a chain from a mapping of word -> successor counts.

        Each word's counts (e.g. a TransitionMatrix row) become the
        unnormalized weights of that word's Distribution.
        """
        return cls(
            [(k, Distribution.from_census(v)) for k, v in matrix.items()]
        )

    def intersect(self, other):
        """Return a new chain whose per-word distributions are ``&``-combined."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            dist = self.get(key, Distribution()) & other.get(key, Distribution())
            if len(dist) > 0:
                result[key] = dist
        return result

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new chain whose per-word distributions are ``|``-combined."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            dist = self.get(key, Distribution()) | other.get(key, Distribution())
            if len(dist) > 0:
                result[key] = dist
        return result

    def __or__(self, other):
        return self.union(other)

    def dump(self, words=None):
        """Print each word followed by its indented successor percentages.

        When *words* is non-empty, only those words are printed.
        """
        for a, dist in sorted(self.items()):
            if words and a not in words:
                continue
            print(a)
            dist.dump(" " * 20)

    def pick_any(self):
        """Return a word chosen uniformly at random from the chain.

        Raises IndexError if the chain is empty.
        """
        return random.choice(sorted(self.keys()))

    def walk(self, word, count=20):
        """Generate up to *count* tokens by following the chain from *word*.

        The paragraph marker "¶" is emitted as "\\n\\n".  When the current
        word has no successors, the token "..." is emitted and the walk
        jumps to a random word of the chain (``pick_any``).  If the chain is
        empty there is nowhere to jump to, so the walk stops early instead
        of raising.
        """
        words = []
        while count > 0:
            words.append("\n\n" if word == "¶" else word)
            if word == "...":
                if not self:
                    # e.g. an intersection with no shared transitions.
                    break
                word = self.pick_any()
            elif word not in self:
                word = "..."
            else:
                word = self[word].select()
                if word is None:
                    # An empty distribution selects nothing: treat it as a
                    # dead end rather than emitting None.
                    word = "..."
            count -= 1
        return words


def load_matrix(filenames, encoding="utf-8"):
    """Build a TransitionMatrix from one or more word-per-line text files.

    Each line is stripped and treated as one word.  Every adjacent pair of
    lines within a file is recorded as a transition.  Nothing is recorded
    across file boundaries, because the predecessor is reset for each file.
    A blank line yields the empty word.  It is recorded as a successor, but
    ``TransitionMatrix.associate`` ignores it as a predecessor.

    :param filenames: iterable of paths, read in order.
    :param encoding: text encoding of the files.  Defaults to UTF-8 so that
        non-ASCII tokens such as the "¶" paragraph marker decode the same
        way on every platform, instead of depending on the locale.
    :returns: the accumulated TransitionMatrix.
    """
    matrix = TransitionMatrix()
    for filename in filenames:
        prev = None  # the first word of each file has no predecessor
        with open(filename, "r", encoding=encoding) as f:
            for line in f:
                word = line.strip()
                matrix.associate(prev, word)
                prev = word
    return matrix


def load_chain(filenames):
    """Read *filenames* into transition counts and convert them to a MarkovChain."""
    counts = load_matrix(filenames)
    chain = MarkovChain.from_matrix(counts)
    return chain


def main(args):
    """Command-line entry point.

    *args* is the argument list without the program name.  After the
    options, the first remaining argument names a command and the rest are
    input files (one word per line, as read by ``load_matrix``):

    * ``concat FILE...``    -- one chain built from all files together;
    * ``intersect FILE...`` -- the ``&`` of each file's MarkovChain;
    * ``union FILE...``     -- the ``|`` of each file's MarkovChain;
    * ``min FILE...``       -- the ``&`` of each file's TransitionMatrix;
    * ``max FILE...``       -- the ``|`` of each file's TransitionMatrix.

    With ``--dump WORDS``, the combined structure is printed instead.  The
    output is limited to the space-separated WORDS, and an empty string
    prints everything.  Otherwise the resulting chain is walked from "¶"
    for ``--count`` steps and the words are printed space-separated.

    A missing command or missing input files is reported through
    ``ArgumentParser.error``, which prints usage and exits.  An unknown
    command raises NotImplementedError.
    """
    import functools
    import operator

    argparser = ArgumentParser()
    argparser.add_argument(
        '--count', type=int, default=150,
        help="How many steps to take when walking the chain"
    )
    argparser.add_argument(
        '--dump', type=str, default=None,
        help="Instead of walking the chain, dump its structure"
    )
    argparser.add_argument(
        '--seed', type=int, default=9001,
        help="Random seed to use"
    )
    argparser.add_argument('--version', action='version', version="%(prog)s 0.0")

    (options, args) = argparser.parse_known_args(args)
    random.seed(options.seed)

    # Commands that load each file separately and fold the results together
    # with a binary operator: whole chains for intersect/union, raw count
    # matrices for min/max.
    folds = {
        "intersect": (load_chain, operator.and_),
        "union": (load_chain, operator.or_),
        "min": (load_matrix, operator.and_),
        "max": (load_matrix, operator.or_),
    }

    if not args:
        argparser.error("a command is required: intersect, concat, union, min, max")
    command, filenames = args[0], args[1:]
    if command != "concat" and command not in folds:
        raise NotImplementedError("intersect, concat, union, min, max is all I know")
    if not filenames:
        argparser.error("{} requires at least one input file".format(command))

    if command == "concat":
        structure = load_chain(filenames)
    else:
        load, combine = folds[command]
        # Loaded and combined left to right: ((f1 op f2) op f3) op ...
        structure = functools.reduce(
            combine, (load([filename]) for filename in filenames)
        )

    if options.dump is not None:
        # A TransitionMatrix (min/max) dumps raw counts; a MarkovChain dumps
        # percentages.
        structure.dump(options.dump.split())
        return

    if command in ("min", "max"):
        structure = MarkovChain.from_matrix(structure)
    print(' '.join(structure.walk("¶", count=options.count)))


if __name__ == '__main__':
    # Script entry: everything after the program name goes to main().
    from sys import argv
    main(argv[1:])