Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py
@@ -0,0 +1,1138 @@
+ """Tools for performing TEBD like algorithms on arbitrary lattices."""
+
+ import collections
+ import itertools
+ import random
+ from collections.abc import Iterable
+
+ from autoray import do, to_numpy
+
+ from ..core import eye, kron, qarray
+ from ..utils import (
+     ExponentialGeometricRollingDiffMean,
+     default_to_neutral_style,
+     ensure_dict,
+ )
+ from ..utils import (
+     progbar as Progbar,
+ )
+ from .drawing import get_colors, get_positions
+ from .tensor_core import Tensor
+
+
+ def edge_coloring(
+     edges,
+     strategy="smallest_last",
+     interchange=True,
+     group=True,
+ ):
+     """Generate an edge coloring for the graph given by ``edges``, using
+     ``networkx.coloring.greedy_color``.
+
+     Parameters
+     ----------
+     edges : sequence[tuple[hashable, hashable]]
+         The edges of the graph.
+     strategy : str or callable, optional
+         The strategy to use for coloring the edges. Can be:
+
+         - 'largest_first'
+         - 'smallest_last'
+         - 'random_sequential'
+         ...
+
+     interchange : bool, optional
+         Whether to use the interchange heuristic. Usually generates better
+         colorings but can be slower.
+     group : bool, optional
+         Whether to group the edges by color or return a flat list.
+     """
+     import networkx as nx
+
+     # find vertex coloring of line graph
+     G = nx.Graph(tuple(edges))
+     edge_colors = nx.coloring.greedy_color(
+         nx.line_graph(G), strategy, interchange=interchange
+     )
+
+     # group the edges by color
+     coloring = {}
+     for edge, color in edge_colors.items():
+         coloring.setdefault(color, []).append(edge)
+
+     if group:
+         return tuple(
+             tuple(tuple(tuple(sorted(edge)) for edge in coloring[color]))
+             for color in sorted(coloring)
+         )
+     else:
+         # flatten sorted groups
+         return tuple(
+             tuple(sorted(edge))
+             for color in sorted(coloring)
+             for edge in coloring[color]
+         )
+
+
+ class LocalHamGen:
+     """Representation of a local hamiltonian defined on a general graph. This
+     combines all two site and one site terms into a single interaction per
+     lattice pair, and caches operations on the terms such as getting their
+     exponential. The sites (nodes) should be hashable and comparable.
+
+     Parameters
+     ----------
+     H2 : dict[tuple[node], array_like]
+         The interaction terms, with each key being a tuple of nodes defining
+         an edge and each value the local hamiltonian term for those two nodes.
+     H1 : array_like or dict[node, array_like], optional
+         The one site term(s). If a single array is given, it is assumed to
+         be the default onsite term for all sites. If a dict is supplied,
+         the keys should represent specific coordinates like
+         ``(i, j)`` with the values the array representing the local term for
+         that site. A default term for all remaining sites can still be
+         supplied with the key ``None``.
+
+     Attributes
+     ----------
+     terms : dict[tuple, array_like]
+         The total effective local term for each interaction (with single site
+         terms appropriately absorbed). Each key is a pair of coordinates
+         ``site_a, site_b`` with ``site_a < site_b``.
+     """
+
+     def __init__(self, H2, H1=None):
+         # caches for not repeating operations / duplicating tensors
+         self._op_cache = collections.defaultdict(dict)
+
+         self.terms = dict(H2)
+
+         # convert qarrays (mostly useful for working with jax)
+         for key, X in self.terms.items():
+             if isinstance(X, qarray):
+                 self.terms[key] = self._convert_from_qarray_cached(X)
+
+         self.sites = tuple(
+             sorted(set(itertools.chain.from_iterable(self.terms)))
+         )
+
+         # first combine terms to ensure coo1 < coo2
+         for where in tuple(filter(bool, self.terms)):
+             coo1, coo2 = where
+             if coo1 < coo2:
+                 continue
+
+             # pop and flip the term
+             X12 = self._flip_cached(self.terms.pop(where))
+
+             # add to, or create, term with flipped coos
+             new_where = coo2, coo1
+             if new_where in self.terms:
+                 self.terms[new_where] = self._add_cached(
+                     self.terms[new_where], X12
+                 )
+             else:
+                 self.terms[new_where] = X12
+
+         # make a mapping of which single sites are covered by which terms
+         # - to merge them into later
+         self._sites_to_covering_terms = collections.defaultdict(list)
+         for where in self.terms:
+             site_a, site_b = where
+             self._sites_to_covering_terms[site_a].append(where)
+             self._sites_to_covering_terms[site_b].append(where)
+
+         # parse one site terms
+         if H1 is None:
+             H1s = dict()
+         elif hasattr(H1, "shape"):
+             # set a default site term
+             H1s = {None: H1}
+         else:
+             H1s = dict(H1)
+
+         # convert qarrays (mostly useful for working with jax)
+         for key, X in H1s.items():
+             if isinstance(X, qarray):
+                 H1s[key] = self._convert_from_qarray_cached(X)
+
+         # possibly set the default single site term
+         default_H1 = H1s.pop(None, None)
+         if default_H1 is not None:
+             for site in self.sites:
+                 H1s.setdefault(site, default_H1)
+
+         # now absorb the single site terms evenly into the two site terms
+         for site, H in H1s.items():
+             # get interacting terms which cover the site
+             pairs = self._sites_to_covering_terms[site]
+             num_pairs = len(pairs)
+             if num_pairs == 0:
+                 raise ValueError(
+                     f"There are no two site terms to add this single site "
+                     f"term to - site {site} is not coupled to anything."
+                 )
+
+             # merge the single site term in equal parts into all covering pairs
+             H_tensoreds = (self._op_id_cached(H), self._id_op_cached(H))
+             for pair in pairs:
+                 H_tensored = H_tensoreds[pair.index(site)]
+                 self.terms[pair] = self._add_cached(
+                     self.terms[pair], self._div_cached(H_tensored, num_pairs)
+                 )
+
+     @property
+     def nsites(self):
+         """The number of sites in the system."""
+         return len(self.sites)
+
+     def items(self):
+         """Iterate over all terms in the hamiltonian. This is mostly for
+         convenient compatibility with ``compute_local_expectation``.
+         """
+         return self.terms.items()
+
+     def _convert_from_qarray_cached(self, x):
+         cache = self._op_cache["convert_from_qarray"]
+         key = id(x)
+         if key not in cache:
+             cache[key] = x.toarray()
+         return cache[key]
+
+     def _flip_cached(self, x):
+         cache = self._op_cache["flip"]
+         key = id(x)
+         if key not in cache:
+             d = int(x.size ** (1 / 4))
+             xf = do("reshape", x, (d, d, d, d))
+             xf = do("transpose", xf, (1, 0, 3, 2))
+             xf = do("reshape", xf, (d * d, d * d))
+             cache[key] = xf
+         return cache[key]
+
+     def _add_cached(self, x, y):
+         cache = self._op_cache["add"]
+         key = (id(x), id(y))
+         if key not in cache:
+             cache[key] = x + y
+         return cache[key]
+
+     def _div_cached(self, x, y):
+         cache = self._op_cache["div"]
+         key = (id(x), y)
+         if key not in cache:
+             cache[key] = x / y
+         return cache[key]
+
+     def _op_id_cached(self, x):
+         cache = self._op_cache["op_id"]
+         key = id(x)
+         if key not in cache:
+             xn = to_numpy(x)
+             d = int(xn.size**0.5)
+             Id = eye(d, dtype=xn.dtype)
+             XI = do("array", kron(xn, Id), like=x)
+             cache[key] = XI
+         return cache[key]
+
+     def _id_op_cached(self, x):
+         cache = self._op_cache["id_op"]
+         key = id(x)
+         if key not in cache:
+             xn = to_numpy(x)
+             d = int(xn.size**0.5)
+             Id = eye(d, dtype=xn.dtype)
+             IX = do("array", kron(Id, xn), like=x)
+             cache[key] = IX
+         return cache[key]
+
+     def _expm_cached(self, G, x):
+         cache = self._op_cache["expm"]
+         key = (id(G), x)
+         if key not in cache:
+             ndim_G = do("ndim", G)
+             need_to_reshape = ndim_G != 2
+             if need_to_reshape:
+                 shape_orig = do("shape", G)
+                 G = do(
+                     "fuse",
+                     G,
+                     range(0, ndim_G // 2),
+                     range(ndim_G // 2, ndim_G),
+                 )
+
+             U = do("linalg.expm", G * x)
+
+             if need_to_reshape:
+                 U = do("reshape", U, shape_orig)
+
+             cache[key] = U
+
+         return cache[key]
+
+     def get_gate(self, where):
+         """Get the local term for pair ``where``, cached."""
+         return self.terms[tuple(sorted(where))]
+
+     def get_gate_expm(self, where, x):
+         """Get the local term for pair ``where``, matrix exponentiated by
+         ``x``, and cached.
+         """
+         return self._expm_cached(self.get_gate(where), x)
+
+     def apply_to_arrays(self, fn):
+         """Apply the function ``fn`` to all the arrays representing terms."""
+         for k, x in self.terms.items():
+             self.terms[k] = fn(x)
+
+     def get_auto_ordering(self, order="sort", **kwargs):
+         """Get an ordering of the terms to use with TEBD, for example. The
+         default is to sort the coordinates then greedily group them into
+         commuting sets.
+
+         Parameters
+         ----------
+         order : {'sort', None, 'random', str}
+             How to order the terms *before* greedily grouping them into
+             commuting (non-coordinate overlapping) sets:
+
+             - ``'sort'`` will sort the coordinate pairs first.
+             - ``None`` will use the current order of terms which should
+               match the order they were supplied to this ``LocalHamGen``
+               instance.
+             - ``'random'`` will randomly shuffle the coordinate pairs
+               before grouping them - *not* the same as returning a
+               completely random order.
+             - ``'random-ungrouped'`` will randomly shuffle the coordinate
+               pairs but *not* group them at all with respect to
+               commutation.
+
+             Any other option will be passed as a strategy to
+             :func:`networkx.coloring.greedy_color` to generate the ordering.
+
+         Returns
+         -------
+         list[tuple[node]]
+             Sequence of coordinate pairs.
+         """
+         if order is None:
+             pairs = self.terms
+         elif order == "sort":
+             pairs = sorted(self.terms)
+         elif order == "random":
+             pairs = list(self.terms)
+             random.shuffle(pairs)
+         elif order == "random-ungrouped":
+             pairs = list(self.terms)
+             random.shuffle(pairs)
+             return pairs
+         else:
+             return edge_coloring(self.terms, order, group=False, **kwargs)
+
+         pairs = {x: None for x in pairs}
+
+         cover = set()
+         ordering = list()
+         while pairs:
+             for pair in tuple(pairs):
+                 ij1, ij2 = pair
+                 if (ij1 not in cover) and (ij2 not in cover):
+                     ordering.append(pair)
+                     pairs.pop(pair)
+                     cover.add(ij1)
+                     cover.add(ij2)
+             cover.clear()
+
+         return ordering
+
+     def __repr__(self):
+         s = "<LocalHamGen(nsites={}, num_terms={})>"
+         return s.format(self.nsites, len(self.terms))
+
+     @default_to_neutral_style
+     def draw(
+         self,
+         ordering="sort",
+         show_norm=True,
+         figsize=None,
+         fontsize=8,
+         legend=True,
+         ax=None,
+         **kwargs,
+     ):
+         """Plot this Hamiltonian as a network.
+
+         Parameters
+         ----------
+         ordering : {'sort', None, 'random'}, optional
+             An ordering of the terms, or an argument to be supplied to
+             :meth:`quimb.tensor.tensor_arbgeom_tebd.LocalHamGen.get_auto_ordering`
+             to generate this automatically.
+         show_norm : bool, optional
+             Show the norm of each term as edge labels.
+         figsize : None or tuple[int], optional
+             Size of the figure, defaults to size of Hamiltonian.
+         fontsize : int, optional
+             Font size for norm labels.
+         legend : bool, optional
+             Whether to show the legend of which terms are in which group.
+         ax : None or matplotlib.Axes, optional
+             Add to an existing set of axes.
+         """
+         import matplotlib.pyplot as plt
+         import networkx as nx
+
+         if figsize is None:
+             L = self.nsites**0.5 + 1
+             figsize = (L, L)
+
+         ax_supplied = ax is not None
+         if not ax_supplied:
+             fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
+             ax.axis("off")
+             ax.set_aspect("equal")
+         else:
+             fig = None
+
+         if ordering is None or isinstance(ordering, str):
+             ordering = self.get_auto_ordering(ordering, **kwargs)
+
+         G = nx.Graph()
+         seen = set()
+         n = 0
+         edge_labels = dict()
+         for where in ordering:
+             site_a, site_b = where
+             if (site_a in seen) or (site_b in seen):
+                 # start a new group
+                 seen = {site_a, site_b}
+                 n += 1
+             else:
+                 seen.add(site_a)
+                 seen.add(site_b)
+
+             nrm = do("linalg.norm", self.terms[where])
+             edge_labels[where] = f"{nrm:.2f}"
+             G.add_edge(site_a, site_b, norm=nrm, group=n)
+
+         num_groups = n + 1
+         colors = get_colors(range(num_groups))
+
+         pos = get_positions(None, G)
+
+         # do the plotting
+         nx.draw_networkx_edges(
+             G,
+             pos=pos,
+             width=tuple(2 * x[2]["norm"] ** 0.5 for x in G.edges(data=True)),
+             edge_color=tuple(
+                 colors[x[2]["group"]] for x in G.edges(data=True)
+             ),
+             alpha=0.8,
+             ax=ax,
+         )
+         nx.draw_networkx_edge_labels(
+             G,
+             pos=pos,
+             edge_labels=edge_labels,
+             font_size=fontsize,
+             font_color=(0.5, 0.5, 0.5),
+             bbox=dict(alpha=0),
+             ax=ax,
+         )
+         nx.draw_networkx_labels(
+             G,
+             pos,
+             font_color=(0.2, 0.2, 0.2),
+             font_size=fontsize,
+             font_weight="bold",
+             ax=ax,
+         )
+
+         # create legend
+         if legend:
+             handles = []
+             for color in colors.values():
+                 handles += [
+                     plt.Line2D(
+                         [0],
+                         [0],
+                         marker="o",
+                         color=color,
+                         linestyle="",
+                         markersize=10,
+                     )
+                 ]
+
+             lbls = [f"Group {i + 1}" for i in range(num_groups)]
+
+             ax.legend(
+                 handles,
+                 lbls,
+                 ncol=max(round(len(handles) / 20), 1),
+                 loc="center left",
+                 bbox_to_anchor=(1, 0.5),
+             )
+
+         return fig, ax
+
+     graph = draw
+
+
+ class TEBDGen:
+     """Generic class for performing time evolving block decimation on an
+     arbitrary graph, i.e. applying the exponential of a Hamiltonian using
+     a product formula that involves applying local exponentiated gates only.
+     """
+
+     def __init__(
+         self,
+         psi0,
+         ham,
+         tau=0.01,
+         D=None,
+         cutoff=1e-10,
+         imag=True,
+         gate_opts=None,
+         ordering=None,
+         second_order_reflect=False,
+         compute_energy_every=None,
+         compute_energy_final=True,
+         compute_energy_opts=None,
+         compute_energy_fn=None,
+         compute_energy_per_site=False,
+         tol=None,
+         callback=None,
+         keep_best=False,
+         progbar=True,
+     ):
+         self.imag = imag
+         if not imag:
+             raise NotImplementedError("Real time evolution not tested yet.")
+
+         self.state = psi0
+         self.ham = ham
+         self.progbar = progbar
+         self.callback = callback
+
+         # default time step to use
+         self.tau = tau
+         self.last_tau = 0.0
+
+         # parse gate application options
+         if D is None:
+             D = self._psi.max_bond()
+         self.gate_opts = ensure_dict(gate_opts)
+         self.gate_opts["max_bond"] = D
+         self.gate_opts.setdefault("cutoff", cutoff)
+         self.gate_opts.setdefault("contract", "reduce-split")
+
+         # parse energy computation options
+         self.compute_energy_opts = ensure_dict(compute_energy_opts)
+
+         self.compute_energy_every = compute_energy_every
+         self.compute_energy_final = compute_energy_final
+         self.compute_energy_fn = compute_energy_fn
+         self.compute_energy_per_site = bool(compute_energy_per_site)
+         self.tol = tol
+
+         if ordering is None:
+
+             def dynamic_random():
+                 return self.ham.get_auto_ordering("random_sequential")
+
+             self.ordering = dynamic_random
+         elif isinstance(ordering, str):
+             self.ordering = self.ham.get_auto_ordering(ordering)
+         elif callable(ordering):
+             self.ordering = ordering
+         else:
+             self.ordering = tuple(ordering)
+
+         self.second_order_reflect = second_order_reflect
+
+         # storage
+         self._n = 0
+         self.its = []
+         self.taus = []
+         self.energies = []
+         self.energy_diffs = []
+         self.egrdm = ExponentialGeometricRollingDiffMean()
+
+         self.keep_best = bool(keep_best)
+         self.best = dict(energy=float("inf"), state=None, it=None)
+         self.stop = False
+
+     def sweep(self, tau):
+         r"""Perform a full sweep of gates at every pair.
+
+         .. math::
+
+             \psi \rightarrow \prod_{\{ij\}} \exp(-\tau H_{ij}) \psi
+
+         """
+         if callable(self.ordering):
+             ordering = self.ordering()
+         else:
+             ordering = self.ordering
+
+         if self.second_order_reflect:
+             ordering = tuple(ordering) + tuple(reversed(ordering))
+             factor = 2.0
+         else:
+             factor = 1.0
+
+         layer = set()
+
+         for where in ordering:
+             if any(coo in layer for coo in where):
+                 # starting a new non-commuting layer
+                 self.postlayer()
+                 layer = set(where)
+             else:
+                 # add to the current layer
+                 layer.update(where)
+
+             if callable(tau):
+                 self.last_tau = tau(where)
+             else:
+                 self.last_tau = tau
+
+             G = self.ham.get_gate_expm(where, -self.last_tau / factor)
+
+             self.gate(G, where)
+
+         self.postlayer()
+
+     def _set_progbar_description(self, pbar):
+         desc = f"n={self._n}, tau={float(self.last_tau):.2g}"
+         if getattr(self, "gauge_diffs", None):
+             desc += f", max|dS|={self.gauge_diffs[-1]:.2g}"
+         if self.energies:
+             desc += f", energy~{float(self.energies[-1]):.6g}"
+         pbar.set_description(desc)
+
+     def evolve(self, steps, tau=None, progbar=None):
+         """Evolve the state with the local Hamiltonian for ``steps`` steps with
+         time step ``tau``.
+         """
+         if tau is not None:
+             if isinstance(tau, Iterable):
+                 taus = itertools.chain(tau, itertools.repeat(tau[-1]))
+             else:
+                 self.tau = tau
+                 taus = itertools.repeat(tau)
+
+         if progbar is None:
+             progbar = self.progbar
+
+         pbar = Progbar(total=steps, disable=not progbar)
+
+         try:
+             for i, tau in zip(range(steps), taus):
+                 # anything required by both energy and sweep
+                 self.presweep(i)
+
+                 # possibly compute the energy
+                 should_compute_energy = bool(self.compute_energy_every) and (
+                     i % self.compute_energy_every == 0
+                 )
+                 if should_compute_energy:
+                     self._check_energy()
+                     self._set_progbar_description(pbar)
+
+                     # check for convergence
+                     self.stop = (self.tol is not None) and (
+                         self.energy_diffs[-1] < self.tol
+                     )
+
+                 if self.stop:
+                     # maybe stop pre sweep
+                     self.stop = False
+                     break
+
+                 # actually perform the gates
+                 self.sweep(tau)
+                 self.postsweep(i)
+
+                 self._n += 1
+                 pbar.update()
+                 self._set_progbar_description(pbar)
+
+                 if self.callback is not None:
+                     if self.callback(self):
+                         break
+
+                 if self.stop:
+                     # maybe stop post sweep
+                     self.stop = False
+                     break
+
+             # possibly compute the energy
+             if self.compute_energy_final:
+                 self._check_energy()
+                 self._set_progbar_description(pbar)
+
+         except KeyboardInterrupt:
+             # allow the user to interrupt early
+             pass
+         finally:
+             pbar.close()
+
+     @property
+     def state(self):
+         """Return a copy of the current state."""
+         return self.get_state()
+
+     @state.setter
+     def state(self, psi):
+         self.set_state(psi)
+
+     @property
+     def n(self):
+         """The number of sweeps performed."""
+         return self._n
+
+     @property
+     def D(self):
+         """The maximum bond dimension."""
+         return self.gate_opts["max_bond"]
+
+     @D.setter
+     def D(self, value):
+         """The maximum bond dimension."""
+         self.gate_opts["max_bond"] = round(value)
+
+     def _check_energy(self):
+         """Logic for maybe computing the energy if needed."""
+         if self.its and (self._n == self.its[-1]):
+             # only compute if haven't already
+             return self.energies[-1]
+
+         if self.compute_energy_fn is not None:
+             en = self.compute_energy_fn(self)
+         else:
+             en = self.compute_energy()
+
+         if self.compute_energy_per_site:
+             en = en / self.ham.nsites
+
+         self.its.append(self._n)
+         self.taus.append(float(self.last_tau))
+
+         # update the energy and possibly the best state
+         self.energies.append(float(en))
+         if self.keep_best and en < self.best["energy"]:
+             self.best["energy"] = en
+             self.best["state"] = self.state
+             self.best["it"] = self._n
+
+         # update the energy difference mean and possibly mark converged
+         self.egrdm.update(float(en))
+         self.energy_diffs.append(self.egrdm.value)
+
+         if self.tol is not None:
+             self.stop = self.energy_diffs[-1] < self.tol
+
+         return self.energies[-1]
+
+     @property
+     def energy(self):
+         """Return the energy of the current state, computing it only if
+         necessary.
+         """
+         return self._check_energy()
+
+     # ------- abstract methods that subclasses might want to override ------- #
+
+     def get_state(self):
+         """The default method for retrieving the current state - simply a copy.
+         Subclasses can override this to perform additional transformations.
+         """
+         return self._psi.copy()
+
+     def set_state(self, psi):
+         """The default method for setting the current state - simply a copy.
+         Subclasses can override this to perform additional transformations.
+         """
+         self._psi = psi.copy()
+
+     def presweep(self, i):
+         """Perform any computations required before the sweep (and energy
+         computation). For the basic TEBD update this is nothing.
+         """
+         pass
+
+     def postlayer(self):
+         """Perform any computations required after each layer of commuting
+         gates. For the basic update this is nothing.
+         """
+         pass
+
+     def postsweep(self, i):
+         """Perform any computations required after the sweep (but before
+         the energy computation). For the basic update this is nothing.
+         """
+         pass
+
+     def gate(self, U, where):
+         """Perform single gate ``U`` at coordinate pair ``where``. This is
+         the most common method to override.
+         """
+         self._psi.gate_(U, where, **self.gate_opts)
+
+     def compute_energy(self):
+         """Compute and return the energy of the current state. Subclasses can
+         override this with a custom method to compute the energy.
+         """
+         return self._psi.compute_local_expectation_cluster(
+             terms=self.ham.terms, **self.compute_energy_opts
+         )
+
+     @default_to_neutral_style
+     def plot(
+         self,
+         zoom="auto",
+         xscale="symlog",
+         xscale_linthresh=20,
+         color_energy=(0.0, 0.5, 1.0),
+         color_gauge_diff=(1.0, 0.5, 0.0),
+         hlines=(),
+         figsize=(8, 4),
+     ):
+         """Plot an overview of the evolution of the energy and gauge diffs.
+
+         Parameters
+         ----------
+         zoom : int or 'auto', optional
+             The number of iterations to zoom in on, or 'auto' to automatically
+             choose a reasonable zoom level.
+         xscale : {'linear', 'log', 'symlog'}, optional
+             The x-axis scale, for the upper plot of the entire evolution.
+         xscale_linthresh : float, optional
+             The linear threshold for the upper symlog scale.
+         color_energy : str or tuple, optional
+             The color to use for the energy plot.
+         color_gauge_diff : str or tuple, optional
+             The color to use for the gauge diff plot.
+         hlines : dict, optional
+             Add horizontal lines to the plot, with keys as labels and values
+             as the y-values.
+         figsize : tuple, optional
+             The size of the figure.
+
+         Returns
+         -------
+         fig, axs : matplotlib.Figure, tuple[matplotlib.Axes]
+         """
+         import matplotlib.pyplot as plt
+         import numpy as np
+         from matplotlib.ticker import ScalarFormatter
+         from matplotlib.colors import hsv_to_rgb
+
+         def set_axis_color(ax, which, color):
+             ax.spines[which].set_visible(True)
+             ax.spines[which].set_color(color)
+             ax.yaxis.label.set_color(color)
+             ax.tick_params(axis="y", colors=color, which="both")
+
+         x_en = np.array(self.its)
+         y_en = np.array(self.energies)
+         x_gd = np.arange(1, len(self.gauge_diffs) + 1)
+         y_gd = np.array(self.gauge_diffs)
+
+         if zoom is not None:
+             if zoom == "auto":
+                 zoom = min(200, self.n // 2)
+             nz = self.n - zoom
+
+         fig, axs = plt.subplots(nrows=2, figsize=figsize)
+
+         # plotted zoomed out
+         # energy
+         axl = axs[0]
+         axl.plot(x_en, y_en, marker="|", color=color_energy)
+         axl.set_xscale(xscale, linthresh=xscale_linthresh)
+         axl.set_ylabel("Energy")
+         axl.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
+         set_axis_color(axl, "left", color_energy)
+         # gauge diff
+         axr = axl.twinx()
+         axr.plot(
+             x_gd,
+             y_gd,
+             linestyle="--",
+             color=color_gauge_diff,
+         )
+         axr.set_ylabel("Max gauge diff")
+         axr.set_yscale("log")
+         set_axis_color(axr, "right", color_gauge_diff)
+
+         axl.axvline(
+             nz,
+             color=(0.5, 0.5, 0.5, 0.5),
+             linestyle="-",
+             linewidth=1,
+         )
+
+         # plotted zoomed in
+         # energy
+         iz = min(range(len(x_en)), key=lambda i: x_en[i] < nz)
+         axl = axs[1]
+         axl.plot(x_en[iz:], y_en[iz:], marker="|", color=color_energy)
+         axl.set_ylabel("Energy")
+         axl.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
+         set_axis_color(axl, "left", color_energy)
+         axl.set_xlabel("Iteration")
+         # gauge diff
+         iz = min(range(len(x_gd)), key=lambda i: x_gd[i] < nz)
+         axr = axl.twinx()
+         axr.plot(
+             x_gd[iz:],
+             y_gd[iz:],
+             linestyle="--",
+             color=color_gauge_diff,
+         )
+         axr.set_ylabel("Max gauge diff")
+         axr.set_yscale("log")
+         set_axis_color(axr, "right", color_gauge_diff)
+
+         if hlines:
+             hlines = dict(hlines)
+             for i, (label, value) in enumerate(hlines.items()):
+                 color = hsv_to_rgb([(0.45 - (0.08 * i)) % 1.0, 0.7, 0.6])
+                 axs[0].axhline(value, color=color, ls=":", label=label)
+                 axs[1].axhline(value, color=color, ls=":", label=label)
+                 axs[0].text(
+                     1, value, label, color=color, va="bottom", ha="left"
+                 )
+                 axs[1].text(
+                     nz, value, label, color=color, va="bottom", ha="left"
+                 )
+
+         return fig, axs
+
+     def __repr__(self):
+         s = "<{}(n={}, tau={}, D={})>"
+         return s.format(self.__class__.__name__, self.n, self.tau, self.D)
+
+
+ class SimpleUpdateGen(TEBDGen):
+     """Simple update for arbitrary geometry hamiltonians.
+
+     Parameters
+     ----------
+     psi0 : TensorNetworkGenVector
+         The initial state.
+     ham : LocalHamGen
+         The local hamiltonian.
+     tau : float, optional
+         The default time step to use.
+     D : int, optional
+         The maximum bond dimension, by default the current maximum bond of
+         ``psi0``.
+     cutoff : float, optional
+         The singular value cutoff to use when applying gates.
+     imag : bool, optional
+         Whether to evolve in imaginary time (default) or real time.
+     gate_opts : dict, optional
+         Other options to supply to the gate application method,
+         :meth:`quimb.tensor.tensor_arbgeom.TensorNetworkGenVector.gate_simple_`.
+     ordering : None, str or callable, optional
+         The ordering of the terms to apply, by default this will be determined
+         automatically.
+     second_order_reflect : bool, optional
+         Whether to use a second order Trotter decomposition by reflecting the
+         ordering.
+     compute_energy_every : int, optional
+         Compute the energy every this many steps.
+     compute_energy_final : bool, optional
+         Whether to compute the energy at the end.
+     compute_energy_opts : dict, optional
+         Options to supply to the energy computation method,
+         :func:`quimb.tensor.tensor_arbgeom.TensorNetworkGenVector.compute_local_expectation_cluster`.
+     compute_energy_fn : callable, optional
+         A custom function to compute the energy, with signature ``fn(su)``,
+         where ``su`` is this instance.
+     compute_energy_per_site : bool, optional
+         Whether to compute the energy per site.
+     tol : float, optional
+         If not ``None``, stop when either energy difference falls below this
+         value, or maximum singular value changes fall below this value.
+     equilibrate_every : int, optional
+         Equilibrate the gauges every this many steps.
+     equilibrate_start : bool, optional
+         Whether to equilibrate the gauges at the start, regardless of
+         ``equilibrate_every``.
+     equilibrate_opts : dict, optional
+         Default options to supply to the gauge equilibration method, see
+         :meth:`quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple`. By
+         default `max_iterations` is set to 100 and `tol` to 1e-3.
+     callback : callable, optional
+         A function to call after each step, with signature ``fn(su)``.
+     keep_best : bool, optional
+         Whether to keep track of the best state and energy.
+     progbar : bool, optional
+         Whether to show a progress bar during evolution.
+     """
+
+     def __init__(
+         self,
+         psi0,
+         ham,
+         tau=0.01,
+         D=None,
+         cutoff=1e-10,
+         imag=True,
+         gate_opts=None,
+         ordering=None,
+         second_order_reflect=False,
+         compute_energy_every=None,
+         compute_energy_final=True,
+         compute_energy_opts=None,
+         compute_energy_fn=None,
+         compute_energy_per_site=False,
+         tol=None,
+         equilibrate_every=0,
+         equilibrate_start=True,
+         equilibrate_opts=None,
+         callback=None,
+         keep_best=False,
+         progbar=True,
+     ):
+         self.equilibrate_every = equilibrate_every
+         self.equilibrate_start = bool(equilibrate_start)
+         self.equilibrate_opts = equilibrate_opts or {}
+         self.equilibrate_opts.setdefault("max_iterations", 100)
+         self.equilibrate_opts.setdefault("tol", 1e-3)
+
+         self.gauges_prev = None
+         self.gauge_diffs = []
+
+         return super().__init__(
+             psi0,
+             ham,
+             tau=tau,
+             D=D,
+             cutoff=cutoff,
+             imag=imag,
+             gate_opts=gate_opts,
+             ordering=ordering,
+             second_order_reflect=second_order_reflect,
+             compute_energy_every=compute_energy_every,
+             compute_energy_final=compute_energy_final,
+             compute_energy_opts=compute_energy_opts,
+             compute_energy_fn=compute_energy_fn,
+             compute_energy_per_site=compute_energy_per_site,
+             tol=tol,
+             callback=callback,
+             keep_best=keep_best,
+             progbar=progbar,
+         )
+
+     def gate(self, G, where):
+         """Application of a single gate ``G`` at ``where``."""
+         self._psi.gate_simple_(G, where, gauges=self.gauges, **self.gate_opts)
+
+         if self.equilibrate_every == "gate":
+             tags = [self._psi.site_tag(x) for x in where]
+             tids = self._psi._get_tids_from_tags(tags, "any")
+             self.equilibrate(touched_tids=tids)
+
+     def equilibrate(self, **kwargs):
+         """Equilibrate the gauges with the current state (like evolving with
+         tau=0).
+         """
+         # allow overriding of default options
+         kwargs = {**self.equilibrate_opts, **kwargs}
+         self._psi.gauge_all_simple_(gauges=self.gauges, **kwargs)
+
+     def postlayer(self):
+         """Performed after each layer of commuting gates."""
+         if self.equilibrate_every == "layer":
+             self.equilibrate()
+
+     def postsweep(self, i):
+         """Performed after every full sweep."""
+         should_equilibrate = (
+             # str settings are equilibrated elsewhere
+             (not isinstance(self.equilibrate_every, str))
+             and (self.equilibrate_every > 0)
+             and (i % self.equilibrate_every == 0)
+         )
+
+         if should_equilibrate:
+             self.equilibrate()
+
+         # check gauges for convergence / progbar
+         if self.gauges_prev is not None:
+             sdiffs = []
+             for k, g in self.gauges.items():
+                 g_prev = self.gauges_prev[k]
+                 try:
+                     sdiff = do("linalg.norm", g - g_prev)
+                 except ValueError:
+                     # gauge has changed size
+                     sdiff = 1.0
+                 sdiffs.append(sdiff)
+
+             max_sdiff = max(sdiffs)
+             self.gauge_diffs.append(max_sdiff)
+
+             if self.tol is not None and (max_sdiff < self.tol):
+                 self.stop = True
+
+         self.gauges_prev = self.gauges.copy()
+
+     def normalize(self):
+         """Normalize the state and simple gauges."""
+         self._psi.normalize_simple(self.gauges)
+
+     def compute_energy(self):
+         """Default estimate of the energy."""
+         return self._psi.compute_local_expectation_cluster(
+             terms=self.ham.terms,
+             gauges=self.gauges,
+             **self.compute_energy_opts,
+         )
+
+     def get_state(self, absorb_gauges=True):
+         """Return the current state, possibly absorbing the gauges.
+
+         Parameters
+         ----------
+         absorb_gauges : bool or "return", optional
+             Whether to absorb the gauges into the state or not. If ``True``, a
+             standard PEPS is returned with the gauges absorbed. If ``False``,
+             the gauges are added to the tensor network but uncontracted. If
+             "return", the gauges are returned separately.
+
+         Returns
+         -------
+         psi : TensorNetwork
+             The current state.
+         gauges : dict
+             The current gauges, if ``absorb_gauges == "return"``.
+         """
+         psi = self._psi.copy()
+
+         if absorb_gauges == "return":
+             return psi, self.gauges.copy()
+
+         if absorb_gauges:
+             psi.gauge_simple_insert(self.gauges)
+         else:
+             for ix, g in self.gauges.items():
+                 psi |= Tensor(g, inds=[ix])
+
+         return psi
+
+     def set_state(self, psi, gauges=None):
+         """Set the current state and possibly the gauges."""
+         self._psi = psi.copy()
+         if gauges is None:
+             self.gauges = {}
+             self._psi.gauge_all_simple_(max_iterations=1, gauges=self.gauges)
+         else:
+             self.gauges = dict(gauges)
+
+         if self.equilibrate_start:
+             self.equilibrate()
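
For orientation, here is a brief illustrative sketch (written for this summary, not part of the packaged file) of how the `LocalHamGen` class and the `edge_coloring` helper added above are typically used. It assumes the upstream `quimb` package, which this wheel vendors under `trajectree/quimb/`, is importable as `quimb`, and that `networkx` is installed for the coloring.

```python
# Sketch: build a LocalHamGen for a Heisenberg ring and inspect term orderings.
import quimb as qu
from quimb.tensor.tensor_arbgeom_tebd import LocalHamGen, edge_coloring

# a 6-site ring; sites just need to be hashable and comparable
edges = [(i, (i + 1) % 6) for i in range(6)]

# two-site Heisenberg term on every edge plus a small uniform Z field;
# LocalHamGen absorbs each single-site term evenly into the covering edges
ham = LocalHamGen(
    H2={edge: qu.ham_heis(2) for edge in edges},
    H1=0.1 * qu.pauli("Z"),
)
print(ham)                         # <LocalHamGen(nsites=6, num_terms=6)>
print(ham.get_gate((1, 0)).shape)  # (4, 4) -- pairs are stored sorted

# edges grouped into sets touching disjoint sites (requires networkx)
print(edge_coloring(edges))
# the ordering TEBDGen would use: sorted pairs, greedily grouped into
# commuting (non-site-overlapping) layers
print(ham.get_auto_ordering("sort"))
```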
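A second sketch shows imaginary-time evolution with `SimpleUpdateGen`. The initial-state builder `TN_from_edges_rand` and its keyword names are assumptions about the upstream `quimb.tensor` API (any `TensorNetworkGenVector` over the same sites should work), and the parameter values are placeholders rather than recommendations.

```python
# Sketch: ground-state search by imaginary-time simple update.
import quimb as qu
import quimb.tensor as qtn
from quimb.tensor.tensor_arbgeom_tebd import LocalHamGen, SimpleUpdateGen

edges = [(i, (i + 1) % 6) for i in range(6)]
ham = LocalHamGen({edge: qu.ham_heis(2) for edge in edges})

# random initial tensor network state on the same graph (assumed builder)
psi0 = qtn.TN_from_edges_rand(edges, D=2, phys_dim=2, seed=42)

su = SimpleUpdateGen(
    psi0,
    ham,
    D=4,                          # max bond dimension while gating
    compute_energy_every=10,      # track the energy during the evolution
    compute_energy_per_site=True,
    equilibrate_every=1,          # re-equilibrate the gauges after each sweep
    keep_best=True,               # remember the lowest-energy state seen
    progbar=False,
)
su.evolve(100, tau=0.1)           # coarse sweeps
su.evolve(100, tau=0.01)          # then refine with a smaller time step

print(su.energy)                  # energy per site of the current state
psi = su.get_state()              # gauges absorbed back into the network
```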