# Copyright (c) 2024 Chris Pressey, Cat's Eye Technologies
# This file is distributed under an MIT license. See LICENSES directory.
# SPDX-License-Identifier: LicenseRef-MIT-X-Chainscape
from argparse import ArgumentParser
import random
class Census(dict):
    """A tally mapping members to positive counts.

    Members only ever carry a count greater than zero. ``accum`` ignores
    non-positive amounts. The ``&`` (per-member minimum) and ``|``
    (per-member maximum) combinations drop every member whose combined count
    is not positive.
    """

    def accum(self, key, amount):
        """Add *amount* to the tally of *key*; amounts that are not > 0 are ignored."""
        if amount > 0:
            self[key] = self.get(key, 0) + amount

    def _combine(self, other, pick):
        """Apply *pick* to both tallies of every member and keep positive results."""
        combined = self.__class__()
        for member in set(self.keys()) | set(other.keys()):
            tally = pick(self.get(member, 0), other.get(member, 0))
            if tally > 0:
                combined[member] = tally
        return combined

    def intersect(self, other):
        """Return a new Census holding the smaller tally of each member."""
        return self._combine(other, min)

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new Census holding the larger tally of each member."""
        return self._combine(other, max)

    def __or__(self, other):
        return self.union(other)

    def dump(self, prefix=""):
        """Print one line per member in sorted order.

        Each line is *prefix*, the member left-aligned in 20 columns, a space,
        and the count.
        """
        for member, tally in sorted(self.items()):
            print('{}{:<20} {}'.format(prefix, member, tally))
class TransitionMatrix(dict):
    """Counts of observed transitions from one word to the next.

    Maps each word to a Census of the words seen directly after it.
    ``&``/``|`` combine the two Census objects recorded for each word with
    Census ``&``/``|``. Words whose combined Census is empty are left out.
    """

    def associate(self, a, b):
        """Record one transition from *a* to *b*.

        Nothing is recorded when *a* is falsy. That covers ``None`` (which
        ``load_matrix`` passes for the first word of a file) and also an
        empty-string *a*.
        """
        if a:
            if a not in self:
                self[a] = Census()
            self[a].accum(b, 1)

    def _combine(self, other, merge):
        """Merge the Census of every word with *merge*; keep only non-empty results."""
        combined = self.__class__()
        for word in set(self.keys()) | set(other.keys()):
            census = merge(self.get(word, Census()), other.get(word, Census()))
            if census:
                combined[word] = census
        return combined

    def intersect(self, other):
        """Return a new matrix whose per-word Census is the ``&`` of both inputs."""
        return self._combine(other, lambda mine, theirs: mine & theirs)

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new matrix whose per-word Census is the ``|`` of both inputs."""
        return self._combine(other, lambda mine, theirs: mine | theirs)

    def __or__(self, other):
        return self.union(other)

    def dump(self, words=None):
        """Print each word, then its successor Census indented by 20 spaces.

        When *words* is non-empty, only the words it contains are printed.
        """
        for word, census in sorted(self.items()):
            if not words or word in words:
                print(word)
                census.dump(" " * 20)
