"""Support for random optimizers, including the random-greedy path."""

import functools
import heapq
import math
import numbers
import random
import time
from collections import deque

from . import helpers, paths

# Public API. ``RandomOptimizer`` is exported because it is the documented base
# class that custom random optimizers are meant to subclass.
__all__ = ["RandomGreedy", "random_greedy", "RandomOptimizer"]

class RandomOptimizer(paths.PathOptimizer):
    """Base class for running any random path finder that benefits
    from repeated calling, possibly in a parallel fashion. Custom random
    optimizers should subclass this, and the ``setup`` method should be
    implemented with the following signature::

        def setup(self, inputs, output, size_dict):
            # custom preparation here ...
            return trial_fn, trial_args

    Where ``trial_fn`` itself should have the signature::

        def trial_fn(r, *trial_args):
            # custom computation of path here
            return ssa_path, cost, size

    Where ``r`` is the run number and could for example be used to seed a
    random number generator. See ``RandomGreedy`` for an example.

    Parameters
    ----------
    max_repeats : int, optional
        The maximum number of repeat trials to have (per call).
    max_time : float, optional
        The maximum amount of time to run the algorithm for.
    minimize : {'flops', 'size'}, optional
        Whether to favour paths that minimize the total estimated flop-count or
        the size of the largest intermediate created.
    parallel : {bool, int, or executor-pool like}, optional
        Whether to parallelize the random trials, by default ``False``. If
        ``True``, use a ``concurrent.futures.ProcessPoolExecutor`` with the same
        number of processes as cores. If an integer is specified, use that many
        processes instead. Finally, you can supply a custom executor-pool which
        should have an API matching that of the python 3 standard library
        module ``concurrent.futures``. Namely, a ``submit`` method that returns
        ``Future`` objects, themselves with ``result`` and ``cancel`` methods.
    pre_dispatch : int, optional
        If running in parallel, how many jobs to pre-dispatch so as to avoid
        submitting all jobs at once. Should also be more than twice the number
        of workers to avoid under-subscription. Default: 128.

    Attributes
    ----------
    path : list[tuple[int]]
        The best path found so far.
    costs : list[int]
        The list of each trial's costs found so far.
    sizes : list[int]
        The list of each trial's largest intermediate size so far.
    best : dict
        The ``'flops'``, ``'size'`` and (once a trial has run) ``'ssa_path'``
        of the best trial found so far.

    Raises
    ------
    ValueError
        If ``minimize`` is not one of ``'flops'`` or ``'size'``.

    See Also
    --------
    RandomGreedy
    """

    def __init__(self, max_repeats=32, max_time=None, minimize='flops', parallel=False,
                 pre_dispatch=128):

        if minimize not in ('flops', 'size'):
            raise ValueError("`minimize` should be one of {'flops', 'size'}.")

        self.max_repeats = max_repeats
        self.max_time = max_time
        self.minimize = minimize
        self.better = paths.get_better_fn(minimize)
        # assigning through the ``parallel`` property creates ``self._executor``
        self.parallel = parallel
        self.pre_dispatch = pre_dispatch

        self.costs = []
        self.sizes = []
        self.best = {'flops': float('inf'), 'size': float('inf')}

        # offset added to the run numbers handed to each trial (see ``__call__``)
        self._repeats_start = 0
        # in-flight futures when running in parallel
        self._futures = deque()

    @property
    def path(self):
        """The best path found so far, converted from SSA to linear form."""
        return paths.ssa_to_linear(self.best['ssa_path'])

    @property
    def parallel(self):
        """The parallelism setting last supplied (see the class docstring)."""
        return self._parallel

    @parallel.setter
    def parallel(self, parallel):
        # shutdown any previous executor if we are managing it
        if getattr(self, '_managing_executor', False):
            self._executor.shutdown()

        self._parallel = parallel
        self._managing_executor = False

        # ``True``/``False`` must be checked before ``numbers.Number`` since
        # ``bool`` is itself a number
        if parallel is False:
            self._executor = None
        elif parallel is True:
            from concurrent.futures import ProcessPoolExecutor
            self._executor = ProcessPoolExecutor()
            self._managing_executor = True
        elif isinstance(parallel, numbers.Number):
            from concurrent.futures import ProcessPoolExecutor
            self._executor = ProcessPoolExecutor(int(parallel))
            self._managing_executor = True
        else:
            # assume a pool-executor has been supplied
            self._executor = parallel

    def _gen_results_parallel(self, repeats, trial_fn, args):
        """Lazily generate results from an executor without submitting all jobs at once.

        Up to ``pre_dispatch`` trials are kept in flight. Once that many are
        queued, the oldest future is waited on and its result yielded before
        the next trial is submitted. Every repeat in ``repeats`` is therefore
        submitted exactly once, and results are yielded in submission order.
        """
        self._futures = deque()
        # at least one job must be queued before ``popleft`` is valid
        max_pending = max(1, self.pre_dispatch)

        # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
        # yield any results, then do both in tandem, before draining the queue
        for r in repeats:
            if len(self._futures) >= max_pending:
                yield self._futures.popleft().result()
            self._futures.append(self._executor.submit(trial_fn, r, *args))

        while self._futures:
            yield self._futures.popleft().result()

    def _cancel_futures(self):
        """Cancel any trials still queued, e.g. after running out of time."""
        if self._executor is not None:
            for f in self._futures:
                f.cancel()
            self._futures.clear()

    def setup(self, inputs, output, size_dict):
        """Return ``(trial_fn, trial_args)``; must be implemented by subclasses."""
        raise NotImplementedError

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        """Run up to ``max_repeats`` further trials and return the best path.

        Results accumulate across calls in ``costs``, ``sizes`` and ``best``.
        ``memory_limit`` is accepted for interface compatibility but is not
        used by this method.
        """
        # start a timer?
        if self.max_time is not None:
            t0 = time.time()

        trial_fn, trial_args = self.setup(inputs, output, size_dict)

        # continue the run numbering from previous calls so new trials differ
        r_start = self._repeats_start + len(self.costs)
        r_stop = r_start + self.max_repeats
        repeats = range(r_start, r_stop)

        # create the trials lazily
        if self._executor is not None:
            trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
        else:
            trials = (trial_fn(r, *trial_args) for r in repeats)

        # assess the trials
        for ssa_path, cost, size in trials:

            # keep track of all costs and sizes
            self.costs.append(cost)
            self.sizes.append(size)

            # check if we have found a new best
            found_new_best = self.better(cost, size, self.best['flops'], self.best['size'])

            if found_new_best:
                self.best['flops'] = cost
                self.best['size'] = size
                self.best['ssa_path'] = ssa_path

            # check if we have run out of time
            if (self.max_time is not None) and (time.time() > t0 + self.max_time):
                break

        # don't leave queued trials running after an early exit
        self._cancel_futures()

        return self.path

    def __del__(self):
        # if we created the parallel pool-executor, shut it down
        if getattr(self, '_managing_executor', False):
            self._executor.shutdown()

