Trajectree 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124):
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +9 -9
  5. trajectree/fock_optics/outputs.py +10 -6
  6. trajectree/fock_optics/utils.py +9 -6
  7. trajectree/sequence/swap.py +5 -4
  8. trajectree/trajectory.py +5 -4
  9. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/METADATA +2 -3
  10. trajectree-0.0.3.dist-info/RECORD +16 -0
  11. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  12. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  13. trajectree/quimb/docs/conf.py +0 -158
  14. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  15. trajectree/quimb/quimb/__init__.py +0 -507
  16. trajectree/quimb/quimb/calc.py +0 -1491
  17. trajectree/quimb/quimb/core.py +0 -2279
  18. trajectree/quimb/quimb/evo.py +0 -712
  19. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  20. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  21. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  22. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  23. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  24. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  25. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  26. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  27. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  28. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  29. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  30. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  31. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  32. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  33. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  34. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  35. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  36. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  37. trajectree/quimb/quimb/gates.py +0 -36
  38. trajectree/quimb/quimb/gen/__init__.py +0 -2
  39. trajectree/quimb/quimb/gen/operators.py +0 -1167
  40. trajectree/quimb/quimb/gen/rand.py +0 -713
  41. trajectree/quimb/quimb/gen/states.py +0 -479
  42. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  43. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  44. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  45. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  46. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  47. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  48. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  49. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  50. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  51. trajectree/quimb/quimb/schematic.py +0 -1518
  52. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  53. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  54. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  55. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  56. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  57. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  58. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  59. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  60. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  61. trajectree/quimb/quimb/tensor/interface.py +0 -114
  62. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  63. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  64. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  65. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  66. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  67. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  68. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  69. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  70. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  71. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  72. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  74. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  75. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  76. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  77. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  78. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  79. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  80. trajectree/quimb/quimb/utils.py +0 -892
  81. trajectree/quimb/tests/__init__.py +0 -0
  82. trajectree/quimb/tests/test_accel.py +0 -501
  83. trajectree/quimb/tests/test_calc.py +0 -788
  84. trajectree/quimb/tests/test_core.py +0 -847
  85. trajectree/quimb/tests/test_evo.py +0 -565
  86. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  87. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  88. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  89. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  90. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  91. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  92. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  93. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  94. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  95. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  96. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  97. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  103. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  104. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  105. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  106. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  107. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  108. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  109. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  110. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  111. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  112. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  113. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  114. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  115. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  116. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  117. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  118. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  119. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  120. trajectree/quimb/tests/test_utils.py +0 -85
  121. trajectree-0.0.1.dist-info/RECORD +0 -126
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/WHEEL +0 -0
  123. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/licenses/LICENSE +0 -0
  124. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/top_level.txt +0 -0