class Distribution(dict):
    """Positive weights for the members of a discrete probability distribution.

    The sum of the weights is cached in ``_total``. ``_set`` is the mutator
    that invalidates that cache, so use it rather than assigning items
    directly. ``&`` and ``|`` first normalize both operands. They then
    combine each member's probabilities as ``p1 * p2`` and
    ``p1 + p2 - p1 * p2``, which are the product and inclusion-exclusion
    rules for independent events.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily computed sum of the weights; None means "recompute on next read".
        self._total = None

    @classmethod
    def from_census(cls, d):
        """Build a Distribution from a mapping of counts, skipping non-positive ones."""
        result = cls()
        for k, v in d.items():
            result._set(k, v)
        return result

    @property
    def total(self):
        """Sum of all weights, cached until the next ``_set``."""
        if self._total is None:
            self._total = sum(self.values())
        return self._total

    def _set(self, key, amount):
        """Store *amount* as the weight of *key* if it is positive; otherwise do nothing."""
        if amount > 0.0:
            self[key] = amount
            self._total = None

    def get_normalized(self, key):
        """Return the probability of *key*.

        Returns 0.0 when *key* is absent or when the total is not positive.
        """
        return self.get(key, 0) / self.total if self.total > 0.0 else 0.0

    def intersect(self, other):
        """Return a new Distribution weighting each member by ``p1 * p2``."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            p1 = self.get_normalized(key)
            p2 = other.get_normalized(key)
            probability_intersection = p1 * p2
            result._set(key, probability_intersection)
        return result

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new Distribution weighting each member by ``p1 + p2 - p1 * p2``."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            p1 = self.get_normalized(key)
            p2 = other.get_normalized(key)
            probability_union = p1 + p2 - (p1 * p2)
            result._set(key, probability_union)
        return result

    def __or__(self, other):
        return self.union(other)

    def select(self):
        """Draw one member at random, proportionally to its weight.

        Members are visited in sorted order, so the result depends only on the
        ``random`` seed and the weights, not on dict insertion order. Returns
        ``None`` only when the distribution is empty.
        """
        threshold = random.random() * self.total
        member = None
        for member, value in sorted(self.items()):
            threshold -= value
            if threshold <= 0:
                return member
        # ``total`` is summed in insertion order, but the weights are subtracted
        # in sorted order. With float weights, rounding can leave a tiny positive
        # remainder after the last subtraction. That draw belongs to the final
        # member; returning None here would leak into walk() and crash the join.
        return member

    def dump(self, prefix=""):
        """Print each member and its share of the total as a percentage, sorted."""
        for member, count in sorted(self.items()):
            print('{}{:<20} {:.2%}'.format(prefix, member, count / self.total))
class MarkovChain(dict):
    """Maps each word to a Distribution over the words that may follow it.

    ``&`` and ``|`` combine the Distributions word by word, using
    Distribution's own ``&`` and ``|``. Words whose combined Distribution is
    empty are left out of the result.
    """

    @classmethod
    def from_matrix(cls, matrix):
        """Build a chain from a mapping of word -> Census of successor counts."""
        return cls(
            [(k, Distribution.from_census(v)) for k, v in matrix.items()]
        )

    def intersect(self, other):
        """Return a new chain combining each word's Distributions with ``&``."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            dist = self.get(key, Distribution()) & other.get(key, Distribution())
            if len(dist) > 0:
                result[key] = dist
        return result

    def __and__(self, other):
        return self.intersect(other)

    def union(self, other):
        """Return a new chain combining each word's Distributions with ``|``."""
        all_keys = set(self.keys()) | set(other.keys())
        result = self.__class__()
        for key in all_keys:
            dist = self.get(key, Distribution()) | other.get(key, Distribution())
            if len(dist) > 0:
                result[key] = dist
        return result

    def __or__(self, other):
        return self.union(other)

    def dump(self, words=None):
        """Print each word followed by its Distribution, indented by 20 spaces.

        When *words* is non-empty, only the words it contains are printed.
        """
        for a, dist in sorted(self.items()):
            if words and a not in words:
                continue
            print(a)
            dist.dump(" " * 20)

    def pick_any(self):
        """Return a word chosen uniformly at random from the chain.

        Keys are sorted first, so the choice depends only on the random seed.
        Raises IndexError on an empty chain.
        """
        return random.choice(sorted(self.keys()))

    def walk(self, word, count=20):
        """Generate up to *count* words by walking the chain from *word*.

        ``"¶"`` is emitted as a paragraph break (``"\\n\\n"``). When the
        current word has no successors, the walk emits ``"..."`` and restarts
        from a random word. An empty chain has nothing to restart from, so the
        walk stops early and the result may be shorter than *count*.
        """
        words = []
        while count > 0:
            words.append("\n\n" if word == "¶" else word)
            if word == "...":
                if not self:
                    # e.g. an intersection of unrelated corpora: pick_any()
                    # would raise IndexError, so end the walk instead.
                    break
                word = self.pick_any()
            elif word not in self:
                word = "..."
            else:
                successor = self[word].select()
                # select() yields None only for an empty Distribution; treat
                # that as a dead end rather than emitting None into the text.
                word = "..." if successor is None else successor
            count -= 1
        return words
