git @ Cat's Eye Technologies relwrite / master src / relwrite / engine.py
master

Tree @master (Download .tar.gz)

engine.py @master · raw · history · blame

def generate(rules, working_utterances, max_matches=None):
    """Apply one round of rewriting to every working utterance.

    Note that an "utterance" can be any mix of terminals and non-terminals.
    For each utterance, every rule ``(pattern, replacement)`` is tried at
    every position where ``pattern`` occurs (at most ``max_matches`` positions
    per rule, if given), and each single-site rewrite is collected.

    An utterance to which no rule applies anywhere is returned as "final";
    whether such an utterance contains only terminals depends on the rules.

    Returns a pair ``(new_working_utterances, final_utterances)``, both sets.
    """
    successors = set()
    finished = set()
    for current in working_utterances:
        rewritten = False
        for pattern, replacement in rules:
            positions = get_match_indices(current, pattern, max_matches=max_matches)
            for position in positions:
                successors.add(
                    replace_at_index(current, pattern, replacement, position)
                )
                rewritten = True
        # Nothing matched in this utterance, so it cannot be derived further.
        if not rewritten:
            finished.add(current)
    return successors, finished


def get_match_indices(utterance, pattern, max_matches=None):
    """Return the start positions at which ``pattern`` occurs in ``utterance``.

    Both arguments are sequences of symbols. The pattern and each candidate
    window are compared as tuples, so a list pattern (e.g. one decoded from
    JSON) still matches a tuple utterance, and vice versa. Overlapping
    occurrences are all reported.

    :param utterance: sequence of symbols to search.
    :param pattern: sequence of symbols to look for.
    :param max_matches: if truthy, stop scanning once this many matches have
        been found; ``None`` or ``0`` means no limit.
    :returns: list of start indices in ascending order.
    """
    target = tuple(pattern)  # normalise once so list/tuple mixes compare equal
    length = len(target)
    matches = []
    for index in range(len(utterance)):
        if tuple(utterance[index:index + length]) == target:
            matches.append(index)
        if max_matches and len(matches) >= max_matches:
            break
    return matches


def replace_at_index(utterance, pattern, replacement, index):
    """Return a new tuple with the ``len(pattern)`` symbols starting at
    ``index`` replaced by the symbols of ``replacement``.

    The input ``utterance`` is not modified.
    """
    end = index + len(pattern)
    symbols = list(utterance)
    # Slice assignment lets the replacement be longer or shorter than the span.
    symbols[index:end] = replacement
    return tuple(symbols)


def _save_snapshot(working_utterances, iteration, verbose):
    """Write the working utterances to ``snapshot-<iteration>.json``."""
    import json
    snapshot_filename = 'snapshot-{}.json'.format(iteration)
    if verbose:
        print('Saving snapshot to {}'.format(snapshot_filename))
    with open(snapshot_filename, 'w', encoding='utf-8') as f:
        # working_utterances may be a set (generate() returns one, and it is
        # only turned into a list when beam trimming applies); json cannot
        # encode sets, so serialise a list. Tuple utterances become arrays.
        f.write(json.dumps(list(working_utterances), indent=4))


def derive(
    rules,
    working_utterances,
    strategy,
    max_derivations=None,
    max_matches=None,
    verbose=False,
    save_snapshots_every=None,
    expand_until=None,
    beam_width=10
):
    """Repeatedly rewrite utterances with ``rules`` and collect final ones.

    Each iteration runs one round of ``generate``. Utterances that no rule
    can rewrite are appended to the result. The rewritten utterances become
    the next working set. Depending on ``strategy``, that set is sorted by a
    score (lowest first) and cut to ``beam_width`` entries (beam search):

    - ``'complete'``: no scoring and no trimming.
    - ``'expand'``: prefer longer utterances. Once the shortest working
      utterance has at least ``expand_until`` symbols, the strategy switches
      to ``'minimize-nonterminals'``. The switch happens immediately if
      ``expand_until`` is ``None`` or 0.
    - ``'contract'``: prefer shorter utterances.
    - ``'minimize-nonterminals'``: prefer utterances with fewer symbols
      beginning with ``'<'``.

    Any other ``strategy`` raises ``KeyError`` after the first round.

    :param rules: iterable of ``(pattern, replacement)`` pairs.
    :param working_utterances: initial collection of utterances.
    :param strategy: one of the strategy names above.
    :param max_derivations: if truthy, stop once this many utterances have
        been collected.
    :param max_matches: passed through to ``generate`` to limit how many
        match positions are used per rule per utterance.
    :param verbose: print progress messages.
    :param save_snapshots_every: if truthy, on every N-th iteration write the
        working utterances to ``snapshot-<iteration>.json`` in the current
        directory.
    :param beam_width: number of utterances kept when a scoring strategy
        applies.
    :returns: list of collected final utterances, in the order found.
    """
    # Beam-search sort keys: lower score sorts first and is kept.
    scoring_functions = {
        'complete': None,
        'expand': lambda u: 0 - len(u),
        'contract': lambda u: len(u),
        'minimize-nonterminals': lambda u: sum(1 for s in u if s.startswith('<')),
    }

    collected_utterances = []
    iteration = 0

    while working_utterances:
        iteration += 1
        if save_snapshots_every and iteration % save_snapshots_every == 0:
            _save_snapshot(working_utterances, iteration, verbose)

        min_length = min(len(u) for u in working_utterances)
        if verbose:
            print('{} working utterances, min length = {}'.format(
                len(working_utterances), min_length
            ))
        if strategy == 'expand' and min_length >= (expand_until or 0):
            if verbose:
                print('Reached {} threshold'.format(expand_until))
            # TODO: make it configurable, which strategy to switch to here?
            strategy = 'minimize-nonterminals'

        working_utterances, final_utterances = generate(
            rules, working_utterances, max_matches=max_matches
        )

        # beam search: sort by score and trim before continuing
        scoring_function = scoring_functions[strategy]
        if scoring_function:
            working_utterances = sorted(
                working_utterances, key=scoring_function
            )[:beam_width]

        for utterance in final_utterances:
            collected_utterances.append(utterance)
            if max_derivations and len(collected_utterances) >= max_derivations:
                return collected_utterances

    return collected_utterances