@@ -1,194 +0,0 @@
1
class RegionGraph:
    """A graph of regions (sets of nodes) linked by containment.

    Each region is stored as a ``frozenset``. A region ``p`` is a direct
    parent of a region ``c`` when ``c`` is a strict subset of ``p`` and no
    other stored region lies strictly between them.

    Parameters
    ----------
    regions : Iterable[Sequence[Hashable]], optional
        Initial regions to add.
    autocomplete : bool, optional
        Whether to call :meth:`autocomplete` after adding ``regions``, which
        adds intersections of intersecting regions.
    """

    def __init__(self, regions=(), autocomplete=True):
        # node -> set of stored regions containing that node
        self.lookup = {}
        # region -> set of direct parents / direct children
        self.parents = {}
        self.children = {}
        # cache of region -> count, cleared whenever a region is added
        self.counts = {}
        for region in regions:
            self.add_region(region)
        if autocomplete:
            self.autocomplete()

    @property
    def regions(self):
        """Tuple of all stored regions, in insertion order."""
        return tuple(self.children)

    def neighbor_regions(self, region):
        """Get all regions that intersect with the given region.

        Parameters
        ----------
        region : Iterable[Hashable]
            The region to find neighbors of. Every node in it must already
            belong to some stored region, otherwise ``KeyError`` is raised.

        Returns
        -------
        set[frozenset]
            A new set of the stored regions sharing at least one node with
            ``region``, excluding ``region`` itself. Empty if ``region`` is
            empty.
        """
        region = frozenset(region)
        # ``set().union`` always builds a fresh set (so ``discard`` below
        # never touches ``self.lookup``) and, unlike ``set.union(*...)``,
        # does not raise ``TypeError`` when ``region`` is empty
        other_regions = set().union(*(self.lookup[node] for node in region))
        other_regions.discard(region)
        return other_regions

    def add_region(self, region):
        """Add a new region and update parent-child relationships.

        Only direct (minimal) containment links are kept. If the new region
        sits strictly between an existing parent and child, the now-redundant
        direct link between that parent and child is removed.

        Parameters
        ----------
        region : Sequence[Hashable]
            The new region to add. Re-adding an existing region is a no-op.
        """
        region = frozenset(region)

        if region in self.parents:
            # already added
            return

        # populate data structures
        self.parents[region] = set()
        self.children[region] = set()
        for node in region:
            # collect regions that contain nodes for fast neighbor lookup
            self.lookup.setdefault(node, set()).add(region)

        # add parent-child relationships
        for other in self.neighbor_regions(region):
            if region.issubset(other):
                self.parents[region].add(other)
                self.children[other].add(region)
            elif other.issubset(region):
                self.children[region].add(other)
                self.parents[other].add(region)

        # prune redundant children: a child contained in a larger child is
        # already reachable through that larger child
        children = sorted(self.children[region], key=len)
        for i, c in enumerate(children):
            if any(c.issubset(cc) for cc in children[i + 1 :]):
                self.children[region].remove(c)
                self.parents[c].remove(region)

        # prune redundant parents: a parent containing a smaller parent is
        # already reachable through that smaller parent
        parents = sorted(self.parents[region], key=len, reverse=True)
        for i, p in enumerate(parents):
            if any(p.issuperset(pp) for pp in parents[i + 1 :]):
                self.parents[region].remove(p)
                self.children[p].remove(region)

        # the new region now sits between each direct parent and each direct
        # child (c < region < p), so an existing direct p -> c link is
        # redundant; removing it keeps reachability (and hence ancestors and
        # counts) unchanged while keeping parent/child sets minimal
        for p in self.parents[region]:
            for c in self.children[region]:
                if c in self.children[p]:
                    self.children[p].remove(c)
                    self.parents[c].remove(p)

        self.counts.clear()

    def autocomplete(self):
        """Add all missing intersecting sub-regions.

        Iterates over a snapshot of the current regions and adds the
        intersection of each with each of its neighbors.
        """
        for r in self.regions:
            for other in self.neighbor_regions(r):
                self.add_region(r & other)

    def autoextend(self, regions=None):
        """Extend this region graph upwards by adding in all pairwise unions of
        regions. If regions is specified, take this as one set of pairs.
        """
        if regions is None:
            regions = self.regions

        # collect neighbors first so that newly added unions are not
        # themselves extended in this pass
        neighbors = {}
        for r in regions:
            for other in self.neighbor_regions(r):
                neighbors.setdefault(r, []).append(other)

        for r, others in neighbors.items():
            for other in others:
                self.add_region(r | other)

    def get_parents(self, region):
        """Get all ancestors that contain the given region, but do not contain
        any other regions that themselves contain the given region.
        """
        return self.parents[region]

    def get_children(self, region):
        """Get all regions that are contained by the given region, but are not
        contained by any other descendants of the given region.
        """
        return self.children[region]

    def get_ancestors(self, region):
        """Get all regions that contain the given region, not just direct
        parents.
        """
        seen = set()
        queue = [region]
        while queue:
            r = queue.pop()
            for rp in self.parents[r]:
                if rp not in seen:
                    seen.add(rp)
                    queue.append(rp)
        return seen

    def get_descendents(self, region):
        """Get all regions that are contained by the given region, not just
        direct children.
        """
        seen = set()
        queue = [region]
        while queue:
            r = queue.pop()
            for rc in self.children[r]:
                if rc not in seen:
                    seen.add(rc)
                    queue.append(rc)
        return seen

    def get_count(self, region):
        """Get the count of the given region, i.e. the correct weighting to
        apply when summing over all regions to avoid overcounting.

        Computed as ``1 - sum(count(a) for a in ancestors(region))`` and
        cached until the next region is added.
        """
        try:
            C = self.counts[region]
        except KeyError:
            # n.b. cache is cleared when any new region is added
            C = self.counts[region] = 1 - sum(
                self.get_count(a) for a in self.get_ancestors(region)
            )
        return C

    def get_total_count(self):
        """Sum of :meth:`get_count` over all regions."""
        return sum(map(self.get_count, self.regions))

    def get_level(self, region):
        """Get the level of the given region.

        Regions with no parents have level 0. Every other region sits one
        below its lowest parent, so the level is the negated length of the
        longest parent chain from ``region`` up to a parentless region.
        """
        if not self.parents[region]:
            return 0
        else:
            return min(self.get_level(p) for p in self.parents[region]) - 1

    def draw(self, pos=None, a=20, scale=1.0, radius=0.1, **drawing_opts):
        """Draw the regions as translucent patches using ``quimb.schematic``.

        Parameters
        ----------
        pos : dict, optional
            Mapping of node -> coordinates. If not given, each node is used
            as its own coordinates (assumes nodes are coordinate tuples).
        a : float, optional
            Passed to ``Drawing``.
        scale : float, optional
            Factor applied to every coordinate.
        radius : float, optional
            Patch radius around each region's nodes.
        drawing_opts
            Supplied to ``Drawing``.

        Returns
        -------
        fig, ax
        """
        from quimb.schematic import Drawing, hash_to_color

        if pos is None:
            pos = {node: node for node in self.lookup}

        def get_draw_pos(coo):
            return tuple(scale * s for s in pos[coo])

        # regions are stacked vertically by their size rank
        sizes = {len(r) for r in self.regions}
        levelmap = {s: i for i, s in enumerate(sorted(sizes))}

        d = Drawing(a=a, **drawing_opts)
        for region in sorted(self.regions, key=len, reverse=True):
            level = levelmap[len(region)]

            coos = [(*get_draw_pos(coo), 2.0 * level) for coo in region]

            d.patch_around(
                coos,
                radius=radius,
                facecolor=hash_to_color(str(region)),
                alpha=1 / 3,
                linestyle="",
                linewidth=3,
            )

        return d.fig, d.ax

    def __repr__(self):
        return (
            f"<RegionGraph(regions={len(self.regions)}, "
            f"total_count={self.get_total_count()})>"
        )