def load_matrix(filenames):
    """Count word-to-word transitions in the given corpus files.

    Each file is read as UTF-8, and each line (with surrounding whitespace
    stripped) is one word. Transitions are recorded only between consecutive
    lines of the same file, so the first line of each file has no
    predecessor. A blank line becomes the empty-string word. That word can
    follow other words, but because ``TransitionMatrix.associate`` ignores
    falsy predecessors, no transition out of it is recorded.

    Returns a TransitionMatrix.
    """
    matrix = TransitionMatrix()
    for filename in filenames:
        prev = None  # no transition into the first word of a file
        # Explicit UTF-8: walks start from the non-ASCII marker "¶". Under a
        # non-UTF-8 platform default (e.g. cp1252), a UTF-8 "¶" in the corpus
        # would decode to different characters and never match it.
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                word = line.strip()
                matrix.associate(prev, word)
                prev = word
    return matrix
def load_chain(filenames):
    """Read the corpus files and return the MarkovChain built from their transitions."""
    transitions = load_matrix(filenames)
    chain = MarkovChain.from_matrix(transitions)
    return chain
def _fold(load, filenames, combine):
    """Load each file on its own with *load*, then reduce left to right with *combine*.

    *filenames* must be non-empty.
    """
    result = load([filenames[0]])
    for filename in filenames[1:]:
        result = combine(result, load([filename]))
    return result


def main(args):
    """Command-line entry point: combine corpora, then walk or dump the result.

    *args* is the argument list without the program name (``sys.argv[1:]``).
    The first positional argument selects the operation; the remaining ones
    are corpus files:

    * ``concat``: one chain built from all files together.
    * ``intersect`` / ``union``: one chain per file, combined with ``&`` / ``|``.
    * ``min`` / ``max``: the raw TransitionMatrix counts of each file are
      combined with ``&`` / ``|``, and only then turned into a chain.

    With ``--dump WORDS``, the combined structure is printed instead,
    restricted to the space-separated WORDS if any are given. Otherwise
    ``--count`` words are generated from the paragraph marker ``¶`` and
    printed space-separated.

    Raises NotImplementedError for an unknown operation. A missing operation
    or a missing file list is reported via ``ArgumentParser.error``, which
    exits with status 2.
    """
    argparser = ArgumentParser()
    argparser.add_argument(
        '--count', type=int, default=150,
        help="How many steps to take when walking the chain"
    )
    argparser.add_argument(
        '--dump', type=str, default=None,
        help="Instead of walking the chain, dump its structure"
    )
    argparser.add_argument(
        '--seed', type=int, default=9001,
        help="Random seed to use"
    )
    argparser.add_argument('--version', action='version', version="%(prog)s 0.0")
    (options, args) = argparser.parse_known_args(args)
    random.seed(options.seed)

    if not args:
        argparser.error("an operation is required: intersect, concat, union, min, max")
    operation, filenames = args[0], args[1:]
    if operation not in ("intersect", "concat", "union", "min", "max"):
        raise NotImplementedError("intersect, concat, union, min, max is all I know")
    if not filenames:
        argparser.error("{} needs at least one input file".format(operation))

    if operation == "concat":
        structure = load_chain(filenames)
    elif operation == "intersect":
        structure = _fold(load_chain, filenames, lambda a, b: a & b)
    elif operation == "union":
        structure = _fold(load_chain, filenames, lambda a, b: a | b)
    elif operation == "min":
        structure = _fold(load_matrix, filenames, lambda a, b: a & b)
    else:  # "max"
        structure = _fold(load_matrix, filenames, lambda a, b: a | b)

    if options.dump is not None:
        # An empty --dump string splits to [], which dumps every entry.
        structure.dump(options.dump.split())
        return

    # min/max combine raw counts; they only become probabilities here.
    if operation in ("min", "max"):
        chain = MarkovChain.from_matrix(structure)
    else:
        chain = structure
    print(' '.join(chain.walk("¶", count=options.count)))
if __name__ == '__main__':
    # Run as a script: hand everything after the program name to main().
    from sys import argv
    main(argv[1:])