Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +3 -0
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
- trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
- trajectree/quimb/docs/conf.py +158 -0
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
- trajectree/quimb/quimb/__init__.py +507 -0
- trajectree/quimb/quimb/calc.py +1491 -0
- trajectree/quimb/quimb/core.py +2279 -0
- trajectree/quimb/quimb/evo.py +712 -0
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +129 -0
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
- trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
- trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
- trajectree/quimb/quimb/experimental/schematic.py +7 -0
- trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
- trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
- trajectree/quimb/quimb/gates.py +36 -0
- trajectree/quimb/quimb/gen/__init__.py +2 -0
- trajectree/quimb/quimb/gen/operators.py +1167 -0
- trajectree/quimb/quimb/gen/rand.py +713 -0
- trajectree/quimb/quimb/gen/states.py +479 -0
- trajectree/quimb/quimb/linalg/__init__.py +6 -0
- trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
- trajectree/quimb/quimb/linalg/autoblock.py +258 -0
- trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
- trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
- trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
- trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
- trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
- trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
- trajectree/quimb/quimb/schematic.py +1518 -0
- trajectree/quimb/quimb/tensor/__init__.py +401 -0
- trajectree/quimb/quimb/tensor/array_ops.py +610 -0
- trajectree/quimb/quimb/tensor/circuit.py +4824 -0
- trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
- trajectree/quimb/quimb/tensor/contraction.py +336 -0
- trajectree/quimb/quimb/tensor/decomp.py +1255 -0
- trajectree/quimb/quimb/tensor/drawing.py +1646 -0
- trajectree/quimb/quimb/tensor/fitting.py +385 -0
- trajectree/quimb/quimb/tensor/geometry.py +583 -0
- trajectree/quimb/quimb/tensor/interface.py +114 -0
- trajectree/quimb/quimb/tensor/networking.py +1058 -0
- trajectree/quimb/quimb/tensor/optimize.py +1818 -0
- trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
- trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
- trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
- trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
- trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
- trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
- trajectree/quimb/quimb/utils.py +892 -0
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +501 -0
- trajectree/quimb/tests/test_calc.py +788 -0
- trajectree/quimb/tests/test_core.py +847 -0
- trajectree/quimb/tests/test_evo.py +565 -0
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +361 -0
- trajectree/quimb/tests/test_gen/test_rand.py +296 -0
- trajectree/quimb/tests/test_gen/test_states.py +261 -0
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
- trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
- trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
- trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
- trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
- trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
- trajectree/quimb/tests/test_utils.py +85 -0
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
- trajectree-0.0.1.dist-info/RECORD +126 -0
- trajectree-0.0.0.dist-info/RECORD +0 -16
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import autoray as ar
|
|
4
|
+
|
|
5
|
+
import quimb.tensor as qtn
|
|
6
|
+
from quimb.utils import oset
|
|
7
|
+
|
|
8
|
+
from .bp_common import (
|
|
9
|
+
BeliefPropagationCommon,
|
|
10
|
+
combine_local_contractions,
|
|
11
|
+
create_lazy_community_edge_map,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class L2BP(BeliefPropagationCommon):
    """Lazy (as in multiple uncontracted tensors per site) 2-norm (as in for
    wavefunctions and operators) belief propagation.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to form the 2-norm of and run BP on.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region,
        which should not overlap. If the tensor network is structured, then
        these are inferred automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.
    """

    def __init__(
        self,
        tn,
        site_tags=None,
        damping=0.0,
        update="sequential",
        local_convergence=True,
        optimize="auto-hq",
        **contract_opts,
    ):
        # take the array backend from the first tensor in the network
        self.backend = next(t.backend for t in tn)
        self.damping = damping
        self.local_convergence = local_convergence
        self.update = update
        self.optimize = optimize
        self.contract_opts = contract_opts

        if site_tags is None:
            self.site_tags = tuple(tn.site_tags)
        else:
            self.site_tags = tuple(site_tags)

        # pass the materialized tuple rather than the raw argument: if
        # ``site_tags`` was a one-shot iterator, ``tuple(...)`` above has
        # already exhausted it
        (
            self.edges,
            self.neighbors,
            self.local_tns,
            self.touch_map,
        ) = create_lazy_community_edge_map(tn, self.site_tags)
        self.touched = oset()

        _abs = ar.get_lib_fn(self.backend, "abs")
        _sum = ar.get_lib_fn(self.backend, "sum")
        _transpose = ar.get_lib_fn(self.backend, "transpose")
        _conj = ar.get_lib_fn(self.backend, "conj")

        def _normalize(x):
            # scale so that the elements sum to one
            return x / _sum(x)

        def _symmetrize(x):
            # add the conjugate transpose, swapping the first and second
            # halves of the axes (the bra and ket halves of a message)
            N = ar.ndim(x)
            perm = (*range(N // 2, N), *range(0, N // 2))
            return x + _conj(_transpose(x, perm))

        def _distance(x, y):
            # summed absolute elementwise difference
            return _sum(_abs(x - y))

        self._normalize = _normalize
        self._symmetrize = _symmetrize
        self._distance = _distance

        # initialize messages
        self.messages = {}

        for pair, bix in self.edges.items():
            cix = tuple(ix + "_l2bp*" for ix in bix)
            remapper = dict(zip(bix, cix))
            output_inds = cix + bix

            # compute leftwards and rightwards messages
            for i, j in (sorted(pair), sorted(pair, reverse=True)):
                tn_i = self.local_tns[i]
                tn_i2 = tn_i & tn_i.conj().reindex_(remapper)
                tm = tn_i2.contract(
                    all,
                    output_inds=output_inds,
                    optimize=self.optimize,
                    drop_tags=True,
                    **self.contract_opts,
                )
                tm.modify(apply=self._symmetrize)
                tm.modify(apply=self._normalize)
                self.messages[i, j] = tm

        # initialize contractions
        self.contraction_tns = {}
        for pair, bix in self.edges.items():
            for i, j in (sorted(pair), sorted(pair, reverse=True)):
                # form the ket side and messages
                tn_i_left = self.local_tns[i]
                # get other incident nodes which aren't j
                ks = [k for k in self.neighbors[i] if k != j]
                tks = [self.messages[k, i] for k in ks]

                # form the 'bra' side
                tn_i_right = tn_i_left.conj()
                # get the bonds that attach the bra to messages
                outer_bix = {
                    ix for k in ks for ix in self.edges[tuple(sorted((k, i)))]
                }
                # need to reindex to join message bonds, and create bra outputs
                remapper = {}
                for ix in tn_i_right.ind_map:
                    if ix in bix:
                        # bra outputs
                        remapper[ix] = ix + "_l2bp**"
                    elif ix in outer_bix:
                        # messages connected
                        remapper[ix] = ix + "_l2bp*"
                    # remaining indices are either internal and will be mangled
                    # or global outer indices and will be contracted directly

                tn_i_right.reindex_(remapper)

                # virtual=True: the network holds references to the message
                # tensors, so in-place message updates are seen here
                self.contraction_tns[i, j] = qtn.TensorNetwork(
                    (tn_i_left, *tks, tn_i_right), virtual=True
                )

    def iterate(self, tol=5e-6):
        """Perform a single round of message updates.

        Parameters
        ----------
        tol : float, optional
            If a message changes (summed absolute elementwise difference) by
            more than ``tol``, the messages it touches are scheduled for the
            next round; otherwise it is counted as converged.

        Returns
        -------
        nconv : int
            Number of updated messages that changed by at most ``tol``.
        ncheck : int
            Number of messages scheduled for checking in this round.
        max_mdiff : float
            Largest message change seen, or ``-1.0`` if none were updated.

        Raises
        ------
        ValueError
            If ``self.update`` is not ``'parallel'`` or ``'sequential'``.
        """
        if self.update not in ("parallel", "sequential"):
            raise ValueError(
                f"Unknown update mode {self.update!r}, expected "
                "'parallel' or 'sequential'."
            )

        if (not self.local_convergence) or (not self.touched):
            # assume if asked to iterate that we want to check all messages
            self.touched.update(
                pair for edge in self.edges for pair in (edge, edge[::-1])
            )

        ncheck = len(self.touched)
        nconv = 0
        max_mdiff = -1.0
        new_touched = oset()

        def _compute_m(key):
            i, j = key
            bix = self.edges[(i, j) if i < j else (j, i)]
            cix = tuple(ix + "_l2bp**" for ix in bix)
            output_inds = cix + bix

            tn_i_to_j = self.contraction_tns[i, j]

            tm_new = tn_i_to_j.contract(
                all,
                output_inds=output_inds,
                drop_tags=True,
                optimize=self.optimize,
                **self.contract_opts,
            )
            tm_new.modify(apply=self._symmetrize)
            tm_new.modify(apply=self._normalize)
            return tm_new.data

        def _update_m(key, data):
            nonlocal nconv, max_mdiff

            tm = self.messages[key]

            if self.damping > 0.0:
                data = (1 - self.damping) * data + self.damping * tm.data

            try:
                mdiff = float(self._distance(tm.data, data))
            except (TypeError, ValueError):
                # handle e.g. lazy arrays
                mdiff = float("inf")

            # NOTE(review): a NaN ``mdiff`` compares False here and so is
            # counted as converged
            if mdiff > tol:
                # mark touching messages for update
                new_touched.update(self.touch_map[key])
            else:
                nconv += 1

            max_mdiff = max(max_mdiff, mdiff)
            # update in place so ``contraction_tns`` sees the new data
            tm.modify(data=data)

        if self.update == "parallel":
            new_data = {}
            # compute all new messages
            while self.touched:
                key = self.touched.pop()
                new_data[key] = _compute_m(key)
            # insert all new messages
            for key, data in new_data.items():
                _update_m(key, data)

        else:  # "sequential"
            # compute each new message and immediately re-insert it
            while self.touched:
                key = self.touched.pop()
                data = _compute_m(key)
                _update_m(key, data)

        self.touched = new_touched

        return nconv, ncheck, max_mdiff

    def normalize_messages(self):
        """Normalize all messages such that for each bond ``<m_i|m_j> = 1`` and
        ``<m_i|m_i> = <m_j|m_j>`` (but in general != 1).
        """
        for i, j in self.edges:
            tmi = self.messages[i, j]
            tmj = self.messages[j, i]
            nij = (tmi @ tmj) ** 0.5
            nii = (tmi @ tmi) ** 0.25
            njj = (tmj @ tmj) ** 0.25
            # modify the stored tensors in place (rather than relying on
            # ``/=``, which could rebind the local name to a new tensor) so
            # both ``self.messages`` and ``self.contraction_tns`` see it
            tmi.modify(data=tmi.data / (nij * nii / njj))
            tmj.modify(data=tmj.data / (nij * njj / nii))

    def contract(self, strip_exponent=False):
        """Estimate the contraction of the norm squared using the current
        messages.

        Parameters
        ----------
        strip_exponent : bool, optional
            Forwarded to ``combine_local_contractions``; if ``True`` the
            result is returned as ``(mantissa, exponent)``.
        """
        tvals = []
        for i, ket in self.local_tns.items():
            # we allow missing keys here for tensors which are just
            # disconnected but still appear in local_tns
            ks = self.neighbors.get(i, ())
            bix = [ix for k in ks for ix in self.edges[tuple(sorted((k, i)))]]
            bra = ket.H.reindex_({ix: ix + "_l2bp*" for ix in bix})
            tni = qtn.TensorNetwork(
                (
                    ket,
                    *(self.messages[k, i] for k in ks),
                    bra,
                )
            )
            tvals.append(
                tni.contract(all, optimize=self.optimize, **self.contract_opts)
            )

        mvals = []
        for i, j in self.edges:
            mvals.append(
                (self.messages[i, j] & self.messages[j, i]).contract(
                    all,
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            )

        return combine_local_contractions(
            tvals, mvals, self.backend, strip_exponent=strip_exponent
        )

    def partial_trace(
        self,
        site,
        normalized=True,
        optimize="auto-hq",
    ):
        """Compute the approximate reduced density matrix of ``site`` using
        the incoming messages as its environment.

        Parameters
        ----------
        site : hashable
            The site, converted to a tag and index via ``site_tag`` and
            ``site_ind`` of the local tensor networks.
        normalized : bool, optional
            Whether to divide the result by its trace.
        optimize : str or PathOptimizer, optional
            The contraction path optimizer to use.

        Returns
        -------
        rho_i : array
            Dense matrix with the ket site index as rows and the bra site
            index as columns.

        Raises
        ------
        ValueError
            If the site index is not found in the site's local network.
        """
        example_tn = next(tn for tn in self.local_tns.values())

        site_tag = example_tn.site_tag(site)
        ket_site_ind = example_tn.site_ind(site)

        ks = self.neighbors[site_tag]
        tn_rho_i = self.local_tns[site_tag].copy()
        tn_bra_i = tn_rho_i.H

        for k in ks:
            tn_rho_i &= self.messages[k, site_tag]

        outer_bix = {
            ix for k in ks for ix in self.edges[tuple(sorted((k, site_tag)))]
        }

        if ket_site_ind not in tn_bra_i.ind_map:
            raise ValueError(
                f"Site index {ket_site_ind!r} for site {site!r} not found "
                f"in the local tensor network tagged {site_tag!r}."
            )

        # open up the site index on the bra side
        bra_site_ind = ket_site_ind + "_l2bp**"
        ind_changes = {ket_site_ind: bra_site_ind}
        for ix in tn_bra_i.ind_map:
            if ix in outer_bix:
                # attach bra message indices
                ind_changes[ix] = ix + "_l2bp*"
        tn_bra_i.reindex_(ind_changes)

        tn_rho_i &= tn_bra_i

        rho_i = tn_rho_i.to_dense(
            [ket_site_ind],
            [bra_site_ind],
            optimize=optimize,
            **self.contract_opts,
        )
        if normalized:
            rho_i = rho_i / ar.do("trace", rho_i)

        return rho_i

    def compress(
        self,
        tn,
        max_bond=None,
        cutoff=5e-6,
        cutoff_mode="rsum2",
        renorm=0,
        lazy=False,
    ):
        """Compress the state ``tn``, assumed to match this L2BP instance,
        using the messages stored.

        For each edge, reduced factors are formed from the two messages and
        oblique projectors are inserted on the bond. ``tn`` is modified in
        place and returned.

        Parameters
        ----------
        tn : TensorNetwork
            The network to compress, with the same site tags and bonds as
            the one this instance was built from.
        max_bond : int, optional
            Maximum bond dimension to keep.
        cutoff : float, optional
            Singular value cutoff.
        cutoff_mode : str, optional
            How ``cutoff`` is applied.
        renorm : int, optional
            Renormalization option passed to the projector computation.
        lazy : bool, optional
            If ``False``, contract the tensors of each site tag together
            afterwards; otherwise leave the projectors uncontracted.
        """
        for (i, j), bix in self.edges.items():
            tml = self.messages[i, j]
            tmr = self.messages[j, i]

            bix_sizes = [tml.ind_size(ix) for ix in bix]
            dm = math.prod(bix_sizes)

            ml = ar.reshape(tml.data, (dm, dm))
            dl = self.local_tns[i].outer_size() // dm
            Rl = qtn.decomp.squared_op_to_reduced_factor(
                ml, dl, dm, right=True
            )

            mr = ar.reshape(tmr.data, (dm, dm)).T
            dr = self.local_tns[j].outer_size() // dm
            Rr = qtn.decomp.squared_op_to_reduced_factor(
                mr, dm, dr, right=False
            )

            Pl, Pr = qtn.decomp.compute_oblique_projectors(
                Rl,
                Rr,
                cutoff_mode=cutoff_mode,
                renorm=renorm,
                max_bond=max_bond,
                cutoff=cutoff,
            )

            Pl = ar.do("reshape", Pl, (*bix_sizes, -1))
            Pr = ar.do("reshape", Pr, (-1, *bix_sizes))

            ltn = tn.select(i)
            rtn = tn.select(j)

            new_lix = [qtn.rand_uuid() for _ in bix]
            new_rix = [qtn.rand_uuid() for _ in bix]
            new_bix = [qtn.rand_uuid()]
            ltn.reindex_(dict(zip(bix, new_lix)))
            rtn.reindex_(dict(zip(bix, new_rix)))

            # ... and insert the new projectors in place
            tn |= qtn.Tensor(Pl, inds=new_lix + new_bix, tags=(i,))
            tn |= qtn.Tensor(Pr, inds=new_bix + new_rix, tags=(j,))

        if not lazy:
            for st in self.site_tags:
                try:
                    tn.contract_tags_(
                        st, optimize=self.optimize, **self.contract_opts
                    )
                except KeyError:
                    # site tag not present in ``tn``: nothing to contract
                    pass

        return tn
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def contract_l2bp(
    tn,
    site_tags=None,
    damping=0.0,
    update="sequential",
    local_convergence=True,
    optimize="auto-hq",
    max_iterations=1000,
    tol=5e-6,
    strip_exponent=False,
    info=None,
    progbar=False,
    **contract_opts,
):
    """Estimate the norm squared of ``tn`` using lazy belief propagation.

    An ``L2BP`` instance is built from ``tn``, run to convergence (or until
    ``max_iterations``), and its ``contract`` estimate is returned.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to estimate the norm squared of.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The contraction strategy to use.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.
    """
    solver = L2BP(
        tn,
        site_tags=site_tags,
        damping=damping,
        update=update,
        local_convergence=local_convergence,
        optimize=optimize,
        **contract_opts,
    )

    run_opts = dict(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    solver.run(**run_opts)

    estimate = solver.contract(strip_exponent=strip_exponent)
    return estimate
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def compress_l2bp(
    tn,
    max_bond,
    cutoff=0.0,
    cutoff_mode="rsum2",
    max_iterations=1000,
    tol=5e-6,
    site_tags=None,
    damping=0.0,
    update="sequential",
    local_convergence=True,
    optimize="auto-hq",
    lazy=False,
    inplace=False,
    info=None,
    progbar=False,
    **contract_opts,
):
    """Compress ``tn`` using lazy belief propagation, producing a tensor
    network with a single tensor per site.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to form the 2-norm of, run BP on and then compress.
    max_bond : int
        The maximum bond dimension to compress to.
    cutoff : float, optional
        The cutoff to use when compressing.
    cutoff_mode : str, optional
        The cutoff mode to use when compressing, e.g. ``"rsum2"``.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region. If
        the tensor network is structured, then these are inferred
        automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    lazy : bool, optional
        Whether to perform the compression lazily, i.e. to leave the computed
        compression projectors uncontracted.
    inplace : bool, optional
        Whether to perform the compression inplace. If ``False`` (the
        default) a copy of ``tn`` is compressed and ``tn`` is left untouched.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.

    Returns
    -------
    TensorNetwork
        The compressed network (``tn`` itself if ``inplace=True``).
    """
    # BP is run on the same network that is then compressed, so the
    # messages match its bonds exactly
    tnc = tn if inplace else tn.copy()

    bp = L2BP(
        tnc,
        site_tags=site_tags,
        damping=damping,
        update=update,
        local_convergence=local_convergence,
        optimize=optimize,
        **contract_opts,
    )
    bp.run(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    return bp.compress(
        tnc,
        max_bond=max_bond,
        cutoff=cutoff,
        cutoff_mode=cutoff_mode,
        lazy=lazy,
    )
|