@@ -1,286 +0,0 @@
1
- """Implementation of arbitrary geometry wavefunction cluster update.
2
- """
3
- import functools
4
-
5
- from quimb.tensor.tensor_core import ensure_dict, bonds, bonds_size, do
6
- from quimb.tensor.tensor_arbgeom_tebd import SimpleUpdateGen
7
-
8
-
9
def gate_inds_nn_fit(
    self,
    G,
    ind1,
    ind2,
    max_bond=None,
    method="als",
    pregauge=2,
    init_simple_guess=True,
    steps=10,
    fit_opts=None,
    contract_opts=None,
    inplace=False,
):
    """Gate two nearest neighbor outer indices, using full fitting of
    reduced tensors with respect to the environment. This is more accurate
    than a simple reduced gate when restricting the bond dimension.

    Parameters
    ----------
    self : TensorNetwork
        The tensor network to gate (written so it can be attached as a
        ``TensorNetwork`` method).
    G : array_like
        The gate to fit.
    ind1, ind2 : str
        The indices to gate.
    max_bond : int, optional
        The maximum bond dimension to use. If ``None``, use the bond
        dimension that the two reduced tensors currently share.
    method : {'als', 'autodiff'}, optional
        The method to use for fitting.
    pregauge : int, optional
        How many times to locally canonize from the purified environment
        tensor to both the left and right reduced tensors.
    init_simple_guess : bool, optional
        Whether to use a 'simple update' guess for the initial guess. This
        can be quite high quality already if pregauging is used.
    steps : int, optional
        The number of steps to use for fitting, can be ``0`` in which case
        the initial guess is used, which in conjunction with the
        environment pregauging can still be quite high quality.
    fit_opts
        Supplied to the ``fit_`` call on the purification.
    contract_opts
        Supplied to
        :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`.
    inplace : bool, optional
        Whether to update the tensor network in place.

    Returns
    -------
    TensorNetwork
        ``self`` if ``inplace`` else a modified copy, with the data of the
        two tensors carrying ``ind1`` and ``ind2`` replaced.
    """
    fit_opts = ensure_dict(fit_opts)
    contract_opts = ensure_dict(contract_opts)

    # all work happens on a copy; results are written back at the end
    ket = self.copy()

    # move indices onto shared bond so environment can be contracted
    ket.reduce_inds_onto_bond(
        ind1, ind2, combine=False, ndim_cutoff=0, tags="__REDUCED__"
    )

    ket.add_tag("__KET__")
    bra = ket.conj().retag_({"__KET__": "__BRA__"})
    # NOTE(review): the in-place gauging / fitting of tk1, tk2 below must
    # reach ``ket`` for the final ``contract_ind`` calls to see it; this
    # presumably relies on ``|`` viewing rather than copying tensors -
    # confirm against the TensorNetwork ``__or__`` semantics.
    norm = ket | bra

    # contract environment -> all but reduced tensors
    norm.contract_tags_("__REDUCED__", "!any", **contract_opts)

    # get tensors and bond names: the single contracted environment tensor,
    # plus the ket and bra copies of each reduced tensor (identified by
    # their outer index and ket/bra tag)
    (tide,) = norm._get_tids_from_tags("__REDUCED__", "!any")
    (te,) = norm._tids_get(tide)
    (tid_k1,) = norm._get_tids_from_inds(ind1) & norm._get_tids_from_tags(
        "__KET__"
    )
    (tid_k2,) = norm._get_tids_from_inds(ind2) & norm._get_tids_from_tags(
        "__KET__"
    )
    (tid_b1,) = norm._get_tids_from_inds(ind1) & norm._get_tids_from_tags(
        "__BRA__"
    )
    (tid_b2,) = norm._get_tids_from_inds(ind2) & norm._get_tids_from_tags(
        "__BRA__"
    )
    tk1 = norm.tensor_map[tid_k1]
    tk2 = norm.tensor_map[tid_k2]
    tb1 = norm.tensor_map[tid_b1]
    tb2 = norm.tensor_map[tid_b2]

    # each reduced tensor shares exactly one bond with the environment
    (ix_ek1,) = bonds(te, tk1)
    (ix_ek2,) = bonds(te, tk2)
    (ix_eb1,) = bonds(te, tb1)
    (ix_eb2,) = bonds(te, tb2)

    if max_bond is None:
        max_bond = bonds_size(tk1, tk2)

    # split environment to get purification: keep only the ket-side factor
    _, tek = te.split(
        left_inds=[ix_eb1, ix_eb2],
        right_inds=[ix_ek1, ix_ek2],
        method="svd",
    )

    #       ┌────────────────┐
    #      ┌─┴┐      ┌──┐    │     :
    #      │k1├──────┤k2├───┐ │     : this is the purification
    #      └─┬┘      └─┬┘  ├─┴┐    : we'll fit to gated version of itself
    #        │         │   │ek│    :
    #        │ind1     │   └─┬┘
    #        │     ind2│   ┌─┴┐
    #        │         │   │eb│
    #      ┌─┴┐      ┌─┴┐  ├─┬┘
    #      │b1├──────┤b2├───┘ │
    #      └─┬┘      └──┘    │
    #        └────────────────┘

    for _ in range(int(pregauge)):
        # perform some conditioning: local gauging from Q -> R tensors
        for ind, ix_ek, tk in [
            (ind1, ix_ek1, tk1),
            (ind2, ix_ek2, tk2),
        ]:
            # R factor of a QR of the environment with ``ix_ek`` on the right
            R = tek.split(
                left_inds=None,
                right_inds=[ix_ek],
                method="qr",
                get="arrays",
            )[1]

            Rinv = do("linalg.inv", R)
            # get Q tensor in uncontracted TN: the tensor on bond ``ix_ek``
            # that does not carry the outer index ``ind``
            (tidkq,) = ket._get_tids_from_inds(
                ix_ek
            ) - ket._get_tids_from_inds(ind)
            tkq = ket.tensor_map[tidkq]
            # need to keep environment and bare tensors in sync: R^-1 goes
            # onto the neighbouring tensor and the environment, R onto the
            # reduced tensor
            tkq.gate_(Rinv.T, ix_ek)
            tek.gate_(Rinv.T, ix_ek)
            tk.gate_(R, ix_ek)

    # form purification
    tnl = tek | tk1 | tk2
    # form gated purification (target of the fit)
    tnl_target = tnl.gate_inds(G, (ind1, ind2), contract=True)

    if init_simple_guess:
        # maybe initialize with simple guess
        tnl.gate_inds_(
            G, (ind1, ind2), contract="split", max_bond=max_bond, cutoff=0.0
        )

    if steps:
        # perform the actual fitting, specifying only the reduced tensors
        tnl.fit_(
            tnl_target,
            tags=["__REDUCED__"],
            method=method,
            steps=steps,
            **fit_opts,
        )

    # re-absorb reduced factors
    ket.contract_ind(ix_ek1)
    ket.contract_ind(ix_ek2)

    # TN we will return
    new = self if inplace else self.copy()
    (tn1,) = new._inds_get(ind1)
    (tn2,) = new._inds_get(ind2)

    # permute to match original tensors
    (t1,) = ket._inds_get(ind1)
    (t2,) = ket._inds_get(ind2)
    t1.transpose_like_(tn1)
    t2.transpose_like_(tn2)

    # write the new data into the returned network's tensors
    tn1.modify(data=t1.data)
    tn2.modify(data=t2.data)

    return new