def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
    """A contraction 'chooser' that weights possible contractions using a
    Boltzmann distribution. Explicitly, given costs ``c_i`` (with ``c_0`` the
    smallest), the relative weights, ``w_i``, are computed as:

        w_i = exp( -(c_i - c_0) / temperature)

    Additionally, if ``rel_temperature`` is set, scale ``temperature`` by
    ``abs(c_0)`` to account for likely fluctuating cost magnitudes during the
    course of a contraction.

    Parameters
    ----------
    queue : list
        The heapified list of candidate contractions. Each entry is
        ``(cost, k1, k2, k12)``, where ``cost`` is a sequence whose first
        element is the numeric cost used for weighting. Up to ``nbranch``
        entries are popped. The unchosen valid ones are pushed back;
        obsolete ones are discarded.
    remaining : dict[str, int]
        Mapping of remaining inputs' indices to the ssa id.
    nbranch : int, optional
        How many potential paths to calculate probability for and choose from
        at each step.
    temperature : float, optional
        When choosing a possible contraction, its relative probability will be
        proportional to ``exp(-cost / temperature)``. Thus the larger
        ``temperature`` is, the further random paths will stray from the normal
        'greedy' path. Conversely, if set to zero, only paths with exactly the
        same cost as the best at each step will be explored.
    rel_temperature : bool, optional
        Whether to normalize the ``temperature`` at each step to the scale of
        the best cost. This is generally beneficial as the magnitude of costs
        can vary significantly throughout a contraction.

    Returns
    -------
    cost, k1, k2, k12
        The chosen candidate, or ``None`` if no valid candidate is left in
        ``queue``.
    """
    n = 0
    choices = []
    while queue and n < nbranch:
        cost, k1, k2, k12 = heapq.heappop(queue)
        if k1 not in remaining or k2 not in remaining:
            continue  # candidate is obsolete
        choices.append((cost, k1, k2, k12))
        n += 1

    if n == 0:
        return None
    if n == 1:
        return choices[0]

    # candidates were popped in heap order, so the first has the lowest cost
    costs = [choice[0][0] for choice in choices]
    cmin = costs[0]

    # adjust by the overall scale to account for fluctuating absolute costs
    if rel_temperature:
        temperature *= max(1, abs(cmin))

    # compute relative probability for each potential contraction
    if temperature == 0.0:
        # zero temperature: only the minimal-cost candidates can be picked
        energies = [1 if c == cmin else 0 for c in costs]
    else:
        # shift by cmin for numerical reasons
        energies = [math.exp(-(c - cmin) / temperature) for c in costs]

    # randomly choose a contraction based on energies
    chosen, = random.choices(range(n), weights=energies)
    cost, k1, k2, k12 = choices.pop(chosen)

    # put the other choices back in the heap
    for other in choices:
        heapq.heappush(queue, other)

    return cost, k1, k2, k12

