pyGSTi 0.9.12.1__cp310-cp310-win_amd64.whl → 0.9.13__cp310-cp310-win_amd64.whl

Files changed (221)
  1. pyGSTi-0.9.13.dist-info/METADATA +197 -0
  2. {pyGSTi-0.9.12.1.dist-info → pyGSTi-0.9.13.dist-info}/RECORD +207 -217
  3. {pyGSTi-0.9.12.1.dist-info → pyGSTi-0.9.13.dist-info}/WHEEL +1 -1
  4. pygsti/_version.py +2 -2
  5. pygsti/algorithms/contract.py +1 -1
  6. pygsti/algorithms/core.py +42 -28
  7. pygsti/algorithms/fiducialselection.py +17 -8
  8. pygsti/algorithms/gaugeopt.py +2 -2
  9. pygsti/algorithms/germselection.py +87 -77
  10. pygsti/algorithms/mirroring.py +0 -388
  11. pygsti/algorithms/randomcircuit.py +165 -1333
  12. pygsti/algorithms/rbfit.py +0 -234
  13. pygsti/baseobjs/basis.py +94 -396
  14. pygsti/baseobjs/errorgenbasis.py +0 -132
  15. pygsti/baseobjs/errorgenspace.py +0 -10
  16. pygsti/baseobjs/label.py +52 -168
  17. pygsti/baseobjs/opcalc/fastopcalc.cp310-win_amd64.pyd +0 -0
  18. pygsti/baseobjs/opcalc/fastopcalc.pyx +2 -2
  19. pygsti/baseobjs/polynomial.py +13 -595
  20. pygsti/baseobjs/statespace.py +1 -0
  21. pygsti/circuits/__init__.py +1 -1
  22. pygsti/circuits/circuit.py +682 -505
  23. pygsti/circuits/circuitconstruction.py +0 -4
  24. pygsti/circuits/circuitlist.py +47 -5
  25. pygsti/circuits/circuitparser/__init__.py +8 -8
  26. pygsti/circuits/circuitparser/fastcircuitparser.cp310-win_amd64.pyd +0 -0
  27. pygsti/circuits/circuitstructure.py +3 -3
  28. pygsti/circuits/cloudcircuitconstruction.py +1 -1
  29. pygsti/data/datacomparator.py +2 -7
  30. pygsti/data/dataset.py +46 -44
  31. pygsti/data/hypothesistest.py +0 -7
  32. pygsti/drivers/bootstrap.py +0 -49
  33. pygsti/drivers/longsequence.py +2 -1
  34. pygsti/evotypes/basereps_cython.cp310-win_amd64.pyd +0 -0
  35. pygsti/evotypes/chp/opreps.py +0 -61
  36. pygsti/evotypes/chp/statereps.py +0 -32
  37. pygsti/evotypes/densitymx/effectcreps.cpp +9 -10
  38. pygsti/evotypes/densitymx/effectreps.cp310-win_amd64.pyd +0 -0
  39. pygsti/evotypes/densitymx/effectreps.pyx +1 -1
  40. pygsti/evotypes/densitymx/opreps.cp310-win_amd64.pyd +0 -0
  41. pygsti/evotypes/densitymx/opreps.pyx +2 -2
  42. pygsti/evotypes/densitymx/statereps.cp310-win_amd64.pyd +0 -0
  43. pygsti/evotypes/densitymx/statereps.pyx +1 -1
  44. pygsti/evotypes/densitymx_slow/effectreps.py +7 -23
  45. pygsti/evotypes/densitymx_slow/opreps.py +16 -23
  46. pygsti/evotypes/densitymx_slow/statereps.py +10 -3
  47. pygsti/evotypes/evotype.py +39 -2
  48. pygsti/evotypes/stabilizer/effectreps.cp310-win_amd64.pyd +0 -0
  49. pygsti/evotypes/stabilizer/effectreps.pyx +0 -4
  50. pygsti/evotypes/stabilizer/opreps.cp310-win_amd64.pyd +0 -0
  51. pygsti/evotypes/stabilizer/opreps.pyx +0 -4
  52. pygsti/evotypes/stabilizer/statereps.cp310-win_amd64.pyd +0 -0
  53. pygsti/evotypes/stabilizer/statereps.pyx +1 -5
  54. pygsti/evotypes/stabilizer/termreps.cp310-win_amd64.pyd +0 -0
  55. pygsti/evotypes/stabilizer/termreps.pyx +0 -7
  56. pygsti/evotypes/stabilizer_slow/effectreps.py +0 -22
  57. pygsti/evotypes/stabilizer_slow/opreps.py +0 -4
  58. pygsti/evotypes/stabilizer_slow/statereps.py +0 -4
  59. pygsti/evotypes/statevec/effectreps.cp310-win_amd64.pyd +0 -0
  60. pygsti/evotypes/statevec/effectreps.pyx +1 -1
  61. pygsti/evotypes/statevec/opreps.cp310-win_amd64.pyd +0 -0
  62. pygsti/evotypes/statevec/opreps.pyx +2 -2
  63. pygsti/evotypes/statevec/statereps.cp310-win_amd64.pyd +0 -0
  64. pygsti/evotypes/statevec/statereps.pyx +1 -1
  65. pygsti/evotypes/statevec/termreps.cp310-win_amd64.pyd +0 -0
  66. pygsti/evotypes/statevec/termreps.pyx +0 -7
  67. pygsti/evotypes/statevec_slow/effectreps.py +0 -3
  68. pygsti/evotypes/statevec_slow/opreps.py +0 -5
  69. pygsti/extras/__init__.py +0 -1
  70. pygsti/extras/drift/stabilityanalyzer.py +3 -1
  71. pygsti/extras/interpygate/__init__.py +12 -0
  72. pygsti/extras/interpygate/core.py +0 -36
  73. pygsti/extras/interpygate/process_tomography.py +44 -10
  74. pygsti/extras/rpe/rpeconstruction.py +0 -2
  75. pygsti/forwardsims/__init__.py +1 -0
  76. pygsti/forwardsims/forwardsim.py +14 -55
  77. pygsti/forwardsims/mapforwardsim.py +69 -18
  78. pygsti/forwardsims/mapforwardsim_calc_densitymx.cp310-win_amd64.pyd +0 -0
  79. pygsti/forwardsims/mapforwardsim_calc_densitymx.pyx +65 -66
  80. pygsti/forwardsims/mapforwardsim_calc_generic.py +91 -13
  81. pygsti/forwardsims/matrixforwardsim.py +63 -15
  82. pygsti/forwardsims/termforwardsim.py +8 -110
  83. pygsti/forwardsims/termforwardsim_calc_stabilizer.cp310-win_amd64.pyd +0 -0
  84. pygsti/forwardsims/termforwardsim_calc_statevec.cp310-win_amd64.pyd +0 -0
  85. pygsti/forwardsims/termforwardsim_calc_statevec.pyx +0 -651
  86. pygsti/forwardsims/torchfwdsim.py +265 -0
  87. pygsti/forwardsims/weakforwardsim.py +2 -2
  88. pygsti/io/__init__.py +1 -2
  89. pygsti/io/mongodb.py +0 -2
  90. pygsti/io/stdinput.py +6 -22
  91. pygsti/layouts/copalayout.py +10 -12
  92. pygsti/layouts/distlayout.py +0 -40
  93. pygsti/layouts/maplayout.py +103 -25
  94. pygsti/layouts/matrixlayout.py +99 -60
  95. pygsti/layouts/prefixtable.py +1534 -52
  96. pygsti/layouts/termlayout.py +1 -1
  97. pygsti/modelmembers/instruments/instrument.py +3 -3
  98. pygsti/modelmembers/instruments/tpinstrument.py +2 -2
  99. pygsti/modelmembers/modelmember.py +0 -17
  100. pygsti/modelmembers/operations/__init__.py +2 -4
  101. pygsti/modelmembers/operations/affineshiftop.py +1 -0
  102. pygsti/modelmembers/operations/composederrorgen.py +1 -1
  103. pygsti/modelmembers/operations/composedop.py +1 -24
  104. pygsti/modelmembers/operations/denseop.py +5 -5
  105. pygsti/modelmembers/operations/eigpdenseop.py +2 -2
  106. pygsti/modelmembers/operations/embeddederrorgen.py +1 -1
  107. pygsti/modelmembers/operations/embeddedop.py +0 -1
  108. pygsti/modelmembers/operations/experrorgenop.py +2 -2
  109. pygsti/modelmembers/operations/fullarbitraryop.py +1 -0
  110. pygsti/modelmembers/operations/fullcptpop.py +2 -2
  111. pygsti/modelmembers/operations/fulltpop.py +28 -6
  112. pygsti/modelmembers/operations/fullunitaryop.py +5 -4
  113. pygsti/modelmembers/operations/lindbladcoefficients.py +93 -78
  114. pygsti/modelmembers/operations/lindbladerrorgen.py +268 -441
  115. pygsti/modelmembers/operations/linearop.py +7 -27
  116. pygsti/modelmembers/operations/opfactory.py +1 -1
  117. pygsti/modelmembers/operations/repeatedop.py +1 -24
  118. pygsti/modelmembers/operations/staticstdop.py +1 -1
  119. pygsti/modelmembers/povms/__init__.py +3 -3
  120. pygsti/modelmembers/povms/basepovm.py +7 -36
  121. pygsti/modelmembers/povms/complementeffect.py +4 -9
  122. pygsti/modelmembers/povms/composedeffect.py +0 -320
  123. pygsti/modelmembers/povms/computationaleffect.py +1 -1
  124. pygsti/modelmembers/povms/computationalpovm.py +3 -1
  125. pygsti/modelmembers/povms/effect.py +3 -5
  126. pygsti/modelmembers/povms/marginalizedpovm.py +0 -79
  127. pygsti/modelmembers/povms/tppovm.py +74 -2
  128. pygsti/modelmembers/states/__init__.py +2 -5
  129. pygsti/modelmembers/states/composedstate.py +0 -317
  130. pygsti/modelmembers/states/computationalstate.py +3 -3
  131. pygsti/modelmembers/states/cptpstate.py +4 -4
  132. pygsti/modelmembers/states/densestate.py +6 -4
  133. pygsti/modelmembers/states/fullpurestate.py +0 -24
  134. pygsti/modelmembers/states/purestate.py +1 -1
  135. pygsti/modelmembers/states/state.py +5 -6
  136. pygsti/modelmembers/states/tpstate.py +28 -10
  137. pygsti/modelmembers/term.py +3 -6
  138. pygsti/modelmembers/torchable.py +50 -0
  139. pygsti/modelpacks/_modelpack.py +1 -1
  140. pygsti/modelpacks/smq1Q_ZN.py +3 -1
  141. pygsti/modelpacks/smq2Q_XXYYII.py +2 -1
  142. pygsti/modelpacks/smq2Q_XY.py +3 -3
  143. pygsti/modelpacks/smq2Q_XYI.py +2 -2
  144. pygsti/modelpacks/smq2Q_XYICNOT.py +3 -3
  145. pygsti/modelpacks/smq2Q_XYICPHASE.py +3 -3
  146. pygsti/modelpacks/smq2Q_XYXX.py +1 -1
  147. pygsti/modelpacks/smq2Q_XYZICNOT.py +3 -3
  148. pygsti/modelpacks/smq2Q_XYZZ.py +1 -1
  149. pygsti/modelpacks/stdtarget.py +0 -121
  150. pygsti/models/cloudnoisemodel.py +1 -2
  151. pygsti/models/explicitcalc.py +3 -3
  152. pygsti/models/explicitmodel.py +3 -13
  153. pygsti/models/fogistore.py +5 -3
  154. pygsti/models/localnoisemodel.py +1 -2
  155. pygsti/models/memberdict.py +0 -12
  156. pygsti/models/model.py +800 -65
  157. pygsti/models/modelconstruction.py +4 -4
  158. pygsti/models/modelnoise.py +2 -2
  159. pygsti/models/modelparaminterposer.py +1 -1
  160. pygsti/models/oplessmodel.py +1 -1
  161. pygsti/models/qutrit.py +15 -14
  162. pygsti/objectivefns/objectivefns.py +73 -138
  163. pygsti/objectivefns/wildcardbudget.py +2 -7
  164. pygsti/optimize/__init__.py +1 -0
  165. pygsti/optimize/arraysinterface.py +28 -0
  166. pygsti/optimize/customcg.py +0 -12
  167. pygsti/optimize/customlm.py +129 -323
  168. pygsti/optimize/customsolve.py +2 -2
  169. pygsti/optimize/optimize.py +0 -84
  170. pygsti/optimize/simplerlm.py +841 -0
  171. pygsti/optimize/wildcardopt.py +19 -598
  172. pygsti/protocols/confidenceregionfactory.py +28 -14
  173. pygsti/protocols/estimate.py +31 -14
  174. pygsti/protocols/gst.py +142 -68
  175. pygsti/protocols/modeltest.py +6 -10
  176. pygsti/protocols/protocol.py +9 -37
  177. pygsti/protocols/rb.py +450 -79
  178. pygsti/protocols/treenode.py +8 -2
  179. pygsti/protocols/vb.py +108 -206
  180. pygsti/protocols/vbdataframe.py +1 -1
  181. pygsti/report/factory.py +0 -15
  182. pygsti/report/fogidiagram.py +1 -17
  183. pygsti/report/modelfunction.py +12 -3
  184. pygsti/report/mpl_colormaps.py +1 -1
  185. pygsti/report/plothelpers.py +8 -2
  186. pygsti/report/reportables.py +41 -37
  187. pygsti/report/templates/offline/pygsti_dashboard.css +6 -0
  188. pygsti/report/templates/offline/pygsti_dashboard.js +12 -0
  189. pygsti/report/workspace.py +2 -14
  190. pygsti/report/workspaceplots.py +326 -504
  191. pygsti/tools/basistools.py +9 -36
  192. pygsti/tools/edesigntools.py +124 -96
  193. pygsti/tools/fastcalc.cp310-win_amd64.pyd +0 -0
  194. pygsti/tools/fastcalc.pyx +35 -81
  195. pygsti/tools/internalgates.py +151 -15
  196. pygsti/tools/jamiolkowski.py +5 -5
  197. pygsti/tools/lindbladtools.py +19 -11
  198. pygsti/tools/listtools.py +0 -114
  199. pygsti/tools/matrixmod2.py +1 -1
  200. pygsti/tools/matrixtools.py +173 -339
  201. pygsti/tools/nameddict.py +1 -1
  202. pygsti/tools/optools.py +154 -88
  203. pygsti/tools/pdftools.py +0 -25
  204. pygsti/tools/rbtheory.py +3 -320
  205. pygsti/tools/slicetools.py +64 -12
  206. pyGSTi-0.9.12.1.dist-info/METADATA +0 -155
  207. pygsti/algorithms/directx.py +0 -711
  208. pygsti/evotypes/qibo/__init__.py +0 -33
  209. pygsti/evotypes/qibo/effectreps.py +0 -78
  210. pygsti/evotypes/qibo/opreps.py +0 -376
  211. pygsti/evotypes/qibo/povmreps.py +0 -98
  212. pygsti/evotypes/qibo/statereps.py +0 -174
  213. pygsti/extras/rb/__init__.py +0 -13
  214. pygsti/extras/rb/benchmarker.py +0 -957
  215. pygsti/extras/rb/dataset.py +0 -378
  216. pygsti/extras/rb/io.py +0 -814
  217. pygsti/extras/rb/simulate.py +0 -1020
  218. pygsti/io/legacyio.py +0 -385
  219. pygsti/modelmembers/povms/denseeffect.py +0 -142
  220. {pyGSTi-0.9.12.1.dist-info → pyGSTi-0.9.13.dist-info}/LICENSE +0 -0
  221. {pyGSTi-0.9.12.1.dist-info → pyGSTi-0.9.13.dist-info}/top_level.txt +0 -0
pygsti/layouts/prefixtable.py
@@ -11,7 +11,10 @@ Defines the PrefixTable class.
 #***************************************************************************************************
 
 import collections as _collections
-
+import networkx as _nx
+import matplotlib.pyplot as plt
+from math import ceil
+from pygsti.baseobjs import Label as _Label
 from pygsti.circuits.circuit import SeparatePOVMCircuit as _SeparatePOVMCircuit
 
 
@@ -38,6 +41,14 @@ class PrefixTable(object):
         `iDest` is always in the range [0,len(circuits_to_evaluate)-1], and
         indexes the result computed for each of the circuits.
 
+        Parameters
+        ----------
+
+
+        circuit_parameter_sensitivities :
+            A map between the circuits in circuits_to_evaluate and the indices of the model parameters
+            on which these circuits depend.
+
         Returns
         -------
         tuple
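The `(iDest, iStart, remaining, iCache)` instructions documented in this docstring are straightforward to consume. Below is a minimal sketch of a table-walking loop, with a hypothetical `propagate` function and initial state `rho0` standing in for pyGSTi's actual forward-simulation machinery:

```python
# Minimal sketch of evaluating circuits via a prefix table.
# `propagate(state, layer)` and `rho0` are hypothetical stand-ins, not pyGSTi API.
def evaluate(table_contents, rho0, propagate):
    results, cache = {}, []
    for iDest, iStart, remaining, iCache in table_contents:
        # resume from a cached intermediate state, or start fresh
        state = cache[iStart] if iStart is not None else rho0
        for layer in remaining:        # apply only the non-shared suffix
            state = propagate(state, layer)
        if iCache is not None:         # iCache == len(cache) when assigned,
            cache.append(state)        # so appending keeps indices aligned
        results[iDest] = state
    return results
```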
@@ -45,19 +56,23 @@ class PrefixTable(object):
             of tuples as given above and `cache_size` is the total size of the state
             cache used to hold intermediate results.
         """
+
         #Sort the operation sequences "alphabetically", so that it's trivial to find common prefixes
-        circuits_to_evaluate_fastlookup = {i: cir for i, cir in enumerate(circuits_to_evaluate)}
         circuits_to_sort_by = [cir.circuit_without_povm if isinstance(cir, _SeparatePOVMCircuit) else cir
                                for cir in circuits_to_evaluate]  # always Circuits - not SeparatePOVMCircuits
-        sorted_circuits_to_sort_by = sorted(list(enumerate(circuits_to_sort_by)), key=lambda x: x[1])
-        sorted_circuits_to_evaluate = [(i, circuits_to_evaluate_fastlookup[i]) for i, _ in sorted_circuits_to_sort_by]
+        #with the current logic in _build_table a candidate circuit is only treated as a possible prefix if
+        #it is shorter than the one it is being evaluated as a prefix for. So it should work to sort these
+        #circuits by length for the purposes of the current logic.
+        sorted_circuits_to_sort_by = sorted(list(enumerate(circuits_to_sort_by)), key=lambda x: len(x[1]))
+        orig_indices, sorted_circuits_to_evaluate = zip(*[(i, circuits_to_evaluate[i]) for i, _ in sorted_circuits_to_sort_by])
+
+        self.sorted_circuits_to_evaluate = sorted_circuits_to_evaluate
+        self.orig_indices = orig_indices
+
+        #get the circuits in a form readily usable for comparisons
+        circuit_reps, circuit_lens = _circuits_to_compare(sorted_circuits_to_evaluate)
+        self.circuit_reps = circuit_reps
 
-        distinct_line_labels = set([cir.line_labels for cir in circuits_to_sort_by])
-        if len(distinct_line_labels) == 1:  # if all circuits have the *same* line labels, we can just compare tuples
-            circuit_reps_to_compare_and_lengths = {i: (cir.layertup, len(cir))
-                                                   for i, cir in enumerate(circuits_to_sort_by)}
-        else:
-            circuit_reps_to_compare_and_lengths = {i: (cir, len(cir)) for i, cir in enumerate(circuits_to_sort_by)}
 
 
         if max_cache_size is None or max_cache_size > 0:
@@ -66,48 +81,13 @@ class PrefixTable(object):
             # Note: this logic could be much better, e.g. computing a cost savings for each
             # potentially-cached item and choosing the best ones, and proper accounting
             # for chains of cached items.
-            cacheIndices = []  # indices into circuits_to_evaluate of the results to cache
-            cache_hits = _collections.defaultdict(lambda: 0)
-
-            for i, _ in sorted_circuits_to_evaluate:
-                circuit, L = circuit_reps_to_compare_and_lengths[i]  # can be a Circuit or a label tuple
-                for cached_index in reversed(cacheIndices):
-                    candidate, Lc = circuit_reps_to_compare_and_lengths[cached_index]
-                    if L >= Lc > 0 and circuit[0:Lc] == candidate:  # a cache hit!
-                        cache_hits[cached_index] += 1
-                        break  # stop looking through cache
-                cacheIndices.append(i)  # cache *everything* in this pass
-
-        # Build prefix table: construct list, only caching items with hits > 0 (up to max_cache_size)
-        cacheIndices = []  # indices into circuits_to_evaluate of the results to cache
-        table_contents = []
-        curCacheSize = 0
-
-        for i, circuit in sorted_circuits_to_evaluate:
-            circuit_rep, L = circuit_reps_to_compare_and_lengths[i]
-
-            #find longest existing prefix for circuit by working backwards
-            # and finding the first string that *is* a prefix of this string
-            # (this will necessarily be the longest prefix, given the sorting)
-            for i_in_cache in range(curCacheSize - 1, -1, -1):  # from curCacheSize-1 -> 0
-                candidate, Lc = circuit_reps_to_compare_and_lengths[cacheIndices[i_in_cache]]
-                if L >= Lc > 0 and circuit_rep[0:Lc] == candidate:  # ">=" allows for duplicates
-                    iStart = i_in_cache  # an index into the *cache*, not into circuits_to_evaluate
-                    remaining = circuit[Lc:]  # *always* a SeparatePOVMCircuit or Circuit
-                    break
-            else:  # no break => no prefix
-                iStart = None
-                remaining = circuit[:]
-
-            # if/where this string should get stored in the cache
-            if (max_cache_size is None or curCacheSize < max_cache_size) and cache_hits.get(i, 0) > 0:
-                iCache = len(cacheIndices)
-                cacheIndices.append(i); curCacheSize += 1
-            else:  # don't store in the cache
-                iCache = None
+            cache_hits = _cache_hits(self.circuit_reps, circuit_lens)
+        else:
+            cache_hits = [None]*len(self.circuit_reps)
 
-            #Add instruction for computing this circuit
-            table_contents.append((i, iStart, remaining, iCache))
+        table_contents, curCacheSize = _build_table(sorted_circuits_to_evaluate, cache_hits,
+                                                    max_cache_size, self.circuit_reps, circuit_lens,
+                                                    orig_indices)
 
         #FUTURE: could perform a second pass, and if there is
         # some threshold number of elements which share the
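Both the cache-assessment pass and the table build (now factored into the `_cache_hits` and `_build_table` helpers added later in this diff) hinge on the same prefix test, `L >= Lc > 0 and circuit[0:Lc] == candidate`. In isolation, on plain tuples:

```python
# The prefix test used by the caching logic: a cached candidate is usable
# iff it is nonempty, no longer than the target, and matches the target's
# leading layers (">=" rather than ">" lets exact duplicates hit too).
def is_prefix(candidate, target):
    Lc, L = len(candidate), len(target)
    return L >= Lc > 0 and target[:Lc] == candidate

assert is_prefix(('Gx', 'Gy'), ('Gx', 'Gy', 'Gx'))
assert is_prefix(('Gx', 'Gy'), ('Gx', 'Gy'))      # duplicate counts as a hit
assert not is_prefix(('Gy',), ('Gx', 'Gy'))
assert not is_prefix((), ('Gx',))                 # empty prefixes never hit
```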
@@ -118,9 +98,196 @@ class PrefixTable(object):
         # order.
         self.contents = table_contents
         self.cache_size = curCacheSize
+        self.circuits_evaluated = circuits_to_sort_by
+
 
     def __len__(self):
         return len(self.contents)
+
+    def num_state_propagations(self):
+        """
+        Return the number of state propagation operations (excluding the action of POVM effects)
+        required for the evaluation strategy given by this PrefixTable.
+        """
+        return sum(self.num_state_propagations_by_circuit().values())
+
+    def num_state_propagations_by_circuit(self):
+        """
+        Return the number of state propagation operations per-circuit
+        (excluding the action of POVM effects) required for the evaluation strategy
+        given by this PrefixTable, returned as a dictionary with keys corresponding to
+        circuits and values corresponding to the number of state propagations
+        required for that circuit.
+        """
+        state_props_by_circuit = {}
+        for i, istart, remainder, _ in self.contents:
+            if len(self.circuits_evaluated[i][0]) > 0 and self.circuits_evaluated[i][0] == _Label('rho0') and istart is None:
+                state_props_by_circuit[self.circuits_evaluated[i]] = len(remainder) - 1
+            else:
+                state_props_by_circuit[self.circuits_evaluated[i]] = len(remainder)
+
+        return state_props_by_circuit
+
+    def num_state_propagations_by_circuit_no_caching(self):
+        """
+        Return the number of state propagation operations per-circuit
+        (excluding the action of POVM effects) required for an evaluation strategy
+        without caching, returned as a dictionary with keys corresponding to
+        circuits and values corresponding to the number of state propagations
+        required for that circuit.
+        """
+        state_props_by_circuit = {}
+        for circuit in self.circuits_evaluated:
+            if len(circuit) > 0 and circuit[0] == _Label('rho0'):
+                state_props_by_circuit[circuit] = len(circuit[1:])
+            else:
+                state_props_by_circuit[circuit] = len(circuit)
+        return state_props_by_circuit
+
+    def num_state_propagations_no_caching(self):
+        """
+        Return the total number of state propagation operations
+        (excluding the action of POVM effects) required for an evaluation strategy
+        without caching.
+        """
+        return sum(self.num_state_propagations_by_circuit_no_caching().values())
+
+    def find_splitting_new(self, max_sub_table_size=None, num_sub_tables=None, initial_cost_metric='size',
+                           rebalancing_cost_metric='propagations', imbalance_threshold=1.2, minimum_improvement_threshold=.1,
+                           verbosity=0):
+        """
+        Find a partition of the indices of this table to define a set of sub-tables with the desired properties.
+
+        This is done in order to reduce the maximum size of any tree (useful for
+        limiting memory consumption or for using multiple cores). Must specify
+        either max_sub_table_size or num_sub_tables.
+
+        Parameters
+        ----------
+        max_sub_table_size : int, optional
+            The maximum size (i.e. list length) of each sub-table. If the
+            original table is smaller than this size, no splitting will occur.
+            If None, then there is no limit.
+
+        num_sub_tables : int, optional
+            The number of sub-tables to split this table into. If this is
+            specified then max_sub_table_size must be None.
+
+        imbalance_threshold : float, optional (default 1.2)
+            This number serves as a tolerance parameter for a final load balancing refinement
+            to the splitting. The value corresponds to a threshold value of the ratio of the heaviest
+            to the lightest subtree such that ratios below this value are considered sufficiently
+            balanced and processing stops.
+
+        minimum_improvement_threshold : float, optional (default .1)
+            A parameter for the final load balancing refinement process that sets a minimum balance
+            improvement (improvement to the ratio of the sizes of two subtrees) such that a rebalancing
+            step is considered worth performing (even if it would otherwise bring the imbalance parameter
+            described above in `imbalance_threshold` below the target value).
+
+        verbosity : int, optional (default 0)
+            How much detail to send to stdout.
+
+        Returns
+        -------
+        list
+            A list of sets of elements to place in sub-tables.
+        """
+
+        table_contents = self.contents
+        if max_sub_table_size is None and num_sub_tables is None:
+            return [set(range(len(table_contents)))]  # no splitting needed
+
+        if max_sub_table_size is not None and num_sub_tables is not None:
+            raise ValueError("Cannot specify both max_sub_table_size and num_sub_tables")
+        if num_sub_tables is not None and num_sub_tables <= 0:
+            raise ValueError("Error: num_sub_tables must be > 0!")
+
+        #Don't split at all if it's unnecessary
+        if max_sub_table_size is None or len(table_contents) < max_sub_table_size:
+            if num_sub_tables is None or num_sub_tables == 1:
+                return [set(range(len(table_contents)))]
+
+        #construct a tree structure describing the prefix structure of the circuit set.
+        circuit_tree = _build_prefix_tree(self.sorted_circuits_to_evaluate, self.circuit_reps, self.orig_indices)
+        circuit_tree_nx = circuit_tree.to_networkx_graph()
+
+        if num_sub_tables is not None:
+            max_max_sub_table_size = len(self.sorted_circuits_to_evaluate)
+            initial_max_sub_table_size = ceil(len(self.sorted_circuits_to_evaluate) / num_sub_tables)
+            cut_edges, new_roots, tree_levels, subtree_weights = tree_partition_kundu_misra(circuit_tree_nx, max_weight=initial_max_sub_table_size,
+                                                                                            weight_key='cost' if initial_cost_metric == 'size' else 'prop_cost',
+                                                                                            return_levels_and_weights=True)
+
+            if len(new_roots) > num_sub_tables:  #iteratively grow the maximum subtree size until we either hit or are less than the target.
+                last_seen_sub_max_sub_table_size_val = None
+                feasible_range = [initial_max_sub_table_size + 1, max_max_sub_table_size - 1]
+                #bisect on max_sub_table_size until we find the smallest value for which len(new_roots) <= num_sub_tables
+                while feasible_range[0] < feasible_range[1]:
+                    current_max_sub_table_size = (feasible_range[0] + feasible_range[1]) // 2
+                    cut_edges, new_roots = tree_partition_kundu_misra(circuit_tree_nx, max_weight=current_max_sub_table_size,
+                                                                      weight_key='cost' if initial_cost_metric == 'size' else 'prop_cost',
+                                                                      test_leaves=False, precomp_levels=tree_levels, precomp_weights=subtree_weights)
+                    if len(new_roots) > num_sub_tables:
+                        feasible_range[0] = current_max_sub_table_size + 1
+                    else:
+                        last_seen_sub_max_sub_table_size_val = (cut_edges, new_roots)  #In the multiple root setting I am seeing some strange
+                        #non-monotonicity, so add this as a fallback in case the final result anomalously has len(roots) > num_sub_tables
+                        feasible_range[1] = current_max_sub_table_size
+                if len(new_roots) > num_sub_tables and last_seen_sub_max_sub_table_size_val is not None:  #fallback
+                    cut_edges, new_roots = last_seen_sub_max_sub_table_size_val
+
+                #only apply the cuts now that we have found our starting point.
+                partitioned_tree = _copy_networkx_graph(circuit_tree_nx)
+                #update the propagation cost attribute of the promoted nodes.
+                #only do this at this point to reduce the need for copying
+                for edge in cut_edges:
+                    partitioned_tree.nodes[edge[1]]['prop_cost'] += partitioned_tree.edges[edge[0], edge[1]]['promotion_cost']
+                partitioned_tree.remove_edges_from(cut_edges)
+
+            #if we have hit the number of partitions, great, we're done!
+            if len(new_roots) == num_sub_tables:
+                #only apply the cuts now that we have found our starting point.
+                partitioned_tree = _copy_networkx_graph(circuit_tree_nx)
+                #update the propagation cost attribute of the promoted nodes.
+                #only do this at this point to reduce the need for copying
+                for edge in cut_edges:
+                    partitioned_tree.nodes[edge[1]]['prop_cost'] += partitioned_tree.edges[edge[0], edge[1]]['promotion_cost']
+                partitioned_tree.remove_edges_from(cut_edges)
+                pass
+            #if we have fewer subtables than requested, decide whether we should strictly
+            #hit the number of partitions, or whether we allow fewer than the requested number to be returned.
+            if len(new_roots) < num_sub_tables:
+                #Perform bisection operations on the heaviest subtrees until we hit the target number.
+                while len(new_roots) < num_sub_tables:
+                    partitioned_tree, new_roots, cut_edges = _bisection_pass(partitioned_tree, cut_edges, new_roots, num_sub_tables,
+                                                                             weight_key='cost' if rebalancing_cost_metric == 'size' else 'prop_cost')
+            #add in a final refinement pass to improve the balancing across subtrees.
+            partitioned_tree, new_roots, addl_cut_edges = _refinement_pass(partitioned_tree, new_roots,
+                                                                           weight_key='cost' if rebalancing_cost_metric == 'size' else 'prop_cost',
+                                                                           imbalance_threshold=imbalance_threshold,
+                                                                           minimum_improvement_threshold=minimum_improvement_threshold)
+        else:
+            cut_edges, new_roots = tree_partition_kundu_misra(circuit_tree_nx, max_weight=max_sub_table_size,
+                                                              weight_key='cost' if initial_cost_metric == 'size' else 'prop_cost')
+            partitioned_tree = _copy_networkx_graph(circuit_tree_nx)
+            for edge in cut_edges:
+                partitioned_tree.nodes[edge[1]]['prop_cost'] += partitioned_tree.edges[edge[0], edge[1]]['promotion_cost']
+            partitioned_tree.remove_edges_from(cut_edges)
+
+        #Collect the original circuit indices for each of the partitioned subtrees.
+        orig_index_groups = []
+        for root in new_roots:
+            if isinstance(root, tuple):
+                ckts = []
+                for elem in root:
+                    ckts.extend(_collect_orig_indices(partitioned_tree, elem))
+                orig_index_groups.append(ckts)
+            else:
+                orig_index_groups.append(_collect_orig_indices(partitioned_tree, root))
+
+        return orig_index_groups
+
 
     def find_splitting(self, max_sub_table_size=None, num_sub_tables=None, cost_metric="size", verbosity=0):
         """
@@ -182,7 +349,7 @@ class PrefixTable(object):
         over the course of the iteration.
         """
 
-        if cost_metric == "applys":
+        if cost_metric == "applies":
             def cost_fn(rem): return len(rem)  # length of remainder = #-apply ops needed
         elif cost_metric == "size":
             def cost_fn(rem): return 1  # everything costs 1 in size of table
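The `PrefixTableJacobian` class added in the final hunk below avoids building one table per model parameter by grouping parameters whose sorted circuit-dependency lists coincide; each equivalence class then shares a single prefix table. The grouping itself is just a dictionary keyed on the (hashable) circuit-rep tuples, sketched here with toy reps:

```python
# Sketch of the equivalence-class grouping used by PrefixTableJacobian:
# parameters with identical circuit-rep tuples map to one shared table.
dependency_reps = [(('rho0', 'Gx'),), (('rho0', 'Gy'),), (('rho0', 'Gx'),)]
classes = {}
for param_idx, reps in enumerate(dependency_reps):
    classes.setdefault(reps, []).append(param_idx)
assert classes == {(('rho0', 'Gx'),): [0, 2], (('rho0', 'Gy'),): [1]}
```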
@@ -333,3 +500,1318 @@ class PrefixTable(object):
333
500
 
334
501
  assert(sum(map(len, subTableSetList)) == len(self)), "sub-table sets are not disjoint!"
335
502
  return subTableSetList
503
+
504
+
505
+ class PrefixTableJacobian(object):
506
+ """
507
+ An ordered list ("table") of circuits to evaluate, where common prefixes can be cached.
508
+ Specialized for purposes of jacobian calculations.
509
+
510
+ """
511
+
512
+ def __init__(self, circuits_to_evaluate, max_cache_size, parameter_circuit_dependencies=None):
513
+ """
514
+ Creates a "prefix table" for evaluating a set of circuits.
515
+
516
+ The table is list of tuples, where each element contains
517
+ instructions for evaluating a particular operation sequence:
518
+
519
+ (iDest, iStart, tuple_of_following_items, iCache)
520
+
521
+ Means that circuit[iDest] = cached_circuit[iStart] + tuple_of_following_items,
522
+ and that the resulting state should be stored at cache index iCache (for
523
+ later reference as an iStart value). The ordering of the returned list
524
+ specifies the evaluation order.
525
+
526
+ `iDest` is always in the range [0,len(circuits_to_evaluate)-1], and
527
+ indexes the result computed for each of the circuits.
528
+
529
+ Parameters
530
+ ----------
531
+
532
+
533
+ circuit_parameter_sensitivities :
534
+ A map between the circuits in circuits_to_evaluate and the indices of the model parameters
535
+ to which these circuits depend.
536
+
537
+ Returns
538
+ -------
539
+ tuple
540
+ A tuple of `(table_contents, cache_size)` where `table_contents` is a list
541
+ of tuples as given above and `cache_size` is the total size of the state
542
+ cache used to hold intermediate results.
543
+ """
544
+ #Sort the operation sequences "alphabetically", so that it's trivial to find common prefixes
545
+ circuits_to_sort_by = [cir.circuit_without_povm if isinstance(cir, _SeparatePOVMCircuit) else cir
546
+ for cir in circuits_to_evaluate] # always Circuits - not SeparatePOVMCircuits
547
+ sorted_circuits_to_sort_by = sorted(list(enumerate(circuits_to_sort_by)), key=lambda x: len(x[1]))
548
+ sorted_circuits_to_evaluate = [(i, circuits_to_evaluate[i]) for i, _ in sorted_circuits_to_sort_by]
549
+ #create a map from sorted_circuits_to_sort_by by can be used to quickly sort each of the parameter
550
+ #dependency lists.
551
+ fast_sorting_map = {circuits_to_evaluate[i]:j for j, (i, _) in enumerate(sorted_circuits_to_sort_by)}
552
+
553
+ #also need a map from circuits to their original indices in circuits_to_evaluate
554
+ #for the purpose of setting the correct destination indices in the evaluation instructions.
555
+ circuit_to_orig_index_map = {circuit: i for i,circuit in enumerate(circuits_to_evaluate)}
556
+
557
+ #use this map to sort the parameter_circuit_dependencies sublists.
558
+ sorted_parameter_circuit_dependencies = []
559
+ sorted_parameter_circuit_dependencies_orig_indices = []
560
+ for sublist in parameter_circuit_dependencies:
561
+ sorted_sublist = [None]*len(sorted_circuits_to_evaluate)
562
+ for ckt in sublist:
563
+ sorted_sublist[fast_sorting_map[ckt]] = ckt
564
+
565
+ #filter out instances of None to get the correctly sized and sorted
566
+ #sublist.
567
+ filtered_sorted_sublist = [val for val in sorted_sublist if val is not None]
568
+ orig_index_sublist = [circuit_to_orig_index_map[ckt] for ckt in filtered_sorted_sublist]
569
+
570
+ sorted_parameter_circuit_dependencies.append(filtered_sorted_sublist)
571
+ sorted_parameter_circuit_dependencies_orig_indices.append(orig_index_sublist)
572
+
573
+ sorted_circuit_reps = []
574
+ sorted_circuit_lengths = []
575
+ for sublist in sorted_parameter_circuit_dependencies:
576
+ circuit_reps, circuit_lengths = _circuits_to_compare(sublist)
577
+ sorted_circuit_reps.append(circuit_reps)
578
+ sorted_circuit_lengths.append(circuit_lengths)
579
+
580
+ #Intuition: The sorted circuit lists should likely break into equivalence classes, wherein multiple
581
+ #parameters will have the same dependent circuits. This is because in typical models parameters
582
+ #appear in blocks corresponding to a particular gate label, and so most of the time it should be the
583
+ #case that the list fractures into all those circuits containing a particular label.
584
+ #This intuition probably breaks down for ImplicitOpModels with complicated layer rules for which
585
+ #the breaking into equivalence classes may have limited savings.
586
+ unique_parameter_circuit_dependency_classes = {}
587
+ for i, sublist in enumerate(sorted_circuit_reps):
588
+ if unique_parameter_circuit_dependency_classes.get(sublist, None) is None:
589
+ unique_parameter_circuit_dependency_classes[sublist] = [i]
590
+ else:
591
+ unique_parameter_circuit_dependency_classes[sublist].append(i)
592
+
593
+ self.unique_parameter_circuit_dependency_classes = unique_parameter_circuit_dependency_classes
594
+
595
+ #the keys of the dictionary already give the needed circuit rep lists for
596
+ #each class, also grab the appropriate list of length for each class.
597
+ sorted_circuit_lengths_by_class = [sorted_circuit_lengths[class_indices[0]]
598
+ for class_indices in unique_parameter_circuit_dependency_classes.values()]
599
+
600
+ #also need representatives fo the entries in sorted_parameter_circuit_dependencies for each class,
601
+ #and for sorted_parameter_circuit_dependencies_orig_indices
602
+ sorted_parameter_circuit_dependencies_by_class = [sorted_parameter_circuit_dependencies[class_indices[0]]
603
+ for class_indices in unique_parameter_circuit_dependency_classes.values()]
604
+ sorted_parameter_circuit_dependencies_orig_indices_by_class = [sorted_parameter_circuit_dependencies_orig_indices[class_indices[0]]
605
+ for class_indices in unique_parameter_circuit_dependency_classes.values()]
606
+
607
+ #now we can just do the calculation for each of these equivalence classes.
608
+
609
+ #get the cache hits for all of the parameter circuit dependency sublists
610
+ if max_cache_size is None or max_cache_size > 0:
611
+ cache_hits_by_class = []
612
+ #CACHE assessment pass: figure out what's worth keeping in the cache.
613
+ # In this pass, we cache *everything* and keep track of how many times each
614
+ # original index (after it's cached) is utilized as a prefix for another circuit.
615
+ # Not: this logic could be much better, e.g. computing a cost savings for each
616
+ # potentially-cached item and choosing the best ones, and proper accounting
617
+ # for chains of cached items.
618
+ for circuit_reps, circuit_lengths in zip(unique_parameter_circuit_dependency_classes.keys(),
619
+ sorted_circuit_lengths_by_class):
620
+ cache_hits_by_class.append(_cache_hits(circuit_reps, circuit_lengths))
621
+ else:
622
+ cache_hits_by_class = [None]*len(unique_parameter_circuit_dependency_classes)
623
+
624
+ #next construct a prefix table for each sublist.
625
+ table_contents_by_class = []
626
+ cache_size_by_class = []
627
+ for sublist, cache_hits, circuit_reps, circuit_lengths, orig_indices in zip(sorted_parameter_circuit_dependencies_by_class,
628
+ cache_hits_by_class,
629
+ unique_parameter_circuit_dependency_classes.keys(),
630
+ sorted_circuit_lengths_by_class,
631
+ sorted_parameter_circuit_dependencies_orig_indices_by_class):
632
+ table_contents, curCacheSize = _build_table(sublist, cache_hits,
633
+ max_cache_size, circuit_reps, circuit_lengths,
634
+ orig_indices)
635
+ table_contents_by_class.append(table_contents)
636
+ cache_size_by_class.append(curCacheSize)
637
+
638
+ #FUTURE: could perform a second pass, and if there is
639
+ # some threshold number of elements which share the
640
+ # *same* iStart and the same beginning of the
641
+ # 'remaining' part then add a new "extra" element
642
+ # (beyond the #circuits index) which computes
643
+ # the shared prefix and insert this into the eval
644
+ # order.
645
+
646
+ #map back from equivalence classes to by parameter.
647
+ table_contents_by_parameter = [None]*len(parameter_circuit_dependencies)
648
+ cache_size_by_parameter = [None]*len(parameter_circuit_dependencies)
649
+ for table_contents, cache_size, param_class in zip(table_contents_by_class, cache_size_by_class,
650
+ unique_parameter_circuit_dependency_classes.values()):
651
+ for idx in param_class:
652
+ table_contents_by_parameter[idx] = table_contents
653
+ cache_size_by_parameter[idx] = cache_size
654
+
655
+ self.contents_by_parameter = table_contents_by_parameter
656
+ self.cache_size_by_parameter = cache_size_by_parameter
657
+ self.parameter_circuit_dependencies = sorted_parameter_circuit_dependencies
658
+
659
+
660
+ #---------Helper Functions------------#
661
+
662
+ def _circuits_to_compare(sorted_circuits_to_evaluate):
663
+
664
+ bare_circuits = [cir.circuit_without_povm if isinstance(cir, _SeparatePOVMCircuit) else cir
665
+ for cir in sorted_circuits_to_evaluate]
666
+ distinct_line_labels = set([cir.line_labels for cir in bare_circuits])
667
+
668
+ circuit_lens = [None]*len(sorted_circuits_to_evaluate)
669
+ if len(distinct_line_labels) == 1:
670
+ circuit_reps = [None]*len(sorted_circuits_to_evaluate)
671
+ for i, cir in enumerate(bare_circuits):
672
+ circuit_reps[i] = cir.layertup
673
+ circuit_lens[i] = len(circuit_reps[i])
674
+ else:
675
+ circuit_reps = bare_circuits
676
+ for i, cir in enumerate(sorted_circuits_to_evaluate):
677
+ circuit_lens[i] = len(circuit_reps[i])
678
+
679
+ return tuple(circuit_reps), tuple(circuit_lens)
680
+
681
+ def _cache_hits(circuit_reps, circuit_lengths):
682
+
683
+ #CACHE assessment pass: figure out what's worth keeping in the cache.
684
+ # In this pass, we cache *everything* and keep track of how many times each
685
+ # original index (after it's cached) is utilized as a prefix for another circuit.
686
+ # Not: this logic could be much better, e.g. computing a cost savings for each
687
+ # potentially-cached item and choosing the best ones, and proper accounting
688
+ # for chains of cached items.
689
+
690
+ cacheIndices = [] # indices into circuits_to_evaluate of the results to cache
691
+ cache_hits = [0]*len(circuit_reps)
692
+
693
+ for i in range(len(circuit_reps)):
694
+ circuit = circuit_reps[i]
695
+ L = circuit_lengths[i] # can be a Circuit or a label tuple
696
+ for cached_index in reversed(cacheIndices):
697
+ candidate = circuit_reps[cached_index]
698
+ Lc = circuit_lengths[cached_index]
699
+ if L >= Lc > 0 and circuit[0:Lc] == candidate: # a cache hit!
700
+ cache_hits[cached_index] += 1
701
+ break # stop looking through cache
702
+ cacheIndices.append(i) # cache *everything* in this pass
703
+
704
+ return cache_hits
705
+
706
+ def _build_table(sorted_circuits_to_evaluate, cache_hits, max_cache_size, circuit_reps, circuit_lengths,
707
+ orig_indices):
708
+
709
+ # Build prefix table: construct list, only caching items with hits > 0 (up to max_cache_size)
710
+ cacheIndices = [] # indices into circuits_to_evaluate of the results to cache
711
+ table_contents = [None]*len(sorted_circuits_to_evaluate)
712
+ curCacheSize = 0
713
+ for j, (i, _) in zip(orig_indices,enumerate(sorted_circuits_to_evaluate)):
714
+
715
+ circuit_rep = circuit_reps[i]
716
+ L = circuit_lengths[i]
717
+
718
+ #find longest existing prefix for circuit by working backwards
719
+ # and finding the first string that *is* a prefix of this string
720
+ # (this will necessarily be the longest prefix, given the sorting)
721
+ for i_in_cache in range(curCacheSize - 1, -1, -1): # from curCacheSize-1 -> 0
722
+ candidate = circuit_reps[cacheIndices[i_in_cache]]
723
+ Lc = circuit_lengths[cacheIndices[i_in_cache]]
724
+ if L >= Lc > 0 and circuit_rep[0:Lc] == candidate: # ">=" allows for duplicates
725
+ iStart = i_in_cache # an index into the *cache*, not into circuits_to_evaluate
726
+ remaining = circuit_rep[Lc:] # *always* a SeparatePOVMCircuit or Circuit
727
+ break
728
+ else: # no break => no prefix
729
+ iStart = None
730
+ remaining = circuit_rep
731
+
732
+ # if/where this string should get stored in the cache
733
+ if (max_cache_size is None or curCacheSize < max_cache_size) and cache_hits[i]:
734
+ iCache = len(cacheIndices)
735
+ cacheIndices.append(i); curCacheSize += 1
736
+ else: # don't store in the cache
737
+ iCache = None
738
+
739
+ #Add instruction for computing this circuit
740
+ table_contents[i] = (j, iStart, remaining, iCache)
741
+
742
+ return table_contents, curCacheSize
743
+
744
+ #helper method for building a tree showing the connections between different circuits
745
+ #for the purposes of prefix-based evaluation.
746
+ def _build_prefix_tree(sorted_circuits_to_evaluate, circuit_reps, orig_indices):
747
+ #assume the input circuits have already been sorted by length.
748
+ circuit_tree = Tree()
749
+ for j, (i, _) in zip(orig_indices,enumerate(sorted_circuits_to_evaluate)):
750
+ circuit_rep = circuit_reps[i]
751
+ #the first layer should be a state preparation. If this isn't in a root in the
752
+ #tree add it.
753
+ root_node = circuit_tree.get_root_node(circuit_rep[0])
754
+ if root_node is None and len(circuit_rep)>0:
755
+ #cost is the number of propagations, so exclude the initial state prep
756
+ root_node = RootNode(circuit_rep[0], cost=0)
757
+ circuit_tree.add_root(root_node)
758
+
759
+ current_node = root_node
760
+ for layerlbl in circuit_reps[i][1:]:
761
+ child_node = current_node.get_child_node(layerlbl)
762
+ if child_node is None:
763
+ child_node = ChildNode(layerlbl, parent=current_node)
764
+ current_node = child_node
765
+ #when we get to the end of the circuit add a pointer on the
766
+ #final node to the original index of this circuit in the
767
+ #circuit list.
768
+ current_node.add_orig_index(j)
769
+
770
+ return circuit_tree
771
+
772
+
773
+ #----------------------Helper classes for managing circuit evaluation tree. --------------#
774
+ class TreeNode:
775
+ def __init__(self, value, children=None, orig_indices=None):
776
+ """
777
+ Parameters
778
+ ----------
779
+ value : any
780
+ The value to be stored in the node.
781
+
782
+ children : list, optional (default is None)
783
+ A list of child nodes. If None, initializes an empty list.
784
+
785
+ orig_indices : list, optional (default is None)
786
+ A list of original indices. If None, initializes an empty list.
787
+ """
788
+ self.value = value
789
+ self.children = [] if children is None else children
790
+ self.orig_indices = [] if orig_indices is None else orig_indices #make this a list to allow for duplicates
791
+
792
+ def add_child(self, child_node):
793
+ """
794
+ Add a child node to the current node.
795
+
796
+ Parameters
797
+ ----------
798
+ child_node : TreeNode
799
+ The child node to be added.
800
+ """
801
+
802
+ self.children.append(child_node)
803
+
804
+ def remove_child(self, child_node):
805
+ """
806
+ Remove a child node from the current node.
807
+
808
+ Parameters
809
+ ----------
810
+ child_node : TreeNode
811
+ The child node to be removed.
812
+ """
813
+ self.children = [child for child in self.children if child is not child_node]
814
+
815
+ def get_child_node(self, value):
816
+ """
817
+ Get the child node associated with the input value. If that node is not present, return None.
818
+
819
+ Parameters
820
+ ----------
821
+ value : any
822
+ The value to search for in the child nodes.
823
+
824
+ Returns
825
+ -------
826
+ TreeNode or None
827
+ The child node with the specified value, or None if not found.
828
+ """
829
+
830
+ for node in self.children:
831
+ if node.value == value:
832
+ return node
833
+ #if we haven't returned already it is because there wasn't a corresponding root,
834
+ #so return None
835
+ return None
836
+
837
+ def add_orig_index(self, value):
838
+ """
839
+ Add an original index to the node.
840
+
841
+ Parameters
842
+ ----------
843
+ value : int
844
+ The original index to be added.
845
+ """
846
+ self.orig_indices.append(value)
847
+
848
+ def traverse(self):
849
+ """
850
+ Traverse the tree in pre-order and return a list of node values.
851
+
852
+ Returns
853
+ -------
854
+ list
855
+ A list of node values in pre-order traversal.
856
+ """
857
+
858
+ nodes = []
859
+ stack = [self]
860
+ while stack:
861
+ node = stack.pop()
862
+ nodes.append(node.value)
863
+ stack.extend(reversed(node.children)) # Add children to stack in reverse order for pre-order traversal
864
+ return nodes
865
+
866
+ def get_descendants(self):
867
+ """
868
+ Get all descendant node values of the current node in pre-order traversal.
869
+
870
+ Returns
871
+ -------
872
+ list
873
+ A list of descendant node values.
874
+ """
875
+ descendants = []
876
+ stack = self.children[:]
877
+ while stack:
878
+ node = stack.pop()
879
+ descendants.append(node.value)
880
+ stack.extend(reversed(node.children)) # Add children to stack in reverse order for pre-order traversal
881
+ return descendants
882
+
883
+ def total_orig_indices(self):
884
+ """
885
+ Calculate the total number of orig_indices values for this node and all of its descendants.
886
+ """
887
+ total = len(self.orig_indices)
888
+ for child in self.get_descendants():
889
+ total += len(child.orig_indices)
890
+ return total
891
+
892
+ def print_tree(self, level=0, prefix=""):
893
+ """
894
+ Print the tree structure starting from the current node.
895
+
896
+ Parameters
897
+ ----------
898
+ level : int, optional (default 0)
899
+ The current level in the tree.
900
+ prefix : str, optional (default "")
901
+ The prefix for the current level.
902
+ """
903
+ connector = "├── " if level > 0 else ""
904
+ print(prefix + connector + str(self.value) +', ' + str(self.orig_indices))
905
+ for i, child in enumerate(self.children):
906
+ if i == len(self.children) - 1:
907
+ child.print_tree(level + 1, prefix + (" " if level > 0 else ""))
908
+ else:
909
+ child.print_tree(level + 1, prefix + ("│ " if level > 0 else ""))
910
+
911
+ #create a class for RootNodes that includes additional initial cost information.
912
+ class RootNode(TreeNode):
913
+ """
914
+ Class for representing a root node for a tree, along with the corresponding metadata
915
+ specific to root nodes.
916
+ """
917
+
918
+ def __init__(self, value, cost=0, tree=None, children=None, orig_indices=None):
919
+ """
920
+ Initialize a RootNode with a value, optional cost, optional tree, optional children, and optional original indices.
921
+
922
+ Parameters
923
+ ----------
924
+ value : any
925
+ The value to be stored in the node.
926
+ cost : int, optional (default is 0)
927
+ The initial cost associated with the root node.
928
+ tree : Tree, optional (default is None)
929
+ The tree to which this root node belongs.
930
+ children : list, optional (default is None)
931
+ A list of child nodes. If None, initializes an empty list.
932
+ orig_indices : list, optional (default is None)
933
+ A list of original indices. If None, initializes an empty list.
934
+ """
935
+ super().__init__(value, children, orig_indices)
936
+ self.cost = cost
937
+ self.tree = tree
938
+
939
+ class ChildNode(TreeNode):
940
+ """
941
+ Class for representing a child node for a tree, along with the corresponding metadata
942
+ specific to child nodes.
943
+ """
944
+ def __init__(self, value, parent=None, children=None, orig_indices=None):
945
+ """
946
+ Parameters
947
+ ----------
948
+ value : any
949
+ The value to be stored in the node.
950
+ parent : TreeNode, optional (default is None)
951
+ The parent node.
952
+ children : list, optional (default is None)
953
+ A list of child nodes. If None, initializes an empty list.
954
+ orig_indices : list, optional (default is None)
955
+ A list of original indices. If None, initializes an empty list.
956
+ """
957
+ super().__init__(value, children, orig_indices)
958
+ self.parent = parent
959
+ if parent is not None:
960
+ parent.add_child(self)
961
+
962
+ def get_ancestors(self):
963
+ """
964
+ Get all ancestor nodes of the current node up to the root node.
965
+
966
+ Returns
967
+ -------
968
+ list
969
+ A list of ancestor nodes.
970
+ """
971
+ ancestors = []
972
+ node = self
973
+ while node:
974
+ ancestors.append(node)
975
+ if isinstance(node, RootNode):
976
+ break
977
+ node = node.parent
978
+ return ancestors
979
+
980
+ def calculate_promotion_cost(self):
981
+ """
982
+ Calculate the cost of promoting this child node to a root node. This
983
+ corresponds to the sum of the cost of this node's current root, plus
984
+ the total number of ancestors (less the root).
985
+ """
986
+ ancestors = self.get_ancestors()
987
+ ancestor_count = len(ancestors) - 1
988
+ current_root = self.get_root()
989
+ current_root_cost = current_root.cost
990
+ return ancestor_count + current_root_cost
991
+
992
+ def promote_to_root(self):
993
+ """
994
+ Promote this child node to a root node, updating the tree structure accordingly.
995
+ """
996
+ # Calculate the cost (I know this is code duplication, but in this case
997
+ #we need the intermediate values as well).
998
+ ancestors = self.get_ancestors()
999
+ ancestor_count = len(ancestors) - 1
1000
+ current_root = self.get_root()
1001
+ current_root_cost = current_root.cost
1002
+ new_root_cost = ancestor_count + current_root_cost
1003
+
1004
+ # Remove this node from its parent's children
1005
+ if self.parent:
1006
+ self.parent.remove_child(self)
1007
+
1008
+ # Create a new RootNode
1009
+ ancestor_values = [ancestor.value for ancestor in reversed(ancestors)]
1010
+ if isinstance(ancestor_values[0], tuple):
1011
+ ancestor_values = list(ancestor_values[0]) + ancestor_values[1:]
1012
+ new_root_value = tuple(ancestor_values)
1013
+ new_root = RootNode(new_root_value, cost=new_root_cost, tree=current_root.tree, children=self.children,
1014
+ orig_indices=self.orig_indices)
1015
+
1016
+ # Update the children of the new RootNode
1017
+ for child in new_root.children:
1018
+ child.parent = new_root
1019
+
1020
+ # Add the new RootNode to the tree
1021
+ if new_root.tree:
1022
+ new_root.tree.add_root(new_root)
1023
+
1024
+ # Delete this ChildNode
1025
+ del self
1026
+
1027
+ def get_root(self):
1028
+ """
1029
+ Get the root node of the current node.
1030
+
1031
+ Returns
1032
+ -------
1033
+ RootNode
1034
+ The root node of the current node.
1035
+ """
1036
+ node = self
1037
+ while node.parent and not isinstance(node.parent, RootNode):
1038
+ node = node.parent
1039
+ return node.parent
1040
+
1041
+ class Tree:
1042
+ """
1043
+ Container class for storing a tree structure (technically a forest, as there
1044
+ can be multiple roots).
1045
+ """
1046
+ def __init__(self, roots=None):
1047
+ """
1048
+ Parameters
1049
+ ----------
1050
+ roots: list of RootNode, optional (default None)
1051
+ List of roots for this tree structure.
1052
+ """
1053
+ self.roots = []
1054
+ self.root_set = set(self.roots)
1055
+
1056
+ def get_root_node(self, value):
1057
+ """
1058
+ Get the root node associated with the input value. If that node is not present, return None.
1059
+
1060
+ Parameters
1061
+ ----------
1062
+ value : any
1063
+ The value to search for in the root nodes.
1064
+
1065
+ Returns
1066
+ -------
1067
+ RootNode or None
1068
+ The root node with the specified value, or None if not found.
1069
+ """
1070
+
1071
+ for node in self.roots:
1072
+ if node.value == value:
1073
+ return node
1074
+ #if we haven't returned already it is because there wasn't a corresponding root,
1075
+ #so return None
1076
+ return None
1077
+
1078
+ def add_root(self, root_node):
1079
+ """
1080
+ Add a root node to the tree.
1081
+
1082
+ Parameters
1083
+ ----------
1084
+ root_node : RootNode
1085
+ The root node to be added.
1086
+ """
1087
+
1088
+ root_node.tree = self
1089
+ self.roots.append(root_node)
1090
+ self.root_set.add(root_node)
1091
+
1092
+ def remove_root(self, root_node):
1093
+ """
1094
+ Remove a root node from the tree.
1095
+
1096
+ Parameters
1097
+ ----------
1098
+ root_node : RootNode
1099
+ The root node to be removed.
1100
+ """
1101
+
1102
+ root_node.tree = None
1103
+ self.roots = [root for root in self.roots if root is not root_node]
1104
+
1105
+ def total_orig_indices(self):
1106
+ """
1107
+ Calculate the total number of original indices for all root nodes and their descendants.
1108
+ """
1109
+ return sum([root.total_orig_indices() for root in self.roots])
1110
+
1111
+ def traverse(self):
1112
+ """
1113
+ Traverse the entire tree in pre-order and return a list of node values.
1114
+
1115
+ Returns
1116
+ -------
1117
+ list
1118
+ A list of node values in pre-order traversal.
1119
+ """
1120
+ nodes = []
1121
+ for root in self.roots:
1122
+ nodes.extend(root.traverse())
1123
+ return nodes
1124
+
1125
+ def count_nodes(self):
1126
+ """
1127
+ Count the total number of nodes in the tree.
1128
+ """
1129
+ count = 0
1130
+ stack = self.roots[:]
1131
+ while stack:
1132
+ node = stack.pop()
1133
+ count += 1
1134
+ stack.extend(node.children)
1135
+ return count
1136
+
1137
+ def print_tree(self):
1138
+ """
1139
+ Print the entire tree structure.
1140
+ """
1141
+ for root in self.roots:
1142
+ root.print_tree()
1143
+
1144
+ def calculate_cost(self):
1145
+ """
1146
+ Calculate the total cost of the tree, including root costs and promotion costs for child nodes.
1147
+ See `RootNode` and `ChildNode`.
1148
+ """
1149
+ total_cost = sum([root.cost for root in self.roots])
1150
+ total_nodes = self.count_nodes()
1151
+ total_child_nodes = total_nodes - len(self.roots)
1152
+ return total_cost + total_child_nodes
1153
+
1154
+ def to_networkx_graph(self):
1155
+ """
1156
+ Convert the tree to a NetworkX directed graph with node and edge attributes.
1157
+
1158
+ Returns
1159
+ -------
1160
+ networkx.DiGraph
1161
+ The NetworkX directed graph representation of the tree.
1162
+ """
1163
+ G = _nx.DiGraph()
1164
+ stack = [(None, root) for root in self.roots]
1165
+ insertion_order = 0
1166
+ while stack:
1167
+ parent, node = stack.pop()
1168
+ node_id = id(node)
1169
+ prop_cost = node.cost if isinstance(node, RootNode) else 1
1170
+ G.add_node(node_id, cost=len(node.orig_indices), orig_indices=tuple(node.orig_indices),
1171
+ label=node.value, prop_cost = prop_cost,
1172
+ insertion_order=insertion_order)
1173
+ insertion_order+=1
1174
+ if parent is not None:
1175
+ parent_id = id(parent)
1176
+ edge_cost = node.calculate_promotion_cost()
1177
+ G.add_edge(parent_id, node_id, promotion_cost=edge_cost)
1178
+ for child in node.children:
1179
+ stack.append((node, child))
1180
+
1181
+ #if there are multiple roots then add an additional virtual root node as the
1182
+ #parent for all of these roots to enable partitioning with later algorithms.
1183
+ if len(self.roots)>1:
1184
+ G.add_node('virtual_root', cost = 0, orig_indices=(), label = (), prop_cost=0, insertion_order=-1)
1185
+ for root in self.roots:
1186
+ G.add_edge('virtual_root', id(root), promotion_cost=0)
1187
+
1188
+ return G
1189
+
1190
+ #--------------- Tree Partitioning Algorithm Helpers (+NetworkX Utilities)-----------------#
1191
+
1192
+ def _draw_graph(G, node_label_key='label', edge_label_key='promotion_cost', figure_size=(10,10)):
1193
+ """
1194
+ Draw the NetworkX graph with node labels.
1195
+
1196
+ Parameters
1197
+ ----------
1198
+ G : networkx.Graph
1199
+ The networkx Graph object to draw.
1200
+
1201
+ node_label_key : str, optional (default 'label')
1202
+ Optional key for the node attribute to use for the node labels.
1203
+
1204
+ edge_label_key : str, optional (default 'cost')
1205
+ Optional key for the edge attribute to use for the edge labels.
1206
+
1207
+ figure_size : tuple of floats, optional (default (10,10))
1208
+ An optional size specifier passed into the matplotlib figure
1209
+ constructor to set the plot size.
1210
+ """
1211
+ plt.figure(figsize=figure_size)
1212
+ pos = _nx.nx_agraph.graphviz_layout(G, prog="dot", args="-Granksep=5 -Gnodesep=10")
1213
+ labels = _nx.get_node_attributes(G, node_label_key)
1214
+ _nx.draw(G, pos, labels=labels, with_labels=True, node_size=500, node_color='lightblue', font_size=6, font_weight='bold')
1215
+ edge_labels = _nx.get_edge_attributes(G, edge_label_key)
1216
+ _nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
1217
+ plt.show()
1218
+
1219
+ def _copy_networkx_graph(G):
+     """
+     Create a new independent copy of a NetworkX directed graph with node and edge attributes that
+     match the original graph. Specialized to copying graphs with the known attributes set by the
+     `to_networkx_graph` method of the `Tree` class.
+
+     Parameters
+     ----------
+     G : networkx.DiGraph
+         The original NetworkX directed graph.
+
+     Returns
+     -------
+     networkx.DiGraph
+         The new independent copy of the NetworkX directed graph.
+     """
+     new_G = _nx.DiGraph()
+
+     # Copy nodes with attributes
+     for node, data in G.nodes(data=True):
+         new_G.add_node(node, cost=data['cost'], orig_indices=data['orig_indices'],
+                        label=data['label'], prop_cost=data['prop_cost'],
+                        insertion_order=data['insertion_order'])
+
+     # Copy edges with attributes
+     for u, v, data in G.edges(data=True):
+         new_G.add_edge(u, v, promotion_cost=data['promotion_cost'])
+
+     return new_G
+
+ def _find_root(tree):
+     """
+     Find the root node of a directed tree.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+
+     Returns
+     -------
+     The networkx node corresponding to the root.
+     """
+     # The root node will have no incoming edges
+     for node in tree.nodes():
+         if tree.in_degree(node) == 0:
+             return node
+     raise ValueError("The input graph is not a valid tree (no root found).")
+
+ def _compute_subtree_weights(tree, root, weight_key):
+     """
+     Compute the total weight of each subtree in a directed tree.
+     The weight of a subtree is defined as the sum of the weights of all nodes
+     in that subtree, including the root of the subtree.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+
+     root : networkx node
+         The root node of the tree.
+
+     weight_key : str
+         A string corresponding to the node attribute to use as the weights.
+
+     Returns
+     -------
+     A dictionary where keys are nodes and values are the total weights of the subtrees rooted at those nodes.
+     """
+     subtree_weights = {}
+     stack = [root]
+     visited = set()
+
+     # Iterative post-order traversal: each node is pushed twice. The first visit
+     # pushes its children; the second visit (after all children have been processed)
+     # accumulates the subtree weight in a bottom-up manner.
+     while stack:
+         node = stack.pop()
+         if node in visited:
+             # All children have been processed, now process the node itself
+             subtree_weight = tree.nodes[node][weight_key]
+             for child in tree.successors(node):
+                 subtree_weight += subtree_weights[child]
+             subtree_weights[node] = subtree_weight
+         else:
+             # First visit: re-push the node so it is processed after its children
+             visited.add(node)
+             stack.append(node)
+             for child in tree.successors(node):
+                 if child not in visited:
+                     stack.append(child)
+
+     return subtree_weights
+
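To make the traversal concrete, here is a small self-contained sketch (not part of the diff; node names and weights are invented, and the helper is assumed to be in scope):

```python
import networkx as nx

# Toy tree: A -> B -> D and A -> C, with node weights stored under 'cost'.
toy = nx.DiGraph()
for i, (name, w) in enumerate([('A', 1), ('B', 2), ('C', 3), ('D', 4)]):
    toy.add_node(name, cost=w, insertion_order=i)
toy.add_edges_from([('A', 'B'), ('A', 'C'), ('B', 'D')])

weights = _compute_subtree_weights(toy, 'A', 'cost')
# weights == {'A': 10, 'B': 6, 'C': 3, 'D': 4} (as a mapping; key order may differ):
# each node's own weight plus the weights of everything below it.
```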
+ def _partition_levels(tree, root):
+     """
+     Partition the nodes of a rooted directed tree into levels based on their distance from the root.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+     root : networkx node
+         The root node of the tree.
+
+     Returns
+     -------
+     list of lists
+         A list where each inner list contains the nodes equidistant from the root,
+         sorted by their 'insertion_order' node attribute.
+     """
+     # Initialize a dictionary to store the level of each node
+     levels = {}
+     # Initialize a queue for BFS
+     queue = _collections.deque([(root, 0)])
+
+     while queue:
+         node, level = queue.popleft()
+         if level not in levels:
+             levels[level] = set()
+         levels[level].add(node)
+
+         for child in tree.successors(node):
+             queue.append((child, level + 1))
+
+     tree_nodes = tree.nodes
+     # Convert the levels dictionary to a list of lists ordered by level
+     sorted_levels = []
+     for level in sorted(levels.keys()):
+         # Sort nodes at each level by 'insertion_order' attribute
+         sorted_nodes = sorted(levels[level], key=lambda node: tree_nodes[node]['insertion_order'])
+         sorted_levels.append(sorted_nodes)
+
+     return sorted_levels
+
+
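Continuing the toy example from above (again, purely illustrative):

```python
levels = _partition_levels(toy, 'A')
# [['A'], ['B', 'C'], ['D']] -- one list per distance from 'A',
# with ties within a level broken by the 'insertion_order' attribute.
```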
+ def _partition_levels_and_compute_subtree_weights(tree, root, weight_key):
+     """
+     Partition the nodes of a rooted directed tree into levels based on their distance from the root
+     and compute the total weight of each subtree.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+     root : networkx node
+         The root node of the tree.
+     weight_key : str
+         A string corresponding to the node attribute to use as the weights.
+
+     Returns
+     -------
+     tuple:
+         - list of lists: A list where each inner list contains the nodes equidistant from the root,
+           sorted by their 'insertion_order' node attribute.
+         - dict: A dictionary where keys are nodes and values are the total weights of the subtrees rooted at those nodes.
+     """
+     # Initialize a dictionary to store the level of each node
+     levels = {}
+     # Initialize a dictionary to store the subtree weights
+     subtree_weights = {}
+     # Initialize a queue for BFS
+     queue = _collections.deque([(root, 0)])
+     # Initialize a stack for DFS to compute subtree weights
+     stack = []
+     visited = set()
+
+     # tree.nodes returns a view, so grab it ahead of time in case
+     # there is overhead associated with that.
+     tree_nodes = tree.nodes
+
+     # successors is a relatively expensive call and we use it at least twice
+     # per node, so compute it once and cache it in a dict.
+     node_successors = {node: list(tree.successors(node)) for node in tree_nodes}
+
+     while queue:
+         node, level = queue.popleft()
+         if node not in visited:
+             visited.add(node)
+             if level not in levels:
+                 levels[level] = set()
+             levels[level].add(node)
+             stack.append(node)
+             for child in node_successors[node]:
+                 queue.append((child, level + 1))
+
+     # Compute subtree weights in a bottom-up manner
+     while stack:
+         node = stack.pop()
+         subtree_weight = tree_nodes[node][weight_key]
+         for child in node_successors[node]:
+             subtree_weight += subtree_weights[child]
+         subtree_weights[node] = subtree_weight
+
+     # Convert the levels dictionary to a list of lists ordered by level
+     sorted_levels = []
+     for level in sorted(levels.keys()):
+         # Sort nodes at each level by 'insertion_order' attribute
+         sorted_nodes = sorted(levels[level], key=lambda node: tree_nodes[node]['insertion_order'])
+         sorted_levels.append(sorted_nodes)
+
+     return sorted_levels, subtree_weights
+
+
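Per the docstrings, the fused helper should agree with the two single-purpose helpers while traversing the tree fewer times; a consistency-check sketch on the toy tree (an assumption based on those docstrings, not a test from the diff):

```python
levels2, weights2 = _partition_levels_and_compute_subtree_weights(toy, 'A', 'cost')
assert levels2 == _partition_levels(toy, 'A')
assert weights2 == _compute_subtree_weights(toy, 'A', 'cost')
```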
+ def _find_leaves(tree):
+     """
+     Find all leaf nodes in a directed tree.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+
+     Returns
+     -------
+     A set of leaf nodes.
+     """
+     leaf_nodes = {node for node in tree.nodes() if tree.out_degree(node) == 0}
+     return leaf_nodes
+
+ def _path_to_root(tree, node, root):
+     """
+     Return a list of nodes along the path from the given node to the root.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+     node : networkx node
+         The starting node.
+     root : networkx node
+         The root node of the tree.
+
+     Returns
+     -------
+     A list of nodes along the path from the given node to the root.
+     """
+     path = []
+     current_node = node
+
+     while current_node != root:
+         path.append(current_node)
+         # Note: for a tree structure there should be just one predecessor, so we're
+         # not worried about nondeterminism here. If we ever apply this to another
+         # graph structure this needs to be reevaluated.
+         predecessors = list(tree.predecessors(current_node))
+         current_node = predecessors[0]
+     path.append(root)
+
+     return path
+
+ def _get_subtree(tree, root):
+     """
+     Return a new graph corresponding to the subtree rooted at the given node.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+
+     root : networkx node
+         The root node of the subtree.
+
+     Returns
+     -------
+     subtree : networkx.DiGraph
+         A new directed graph corresponding to the subtree rooted at the given node.
+     """
+     # Create a new directed graph for the subtree
+     subtree = _nx.DiGraph()
+
+     # Use a queue to perform BFS and add nodes and edges to the subtree
+     # (a deque gives O(1) pops from the front, unlike list.pop(0))
+     queue = _collections.deque([root])
+     while queue:
+         node = queue.popleft()
+         subtree.add_node(node, **tree.nodes[node])
+         for child in tree.successors(node):
+             subtree.add_edge(node, child, **tree.edges[node, child])
+             queue.append(child)
+
+     return subtree
+
+ def _collect_orig_indices(tree, root):
+     """
+     Collect all values of the 'orig_indices' node attributes in the subtree rooted at the given node.
+     The 'orig_indices' values are tuples, and the function flattens these tuples into a single list.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         The directed tree.
+
+     root : networkx node
+         The root node of the subtree.
+
+     Returns
+     -------
+     list
+         A flattened list of all values of the 'orig_indices' node attributes in the subtree.
+     """
+     orig_indices_list = []
+     stack = [root]  # popped from the end, so this is a DFS
+
+     # TODO: See if this would be any faster with one of the dfs/bfs iterators in networkx
+     while stack:
+         node = stack.pop()
+         orig_indices_list.extend(tree.nodes[node]['orig_indices'])
+         for child in tree.successors(node):
+             stack.append(child)
+
+     return sorted(orig_indices_list)  # sort to account for any nondeterministic traversal order
+
+ def _process_node_km(node, tree, subtree_weights, cut_edges, max_weight, root, new_roots):
+     """
+     Helper function for the Kundu-Misra algorithm. This function processes a node
+     by cutting the edges to its highest-weight children until the node's subtree weight
+     is below the maximum weight threshold, updating the subtree weights of any ancestors
+     as needed.
+     """
+     # If the subtree weight of this node is at most the max weight we can stop right away
+     # and avoid sorting the child weights.
+     if subtree_weights[node] <= max_weight:
+         return
+
+     tree_nodes = tree.nodes
+     # Otherwise sort the weights of the child nodes to find the heaviest ones,
+     # sorting first by insertion order to ensure determinism.
+     weighted_children = [(child, subtree_weights[child]) for child in
+                          sorted(tree.successors(node), key=lambda node: tree_nodes[node]['insertion_order'])]
+     sorted_weighted_children = sorted(weighted_children, key=lambda x: x[1], reverse=True)
+
+     # Get the path of nodes up to the root, which need their weights updated upon edge removal.
+     nodes_to_update = _path_to_root(tree, node, root)
+
+     # Remove the heaviest children until the weight is below the maximum weight.
+     removed_child_index = 0  # track the index of the child being removed
+     while subtree_weights[node] > max_weight:
+         removed_child = sorted_weighted_children[removed_child_index][0]
+         # Add the edge to this child to the list of those cut.
+         cut_edges.append((node, removed_child))
+         new_roots.append(removed_child)
+         removed_child_weight = subtree_weights[removed_child]
+         # Update the subtree weight of the current node and all parents up to the root.
+         for node_to_update in nodes_to_update:
+             subtree_weights[node_to_update] -= removed_child_weight
+         # Update the index:
+         removed_child_index += 1
+
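A worked sketch of a single processing step (invented names and weights, helpers assumed in scope; note the function mutates `subtree_weights`, `cut_edges`, and `new_roots` in place):

```python
import networkx as nx

# Toy tree: R -> A, A -> {B1, B2, B3}; 'cost' weights: R=1, A=1, B1=5, B2=3, B3=2.
T = nx.DiGraph()
for i, (name, w) in enumerate([('R', 1), ('A', 1), ('B1', 5), ('B2', 3), ('B3', 2)]):
    T.add_node(name, cost=w, insertion_order=i)
T.add_edges_from([('R', 'A'), ('A', 'B1'), ('A', 'B2'), ('A', 'B3')])

weights = _compute_subtree_weights(T, 'R', 'cost')  # A's subtree weighs 11, R's 12
cut_edges, new_roots = [], []
_process_node_km('A', T, weights, cut_edges, 6, 'R', new_roots)
# cut_edges == [('A', 'B1')]            -- one cut of the heaviest child suffices
# weights['A'] == 6, weights['R'] == 7  -- ancestor weights updated in place
```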
+ def tree_partition_kundu_misra(tree, max_weight, weight_key='cost', test_leaves=True,
+                                return_levels_and_weights=False, precomp_levels=None,
+                                precomp_weights=None):
+     """
+     Algorithm for an optimal minimum-cardinality k-partition of a tree (a partition
+     of a tree into the fewest clusters of weight at most k), based on a slightly less
+     sophisticated implementation of the algorithm from "A Linear Tree Partitioning
+     Algorithm" by Kundu and Misra (SIAM J. Comput. Vol. 6, No. 1, March 1977). It is
+     less sophisticated because the strictly linear-time implementation uses a
+     linear-time median-finding routine, while this implementation uses sorting
+     (O(n log n) time); in practice the highly-optimized C implementation of sorting
+     is likely to beat an uglier python implementation of median finding for most
+     problem instances of interest anyhow.
+
+     Parameters
+     ----------
+     tree : networkx.DiGraph
+         An input graph representing the directed tree to perform partitioning on.
+
+     max_weight : int
+         Maximum node weight allowed for each partition.
+
+     weight_key : str, optional (default 'cost')
+         An optional string denoting the node attribute label to use for node weights
+         in partitioning.
+
+     test_leaves : bool, optional (default True)
+         When True an initial test is performed to ensure that the weights of the leaves
+         are all no greater than the maximum weight. Only turn this off if you know for
+         certain this holds.
+
+     return_levels_and_weights : bool, optional (default False)
+         If True, additionally return the constructed tree level structure (the lists of
+         nodes partitioned by distance from the root) and the subtree weights.
+
+     precomp_levels : list of lists, optional (default None)
+         A list where each inner list contains nodes that are equidistant from the root.
+
+     precomp_weights : dict, optional (default None)
+         A dictionary where keys are nodes and values are the total weights of the subtrees rooted at those nodes.
+
+     Returns
+     -------
+     cut_edges : list of tuples
+         A list of the parent-child node pairs whose edges should be cut to partition the tree.
+
+     new_roots : list
+         The root nodes of the subtrees produced by the partition, sorted by insertion order.
+
+     tree_levels : list of lists
+         Only returned when `return_levels_and_weights` is True. The nodes partitioned by
+         distance from the root.
+
+     subtree_weights : dict
+         Only returned when `return_levels_and_weights` is True. The subtree weight of each
+         node, as computed before any cuts.
+     """
+     cut_edges = []  # list of cut edges
+     new_roots = []  # list of the subtree root nodes in the partitioned tree
+
+     # Find the root node of the tree:
+     root = _find_root(tree)
+     new_roots.append(root)
+
+     tree_nodes = tree.nodes
+
+     if test_leaves:
+         # Find the leaves:
+         leaves = _find_leaves(tree)
+         # Make sure that the weights of the leaves are all no greater than the maximum weight.
+         msg = 'The weight of at least one leaf is greater than the maximum weight, no partition possible.'
+         assert all([tree_nodes[leaf][weight_key] <= max_weight for leaf in leaves]), msg
+
+     # Precompute a list of subtree weights, which will be dynamically updated as we make cuts. Also
+     # partition the tree into levels based on distance from the root.
+     if precomp_levels is None and precomp_weights is None:
+         tree_levels, subtree_weights = _partition_levels_and_compute_subtree_weights(tree, root, weight_key)
+     else:
+         tree_levels = precomp_levels if precomp_levels is not None else _partition_levels(tree, root)
+         subtree_weights = precomp_weights.copy() if precomp_weights is not None else _compute_subtree_weights(tree, root, weight_key)
+
+     # The subtree_weights get modified in-place by _process_node_km, so create a copy for the return value.
+     if return_levels_and_weights:
+         subtree_weights_orig = subtree_weights.copy()
+
+     # Begin processing the nodes level-by-level, deepest level first.
+     for level in reversed(tree_levels):
+         for node in level:
+             _process_node_km(node, tree, subtree_weights, cut_edges, max_weight, root, new_roots)
+
+     # Sort the new root nodes in case there are determinism issues.
+     new_roots = sorted(new_roots, key=lambda node: tree_nodes[node]['insertion_order'])
+
+     if return_levels_and_weights:
+         return cut_edges, new_roots, tree_levels, subtree_weights_orig
+     else:
+         return cut_edges, new_roots
+
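A usage sketch on the toy tree built above (again assuming the functions are in scope; note the function returns the edges to cut rather than cutting them itself):

```python
cut_edges, new_roots = tree_partition_kundu_misra(T, max_weight=6, weight_key='cost')
# cut_edges == [('A', 'B1'), ('R', 'A')]; new_roots == ['R', 'A', 'B1']
for parent, child in cut_edges:
    T.remove_edge(parent, child)
# Resulting cluster weights: {R} = 1, {A, B2, B3} = 6, {B1} = 5 -- all <= max_weight.
```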
+ def _bisect_tree(tree, subtree_root, subtree_weights, weight_key, root_cost=0, target_proportion=.5):
+     # Perform a bisection on the subtree. Loop through the tree beginning at the root,
+     # and find as cheap an edge as possible which, when cut, splits the tree by cost as
+     # close to the target proportion as possible.
+     heaviest_subtree_levels = _partition_levels(tree, subtree_root)
+     new_subtree_cost = {}
+
+     new_subtree_cost[subtree_root] = subtree_weights[subtree_root]
+     for i, level in enumerate(heaviest_subtree_levels[1:]):  # skip the root
+         for node in level:
+             # Calculate the cost of a new subtree rooted at this node. For propagation costs
+             # this is the subtree weight plus the level index plus the propagation cost of
+             # the current root.
+             new_subtree_cost[node] = (subtree_weights[node] + i + root_cost) if weight_key == 'prop_cost' \
+                 else subtree_weights[node]
+
+     # Find the node that results in as close as possible to the target split of the subtree
+     # in terms of cost.
+     target_prop_cost = new_subtree_cost[subtree_root] * target_proportion
+     closest_node = subtree_root
+     closest_distance = new_subtree_cost[subtree_root]
+     for node, cost in new_subtree_cost.items():  # the nodes in each level are sorted, so this should be alright for determinism
+         current_distance = abs(cost - target_prop_cost)
+         if current_distance < closest_distance:
+             closest_distance = current_distance
+             closest_node = node
+     # We now have the node which, when promoted to a root, produces the split closest to the
+     # target proportion in terms of cost. Perform that bisection now.
+     if closest_node is not subtree_root:
+         # A tree node should have only one predecessor, so we don't need to worry about determinism.
+         cut_edge = (list(tree.predecessors(closest_node))[0], closest_node)
+         return cut_edge, (new_subtree_cost[closest_node], subtree_weights[subtree_root] - subtree_weights[closest_node])
+     else:
+         return None, None
+
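A minimal sketch of the bisection on a chain (invented names; `weight_key='cost'` so the propagation-cost adjustments are inactive):

```python
import networkx as nx

# Toy chain: P0 -> P1 -> P2 -> P3, unit 'cost' weights.
chain = nx.DiGraph()
for i in range(4):
    chain.add_node(f'P{i}', cost=1, insertion_order=i)
chain.add_edges_from([('P0', 'P1'), ('P1', 'P2'), ('P2', 'P3')])

w = _compute_subtree_weights(chain, 'P0', 'cost')  # P0: 4, P1: 3, P2: 2, P3: 1
cut_edge, (new_cost, remaining_cost) = _bisect_tree(chain, 'P0', w, 'cost')
# cut_edge == ('P1', 'P2'); the two halves each weigh 2, an exact bisection.
```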
+ def _bisection_pass(partitioned_tree, cut_edges, new_roots, num_sub_tables, weight_key):
+     partitioned_tree = _copy_networkx_graph(partitioned_tree)
+     subtree_weights = [(root, _compute_subtree_weights(partitioned_tree, root, weight_key)) for root in new_roots]
+     sorted_subtree_weights = sorted(subtree_weights, key=lambda x: x[1][x[0]], reverse=True)
+
+     # Perform a bisection on the heaviest subtree: loop through the tree beginning at the root,
+     # and find as cheap an edge as possible which, when cut, approximately bisects the tree by cost.
+     for i in range(len(sorted_subtree_weights)):
+         heaviest_subtree_root = sorted_subtree_weights[i][0]
+         heaviest_subtree_weights = sorted_subtree_weights[i][1]
+         root_cost = partitioned_tree.nodes[heaviest_subtree_root][weight_key] if weight_key == 'prop_cost' else 0
+         cut_edge, new_subtree_costs = _bisect_tree(partitioned_tree, heaviest_subtree_root, heaviest_subtree_weights,
+                                                    weight_key, root_cost)
+         if cut_edge is not None:
+             cut_edges.append(cut_edge)
+             new_roots.append(cut_edge[1])
+             # Cut the prescribed edge.
+             partitioned_tree.remove_edge(cut_edge[0], cut_edge[1])
+         # Check whether we need to continue partitioning subtrees.
+         if len(new_roots) == num_sub_tables:
+             break
+     # Sort the new root nodes in case there are determinism issues.
+     new_roots = sorted(new_roots, key=lambda node: partitioned_tree.nodes[node]['insertion_order'])
+
+     return partitioned_tree, new_roots, cut_edges
+
+ def _refinement_pass(partitioned_tree, roots, weight_key, imbalance_threshold=1.2, minimum_improvement_threshold=.1):
+     # Refine the partitioning to improve the balancing of the specified weights across the subtrees.
+     # Start by recomputing the latest subtree weights and ranking them from heaviest to lightest.
+     partitioned_tree = _copy_networkx_graph(partitioned_tree)
+     subtree_weights = [(root, _compute_subtree_weights(partitioned_tree, root, weight_key)) for root in roots]
+     sorted_subtree_weights = sorted(subtree_weights, key=lambda x: x[1][x[0]], reverse=True)
+
+     partitioned_tree_nodes = partitioned_tree.nodes
+
+     # Strategy: pair the heaviest and lightest subtrees and identify the subtree of the heaviest that could be
+     # snipped out and added to the lightest to bring their weights as close together as possible.
+     # Next do this for the second heaviest and second lightest, etc.
+     # Only do so while the imbalance ratio (heaviest weight over lightest weight) is
+     # above a specified threshold.
+     heavy_light_pairs = _pair_elements(sorted_subtree_weights)
+     heavy_light_pair_indices = _pair_elements(list(range(len(sorted_subtree_weights))))
+     heavy_light_weights = [(sorted_subtree_weights[i][1][sorted_subtree_weights[i][0]],
+                             sorted_subtree_weights[j][1][sorted_subtree_weights[j][0]])
+                            for i, j in heavy_light_pair_indices]
+     heavy_light_ratios = [weight_1 / weight_2 for weight_1, weight_2 in heavy_light_weights]
+
+     heavy_light_pairs_to_balance = heavy_light_pairs if len(sorted_subtree_weights) % 2 == 0 else heavy_light_pairs[0:-1]
+     new_roots = []
+     addl_cut_edges = []
+     pair_iter = iter(range(len(heavy_light_pairs_to_balance)))
+     for i in pair_iter:
+         # If the ratio is above the threshold then try a rebalancing step.
+         if heavy_light_ratios[i] > imbalance_threshold:
+             # Calculate the fraction of the heavy tree that would be needed to bring the weight of the
+             # lighter tree in line.
+             root_cost = partitioned_tree_nodes[heavy_light_pairs[i][0][0]][weight_key] if weight_key == 'prop_cost' else 0
+
+             rebalancing_target_fraction = (.5 * (heavy_light_weights[i][0] - heavy_light_weights[i][1])) / heavy_light_weights[i][0]
+             cut_edge, new_subtree_weights = _bisect_tree(partitioned_tree, heavy_light_pairs[i][0][0], heavy_light_pairs[i][0][1],
+                                                          weight_key, root_cost=root_cost,
+                                                          target_proportion=rebalancing_target_fraction)
+             # Before applying the edge cut, check whether the edge we found was close enough
+             # to bring us below the threshold.
+             if cut_edge is not None:
+                 new_light_tree_weight = new_subtree_weights[0] + heavy_light_weights[i][1]
+                 new_heavy_tree_weight = new_subtree_weights[1]
+                 new_heavy_light_ratio = new_heavy_tree_weight / new_light_tree_weight
+                 if new_heavy_light_ratio > imbalance_threshold and \
+                    (heavy_light_ratios[i] - new_heavy_light_ratio) < minimum_improvement_threshold:
+                     # We're only as good as the worst balancing, so if we are unable to
+                     # balance below the threshold and the improvement is below some minimum threshold,
+                     # then we won't make an update and will terminate.
+                     # (Maybe we should throw a warning too? It isn't clear whether that would just be
+                     # confusing to end-users who wouldn't know what was meant or whether it was important.)
+                     # Also add the roots of any of the pairs we haven't yet processed.
+                     remaining_indices = [i] + [j for j in pair_iter]
+                     for idx in remaining_indices:
+                         new_roots.extend((heavy_light_pairs[idx][0][0], heavy_light_pairs[idx][1][0]))
+                     break
+                 else:
+                     # Append the original root of the heavy tree, and a tuple of roots for the light tree
+                     # plus the bisected part of the heavy one.
+                     new_roots.append(heavy_light_pairs[i][0][0])
+                     new_roots.append((heavy_light_pairs[i][1][0], cut_edge[1]))
+                     addl_cut_edges.append(cut_edge)
+                     # Apply the cut.
+                     partitioned_tree.remove_edge(cut_edge[0], cut_edge[1])
+             else:
+                 # If the cut edge is None, append the original heavy and light roots.
+                 new_roots.extend((heavy_light_pairs[i][0][0], heavy_light_pairs[i][1][0]))
+         else:
+             # Since we're pairing up successive pairs of heavy and light elements, once we see
+             # one which is sufficiently balanced we know the rest must be too.
+             remaining_indices = [i] + [j for j in pair_iter]
+             for idx in remaining_indices:
+                 new_roots.extend((heavy_light_pairs[idx][0][0], heavy_light_pairs[idx][1][0]))
+             break
+
+     # If the number of subtrees was odd to start, we need to append the median-weight element, which
+     # hasn't been processed.
+     if len(sorted_subtree_weights) % 2 != 0:
+         new_roots.append(heavy_light_pairs[-1][0][0])
+
+     return partitioned_tree, new_roots, addl_cut_edges
+
+
+ # Helper function for pairing up heavy and light subtrees.
+ def _pair_elements(lst):
+     paired_list = []
+     length = len(lst)
+
+     for i in range((length + 1) // 2):
+         if i == length - i - 1:
+             paired_list.append((lst[i], lst[i]))
+         else:
+             paired_list.append((lst[i], lst[length - i - 1]))
+
+     return paired_list
+
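A quick illustration of the pairing (invented input):

```python
_pair_elements([5, 4, 3, 2, 1])
# -> [(5, 1), (4, 2), (3, 3)]: heaviest with lightest, next-heaviest with
# next-lightest, and the middle element pairs with itself when the length is odd.
```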
+