pyGSTi-0.9.12-cp38-cp38-win_amd64.whl → pyGSTi-0.9.13-cp38-cp38-win_amd64.whl
- pyGSTi-0.9.13.dist-info/METADATA +185 -0
- {pyGSTi-0.9.12.dist-info → pyGSTi-0.9.13.dist-info}/RECORD +211 -220
- {pyGSTi-0.9.12.dist-info → pyGSTi-0.9.13.dist-info}/WHEEL +1 -1
- pygsti/_version.py +2 -2
- pygsti/algorithms/contract.py +1 -1
- pygsti/algorithms/core.py +62 -35
- pygsti/algorithms/fiducialpairreduction.py +95 -110
- pygsti/algorithms/fiducialselection.py +17 -8
- pygsti/algorithms/gaugeopt.py +2 -2
- pygsti/algorithms/germselection.py +87 -77
- pygsti/algorithms/mirroring.py +0 -388
- pygsti/algorithms/randomcircuit.py +165 -1333
- pygsti/algorithms/rbfit.py +0 -234
- pygsti/baseobjs/basis.py +94 -396
- pygsti/baseobjs/errorgenbasis.py +0 -132
- pygsti/baseobjs/errorgenspace.py +0 -10
- pygsti/baseobjs/label.py +52 -168
- pygsti/baseobjs/opcalc/fastopcalc.cp38-win_amd64.pyd +0 -0
- pygsti/baseobjs/opcalc/fastopcalc.pyx +2 -2
- pygsti/baseobjs/polynomial.py +13 -595
- pygsti/baseobjs/protectedarray.py +72 -132
- pygsti/baseobjs/statespace.py +1 -0
- pygsti/circuits/__init__.py +1 -1
- pygsti/circuits/circuit.py +753 -504
- pygsti/circuits/circuitconstruction.py +0 -4
- pygsti/circuits/circuitlist.py +47 -5
- pygsti/circuits/circuitparser/__init__.py +8 -8
- pygsti/circuits/circuitparser/fastcircuitparser.cp38-win_amd64.pyd +0 -0
- pygsti/circuits/circuitstructure.py +3 -3
- pygsti/circuits/cloudcircuitconstruction.py +27 -14
- pygsti/data/datacomparator.py +4 -9
- pygsti/data/dataset.py +51 -46
- pygsti/data/hypothesistest.py +0 -7
- pygsti/drivers/bootstrap.py +0 -49
- pygsti/drivers/longsequence.py +46 -10
- pygsti/evotypes/basereps_cython.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/chp/opreps.py +0 -61
- pygsti/evotypes/chp/statereps.py +0 -32
- pygsti/evotypes/densitymx/effectcreps.cpp +9 -10
- pygsti/evotypes/densitymx/effectreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/densitymx/effectreps.pyx +1 -1
- pygsti/evotypes/densitymx/opreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/densitymx/opreps.pyx +2 -2
- pygsti/evotypes/densitymx/statereps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/densitymx/statereps.pyx +1 -1
- pygsti/evotypes/densitymx_slow/effectreps.py +7 -23
- pygsti/evotypes/densitymx_slow/opreps.py +16 -23
- pygsti/evotypes/densitymx_slow/statereps.py +10 -3
- pygsti/evotypes/evotype.py +39 -2
- pygsti/evotypes/stabilizer/effectreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/stabilizer/effectreps.pyx +0 -4
- pygsti/evotypes/stabilizer/opreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/stabilizer/opreps.pyx +0 -4
- pygsti/evotypes/stabilizer/statereps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/stabilizer/statereps.pyx +1 -5
- pygsti/evotypes/stabilizer/termreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/stabilizer/termreps.pyx +0 -7
- pygsti/evotypes/stabilizer_slow/effectreps.py +0 -22
- pygsti/evotypes/stabilizer_slow/opreps.py +0 -4
- pygsti/evotypes/stabilizer_slow/statereps.py +0 -4
- pygsti/evotypes/statevec/effectreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/statevec/effectreps.pyx +1 -1
- pygsti/evotypes/statevec/opreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/statevec/opreps.pyx +2 -2
- pygsti/evotypes/statevec/statereps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/statevec/statereps.pyx +1 -1
- pygsti/evotypes/statevec/termreps.cp38-win_amd64.pyd +0 -0
- pygsti/evotypes/statevec/termreps.pyx +0 -7
- pygsti/evotypes/statevec_slow/effectreps.py +0 -3
- pygsti/evotypes/statevec_slow/opreps.py +0 -5
- pygsti/extras/__init__.py +0 -1
- pygsti/extras/drift/signal.py +1 -1
- pygsti/extras/drift/stabilityanalyzer.py +3 -1
- pygsti/extras/interpygate/__init__.py +12 -0
- pygsti/extras/interpygate/core.py +0 -36
- pygsti/extras/interpygate/process_tomography.py +44 -10
- pygsti/extras/rpe/rpeconstruction.py +0 -2
- pygsti/forwardsims/__init__.py +1 -0
- pygsti/forwardsims/forwardsim.py +50 -93
- pygsti/forwardsims/mapforwardsim.py +78 -20
- pygsti/forwardsims/mapforwardsim_calc_densitymx.cp38-win_amd64.pyd +0 -0
- pygsti/forwardsims/mapforwardsim_calc_densitymx.pyx +65 -66
- pygsti/forwardsims/mapforwardsim_calc_generic.py +91 -13
- pygsti/forwardsims/matrixforwardsim.py +72 -17
- pygsti/forwardsims/termforwardsim.py +9 -111
- pygsti/forwardsims/termforwardsim_calc_stabilizer.cp38-win_amd64.pyd +0 -0
- pygsti/forwardsims/termforwardsim_calc_statevec.cp38-win_amd64.pyd +0 -0
- pygsti/forwardsims/termforwardsim_calc_statevec.pyx +0 -651
- pygsti/forwardsims/torchfwdsim.py +265 -0
- pygsti/forwardsims/weakforwardsim.py +2 -2
- pygsti/io/__init__.py +1 -2
- pygsti/io/mongodb.py +0 -2
- pygsti/io/stdinput.py +6 -22
- pygsti/layouts/copalayout.py +10 -12
- pygsti/layouts/distlayout.py +0 -40
- pygsti/layouts/maplayout.py +103 -25
- pygsti/layouts/matrixlayout.py +99 -60
- pygsti/layouts/prefixtable.py +1534 -52
- pygsti/layouts/termlayout.py +1 -1
- pygsti/modelmembers/instruments/instrument.py +3 -3
- pygsti/modelmembers/instruments/tpinstrument.py +2 -2
- pygsti/modelmembers/modelmember.py +0 -17
- pygsti/modelmembers/operations/__init__.py +3 -4
- pygsti/modelmembers/operations/affineshiftop.py +206 -0
- pygsti/modelmembers/operations/composederrorgen.py +1 -1
- pygsti/modelmembers/operations/composedop.py +1 -24
- pygsti/modelmembers/operations/denseop.py +5 -5
- pygsti/modelmembers/operations/eigpdenseop.py +2 -2
- pygsti/modelmembers/operations/embeddederrorgen.py +1 -1
- pygsti/modelmembers/operations/embeddedop.py +0 -1
- pygsti/modelmembers/operations/experrorgenop.py +5 -2
- pygsti/modelmembers/operations/fullarbitraryop.py +1 -0
- pygsti/modelmembers/operations/fullcptpop.py +2 -2
- pygsti/modelmembers/operations/fulltpop.py +28 -6
- pygsti/modelmembers/operations/fullunitaryop.py +5 -4
- pygsti/modelmembers/operations/lindbladcoefficients.py +93 -78
- pygsti/modelmembers/operations/lindbladerrorgen.py +268 -441
- pygsti/modelmembers/operations/linearop.py +7 -27
- pygsti/modelmembers/operations/opfactory.py +1 -1
- pygsti/modelmembers/operations/repeatedop.py +1 -24
- pygsti/modelmembers/operations/staticstdop.py +1 -1
- pygsti/modelmembers/povms/__init__.py +3 -3
- pygsti/modelmembers/povms/basepovm.py +7 -36
- pygsti/modelmembers/povms/complementeffect.py +4 -9
- pygsti/modelmembers/povms/composedeffect.py +0 -320
- pygsti/modelmembers/povms/computationaleffect.py +1 -1
- pygsti/modelmembers/povms/computationalpovm.py +3 -1
- pygsti/modelmembers/povms/effect.py +3 -5
- pygsti/modelmembers/povms/marginalizedpovm.py +3 -81
- pygsti/modelmembers/povms/tppovm.py +74 -2
- pygsti/modelmembers/states/__init__.py +2 -5
- pygsti/modelmembers/states/composedstate.py +0 -317
- pygsti/modelmembers/states/computationalstate.py +3 -3
- pygsti/modelmembers/states/cptpstate.py +4 -4
- pygsti/modelmembers/states/densestate.py +10 -8
- pygsti/modelmembers/states/fullpurestate.py +0 -24
- pygsti/modelmembers/states/purestate.py +1 -1
- pygsti/modelmembers/states/state.py +5 -6
- pygsti/modelmembers/states/tpstate.py +28 -10
- pygsti/modelmembers/term.py +3 -6
- pygsti/modelmembers/torchable.py +50 -0
- pygsti/modelpacks/_modelpack.py +1 -1
- pygsti/modelpacks/smq1Q_ZN.py +3 -1
- pygsti/modelpacks/smq2Q_XXYYII.py +2 -1
- pygsti/modelpacks/smq2Q_XY.py +3 -3
- pygsti/modelpacks/smq2Q_XYI.py +2 -2
- pygsti/modelpacks/smq2Q_XYICNOT.py +3 -3
- pygsti/modelpacks/smq2Q_XYICPHASE.py +3 -3
- pygsti/modelpacks/smq2Q_XYXX.py +1 -1
- pygsti/modelpacks/smq2Q_XYZICNOT.py +3 -3
- pygsti/modelpacks/smq2Q_XYZZ.py +1 -1
- pygsti/modelpacks/stdtarget.py +0 -121
- pygsti/models/cloudnoisemodel.py +1 -2
- pygsti/models/explicitcalc.py +3 -3
- pygsti/models/explicitmodel.py +3 -13
- pygsti/models/fogistore.py +5 -3
- pygsti/models/localnoisemodel.py +1 -2
- pygsti/models/memberdict.py +0 -12
- pygsti/models/model.py +801 -68
- pygsti/models/modelconstruction.py +4 -4
- pygsti/models/modelnoise.py +2 -2
- pygsti/models/modelparaminterposer.py +1 -1
- pygsti/models/oplessmodel.py +1 -1
- pygsti/models/qutrit.py +15 -14
- pygsti/objectivefns/objectivefns.py +75 -140
- pygsti/objectivefns/wildcardbudget.py +2 -7
- pygsti/optimize/__init__.py +1 -0
- pygsti/optimize/arraysinterface.py +28 -0
- pygsti/optimize/customcg.py +0 -12
- pygsti/optimize/customlm.py +129 -323
- pygsti/optimize/customsolve.py +2 -2
- pygsti/optimize/optimize.py +0 -84
- pygsti/optimize/simplerlm.py +841 -0
- pygsti/optimize/wildcardopt.py +19 -598
- pygsti/protocols/confidenceregionfactory.py +28 -14
- pygsti/protocols/estimate.py +31 -14
- pygsti/protocols/gst.py +238 -142
- pygsti/protocols/modeltest.py +19 -12
- pygsti/protocols/protocol.py +9 -37
- pygsti/protocols/rb.py +450 -79
- pygsti/protocols/treenode.py +8 -2
- pygsti/protocols/vb.py +108 -206
- pygsti/protocols/vbdataframe.py +1 -1
- pygsti/report/factory.py +0 -15
- pygsti/report/fogidiagram.py +1 -17
- pygsti/report/modelfunction.py +12 -3
- pygsti/report/mpl_colormaps.py +1 -1
- pygsti/report/plothelpers.py +11 -3
- pygsti/report/report.py +16 -0
- pygsti/report/reportables.py +41 -37
- pygsti/report/templates/offline/pygsti_dashboard.css +6 -0
- pygsti/report/templates/offline/pygsti_dashboard.js +12 -0
- pygsti/report/workspace.py +2 -14
- pygsti/report/workspaceplots.py +328 -505
- pygsti/tools/basistools.py +9 -36
- pygsti/tools/edesigntools.py +124 -96
- pygsti/tools/fastcalc.cp38-win_amd64.pyd +0 -0
- pygsti/tools/fastcalc.pyx +35 -81
- pygsti/tools/internalgates.py +151 -15
- pygsti/tools/jamiolkowski.py +5 -5
- pygsti/tools/lindbladtools.py +19 -11
- pygsti/tools/listtools.py +0 -114
- pygsti/tools/matrixmod2.py +1 -1
- pygsti/tools/matrixtools.py +173 -339
- pygsti/tools/nameddict.py +1 -1
- pygsti/tools/optools.py +154 -88
- pygsti/tools/pdftools.py +0 -25
- pygsti/tools/rbtheory.py +3 -320
- pygsti/tools/slicetools.py +64 -12
- pyGSTi-0.9.12.dist-info/METADATA +0 -157
- pygsti/algorithms/directx.py +0 -711
- pygsti/evotypes/qibo/__init__.py +0 -33
- pygsti/evotypes/qibo/effectreps.py +0 -78
- pygsti/evotypes/qibo/opreps.py +0 -376
- pygsti/evotypes/qibo/povmreps.py +0 -98
- pygsti/evotypes/qibo/statereps.py +0 -174
- pygsti/extras/rb/__init__.py +0 -13
- pygsti/extras/rb/benchmarker.py +0 -957
- pygsti/extras/rb/dataset.py +0 -378
- pygsti/extras/rb/io.py +0 -814
- pygsti/extras/rb/simulate.py +0 -1020
- pygsti/io/legacyio.py +0 -385
- pygsti/modelmembers/povms/denseeffect.py +0 -142
- {pyGSTi-0.9.12.dist-info → pyGSTi-0.9.13.dist-info}/LICENSE +0 -0
- {pyGSTi-0.9.12.dist-info → pyGSTi-0.9.13.dist-info}/top_level.txt +0 -0
pygsti/layouts/prefixtable.py
CHANGED
@@ -11,7 +11,10 @@ Defines the PrefixTable class.
 #***************************************************************************************************
 
 import collections as _collections
-
+import networkx as _nx
+import matplotlib.pyplot as plt
+from math import ceil
+from pygsti.baseobjs import Label as _Label
 from pygsti.circuits.circuit import SeparatePOVMCircuit as _SeparatePOVMCircuit
 
 
@@ -38,6 +41,14 @@ class PrefixTable(object):
         `iDest` is always in the range [0,len(circuits_to_evaluate)-1], and
         indexes the result computed for each of the circuits.
 
+        Parameters
+        ----------
+
+
+        circuit_parameter_sensitivities :
+            A map between the circuits in circuits_to_evaluate and the indices of the model parameters
+            to which these circuits depend.
+
         Returns
         -------
         tuple
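Illustrative sketch (not part of the diff): the table entries referred to above have the form (iDest, iStart, remaining, iCache); the PrefixTableJacobian docstring later in this diff spells out the semantics (circuit[iDest] = cached_circuit[iStart] + remaining, with the result optionally stored at cache slot iCache). A minimal toy replay of such instructions, with label tuples standing in for quantum states; the function name and toy representation are hypothetical, not pyGSTi code:

    # Illustrative only -- toy "states" are just tuples of applied layer labels.
    def evaluate_prefix_table(table_contents, num_circuits):
        results = [None] * num_circuits
        cache = []  # cache[iCache] holds an intermediate "state"
        for i_dest, i_start, remaining, i_cache in table_contents:
            state = cache[i_start] if i_start is not None else ()
            for layer_label in remaining:       # propagate through the remaining layers
                state = state + (layer_label,)  # stand-in for applying an operation
            results[i_dest] = state
            if i_cache is not None:             # store for reuse as a later iStart
                assert i_cache == len(cache)
                cache.append(state)
        return results

    # ('Gx',) is cached (iCache=0) and reused as the prefix of ('Gx', 'Gy'):
    table = [(0, None, ('Gx',), 0), (1, 0, ('Gy',), None)]
    print(evaluate_prefix_table(table, 2))  # [('Gx',), ('Gx', 'Gy')]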
@@ -45,19 +56,23 @@ class PrefixTable(object):
         of tuples as given above and `cache_size` is the total size of the state
         cache used to hold intermediate results.
         """
+
         #Sort the operation sequences "alphabetically", so that it's trivial to find common prefixes
-        circuits_to_evaluate_fastlookup = {i: cir for i, cir in enumerate(circuits_to_evaluate)}
         circuits_to_sort_by = [cir.circuit_without_povm if isinstance(cir, _SeparatePOVMCircuit) else cir
                                for cir in circuits_to_evaluate]  # always Circuits - not SeparatePOVMCircuits
-
-
+        #with the current logic in _build_table a candidate circuit is only treated as a possible prefix if
+        #it is shorter than the one it is being evaluated as a prefix for. So it should work to sort these
+        #circuits by length for the purposes of the current logic.
+        sorted_circuits_to_sort_by = sorted(list(enumerate(circuits_to_sort_by)), key=lambda x: len(x[1]))
+        orig_indices, sorted_circuits_to_evaluate = zip(*[(i, circuits_to_evaluate[i]) for i, _ in sorted_circuits_to_sort_by])
+
+        self.sorted_circuits_to_evaluate = sorted_circuits_to_evaluate
+        self.orig_indices = orig_indices
+
+        #get the circuits in a form readily usable for comparisons
+        circuit_reps, circuit_lens = _circuits_to_compare(sorted_circuits_to_evaluate)
+        self.circuit_reps = circuit_reps
 
-        distinct_line_labels = set([cir.line_labels for cir in circuits_to_sort_by])
-        if len(distinct_line_labels) == 1:  # if all circuits have the *same* line labels, we can just compare tuples
-            circuit_reps_to_compare_and_lengths = {i: (cir.layertup, len(cir))
-                                                   for i, cir in enumerate(circuits_to_sort_by)}
-        else:
-            circuit_reps_to_compare_and_lengths = {i: (cir, len(cir)) for i, cir in enumerate(circuits_to_sort_by)}
 
         if max_cache_size is None or max_cache_size > 0:
             #CACHE assessment pass: figure out what's worth keeping in the cache.
@@ -66,48 +81,13 @@ class PrefixTable(object):
             # Not: this logic could be much better, e.g. computing a cost savings for each
             # potentially-cached item and choosing the best ones, and proper accounting
             # for chains of cached items.
-
-
-
-            for i, _ in sorted_circuits_to_evaluate:
-                circuit, L = circuit_reps_to_compare_and_lengths[i]  # can be a Circuit or a label tuple
-                for cached_index in reversed(cacheIndices):
-                    candidate, Lc = circuit_reps_to_compare_and_lengths[cached_index]
-                    if L >= Lc > 0 and circuit[0:Lc] == candidate:  # a cache hit!
-                        cache_hits[cached_index] += 1
-                        break  # stop looking through cache
-                cacheIndices.append(i)  # cache *everything* in this pass
-
-            # Build prefix table: construct list, only caching items with hits > 0 (up to max_cache_size)
-            cacheIndices = []  # indices into circuits_to_evaluate of the results to cache
-            table_contents = []
-            curCacheSize = 0
-
-            for i, circuit in sorted_circuits_to_evaluate:
-                circuit_rep, L = circuit_reps_to_compare_and_lengths[i]
-
-                #find longest existing prefix for circuit by working backwards
-                # and finding the first string that *is* a prefix of this string
-                # (this will necessarily be the longest prefix, given the sorting)
-                for i_in_cache in range(curCacheSize - 1, -1, -1):  # from curCacheSize-1 -> 0
-                    candidate, Lc = circuit_reps_to_compare_and_lengths[cacheIndices[i_in_cache]]
-                    if L >= Lc > 0 and circuit_rep[0:Lc] == candidate:  # ">=" allows for duplicates
-                        iStart = i_in_cache  # an index into the *cache*, not into circuits_to_evaluate
-                        remaining = circuit[Lc:]  # *always* a SeparatePOVMCircuit or Circuit
-                        break
-                else:  # no break => no prefix
-                    iStart = None
-                    remaining = circuit[:]
-
-                # if/where this string should get stored in the cache
-                if (max_cache_size is None or curCacheSize < max_cache_size) and cache_hits.get(i, 0) > 0:
-                    iCache = len(cacheIndices)
-                    cacheIndices.append(i); curCacheSize += 1
-                else:  # don't store in the cache
-                    iCache = None
+            cache_hits = _cache_hits(self.circuit_reps, circuit_lens)
+        else:
+            cache_hits = [None]*len(self.circuit_reps)
 
-
-
+        table_contents, curCacheSize = _build_table(sorted_circuits_to_evaluate, cache_hits,
+                                                    max_cache_size, self.circuit_reps, circuit_lens,
+                                                    orig_indices)
 
         #FUTURE: could perform a second pass, and if there is
         # some threshold number of elements which share the
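Illustrative sketch (not part of the diff): the inline cache-assessment loop removed here is factored into the module-level _cache_hits helper (defined further down in this diff). A standalone toy version of that counting pass, which records how often each circuit, once cached, serves as the longest previously-seen prefix of a later circuit:

    # Illustrative only -- mirrors the _cache_hits logic on toy label tuples.
    def toy_cache_hits(circuit_reps, circuit_lengths):
        cache_indices, hits = [], [0] * len(circuit_reps)
        for i, (circuit, L) in enumerate(zip(circuit_reps, circuit_lengths)):
            for cached in reversed(cache_indices):
                Lc = circuit_lengths[cached]
                if L >= Lc > 0 and circuit[:Lc] == circuit_reps[cached]:
                    hits[cached] += 1   # a cache hit!
                    break               # stop looking through the cache
            cache_indices.append(i)     # cache *everything* in this pass
        return hits

    reps = (('rho0',), ('rho0', 'Gx'), ('rho0', 'Gx', 'Gx'))
    print(toy_cache_hits(reps, tuple(map(len, reps))))  # [1, 1, 0]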
@@ -118,9 +98,196 @@ class PrefixTable(object):
         # order.
         self.contents = table_contents
         self.cache_size = curCacheSize
+        self.circuits_evaluated = circuits_to_sort_by
+
 
     def __len__(self):
         return len(self.contents)
+
+    def num_state_propagations(self):
+        """
+        Return the number of state propagation operations (excluding the action of POVM effects)
+        required for the evaluation strategy given by this PrefixTable.
+        """
+        return sum(self.num_state_propagations_by_circuit().values())
+
+    def num_state_propagations_by_circuit(self):
+        """
+        Return the number of state propagation operations per-circuit
+        (excluding the action of POVM effects) required for the evaluation strategy
+        given by this PrefixTable, returned as a dictionary with keys corresponding to
+        circuits and values corresponding to the number of state propagations
+        required for that circuit.
+        """
+        state_props_by_circuit = {}
+        for i, istart, remainder, _ in self.contents:
+            if len(self.circuits_evaluated[i][0])>0 and self.circuits_evaluated[i][0] == _Label('rho0') and istart is None:
+                state_props_by_circuit[self.circuits_evaluated[i]] = len(remainder)-1
+            else:
+                state_props_by_circuit[self.circuits_evaluated[i]] = len(remainder)
+
+        return state_props_by_circuit
+
+    def num_state_propagations_by_circuit_no_caching(self):
+        """
+        Return the number of state propagation operations per-circuit
+        (excluding the action of POVM effects) required for an evaluation strategy
+        without caching, returned as a dictionary with keys corresponding to
+        circuits and values corresponding to the number of state propagations
+        required for that circuit.
+        """
+        state_props_by_circuit = {}
+        for circuit in self.circuits_evaluated:
+            if len(circuit)>0 and circuit[0] == _Label('rho0'):
+                state_props_by_circuit[circuit] = len(circuit[1:])
+            else:
+                state_props_by_circuit[circuit] = len(circuit)
+        return state_props_by_circuit
+
+    def num_state_propagations_no_caching(self):
+        """
+        Return the total number of state propagation operations
+        (excluding the action of POVM effects) required for an evaluation strategy
+        without caching.
+        """
+        return sum(self.num_state_propagations_by_circuit_no_caching().values())
+
+    def find_splitting_new(self, max_sub_table_size=None, num_sub_tables=None, initial_cost_metric='size',
+                           rebalancing_cost_metric='propagations', imbalance_threshold=1.2, minimum_improvement_threshold=.1,
+                           verbosity=0):
+        """
+        Find a partition of the indices of this table to define a set of sub-tables with the desire properties.
+
+        This is done in order to reduce the maximum size of any tree (useful for
+        limiting memory consumption or for using multiple cores).  Must specify
+        either max_sub_tree_size or num_sub_trees.
+
+        Parameters
+        ----------
+        max_sub_table_size : int, optional
+            The maximum size (i.e. list length) of each sub-table.  If the
+            original table is smaller than this size, no splitting will occur.
+            If None, then there is no limit.
+
+        num_sub_tables : int, optional
+            The maximum size (i.e. list length) of each sub-table.  If the
+            original table is smaller than this size, no splitting will occur.
+
+        imbalance_threshold : float, optional (default 1.2)
+            This number serves as a tolerance parameter for a final load balancing refinement
+            to the splitting. The value coresponds to a threshold value of the ratio of the heaviest
+            to the lightest subtree such that ratios below this value are considered sufficiently
+            balanced and processing stops.
+
+        minimum_improvement_threshold : float, optional (default .1)
+            A parameter for the final load balancing refinement process that sets a minimum balance
+            improvement (improvement to the ratio of the sizes of two subtrees) such that a rebalancing
+            step is considered worth performing (even if it would otherwise bring the imbalance parameter
+            described above in `imbalance_threshold` below the target value) .
+
+        verbosity : int, optional (default 0)
+            How much detail to send to stdout.
+
+        Returns
+        -------
+        list
+            A list of sets of elements to place in sub-tables.
+        """
+
+        table_contents = self.contents
+        if max_sub_table_size is None and num_sub_tables is None:
+            return [set(range(len(table_contents)))]  # no splitting needed
+
+        if max_sub_table_size is not None and num_sub_tables is not None:
+            raise ValueError("Cannot specify both max_sub_table_size and num_sub_tables")
+        if num_sub_tables is not None and num_sub_tables <= 0:
+            raise ValueError("Error: num_sub_tables must be > 0!")
+
+        #Don't split at all if it's unnecessary
+        if max_sub_table_size is None or len(table_contents) < max_sub_table_size:
+            if num_sub_tables is None or num_sub_tables == 1:
+                return [set(range(len(table_contents)))]
+
+        #construct a tree structure describing the prefix strucure of the circuit set.
+        circuit_tree = _build_prefix_tree(self.sorted_circuits_to_evaluate, self.circuit_reps, self.orig_indices)
+        circuit_tree_nx = circuit_tree.to_networkx_graph()
+
+        if num_sub_tables is not None:
+            max_max_sub_table_size = len(self.sorted_circuits_to_evaluate)
+            initial_max_sub_table_size = ceil(len(self.sorted_circuits_to_evaluate)/num_sub_tables)
+            cut_edges, new_roots, tree_levels, subtree_weights = tree_partition_kundu_misra(circuit_tree_nx, max_weight=initial_max_sub_table_size,
+                                                                                            weight_key= 'cost' if initial_cost_metric=='size' else 'prop_cost',
+                                                                                            return_levels_and_weights=True)
+
+            if len(new_roots) > num_sub_tables: #iteratively row the maximum subtree size until we either hit or are less than the target.
+                last_seen_sub_max_sub_table_size_val = None
+                feasible_range = [initial_max_sub_table_size+1, max_max_sub_table_size-1]
+                #bisect on max_sub_table_size until we find the smallest value for which len(new_roots) <= num_sub_tables
+                while feasible_range[0] < feasible_range[1]:
+                    current_max_sub_table_size = (feasible_range[0] + feasible_range[1])//2
+                    cut_edges, new_roots = tree_partition_kundu_misra(circuit_tree_nx, max_weight=current_max_sub_table_size,
+                                                                      weight_key='cost' if initial_cost_metric=='size' else 'prop_cost',
+                                                                      test_leaves=False, precomp_levels=tree_levels, precomp_weights=subtree_weights)
+                    if len(new_roots) > num_sub_tables:
+                        feasible_range[0] = current_max_sub_table_size+1
+                    else:
+                        last_seen_sub_max_sub_table_size_val = (cut_edges, new_roots) #In the multiple root setting I am seeing some strange
+                        #non-monotonicity, so add this as a fall back in case the final result anomalously has len(roots)>num_sub_tables
+                        feasible_range[1] = current_max_sub_table_size
+                if len(new_roots)>num_sub_tables and last_seen_sub_max_sub_table_size_val is not None: #fallback
+                    cut_edges, new_roots = last_seen_sub_max_sub_table_size_val
+
+                #only apply the cuts now that we have found our starting point.
+                partitioned_tree = _copy_networkx_graph(circuit_tree_nx)
+                #update the propagation cost attribute of the promoted nodes.
+                #only do this at this point to reduce the need for copying
+                for edge in cut_edges:
+                    partitioned_tree.nodes[edge[1]]['prop_cost'] += partitioned_tree.edges[edge[0], edge[1]]['promotion_cost']
+                partitioned_tree.remove_edges_from(cut_edges)
+
+            #if we have hit the number of partitions, great, we're done!
+            if len(new_roots) == num_sub_tables:
+                #only apply the cuts now that we have found our starting point.
+                partitioned_tree = _copy_networkx_graph(circuit_tree_nx)
+                #update the propagation cost attribute of the promoted nodes.
+                #only do this at this point to reduce the need for copying
+                for edge in cut_edges:
+                    partitioned_tree.nodes[edge[1]]['prop_cost'] += partitioned_tree.edges[edge[0], edge[1]]['promotion_cost']
+                partitioned_tree.remove_edges_from(cut_edges)
+                pass
+            #if we have fewer subtables then we need to look whether or not we should strictly
+            #hit the number of partitions, or whether we allow for fewer than the requested number to be returned.
+            if len(new_roots) < num_sub_tables:
+                #Perform bisection operations on the heaviest subtrees until we hit the target number.
+                while len(new_roots) < num_sub_tables:
+                    partitioned_tree, new_roots, cut_edges = _bisection_pass(partitioned_tree, cut_edges, new_roots, num_sub_tables,
+                                                                             weight_key='cost' if rebalancing_cost_metric=='size' else 'prop_cost')
+                #add in a final refinement pass to improve the balancing across subtrees.
+                partitioned_tree, new_roots, addl_cut_edges = _refinement_pass(partitioned_tree, new_roots,
+                                                                               weight_key='cost' if rebalancing_cost_metric=='size' else 'prop_cost',
+                                                                               imbalance_threshold= imbalance_threshold,
+                                                                               minimum_improvement_threshold= minimum_improvement_threshold)
+        else:
+            cut_edges, new_roots = tree_partition_kundu_misra(circuit_tree_nx, max_weight = max_sub_table_size,
+                                                              weight_key='cost' if initial_cost_metric=='size' else 'prop_cost')
+            partitioned_tree = _copy_networkx_graph(circuit_tree_nx)
+            for edge in cut_edges:
+                partitioned_tree.nodes[edge[1]]['prop_cost'] += partitioned_tree.edges[edge[0], edge[1]]['promotion_cost']
+            partitioned_tree.remove_edges_from(cut_edges)
+
+        #Collect the original circuit indices for each of the parititioned subtrees.
+        orig_index_groups = []
+        for root in new_roots:
+            if isinstance(root,tuple):
+                ckts = []
+                for elem in root:
+                    ckts.extend(_collect_orig_indices(partitioned_tree, elem))
+                orig_index_groups.append(ckts)
+            else:
+                orig_index_groups.append(_collect_orig_indices(partitioned_tree, root))
+
+        return orig_index_groups
+
 
     def find_splitting(self, max_sub_table_size=None, num_sub_tables=None, cost_metric="size", verbosity=0):
         """
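Illustrative sketch (not part of the diff): find_splitting_new above bisects on max_sub_table_size, looking for the smallest weight cap whose partition uses at most num_sub_tables parts. The same pattern on a 1-D stand-in, where a greedy chain partition plays the role of tree_partition_kundu_misra and all names are hypothetical:

    # Illustrative only -- smaller caps yield more parts, so binary search applies.
    from math import ceil

    def num_parts(weights, max_weight):
        parts, current = 1, 0
        for w in weights:
            if current + w > max_weight:  # start a new part when the cap is exceeded
                parts, current = parts + 1, 0
            current += w
        return parts

    def smallest_feasible_max_weight(weights, num_sub_tables):
        lo = ceil(sum(weights) / num_sub_tables)  # analogue of initial_max_sub_table_size
        hi = sum(weights)                         # analogue of max_max_sub_table_size
        while lo < hi:
            mid = (lo + hi) // 2
            if num_parts(weights, mid) > num_sub_tables:
                lo = mid + 1                      # too many parts: allow heavier subtrees
            else:
                hi = mid                          # feasible: try lighter subtrees
        return lo

    print(smallest_feasible_max_weight([5, 3, 8, 2, 6, 4], num_sub_tables=3))  # -> 10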
@@ -182,7 +349,7 @@ class PrefixTable(object):
         over the course of the iteration.
         """
 
-        if cost_metric == "
+        if cost_metric == "applies":
             def cost_fn(rem): return len(rem)  # length of remainder = #-apply ops needed
         elif cost_metric == "size":
            def cost_fn(rem): return 1  # everything costs 1 in size of table
@@ -333,3 +500,1318 @@ class PrefixTable(object):
|
|
333
500
|
|
334
501
|
assert(sum(map(len, subTableSetList)) == len(self)), "sub-table sets are not disjoint!"
|
335
502
|
return subTableSetList
|
503
|
+
|
504
|
+
|
505
|
+
class PrefixTableJacobian(object):
|
506
|
+
"""
|
507
|
+
An ordered list ("table") of circuits to evaluate, where common prefixes can be cached.
|
508
|
+
Specialized for purposes of jacobian calculations.
|
509
|
+
|
510
|
+
"""
|
511
|
+
|
512
|
+
def __init__(self, circuits_to_evaluate, max_cache_size, parameter_circuit_dependencies=None):
|
513
|
+
"""
|
514
|
+
Creates a "prefix table" for evaluating a set of circuits.
|
515
|
+
|
516
|
+
The table is list of tuples, where each element contains
|
517
|
+
instructions for evaluating a particular operation sequence:
|
518
|
+
|
519
|
+
(iDest, iStart, tuple_of_following_items, iCache)
|
520
|
+
|
521
|
+
Means that circuit[iDest] = cached_circuit[iStart] + tuple_of_following_items,
|
522
|
+
and that the resulting state should be stored at cache index iCache (for
|
523
|
+
later reference as an iStart value). The ordering of the returned list
|
524
|
+
specifies the evaluation order.
|
525
|
+
|
526
|
+
`iDest` is always in the range [0,len(circuits_to_evaluate)-1], and
|
527
|
+
indexes the result computed for each of the circuits.
|
528
|
+
|
529
|
+
Parameters
|
530
|
+
----------
|
531
|
+
|
532
|
+
|
533
|
+
circuit_parameter_sensitivities :
|
534
|
+
A map between the circuits in circuits_to_evaluate and the indices of the model parameters
|
535
|
+
to which these circuits depend.
|
536
|
+
|
537
|
+
Returns
|
538
|
+
-------
|
539
|
+
tuple
|
540
|
+
A tuple of `(table_contents, cache_size)` where `table_contents` is a list
|
541
|
+
of tuples as given above and `cache_size` is the total size of the state
|
542
|
+
cache used to hold intermediate results.
|
543
|
+
"""
|
544
|
+
#Sort the operation sequences "alphabetically", so that it's trivial to find common prefixes
|
545
|
+
circuits_to_sort_by = [cir.circuit_without_povm if isinstance(cir, _SeparatePOVMCircuit) else cir
|
546
|
+
for cir in circuits_to_evaluate] # always Circuits - not SeparatePOVMCircuits
|
547
|
+
sorted_circuits_to_sort_by = sorted(list(enumerate(circuits_to_sort_by)), key=lambda x: len(x[1]))
|
548
|
+
sorted_circuits_to_evaluate = [(i, circuits_to_evaluate[i]) for i, _ in sorted_circuits_to_sort_by]
|
549
|
+
#create a map from sorted_circuits_to_sort_by by can be used to quickly sort each of the parameter
|
550
|
+
#dependency lists.
|
551
|
+
fast_sorting_map = {circuits_to_evaluate[i]:j for j, (i, _) in enumerate(sorted_circuits_to_sort_by)}
|
552
|
+
|
553
|
+
#also need a map from circuits to their original indices in circuits_to_evaluate
|
554
|
+
#for the purpose of setting the correct destination indices in the evaluation instructions.
|
555
|
+
circuit_to_orig_index_map = {circuit: i for i,circuit in enumerate(circuits_to_evaluate)}
|
556
|
+
|
557
|
+
#use this map to sort the parameter_circuit_dependencies sublists.
|
558
|
+
sorted_parameter_circuit_dependencies = []
|
559
|
+
sorted_parameter_circuit_dependencies_orig_indices = []
|
560
|
+
for sublist in parameter_circuit_dependencies:
|
561
|
+
sorted_sublist = [None]*len(sorted_circuits_to_evaluate)
|
562
|
+
for ckt in sublist:
|
563
|
+
sorted_sublist[fast_sorting_map[ckt]] = ckt
|
564
|
+
|
565
|
+
#filter out instances of None to get the correctly sized and sorted
|
566
|
+
#sublist.
|
567
|
+
filtered_sorted_sublist = [val for val in sorted_sublist if val is not None]
|
568
|
+
orig_index_sublist = [circuit_to_orig_index_map[ckt] for ckt in filtered_sorted_sublist]
|
569
|
+
|
570
|
+
sorted_parameter_circuit_dependencies.append(filtered_sorted_sublist)
|
571
|
+
sorted_parameter_circuit_dependencies_orig_indices.append(orig_index_sublist)
|
572
|
+
|
573
|
+
sorted_circuit_reps = []
|
574
|
+
sorted_circuit_lengths = []
|
575
|
+
for sublist in sorted_parameter_circuit_dependencies:
|
576
|
+
circuit_reps, circuit_lengths = _circuits_to_compare(sublist)
|
577
|
+
sorted_circuit_reps.append(circuit_reps)
|
578
|
+
sorted_circuit_lengths.append(circuit_lengths)
|
579
|
+
|
580
|
+
#Intuition: The sorted circuit lists should likely break into equivalence classes, wherein multiple
|
581
|
+
#parameters will have the same dependent circuits. This is because in typical models parameters
|
582
|
+
#appear in blocks corresponding to a particular gate label, and so most of the time it should be the
|
583
|
+
#case that the list fractures into all those circuits containing a particular label.
|
584
|
+
#This intuition probably breaks down for ImplicitOpModels with complicated layer rules for which
|
585
|
+
#the breaking into equivalence classes may have limited savings.
|
586
|
+
unique_parameter_circuit_dependency_classes = {}
|
587
|
+
for i, sublist in enumerate(sorted_circuit_reps):
|
588
|
+
if unique_parameter_circuit_dependency_classes.get(sublist, None) is None:
|
589
|
+
unique_parameter_circuit_dependency_classes[sublist] = [i]
|
590
|
+
else:
|
591
|
+
unique_parameter_circuit_dependency_classes[sublist].append(i)
|
592
|
+
|
593
|
+
self.unique_parameter_circuit_dependency_classes = unique_parameter_circuit_dependency_classes
|
594
|
+
|
595
|
+
#the keys of the dictionary already give the needed circuit rep lists for
|
596
|
+
#each class, also grab the appropriate list of length for each class.
|
597
|
+
sorted_circuit_lengths_by_class = [sorted_circuit_lengths[class_indices[0]]
|
598
|
+
for class_indices in unique_parameter_circuit_dependency_classes.values()]
|
599
|
+
|
600
|
+
#also need representatives fo the entries in sorted_parameter_circuit_dependencies for each class,
|
601
|
+
#and for sorted_parameter_circuit_dependencies_orig_indices
|
602
|
+
sorted_parameter_circuit_dependencies_by_class = [sorted_parameter_circuit_dependencies[class_indices[0]]
|
603
|
+
for class_indices in unique_parameter_circuit_dependency_classes.values()]
|
604
|
+
sorted_parameter_circuit_dependencies_orig_indices_by_class = [sorted_parameter_circuit_dependencies_orig_indices[class_indices[0]]
|
605
|
+
for class_indices in unique_parameter_circuit_dependency_classes.values()]
|
606
|
+
|
607
|
+
#now we can just do the calculation for each of these equivalence classes.
|
608
|
+
|
609
|
+
#get the cache hits for all of the parameter circuit dependency sublists
|
610
|
+
if max_cache_size is None or max_cache_size > 0:
|
611
|
+
cache_hits_by_class = []
|
612
|
+
#CACHE assessment pass: figure out what's worth keeping in the cache.
|
613
|
+
# In this pass, we cache *everything* and keep track of how many times each
|
614
|
+
# original index (after it's cached) is utilized as a prefix for another circuit.
|
615
|
+
# Not: this logic could be much better, e.g. computing a cost savings for each
|
616
|
+
# potentially-cached item and choosing the best ones, and proper accounting
|
617
|
+
# for chains of cached items.
|
618
|
+
for circuit_reps, circuit_lengths in zip(unique_parameter_circuit_dependency_classes.keys(),
|
619
|
+
sorted_circuit_lengths_by_class):
|
620
|
+
cache_hits_by_class.append(_cache_hits(circuit_reps, circuit_lengths))
|
621
|
+
else:
|
622
|
+
cache_hits_by_class = [None]*len(unique_parameter_circuit_dependency_classes)
|
623
|
+
|
624
|
+
#next construct a prefix table for each sublist.
|
625
|
+
table_contents_by_class = []
|
626
|
+
cache_size_by_class = []
|
627
|
+
for sublist, cache_hits, circuit_reps, circuit_lengths, orig_indices in zip(sorted_parameter_circuit_dependencies_by_class,
|
628
|
+
cache_hits_by_class,
|
629
|
+
unique_parameter_circuit_dependency_classes.keys(),
|
630
|
+
sorted_circuit_lengths_by_class,
|
631
|
+
sorted_parameter_circuit_dependencies_orig_indices_by_class):
|
632
|
+
table_contents, curCacheSize = _build_table(sublist, cache_hits,
|
633
|
+
max_cache_size, circuit_reps, circuit_lengths,
|
634
|
+
orig_indices)
|
635
|
+
table_contents_by_class.append(table_contents)
|
636
|
+
cache_size_by_class.append(curCacheSize)
|
637
|
+
|
638
|
+
#FUTURE: could perform a second pass, and if there is
|
639
|
+
# some threshold number of elements which share the
|
640
|
+
# *same* iStart and the same beginning of the
|
641
|
+
# 'remaining' part then add a new "extra" element
|
642
|
+
# (beyond the #circuits index) which computes
|
643
|
+
# the shared prefix and insert this into the eval
|
644
|
+
# order.
|
645
|
+
|
646
|
+
#map back from equivalence classes to by parameter.
|
647
|
+
table_contents_by_parameter = [None]*len(parameter_circuit_dependencies)
|
648
|
+
cache_size_by_parameter = [None]*len(parameter_circuit_dependencies)
|
649
|
+
for table_contents, cache_size, param_class in zip(table_contents_by_class, cache_size_by_class,
|
650
|
+
unique_parameter_circuit_dependency_classes.values()):
|
651
|
+
for idx in param_class:
|
652
|
+
table_contents_by_parameter[idx] = table_contents
|
653
|
+
cache_size_by_parameter[idx] = cache_size
|
654
|
+
|
655
|
+
self.contents_by_parameter = table_contents_by_parameter
|
656
|
+
self.cache_size_by_parameter = cache_size_by_parameter
|
657
|
+
self.parameter_circuit_dependencies = sorted_parameter_circuit_dependencies
|
658
|
+
|
659
|
+
|
660
|
+
#---------Helper Functions------------#
|
661
|
+
|
662
|
+
def _circuits_to_compare(sorted_circuits_to_evaluate):
|
663
|
+
|
664
|
+
bare_circuits = [cir.circuit_without_povm if isinstance(cir, _SeparatePOVMCircuit) else cir
|
665
|
+
for cir in sorted_circuits_to_evaluate]
|
666
|
+
distinct_line_labels = set([cir.line_labels for cir in bare_circuits])
|
667
|
+
|
668
|
+
circuit_lens = [None]*len(sorted_circuits_to_evaluate)
|
669
|
+
if len(distinct_line_labels) == 1:
|
670
|
+
circuit_reps = [None]*len(sorted_circuits_to_evaluate)
|
671
|
+
for i, cir in enumerate(bare_circuits):
|
672
|
+
circuit_reps[i] = cir.layertup
|
673
|
+
circuit_lens[i] = len(circuit_reps[i])
|
674
|
+
else:
|
675
|
+
circuit_reps = bare_circuits
|
676
|
+
for i, cir in enumerate(sorted_circuits_to_evaluate):
|
677
|
+
circuit_lens[i] = len(circuit_reps[i])
|
678
|
+
|
679
|
+
return tuple(circuit_reps), tuple(circuit_lens)
|
680
|
+
|
681
|
+
def _cache_hits(circuit_reps, circuit_lengths):
|
682
|
+
|
683
|
+
#CACHE assessment pass: figure out what's worth keeping in the cache.
|
684
|
+
# In this pass, we cache *everything* and keep track of how many times each
|
685
|
+
# original index (after it's cached) is utilized as a prefix for another circuit.
|
686
|
+
# Not: this logic could be much better, e.g. computing a cost savings for each
|
687
|
+
# potentially-cached item and choosing the best ones, and proper accounting
|
688
|
+
# for chains of cached items.
|
689
|
+
|
690
|
+
cacheIndices = [] # indices into circuits_to_evaluate of the results to cache
|
691
|
+
cache_hits = [0]*len(circuit_reps)
|
692
|
+
|
693
|
+
for i in range(len(circuit_reps)):
|
694
|
+
circuit = circuit_reps[i]
|
695
|
+
L = circuit_lengths[i] # can be a Circuit or a label tuple
|
696
|
+
for cached_index in reversed(cacheIndices):
|
697
|
+
candidate = circuit_reps[cached_index]
|
698
|
+
Lc = circuit_lengths[cached_index]
|
699
|
+
if L >= Lc > 0 and circuit[0:Lc] == candidate: # a cache hit!
|
700
|
+
cache_hits[cached_index] += 1
|
701
|
+
break # stop looking through cache
|
702
|
+
cacheIndices.append(i) # cache *everything* in this pass
|
703
|
+
|
704
|
+
return cache_hits
|
705
|
+
|
706
|
+
def _build_table(sorted_circuits_to_evaluate, cache_hits, max_cache_size, circuit_reps, circuit_lengths,
|
707
|
+
orig_indices):
|
708
|
+
|
709
|
+
# Build prefix table: construct list, only caching items with hits > 0 (up to max_cache_size)
|
710
|
+
cacheIndices = [] # indices into circuits_to_evaluate of the results to cache
|
711
|
+
table_contents = [None]*len(sorted_circuits_to_evaluate)
|
712
|
+
curCacheSize = 0
|
713
|
+
for j, (i, _) in zip(orig_indices,enumerate(sorted_circuits_to_evaluate)):
|
714
|
+
|
715
|
+
circuit_rep = circuit_reps[i]
|
716
|
+
L = circuit_lengths[i]
|
717
|
+
|
718
|
+
#find longest existing prefix for circuit by working backwards
|
719
|
+
# and finding the first string that *is* a prefix of this string
|
720
|
+
# (this will necessarily be the longest prefix, given the sorting)
|
721
|
+
for i_in_cache in range(curCacheSize - 1, -1, -1): # from curCacheSize-1 -> 0
|
722
|
+
candidate = circuit_reps[cacheIndices[i_in_cache]]
|
723
|
+
Lc = circuit_lengths[cacheIndices[i_in_cache]]
|
724
|
+
if L >= Lc > 0 and circuit_rep[0:Lc] == candidate: # ">=" allows for duplicates
|
725
|
+
iStart = i_in_cache # an index into the *cache*, not into circuits_to_evaluate
|
726
|
+
remaining = circuit_rep[Lc:] # *always* a SeparatePOVMCircuit or Circuit
|
727
|
+
break
|
728
|
+
else: # no break => no prefix
|
729
|
+
iStart = None
|
730
|
+
remaining = circuit_rep
|
731
|
+
|
732
|
+
# if/where this string should get stored in the cache
|
733
|
+
if (max_cache_size is None or curCacheSize < max_cache_size) and cache_hits[i]:
|
734
|
+
iCache = len(cacheIndices)
|
735
|
+
cacheIndices.append(i); curCacheSize += 1
|
736
|
+
else: # don't store in the cache
|
737
|
+
iCache = None
|
738
|
+
|
739
|
+
#Add instruction for computing this circuit
|
740
|
+
table_contents[i] = (j, iStart, remaining, iCache)
|
741
|
+
|
742
|
+
return table_contents, curCacheSize
|
743
|
+
|
744
|
+
#helper method for building a tree showing the connections between different circuits
|
745
|
+
#for the purposes of prefix-based evaluation.
|
746
|
+
def _build_prefix_tree(sorted_circuits_to_evaluate, circuit_reps, orig_indices):
|
747
|
+
#assume the input circuits have already been sorted by length.
|
748
|
+
circuit_tree = Tree()
|
749
|
+
for j, (i, _) in zip(orig_indices,enumerate(sorted_circuits_to_evaluate)):
|
750
|
+
circuit_rep = circuit_reps[i]
|
751
|
+
#the first layer should be a state preparation. If this isn't in a root in the
|
752
|
+
#tree add it.
|
753
|
+
root_node = circuit_tree.get_root_node(circuit_rep[0])
|
754
|
+
if root_node is None and len(circuit_rep)>0:
|
755
|
+
#cost is the number of propagations, so exclude the initial state prep
|
756
|
+
root_node = RootNode(circuit_rep[0], cost=0)
|
757
|
+
circuit_tree.add_root(root_node)
|
758
|
+
|
759
|
+
current_node = root_node
|
760
|
+
for layerlbl in circuit_reps[i][1:]:
|
761
|
+
child_node = current_node.get_child_node(layerlbl)
|
762
|
+
if child_node is None:
|
763
|
+
child_node = ChildNode(layerlbl, parent=current_node)
|
764
|
+
current_node = child_node
|
765
|
+
#when we get to the end of the circuit add a pointer on the
|
766
|
+
#final node to the original index of this circuit in the
|
767
|
+
#circuit list.
|
768
|
+
current_node.add_orig_index(j)
|
769
|
+
|
770
|
+
return circuit_tree
|
771
|
+
|
772
|
+
|
773
|
+
#----------------------Helper classes for managing circuit evaluation tree. --------------#
|
774
|
+
class TreeNode:
|
775
|
+
def __init__(self, value, children=None, orig_indices=None):
|
776
|
+
"""
|
777
|
+
Parameters
|
778
|
+
----------
|
779
|
+
value : any
|
780
|
+
The value to be stored in the node.
|
781
|
+
|
782
|
+
children : list, optional (default is None)
|
783
|
+
A list of child nodes. If None, initializes an empty list.
|
784
|
+
|
785
|
+
orig_indices : list, optional (default is None)
|
786
|
+
A list of original indices. If None, initializes an empty list.
|
787
|
+
"""
|
788
|
+
self.value = value
|
789
|
+
self.children = [] if children is None else children
|
790
|
+
self.orig_indices = [] if orig_indices is None else orig_indices #make this a list to allow for duplicates
|
791
|
+
|
792
|
+
def add_child(self, child_node):
|
793
|
+
"""
|
794
|
+
Add a child node to the current node.
|
795
|
+
|
796
|
+
Parameters
|
797
|
+
----------
|
798
|
+
child_node : TreeNode
|
799
|
+
The child node to be added.
|
800
|
+
"""
|
801
|
+
|
802
|
+
self.children.append(child_node)
|
803
|
+
|
804
|
+
def remove_child(self, child_node):
|
805
|
+
"""
|
806
|
+
Remove a child node from the current node.
|
807
|
+
|
808
|
+
Parameters
|
809
|
+
----------
|
810
|
+
child_node : TreeNode
|
811
|
+
The child node to be removed.
|
812
|
+
"""
|
813
|
+
self.children = [child for child in self.children if child is not child_node]
|
814
|
+
|
815
|
+
def get_child_node(self, value):
|
816
|
+
"""
|
817
|
+
Get the child node associated with the input value. If that node is not present, return None.
|
818
|
+
|
819
|
+
Parameters
|
820
|
+
----------
|
821
|
+
value : any
|
822
|
+
The value to search for in the child nodes.
|
823
|
+
|
824
|
+
Returns
|
825
|
+
-------
|
826
|
+
TreeNode or None
|
827
|
+
The child node with the specified value, or None if not found.
|
828
|
+
"""
|
829
|
+
|
830
|
+
for node in self.children:
|
831
|
+
if node.value == value:
|
832
|
+
return node
|
833
|
+
#if we haven't returned already it is because there wasn't a corresponding root,
|
834
|
+
#so return None
|
835
|
+
return None
|
836
|
+
|
837
|
+
def add_orig_index(self, value):
|
838
|
+
"""
|
839
|
+
Add an original index to the node.
|
840
|
+
|
841
|
+
Parameters
|
842
|
+
----------
|
843
|
+
value : int
|
844
|
+
The original index to be added.
|
845
|
+
"""
|
846
|
+
self.orig_indices.append(value)
|
847
|
+
|
848
|
+
def traverse(self):
|
849
|
+
"""
|
850
|
+
Traverse the tree in pre-order and return a list of node values.
|
851
|
+
|
852
|
+
Returns
|
853
|
+
-------
|
854
|
+
list
|
855
|
+
A list of node values in pre-order traversal.
|
856
|
+
"""
|
857
|
+
|
858
|
+
nodes = []
|
859
|
+
stack = [self]
|
860
|
+
while stack:
|
861
|
+
node = stack.pop()
|
862
|
+
nodes.append(node.value)
|
863
|
+
stack.extend(reversed(node.children)) # Add children to stack in reverse order for pre-order traversal
|
864
|
+
return nodes
|
865
|
+
|
866
|
+
def get_descendants(self):
|
867
|
+
"""
|
868
|
+
Get all descendant node values of the current node in pre-order traversal.
|
869
|
+
|
870
|
+
Returns
|
871
|
+
-------
|
872
|
+
list
|
873
|
+
A list of descendant node values.
|
874
|
+
"""
|
875
|
+
descendants = []
|
876
|
+
stack = self.children[:]
|
877
|
+
while stack:
|
878
|
+
node = stack.pop()
|
879
|
+
descendants.append(node.value)
|
880
|
+
stack.extend(reversed(node.children)) # Add children to stack in reverse order for pre-order traversal
|
881
|
+
return descendants
|
882
|
+
|
883
|
+
def total_orig_indices(self):
|
884
|
+
"""
|
885
|
+
Calculate the total number of orig_indices values for this node and all of its descendants.
|
886
|
+
"""
|
887
|
+
total = len(self.orig_indices)
|
888
|
+
for child in self.get_descendants():
|
889
|
+
total += len(child.orig_indices)
|
890
|
+
return total
|
891
|
+
|
892
|
+
def print_tree(self, level=0, prefix=""):
|
893
|
+
"""
|
894
|
+
Print the tree structure starting from the current node.
|
895
|
+
|
896
|
+
Parameters
|
897
|
+
----------
|
898
|
+
level : int, optional (default 0)
|
899
|
+
The current level in the tree.
|
900
|
+
prefix : str, optional (default "")
|
901
|
+
The prefix for the current level.
|
902
|
+
"""
|
903
|
+
connector = "├── " if level > 0 else ""
|
904
|
+
print(prefix + connector + str(self.value) +', ' + str(self.orig_indices))
|
905
|
+
for i, child in enumerate(self.children):
|
906
|
+
if i == len(self.children) - 1:
|
907
|
+
child.print_tree(level + 1, prefix + (" " if level > 0 else ""))
|
908
|
+
else:
|
909
|
+
child.print_tree(level + 1, prefix + ("│ " if level > 0 else ""))
|
910
|
+
|
911
|
+
#create a class for RootNodes that includes additional initial cost information.
|
912
|
+
class RootNode(TreeNode):
|
913
|
+
"""
|
914
|
+
Class for representing a root node for a tree, along with the corresponding metadata
|
915
|
+
specific to root nodes.
|
916
|
+
"""
|
917
|
+
|
918
|
+
def __init__(self, value, cost=0, tree=None, children=None, orig_indices=None):
|
919
|
+
"""
|
920
|
+
Initialize a RootNode with a value, optional cost, optional tree, optional children, and optional original indices.
|
921
|
+
|
922
|
+
Parameters
|
923
|
+
----------
|
924
|
+
value : any
|
925
|
+
The value to be stored in the node.
|
926
|
+
cost : int, optional (default is 0)
|
927
|
+
The initial cost associated with the root node.
|
928
|
+
tree : Tree, optional (default is None)
|
929
|
+
The tree to which this root node belongs.
|
930
|
+
children : list, optional (default is None)
|
931
|
+
A list of child nodes. If None, initializes an empty list.
|
932
|
+
orig_indices : list, optional (default is None)
|
933
|
+
A list of original indices. If None, initializes an empty list.
|
934
|
+
"""
|
935
|
+
super().__init__(value, children, orig_indices)
|
936
|
+
self.cost = cost
|
937
|
+
self.tree = tree
|
938
|
+
|
939
|
+
class ChildNode(TreeNode):
|
940
|
+
"""
|
941
|
+
Class for representing a child node for a tree, along with the corresponding metadata
|
942
|
+
specific to child nodes.
|
943
|
+
"""
|
944
|
+
def __init__(self, value, parent=None, children=None, orig_indices=None):
|
945
|
+
"""
|
946
|
+
Parameters
|
947
|
+
----------
|
948
|
+
value : any
|
949
|
+
The value to be stored in the node.
|
950
|
+
parent : TreeNode, optional (default is None)
|
951
|
+
The parent node.
|
952
|
+
children : list, optional (default is None)
|
953
|
+
A list of child nodes. If None, initializes an empty list.
|
954
|
+
orig_indices : list, optional (default is None)
|
955
|
+
A list of original indices. If None, initializes an empty list.
|
956
|
+
"""
|
957
|
+
super().__init__(value, children, orig_indices)
|
958
|
+
self.parent = parent
|
959
|
+
if parent is not None:
|
960
|
+
parent.add_child(self)
|
961
|
+
|
962
|
+
def get_ancestors(self):
|
963
|
+
"""
|
964
|
+
Get all ancestor nodes of the current node up to the root node.
|
965
|
+
|
966
|
+
Returns
|
967
|
+
-------
|
968
|
+
list
|
969
|
+
A list of ancestor nodes.
|
970
|
+
"""
|
971
|
+
ancestors = []
|
972
|
+
node = self
|
973
|
+
while node:
|
974
|
+
ancestors.append(node)
|
975
|
+
if isinstance(node, RootNode):
|
976
|
+
break
|
977
|
+
node = node.parent
|
978
|
+
return ancestors
|
979
|
+
|
980
|
+
def calculate_promotion_cost(self):
|
981
|
+
"""
|
982
|
+
Calculate the cost of promoting this child node to a root node. This
|
983
|
+
corresponds to the sum of the cost of this node's current root, plus
|
984
|
+
the total number of ancestors (less the root).
|
985
|
+
"""
|
986
|
+
ancestors = self.get_ancestors()
|
987
|
+
ancestor_count = len(ancestors) - 1
|
988
|
+
current_root = self.get_root()
|
989
|
+
current_root_cost = current_root.cost
|
990
|
+
return ancestor_count + current_root_cost
|
991
|
+
|
992
|
+
def promote_to_root(self):
|
993
|
+
"""
|
994
|
+
Promote this child node to a root node, updating the tree structure accordingly.
|
995
|
+
"""
|
996
|
+
# Calculate the cost (I know this is code duplication, but in this case
|
997
|
+
#we need the intermediate values as well).
|
998
|
+
ancestors = self.get_ancestors()
|
999
|
+
ancestor_count = len(ancestors) - 1
|
1000
|
+
current_root = self.get_root()
|
1001
|
+
current_root_cost = current_root.cost
|
1002
|
+
new_root_cost = ancestor_count + current_root_cost
|
1003
|
+
|
1004
|
+
# Remove this node from its parent's children
|
1005
|
+
if self.parent:
|
1006
|
+
self.parent.remove_child(self)
|
1007
|
+
|
1008
|
+
# Create a new RootNode
|
1009
|
+
ancestor_values = [ancestor.value for ancestor in reversed(ancestors)]
|
1010
|
+
if isinstance(ancestor_values[0], tuple):
|
1011
|
+
ancestor_values = list(ancestor_values[0]) + ancestor_values[1:]
|
1012
|
+
new_root_value = tuple(ancestor_values)
|
1013
|
+
new_root = RootNode(new_root_value, cost=new_root_cost, tree=current_root.tree, children=self.children,
|
1014
|
+
orig_indices=self.orig_indices)
|
1015
|
+
|
1016
|
+
# Update the children of the new RootNode
|
1017
|
+
for child in new_root.children:
|
1018
|
+
child.parent = new_root
|
1019
|
+
|
1020
|
+
# Add the new RootNode to the tree
|
1021
|
+
if new_root.tree:
|
1022
|
+
new_root.tree.add_root(new_root)
|
1023
|
+
|
1024
|
+
# Delete this ChildNode
|
1025
|
+
del self
|
1026
|
+
|
1027
|
+
def get_root(self):
|
1028
|
+
"""
|
1029
|
+
Get the root node of the current node.
|
1030
|
+
|
1031
|
+
Returns
|
1032
|
+
-------
|
1033
|
+
RootNode
|
1034
|
+
The root node of the current node.
|
1035
|
+
"""
|
1036
|
+
node = self
|
1037
|
+
while node.parent and not isinstance(node.parent, RootNode):
|
1038
|
+
node = node.parent
|
1039
|
+
return node.parent
|
1040
|
+
|
1041
|
+
class Tree:
|
1042
|
+
"""
|
1043
|
+
Container class for storing a tree structure (technically a forest, as there
|
1044
|
+
can be multiple roots).
|
1045
|
+
"""
|
1046
|
+
def __init__(self, roots=None):
|
1047
|
+
"""
|
1048
|
+
Parameters
|
1049
|
+
----------
|
1050
|
+
roots: list of RootNode, optional (default None)
|
1051
|
+
List of roots for this tree structure.
|
1052
|
+
"""
|
1053
|
+
self.roots = []
|
1054
|
+
self.root_set = set(self.roots)
|
1055
|
+
|
1056
|
+
def get_root_node(self, value):
|
1057
|
+
"""
|
1058
|
+
Get the root node associated with the input value. If that node is not present, return None.
|
1059
|
+
|
1060
|
+
Parameters
|
1061
|
+
----------
|
1062
|
+
value : any
|
1063
|
+
The value to search for in the root nodes.
|
1064
|
+
|
1065
|
+
Returns
|
1066
|
+
-------
|
1067
|
+
RootNode or None
|
1068
|
+
The root node with the specified value, or None if not found.
|
1069
|
+
"""
|
1070
|
+
|
1071
|
+
for node in self.roots:
|
1072
|
+
if node.value == value:
|
1073
|
+
return node
|
1074
|
+
#if we haven't returned already it is because there wasn't a corresponding root,
|
1075
|
+
#so return None
|
1076
|
+
return None
|
1077
|
+
|
1078
|
+
def add_root(self, root_node):
|
1079
|
+
"""
|
1080
|
+
Add a root node to the tree.
|
1081
|
+
|
1082
|
+
Parameters
|
1083
|
+
----------
|
1084
|
+
root_node : RootNode
|
1085
|
+
The root node to be added.
|
1086
|
+
"""
|
1087
|
+
|
1088
|
+
root_node.tree = self
|
1089
|
+
self.roots.append(root_node)
|
1090
|
+
self.root_set.add(root_node)
|
1091
|
+
|
1092
|
+
def remove_root(self, root_node):
|
1093
|
+
"""
|
1094
|
+
Remove a root node from the tree.
|
1095
|
+
|
1096
|
+
Parameters
|
1097
|
+
----------
|
1098
|
+
root_node : RootNode
|
1099
|
+
The root node to be removed.
|
1100
|
+
"""
|
1101
|
+
|
1102
|
+
root_node.tree = None
|
1103
|
+
self.roots = [root for root in self.roots if root is not root_node]
|
1104
|
+
|
1105
|
+
def total_orig_indices(self):
|
1106
|
+
"""
|
1107
|
+
Calculate the total number of original indices for all root nodes and their descendants.
|
1108
|
+
"""
|
1109
|
+
return sum([root.total_orig_indices() for root in self.roots])
|
1110
|
+
|
1111
|
+
def traverse(self):
|
1112
|
+
"""
|
1113
|
+
Traverse the entire tree in pre-order and return a list of node values.
|
1114
|
+
|
1115
|
+
Returns
|
1116
|
+
-------
|
1117
|
+
list
|
1118
|
+
A list of node values in pre-order traversal.
|
1119
|
+
"""
|
1120
|
+
nodes = []
|
1121
|
+
for root in self.roots:
|
1122
|
+
nodes.extend(root.traverse())
|
1123
|
+
return nodes
|
1124
|
+
|
1125
|
+
def count_nodes(self):
|
1126
|
+
"""
|
1127
|
+
Count the total number of nodes in the tree.
|
1128
|
+
"""
|
1129
|
+
count = 0
|
1130
|
+
stack = self.roots[:]
|
1131
|
+
while stack:
|
1132
|
+
node = stack.pop()
|
1133
|
+
count += 1
|
1134
|
+
stack.extend(node.children)
|
1135
|
+
return count
|
1136
|
+
|
1137
|
+
def print_tree(self):
|
1138
|
+
"""
|
1139
|
+
Print the entire tree structure.
|
1140
|
+
"""
|
1141
|
+
for root in self.roots:
|
1142
|
+
root.print_tree()
|
1143
|
+
|
1144
|
+
def calculate_cost(self):
|
1145
|
+
"""
|
1146
|
+
Calculate the total cost of the tree, including root costs and promotion costs for child nodes.
|
1147
|
+
See `RootNode` and `ChildNode`.
|
1148
|
+
"""
|
1149
|
+
total_cost = sum([root.cost for root in self.roots])
|
1150
|
+
total_nodes = self.count_nodes()
|
1151
|
+
total_child_nodes = total_nodes - len(self.roots)
|
1152
|
+
return total_cost + total_child_nodes
|
1153
|
+
|
    def to_networkx_graph(self):
        """
        Convert the tree to a NetworkX directed graph with node and edge attributes.

        Returns
        -------
        networkx.DiGraph
            The NetworkX directed graph representation of the tree.
        """
        G = _nx.DiGraph()
        stack = [(None, root) for root in self.roots]
        insertion_order = 0
        while stack:
            parent, node = stack.pop()
            node_id = id(node)
            prop_cost = node.cost if isinstance(node, RootNode) else 1
            G.add_node(node_id, cost=len(node.orig_indices), orig_indices=tuple(node.orig_indices),
                       label=node.value, prop_cost=prop_cost,
                       insertion_order=insertion_order)
            insertion_order += 1
            if parent is not None:
                parent_id = id(parent)
                edge_cost = node.calculate_promotion_cost()
                G.add_edge(parent_id, node_id, promotion_cost=edge_cost)
            for child in node.children:
                stack.append((node, child))

        #if there are multiple roots then add an additional virtual root node as the
        #parent for all of these roots to enable partitioning with later algorithms.
        if len(self.roots) > 1:
            G.add_node('virtual_root', cost=0, orig_indices=(), label=(), prop_cost=0, insertion_order=-1)
            for root in self.roots:
                G.add_edge('virtual_root', id(root), promotion_cost=0)

        return G

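#For reference, the graph produced by `to_networkx_graph` above uses the following
#attribute schema (summarized from the method body): each node carries 'cost',
#'orig_indices', 'label', 'prop_cost' and 'insertion_order', and each edge carries
#'promotion_cost'. When multiple roots exist a synthetic 'virtual_root' node with
#zero costs is added as a common parent. The helpers below assume this schema.
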
#--------------- Tree Partitioning Algorithm Helpers (+NetworkX Utilities) -----------------#

def _draw_graph(G, node_label_key='label', edge_label_key='promotion_cost', figure_size=(10, 10)):
    """
    Draw the NetworkX graph with node labels.

    Parameters
    ----------
    G : networkx.Graph
        The networkx Graph object to draw.

    node_label_key : str, optional (default 'label')
        Optional key for the node attribute to use for the node labels.

    edge_label_key : str, optional (default 'promotion_cost')
        Optional key for the edge attribute to use for the edge labels.

    figure_size : tuple of floats, optional (default (10,10))
        An optional size specifier passed into the matplotlib figure
        constructor to set the plot size.
    """
    plt.figure(figsize=figure_size)
    pos = _nx.nx_agraph.graphviz_layout(G, prog="dot", args="-Granksep=5 -Gnodesep=10")
    labels = _nx.get_node_attributes(G, node_label_key)
    _nx.draw(G, pos, labels=labels, with_labels=True, node_size=500, node_color='lightblue', font_size=6, font_weight='bold')
    edge_labels = _nx.get_edge_attributes(G, edge_label_key)
    _nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    plt.show()

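#note: `_nx.nx_agraph.graphviz_layout` requires Graphviz and the optional pygraphviz
#package; if they are not installed, a generic layout such as `_nx.spring_layout(G)`
#could be substituted here, at the cost of the hierarchical "dot" arrangement.
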
def _copy_networkx_graph(G):
    """
    Create a new independent copy of a NetworkX directed graph with node and edge attributes that
    match the original graph. Specialized to copying graphs with the known attributes set by the
    `to_networkx_graph` method of the `Tree` class.

    Parameters
    ----------
    G : networkx.DiGraph
        The original NetworkX directed graph.

    Returns
    -------
    networkx.DiGraph
        The new independent copy of the NetworkX directed graph.
    """
    new_G = _nx.DiGraph()

    # Copy nodes with attributes
    for node, data in G.nodes(data=True):
        new_G.add_node(node, cost=data['cost'], orig_indices=data['orig_indices'],
                       label=data['label'], prop_cost=data['prop_cost'],
                       insertion_order=data['insertion_order'])

    # Copy edges with attributes
    for u, v, data in G.edges(data=True):
        new_G.add_edge(u, v, promotion_cost=data['promotion_cost'])

    return new_G

def _find_root(tree):
    """
    Find the root node of a directed tree.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.

    Returns
    -------
    networkx node corresponding to the root.
    """
    # The root node will have no incoming edges
    for node in tree.nodes():
        if tree.in_degree(node) == 0:
            return node
    raise ValueError("The input graph is not a valid tree (no root found).")

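#For illustration (toy data): the root is the unique node with in-degree zero.
#   >>> T = _nx.DiGraph([('r', 'a'), ('a', 'b')])
#   >>> _find_root(T)
#   'r'
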
def _compute_subtree_weights(tree, root, weight_key):
    """
    This function computes the total weight of each subtree in a directed tree.
    The weight of a subtree is defined as the sum of the weights of all nodes
    in that subtree, including the root of the subtree.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.

    root : networkx node
        The root node of the tree.

    weight_key : str
        A string corresponding to the node attribute to use as the weights.

    Returns
    -------
    A dictionary where keys are nodes and values are the total weights of the subtrees rooted at those nodes.
    """
    subtree_weights = {}
    stack = [root]
    visited = set()

    # Calculate the subtree weights in a bottom-up manner (iterative post-order traversal)
    while stack:
        node = stack.pop()
        if node in visited:
            # All children have been processed, now process the node itself
            subtree_weight = tree.nodes[node][weight_key]
            for child in tree.successors(node):
                subtree_weight += subtree_weights[child]
            subtree_weights[node] = subtree_weight
        else:
            # First visit: push the node back so it is processed after its children
            visited.add(node)
            stack.append(node)
            for child in tree.successors(node):
                if child not in visited:
                    stack.append(child)

    return subtree_weights

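#For illustration (toy data): with node weights stored under a 'cost' attribute,
#   >>> T = _nx.DiGraph([('a', 'b'), ('a', 'c')])
#   >>> _nx.set_node_attributes(T, {'a': 1, 'b': 2, 'c': 3}, 'cost')
#   >>> _compute_subtree_weights(T, 'a', 'cost')
#   {'c': 3, 'b': 2, 'a': 6}  #(dict ordering may vary)
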
def _partition_levels(tree, root):
    """
    Partition the nodes of a rooted directed tree into levels based on their distance from the root.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.
    root : networkx node
        The root node of the tree.

    Returns
    -------
    list of lists:
        A list where each inner list contains the nodes equidistant from the root,
        sorted by their 'insertion_order' attribute.
    """
    # Initialize a dictionary to store the level of each node
    levels = {}
    # Initialize a queue for BFS
    queue = _collections.deque([(root, 0)])

    while queue:
        node, level = queue.popleft()
        if level not in levels:
            levels[level] = set()
        levels[level].add(node)

        for child in tree.successors(node):
            queue.append((child, level + 1))

    tree_nodes = tree.nodes
    # Convert the levels dictionary to a list of lists ordered by level
    sorted_levels = []
    for level in sorted(levels.keys()):
        # Sort nodes at each level by 'insertion_order' attribute
        sorted_nodes = sorted(levels[level], key=lambda node: tree_nodes[node]['insertion_order'])
        sorted_levels.append(sorted_nodes)

    return sorted_levels


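#For illustration (toy data): nodes need an 'insertion_order' attribute for the
#deterministic within-level sort.
#   >>> T = _nx.DiGraph([('r', 'a'), ('r', 'b'), ('a', 'c')])
#   >>> _nx.set_node_attributes(T, {'r': 0, 'a': 1, 'b': 2, 'c': 3}, 'insertion_order')
#   >>> _partition_levels(T, 'r')
#   [['r'], ['a', 'b'], ['c']]
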
def _partition_levels_and_compute_subtree_weights(tree, root, weight_key):
    """
    Partition the nodes of a rooted directed tree into levels based on their distance from the root
    and compute the total weight of each subtree.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.
    root : networkx node
        The root node of the tree.
    weight_key : str
        A string corresponding to the node attribute to use as the weights.

    Returns
    -------
    tuple:
        - list of lists: A list where each inner list contains the nodes equidistant from the root,
          sorted by insertion order.
        - dict: A dictionary where keys are nodes and values are the total weights of the subtrees rooted at those nodes.
    """
    # Initialize a dictionary to store the level of each node
    levels = {}
    # Initialize a dictionary to store the subtree weights
    subtree_weights = {}
    # Initialize a queue for BFS
    queue = _collections.deque([(root, 0)])
    # Initialize a stack for DFS to compute subtree weights
    stack = []
    visited = set()

    #tree.nodes returns a view, so grab this ahead of time in case
    #there is overhead with that.
    tree_nodes = tree.nodes

    #successors is kind of an expensive call and we use it at least twice
    #per node, so let's just compute it once and cache it in a dict.
    node_successors = {node: list(tree.successors(node)) for node in tree_nodes}

    while queue:
        node, level = queue.popleft()
        if node not in visited:
            visited.add(node)
            if level not in levels:
                levels[level] = set()
            levels[level].add(node)
            stack.append(node)
            for child in node_successors[node]:
                queue.append((child, level + 1))

    # Compute subtree weights in a bottom-up manner
    while stack:
        node = stack.pop()
        subtree_weight = tree_nodes[node][weight_key]
        for child in node_successors[node]:
            subtree_weight += subtree_weights[child]
        subtree_weights[node] = subtree_weight

    # Convert the levels dictionary to a list of lists ordered by level
    sorted_levels = []
    for level in sorted(levels.keys()):
        # Sort nodes at each level by 'insertion_order' attribute
        sorted_nodes = sorted(levels[level], key=lambda node: tree_nodes[node]['insertion_order'])
        sorted_levels.append(sorted_nodes)

    return sorted_levels, subtree_weights


def _find_leaves(tree):
    """
    Find all leaf nodes in a directed tree.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.

    Returns
    -------
    A set of leaf nodes.
    """
    leaf_nodes = {node for node in tree.nodes() if tree.out_degree(node) == 0}
    return leaf_nodes

def _path_to_root(tree, node, root):
    """
    Return a list of nodes along the path from the given node to the root.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.
    node : networkx node
        The starting node.
    root : networkx node
        The root node of the tree.

    Returns
    -------
    A list of nodes along the path from the given node to the root.
    """
    path = []
    current_node = node

    while current_node != root:
        path.append(current_node)
        #note: for a tree structure there should be just one predecessor,
        #so we're not worried about nondeterminism; if we ever apply this to another
        #graph structure this needs to be reevaluated.
        predecessors = list(tree.predecessors(current_node))
        current_node = predecessors[0]
    path.append(root)

    return path

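#For illustration: on the chain r -> a -> b, _path_to_root(T, 'b', 'r') returns ['b', 'a', 'r'].
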
def _get_subtree(tree, root):
    """
    Return a new graph corresponding to the subtree rooted at the given node.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.

    root : networkx node
        The root node of the subtree.

    Returns
    -------
    subtree : networkx.DiGraph
        A new directed graph corresponding to the subtree rooted at the given node.
    """
    # Create a new directed graph for the subtree
    subtree = _nx.DiGraph()

    # Use a queue to perform BFS and add nodes and edges to the subtree
    # (a deque gives O(1) pops from the front, unlike list.pop(0))
    queue = _collections.deque([root])
    while queue:
        node = queue.popleft()
        subtree.add_node(node, **tree.nodes[node])
        for child in tree.successors(node):
            subtree.add_edge(node, child, **tree.edges[node, child])
            queue.append(child)

    return subtree

def _collect_orig_indices(tree, root):
    """
    Collect all values of the 'orig_indices' node attributes in the subtree rooted at the given node.
    The 'orig_indices' values are tuples, and the function flattens these tuples into a single list.

    Parameters
    ----------
    tree : networkx.DiGraph
        The directed tree.

    root : networkx node
        The root node of the subtree.

    Returns
    -------
    list
        A flattened list of all values of the 'orig_indices' node attributes in the subtree.
    """
    orig_indices_list = []
    queue = [root]

    #TODO: See if this would be any faster with one of the dfs/bfs iterators in networkx
    while queue:
        node = queue.pop()
        orig_indices_list.extend(tree.nodes[node]['orig_indices'])
        for child in tree.successors(node):
            queue.append(child)

    return sorted(orig_indices_list)  #sort it to account for any nondeterministic traversal order.

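#For illustration (toy data): with 'orig_indices' tuples attached to each node,
#   >>> T = _nx.DiGraph([('r', 'a'), ('r', 'b')])
#   >>> _nx.set_node_attributes(T, {'r': (0, 1), 'a': (2,), 'b': (3, 4)}, 'orig_indices')
#   >>> _collect_orig_indices(T, 'r')
#   [0, 1, 2, 3, 4]
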
def _process_node_km(node, tree, subtree_weights, cut_edges, max_weight, root, new_roots):
    """
    Helper function for the Kundu-Misra algorithm. This function processes each node
    by cutting the edges to its heaviest children until the node's subtree weight
    is below the maximum weight threshold, updating the subtree weights of any ancestors
    as needed.
    """
    #if the subtree weight of this node is less than max weight we can stop right away
    #and avoid the sorting of the child weights.
    if subtree_weights[node] <= max_weight:
        return

    tree_nodes = tree.nodes
    #otherwise we will sort the weights of the child nodes to get the heaviest ones,
    #sorting by insertion order first to ensure determinism.
    weighted_children = [(child, subtree_weights[child]) for child in
                         sorted(tree.successors(node), key=lambda n: tree_nodes[n]['insertion_order'])]
    sorted_weighted_children = sorted(weighted_children, key=lambda x: x[1], reverse=True)

    #get the path of nodes up to the root which need to have their weights updated upon edge removal.
    nodes_to_update = _path_to_root(tree, node, root)

    #remove the heaviest children until the weight is below the maximum weight.
    removed_child_index = 0  #track the index of the child being removed.
    while subtree_weights[node] > max_weight:
        removed_child = sorted_weighted_children[removed_child_index][0]
        #add the edge to this child to the list of those cut.
        cut_edges.append((node, removed_child))
        new_roots.append(removed_child)
        removed_child_weight = subtree_weights[removed_child]
        #update the subtree weight of the current node and all parents up to the root.
        for node_to_update in nodes_to_update:
            subtree_weights[node_to_update] -= removed_child_weight
        #update index:
        removed_child_index += 1

def tree_partition_kundu_misra(tree, max_weight, weight_key='cost', test_leaves=True,
                               return_levels_and_weights=False, precomp_levels=None,
                               precomp_weights=None):
    """
    Algorithm for an optimal minimum-cardinality k-partition of a tree (a partition
    of a tree into clusters of size at most k), based on a slightly less sophisticated
    implementation of the algorithm from "A Linear Tree Partitioning Algorithm"
    by Kundu and Misra (SIAM J. Comput. Vol. 6, No. 1, March 1977). It is less
    sophisticated because the strictly linear-time implementation relies on a
    linear-time median-finding routine, while this implementation uses sorting
    (n log(n) time). In practice the highly-optimized C implementation of sorting
    is likely to beat an uglier python implementation of median finding for most
    problem instances of interest anyhow.

    Parameters
    ----------
    tree : networkx.DiGraph
        An input graph representing the directed tree to perform partitioning on.

    max_weight : int
        Maximum node weight allowed for each partition.

    weight_key : str, optional (default 'cost')
        An optional string denoting the node attribute label to use for node weights
        in partitioning.

    test_leaves : bool, optional (default True)
        When True an initial test is performed to ensure that the weights of the leaves are all
        less than the maximum weight. Only turn this off if you know for certain it is true.

    return_levels_and_weights : bool, optional (default False)
        If True return the constructed tree level structure (the lists of nodes partitioned
        by distance from the root) and subtree weights.

    precomp_levels : list of lists, optional (default None)
        A list where each inner list contains nodes that are equidistant from the root.

    precomp_weights : dict, optional (default None)
        A dictionary where keys are nodes and values are the total weights of the subtrees rooted at those nodes.

    Returns
    -------
    cut_edges : list of tuples
        A list of the parent-child node pairs whose edges were cut in partitioning the tree.
        (The input tree itself is not modified; callers apply these cuts themselves.)

    new_roots : list
        A list of the nodes that serve as roots of the subtrees in the partitioned tree,
        sorted by insertion order.

    tree_levels : list of lists
        Only returned when `return_levels_and_weights` is True. The nodes of the tree
        partitioned by distance from the root.

    subtree_weights : dict
        Only returned when `return_levels_and_weights` is True. A dictionary mapping each node
        to the total weight of the subtree rooted at that node, as computed before any cuts.
    """
    cut_edges = []  #list of cut edges.
    new_roots = []  #list of the subtree root nodes in the partitioned tree

    #find the root node of the tree:
    root = _find_root(tree)
    new_roots.append(root)

    tree_nodes = tree.nodes

    if test_leaves:
        #find the leaves:
        leaves = _find_leaves(tree)
        #make sure that the weights of the leaves are all less than the maximum weight.
        msg = 'The maximum node weight for at least one leaf is greater than the maximum weight, no partition possible.'
        assert all([tree_nodes[leaf][weight_key] <= max_weight for leaf in leaves]), msg

    #precompute a list of subtree weights which will be dynamically updated as we make cuts. Also
    #partition the tree into levels based on distance from the root.
    if precomp_levels is None and precomp_weights is None:
        tree_levels, subtree_weights = _partition_levels_and_compute_subtree_weights(tree, root, weight_key)
    else:
        tree_levels = precomp_levels if precomp_levels is not None else _partition_levels(tree, root)
        subtree_weights = precomp_weights.copy() if precomp_weights is not None else _compute_subtree_weights(tree, root, weight_key)

    #the subtree_weights get modified in-place by _process_node_km, so create a copy for the return value.
    if return_levels_and_weights:
        subtree_weights_orig = subtree_weights.copy()

    #begin processing the nodes level-by-level, working from the deepest level up.
    for level in reversed(tree_levels):
        for node in level:
            _process_node_km(node, tree, subtree_weights, cut_edges, max_weight, root, new_roots)

    #sort the new root nodes in case there are determinism issues
    new_roots = sorted(new_roots, key=lambda node: tree_nodes[node]['insertion_order'])

    if return_levels_and_weights:
        return cut_edges, new_roots, tree_levels, subtree_weights_orig
    else:
        return cut_edges, new_roots

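#For illustration (toy data, hypothetical node names): partition a 4-node tree so that
#no cluster has total 'cost' greater than 4.
#   >>> T = _nx.DiGraph([('r', 'a'), ('a', 'b'), ('a', 'c')])
#   >>> _nx.set_node_attributes(T, {'r': 1, 'a': 2, 'b': 2, 'c': 2}, 'cost')
#   >>> _nx.set_node_attributes(T, {'r': 0, 'a': 1, 'b': 2, 'c': 3}, 'insertion_order')
#   >>> tree_partition_kundu_misra(T, max_weight=4)
#   ([('a', 'b'), ('r', 'a')], ['r', 'a', 'b'])
#Cutting the edges (a, b) and (r, a) yields the clusters {r}, {a, c} and {b}, each with
#total cost at most 4.
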
def _bisect_tree(tree, subtree_root, subtree_weights, weight_key, root_cost=0, target_proportion=.5):
    #perform a bisection on the subtree. Loop through the tree beginning at the root,
    #and find as cheap as possible an edge which, when cut, approximately bisects the tree based on cost.

    heaviest_subtree_levels = _partition_levels(tree, subtree_root)
    new_subtree_cost = {}

    new_subtree_cost[subtree_root] = subtree_weights[subtree_root]
    for i, level in enumerate(heaviest_subtree_levels[1:]):  #skip the root.
        for node in level:
            #calculate the cost of a new subtree rooted at this node. This is the current cost
            #plus the current level plus the propagation cost of the current root.
            new_subtree_cost[node] = (subtree_weights[node] + i + root_cost) if weight_key == 'prop_cost' else subtree_weights[node]

    #find the node that results in as close as possible to a bisection of the subtree
    #in terms of propagation cost.
    target_prop_cost = new_subtree_cost[subtree_root] * target_proportion
    closest_node = subtree_root
    closest_distance = new_subtree_cost[subtree_root]
    for node, cost in new_subtree_cost.items():  #since the nodes in each level are sorted this should be alright for determinism.
        current_distance = abs(cost - target_prop_cost)
        if current_distance < closest_distance:
            closest_distance = current_distance
            closest_node = node
    #we now have the node which, when promoted to a root, produces the tree closest to a bisection in terms of
    #propagation cost. Let's perform that bisection now.
    if closest_node is not subtree_root:
        #a node in a tree has only one predecessor, so we don't need to worry about determinism here.
        cut_edge = (list(tree.predecessors(closest_node))[0], closest_node)
        return cut_edge, (new_subtree_cost[closest_node], subtree_weights[subtree_root] - subtree_weights[closest_node])
    else:
        return None, None

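#Design note: rather than re-running the partitioner from scratch, the rebalancing
#helpers above and below work by "promoting" a single node to a new root. For each
#candidate node `_bisect_tree` estimates the cost of the subtree it would head
#(including, for 'prop_cost' weights, the node's depth plus the original root's cost)
#and cuts the edge into the candidate whose subtree comes closest to
#`target_proportion` of the total weight.
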
def _bisection_pass(partitioned_tree, cut_edges, new_roots, num_sub_tables, weight_key):
    partitioned_tree = _copy_networkx_graph(partitioned_tree)
    subtree_weights = [(root, _compute_subtree_weights(partitioned_tree, root, weight_key)) for root in new_roots]
    sorted_subtree_weights = sorted(subtree_weights, key=lambda x: x[1][x[0]], reverse=True)

    #perform a bisection on the heaviest subtree. Loop through the tree beginning at the root,
    #and find as cheap as possible an edge which, when cut, approximately bisects the tree based on cost.
    for i in range(len(sorted_subtree_weights)):
        heaviest_subtree_root = sorted_subtree_weights[i][0]
        heaviest_subtree_weights = sorted_subtree_weights[i][1]
        root_cost = partitioned_tree.nodes[heaviest_subtree_root][weight_key] if weight_key == 'prop_cost' else 0
        cut_edge, new_subtree_costs = _bisect_tree(partitioned_tree, heaviest_subtree_root, heaviest_subtree_weights, weight_key, root_cost)
        if cut_edge is not None:
            cut_edges.append(cut_edge)
            new_roots.append(cut_edge[1])
            #cut the prescribed edge.
            partitioned_tree.remove_edge(cut_edge[0], cut_edge[1])
        #check whether we need to continue partitioning subtrees.
        if len(new_roots) == num_sub_tables:
            break
    #sort the new root nodes in case there are determinism issues
    new_roots = sorted(new_roots, key=lambda node: partitioned_tree.nodes[node]['insertion_order'])

    return partitioned_tree, new_roots, cut_edges

def _refinement_pass(partitioned_tree, roots, weight_key, imbalance_threshold=1.2, minimum_improvement_threshold=.1):
    #refine the partitioning to improve the balancing of the specified weights across the subtrees.
    #start by recomputing the latest subtree weights and ranking them from heaviest to lightest.
    partitioned_tree = _copy_networkx_graph(partitioned_tree)
    subtree_weights = [(root, _compute_subtree_weights(partitioned_tree, root, weight_key)) for root in roots]
    sorted_subtree_weights = sorted(subtree_weights, key=lambda x: x[1][x[0]], reverse=True)

    partitioned_tree_nodes = partitioned_tree.nodes

    #Strategy: pair the heaviest and lightest subtrees and identify the subtree in the heaviest that could be
    #snipped out and added to the lightest to bring their weights as close as possible.
    #Next do this for the second heaviest and second lightest, etc.
    #Only do so while the imbalance ratio, the ratio between the heaviest and lightest subtrees, is
    #above a specified threshold.
    heavy_light_pairs = _pair_elements(sorted_subtree_weights)
    heavy_light_pair_indices = _pair_elements(list(range(len(sorted_subtree_weights))))
    heavy_light_weights = [(sorted_subtree_weights[i][1][sorted_subtree_weights[i][0]], sorted_subtree_weights[j][1][sorted_subtree_weights[j][0]])
                           for i, j in heavy_light_pair_indices]
    heavy_light_ratios = [weight_1 / weight_2 for weight_1, weight_2 in heavy_light_weights]

    heavy_light_pairs_to_balance = heavy_light_pairs if len(sorted_subtree_weights) % 2 == 0 else heavy_light_pairs[0:-1]
    new_roots = []
    addl_cut_edges = []
    pair_iter = iter(range(len(heavy_light_pairs_to_balance)))
    for i in pair_iter:
        #if the ratio is above the threshold then try a rebalancing step.
        if heavy_light_ratios[i] > imbalance_threshold:
            #calculate the fraction of the heavy tree that would be needed to bring the weight of the
            #lighter tree in line.
            root_cost = partitioned_tree_nodes[heavy_light_pairs[i][0][0]][weight_key] if weight_key == 'prop_cost' else 0

            rebalancing_target_fraction = (.5 * (heavy_light_weights[i][0] - heavy_light_weights[i][1])) / heavy_light_weights[i][0]
            cut_edge, new_subtree_weights = _bisect_tree(partitioned_tree, heavy_light_pairs[i][0][0], heavy_light_pairs[i][0][1],
                                                         weight_key, root_cost=root_cost,
                                                         target_proportion=rebalancing_target_fraction)
            #before applying the edge cut check whether the edge we found was close enough
            #to bring us below the threshold.
            if cut_edge is not None:
                new_light_tree_weight = new_subtree_weights[0] + heavy_light_weights[i][1]
                new_heavy_tree_weight = new_subtree_weights[1]
                new_heavy_light_ratio = new_heavy_tree_weight / new_light_tree_weight
                if new_heavy_light_ratio > imbalance_threshold and \
                   (heavy_light_ratios[i] - new_heavy_light_ratio) < minimum_improvement_threshold:
                    #We're only as good as the worst balancing, so if we are unable to
                    #balance below the threshold and the improvement is below some minimum threshold
                    #then we won't make an update and will terminate.
                    #Maybe we should throw a warning too? It isn't clear whether that would just be
                    #confusing to end-users who wouldn't know what was meant or whether it was important.
                    #also add the roots of any of the pairs we haven't yet processed.
                    remaining_indices = [i] + [j for j in pair_iter]
                    for idx in remaining_indices:
                        new_roots.extend((heavy_light_pairs[idx][0][0], heavy_light_pairs[idx][1][0]))
                    break
                else:
                    #append the original root of the heavy tree, and a tuple of roots for the light plus the
                    #bisected part of the heavy.
                    new_roots.append(heavy_light_pairs[i][0][0])
                    new_roots.append((heavy_light_pairs[i][1][0], cut_edge[1]))
                    addl_cut_edges.append(cut_edge)
                    #apply the cut
                    partitioned_tree.remove_edge(cut_edge[0], cut_edge[1])
            else:
                #if the cut edge is None append the original heavy and light roots.
                new_roots.extend((heavy_light_pairs[i][0][0], heavy_light_pairs[i][1][0]))
        #since we're pairing up subsequent pairs of heavy and light
        #elements, once we see one which is sufficiently balanced we
        #know the rest must be.
        else:
            remaining_indices = [i] + [j for j in pair_iter]
            for idx in remaining_indices:
                new_roots.extend((heavy_light_pairs[idx][0][0], heavy_light_pairs[idx][1][0]))
            break

    #if the number of subtrees was odd to start we need to append the median-weight element which hasn't
    #been processed.
    if len(sorted_subtree_weights) % 2 != 0:
        new_roots.append(heavy_light_pairs[-1][0][0])

    return partitioned_tree, new_roots, addl_cut_edges


#helper function for pairing up heavy and light subtrees.
def _pair_elements(lst):
    paired_list = []
    length = len(lst)

    for i in range((length + 1) // 2):
        if i == length - i - 1:
            paired_list.append((lst[i], lst[i]))
        else:
            paired_list.append((lst[i], lst[length - i - 1]))

    return paired_list

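#For illustration: elements are paired heaviest-with-lightest, working inward.
#   >>> _pair_elements([5, 4, 3, 2, 1])
#   [(5, 1), (4, 2), (3, 3)]
#With an odd-length input the median element is paired with itself, which is why
#`_refinement_pass` handles the final pair specially.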