def ssa_path_compute_cost(ssa_path, inputs, output, size_dict):
    """Compute the flops and max size of an ssa path.

    Parameters
    ----------
    ssa_path : sequence of (int, int)
        Pairwise contractions in static single assignment form. The
        intermediate produced by the ``n``-th contraction is given the id
        ``len(inputs) + n``, and later steps may refer to it by that id.
    inputs : sequence of iterables
        The indices of each input term.
    output : iterable
        The indices of the output.
    size_dict : dict
        Mapping of each index to its dimension size.

    Returns
    -------
    total_cost
        Sum of the flops reported by ``paths.calc_k12_flops`` for each step.
    max_size
        Largest intermediate size, as given by ``helpers.compute_size_by_dict``.
    """
    # private copy: new intermediates are appended to it below
    inputs = list(map(frozenset, inputs))
    output = frozenset(output)
    remaining = set(range(len(inputs)))
    total_cost = 0
    max_size = 0

    for i, j in ssa_path:
        k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict)
        # i and j are consumed; the new intermediate takes the next ssa id
        remaining.discard(i)
        remaining.discard(j)
        remaining.add(len(inputs))
        inputs.append(k12)
        total_cost += flops12
        max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))

    return total_cost, max_size

def _trial_greedy_ssa_path_and_cost(r, inputs, output, size_dict, choose_fn, cost_fn):
    """A single, repeatable, greedy trial run.

    The global ``random`` generator is seeded with the run number ``r``.
    Randomness drawn from it (as ``thermal_chooser`` does) is therefore
    reproducible for a given ``r``. Trial ``0`` ignores ``choose_fn`` and
    runs the standard greedy approach.

    Returns
    -------
    ssa_path, cost, size
        The path found, its total flops and its largest intermediate size,
        as computed by ``ssa_path_compute_cost``.
    """
    if r == 0:
        # always start with the standard greedy approach
        choose_fn = None

    # seed from the run number so that each trial is repeatable
    random.seed(r)

    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)

    return ssa_path, cost, size

class RandomGreedy(RandomOptimizer):
    """Random-greedy path optimizer.

    Each trial runs the greedy contraction search, but at every step picks
    among the best ``nbranch`` candidates with Boltzmann weights (see
    ``thermal_chooser``). The best path over all trials is kept. The first
    trial (run number 0) uses the plain greedy choice.

    Parameters
    ----------
    cost_fn : callable or str, optional
        A function that returns a heuristic 'cost' of a potential contraction
        with which to sort candidates. Should have signature
        ``cost_fn(size12, size1, size2, k12, k1, k2)``. The default is the
        string ``'memory-removed-jitter'``, presumably the name of a built-in
        heuristic resolved by ``paths.ssa_greedy_optimize``.
    temperature : float, optional
        When choosing a possible contraction, its relative probability will be
        proportional to ``exp(-cost / temperature)``. Thus the larger
        ``temperature`` is, the further random paths will stray from the normal
        'greedy' path. Conversely, if set to zero, only paths with exactly the
        same cost as the best at each step will be explored.
    rel_temperature : bool, optional
        Whether to normalize the ``temperature`` at each step to the scale of
        the best cost. This is generally beneficial as the magnitude of costs
        can vary significantly throughout a contraction. If False, the
        algorithm will end up branching when the absolute cost is low, but
        stick to the 'greedy' path when the cost is high - this can also be
        beneficial.
    nbranch : int, optional
        How many potential paths to calculate probability for and choose from
        at each step. With ``nbranch=1`` every trial is plain greedy.
    kwargs
        Supplied to RandomOptimizer.

    See Also
    --------
    RandomOptimizer
    """

    def __init__(self, cost_fn='memory-removed-jitter', temperature=1.0,
                 rel_temperature=True, nbranch=8, **kwargs):
        self.cost_fn = cost_fn
        self.temperature = temperature
        self.rel_temperature = rel_temperature
        self.nbranch = nbranch
        # set up max_repeats, parallel executor, costs/sizes/best, etc.
        super().__init__(**kwargs)

    @property
    def choose_fn(self):
        """The function that chooses which contraction to take - make this a
        property so that ``temperature`` and ``nbranch`` etc. can be updated
        between runs. ``None`` when ``nbranch == 1``.
        """
        if self.nbranch == 1:
            return None

        return functools.partial(thermal_chooser, temperature=self.temperature,
                                 nbranch=self.nbranch, rel_temperature=self.rel_temperature)

    def setup(self, inputs, output, size_dict):
        """Return the greedy trial function and the arguments it is called with."""
        fn = _trial_greedy_ssa_path_and_cost
        args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
        return fn, args

def random_greedy(inputs, output, idx_dict, memory_limit=None, **optimizer_kwargs):
    """Find a contraction path using a freshly built ``RandomGreedy`` optimizer.

    Parameters
    ----------
    inputs, output, idx_dict, memory_limit
        Passed positionally, unchanged, to the optimizer's ``__call__``.
    **optimizer_kwargs
        Keyword arguments used to construct the ``RandomGreedy`` instance.

    Returns
    -------
    The best path found, as returned by the optimizer.
    """
    return RandomGreedy(**optimizer_kwargs)(inputs, output, idx_dict, memory_limit)