183
-
184
-
185
def gate_inds_nn_fit_(self, *args, **kwargs):
    """Inplace version of ``gate_inds_nn_fit``.

    Defined as a plain function rather than a ``functools.partialmethod``.
    A ``partialmethod`` object is only usable as a class attribute and
    cannot be called directly at module level. A plain function is callable
    directly and still binds ``self`` normally when attached to a class.
    An explicit ``inplace`` keyword still overrides the default, matching
    ``partialmethod`` semantics.
    """
    kwargs.setdefault("inplace", True)
    return gate_inds_nn_fit(self, *args, **kwargs)


# TensorNetwork.gate_inds_nn_fit = gate_inds_nn_fit
# TensorNetwork.gate_inds_nn_fit_ = gate_inds_nn_fit_
189
-
190
-
191
class ClusterUpdateNNGen(SimpleUpdateGen):
    """Cluster update for arbitrary geometry nearest neighbor hamiltonians.

    This keeps track of simple update style gauges, in order to approximately
    partial trace beyond ``cluster_radius`` and form an approximate environment
    for two nearest neighbor sites that can be used to fit the gate with
    higher quality than simple update alone.

    Parameters
    ----------
    psi0, ham, tau, D, imag, gate_opts, ordering, second_order_reflect, \
compute_energy_every, compute_energy_final, compute_energy_opts, \
compute_energy_fn, compute_energy_per_site, callback, keep_best, progbar
        Forwarded unchanged to ``SimpleUpdateGen``. The resulting
        ``self.gate_opts`` (with any ``"cutoff"`` and ``"contract"`` entries
        removed) are passed as keyword arguments to ``gate_inds_nn_fit``.
    cluster_radius : int, optional
        ``max_distance`` used when selecting the local cluster around the
        two gated sites.
    cluster_fillin : int, optional
        ``fillin`` used when selecting the local cluster.
    gauge_smudge : float, optional
        ``smudge`` passed to the simple gauging routines.
    """

    def __init__(
        self,
        psi0,
        ham,
        tau=0.01,
        D=None,
        cluster_radius=1,
        cluster_fillin=0,
        gauge_smudge=1e-6,
        imag=True,
        gate_opts=None,
        ordering=None,
        second_order_reflect=False,
        compute_energy_every=None,
        compute_energy_final=True,
        compute_energy_opts=None,
        compute_energy_fn=None,
        compute_energy_per_site=False,
        callback=None,
        keep_best=False,
        progbar=True,
    ):
        super().__init__(
            psi0=psi0,
            ham=ham,
            tau=tau,
            D=D,
            imag=imag,
            gate_opts=gate_opts,
            ordering=ordering,
            second_order_reflect=second_order_reflect,
            compute_energy_every=compute_energy_every,
            compute_energy_final=compute_energy_final,
            compute_energy_opts=compute_energy_opts,
            compute_energy_fn=compute_energy_fn,
            compute_energy_per_site=compute_energy_per_site,
            callback=callback,
            keep_best=keep_best,
            progbar=progbar,
        )
        self.cluster_radius = cluster_radius
        self.cluster_fillin = cluster_fillin
        self.gauge_smudge = gauge_smudge

        # override some default TEBDGen gate_opts: ``gate_inds_nn_fit``
        # accepts neither option, so drop them, tolerating their absence
        self.gate_opts.pop("cutoff", None)
        self.gate_opts.pop("contract", None)

    def gate(self, U, where):
        """Apply gate ``U`` to the pair of sites ``where``.

        The gate is fitted within a simple-gauged local cluster around the
        two sites. Afterwards the simple gauges near those sites are
        refreshed and the norms of the state are equalized.
        """
        taga, tagb = self._psi.gen_tags_from_coos(where)
        inda, indb = self._psi.gen_inds_from_coos(where)

        # get the local cluster; ``virtual=True`` presumably views (rather
        # than copies) the tensors of ``self._psi``, which is what lets the
        # in-place fit below update the actual state
        psi_local = self._psi.select_local(
            (taga, tagb),
            "any",
            max_distance=self.cluster_radius,
            fillin=self.cluster_fillin,
            virtual=True,
        )

        # temporarily gauge it with 'simple' gauges
        with psi_local.gauge_simple_temp(
            self.gauges,
            smudge=self.gauge_smudge,
        ):
            # fit the gate to the gauged local cluster. This must be done in
            # place: with ``inplace=False`` the fitted data would only be
            # written into a fresh copy, whose return value is discarded
            # here, so the state would never change
            gate_inds_nn_fit(
                psi_local,
                U,
                inda,
                indb,
                **{**self.gate_opts, "inplace": True},
            )

        # update nearest gauges for the modified tensors
        self._psi.gauge_local_(
            (taga, tagb),
            "any",
            max_distance=1,
            method="simple",
            gauges=self.gauges,
            smudge=self.gauge_smudge,
        )

        # perform some conditioning
        self._psi.equalize_norms_(1.0)
        self._psi.exponent = 0.0