Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +3 -0
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
- trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
- trajectree/quimb/docs/conf.py +158 -0
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
- trajectree/quimb/quimb/__init__.py +507 -0
- trajectree/quimb/quimb/calc.py +1491 -0
- trajectree/quimb/quimb/core.py +2279 -0
- trajectree/quimb/quimb/evo.py +712 -0
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +129 -0
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
- trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
- trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
- trajectree/quimb/quimb/experimental/schematic.py +7 -0
- trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
- trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
- trajectree/quimb/quimb/gates.py +36 -0
- trajectree/quimb/quimb/gen/__init__.py +2 -0
- trajectree/quimb/quimb/gen/operators.py +1167 -0
- trajectree/quimb/quimb/gen/rand.py +713 -0
- trajectree/quimb/quimb/gen/states.py +479 -0
- trajectree/quimb/quimb/linalg/__init__.py +6 -0
- trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
- trajectree/quimb/quimb/linalg/autoblock.py +258 -0
- trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
- trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
- trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
- trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
- trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
- trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
- trajectree/quimb/quimb/schematic.py +1518 -0
- trajectree/quimb/quimb/tensor/__init__.py +401 -0
- trajectree/quimb/quimb/tensor/array_ops.py +610 -0
- trajectree/quimb/quimb/tensor/circuit.py +4824 -0
- trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
- trajectree/quimb/quimb/tensor/contraction.py +336 -0
- trajectree/quimb/quimb/tensor/decomp.py +1255 -0
- trajectree/quimb/quimb/tensor/drawing.py +1646 -0
- trajectree/quimb/quimb/tensor/fitting.py +385 -0
- trajectree/quimb/quimb/tensor/geometry.py +583 -0
- trajectree/quimb/quimb/tensor/interface.py +114 -0
- trajectree/quimb/quimb/tensor/networking.py +1058 -0
- trajectree/quimb/quimb/tensor/optimize.py +1818 -0
- trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
- trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
- trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
- trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
- trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
- trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
- trajectree/quimb/quimb/utils.py +892 -0
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +501 -0
- trajectree/quimb/tests/test_calc.py +788 -0
- trajectree/quimb/tests/test_core.py +847 -0
- trajectree/quimb/tests/test_evo.py +565 -0
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +361 -0
- trajectree/quimb/tests/test_gen/test_rand.py +296 -0
- trajectree/quimb/tests/test_gen/test_states.py +261 -0
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
- trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
- trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
- trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
- trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
- trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
- trajectree/quimb/tests/test_utils.py +85 -0
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
- trajectree-0.0.1.dist-info/RECORD +126 -0
- trajectree-0.0.0.dist-info/RECORD +0 -16
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,653 @@
|
|
|
1
|
+
import autoray as ar
|
|
2
|
+
|
|
3
|
+
import quimb.tensor as qtn
|
|
4
|
+
from quimb.utils import oset
|
|
5
|
+
|
|
6
|
+
from .bp_common import (
|
|
7
|
+
BeliefPropagationCommon,
|
|
8
|
+
combine_local_contractions,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class D2BP(BeliefPropagationCommon):
    """Dense (as in one tensor per site) 2-norm (as in for wavefunctions and
    operators) belief propagation. Allows messages reuse. This version assumes
    no hyper indices (i.e. a standard PEPS like tensor network).

    Potential use cases for D2BP and a PEPS like tensor network are:

        - globally compressing it from bond dimension ``D`` to ``D'``
        - eagerly applying gates and locally compressing back to ``D``
        - sampling configurations
        - estimating the norm of the tensor network


    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to form the 2-norm of and run BP on.
    messages : dict[(str, int), array_like], optional
        The initial messages to use, effectively defaults to all ones if not
        specified. Keys are ``(bond_index, tid)`` where ``tid`` is the tensor
        *receiving* the message. A supplied dict is used (and filled in /
        updated) directly rather than copied.
    output_inds : set[str], optional
        The indices to consider as output (dangling) indices of the tn.
        Computed automatically if not specified.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    damping : float, optional
        The damping factor to use, 0.0 means no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.
    """

    def __init__(
        self,
        tn,
        messages=None,
        output_inds=None,
        optimize="auto-hq",
        damping=0.0,
        update="sequential",
        local_convergence=True,
        **contract_opts,
    ):
        from quimb.tensor.contraction import array_contract_expression

        self.tn = tn
        self.contract_opts = contract_opts
        self.contract_opts.setdefault("optimize", optimize)
        self.damping = damping
        self.local_convergence = local_convergence
        self.update = update

        if output_inds is None:
            self.output_inds = set(self.tn.outer_inds())
        else:
            self.output_inds = set(output_inds)

        # backend-specific kernels, looked up once from the first tensor
        self.backend = next(t.backend for t in tn)
        _abs = ar.get_lib_fn(self.backend, "abs")
        _sum = ar.get_lib_fn(self.backend, "sum")

        def _normalize(x):
            # scale so that all entries of the message sum to one
            return x / _sum(x)

        def _distance(x, y):
            # L1 distance between two messages
            return _sum(_abs(x - y))

        self._normalize = _normalize
        self._distance = _distance

        if messages is None:
            self.messages = {}
        else:
            self.messages = messages

        # record which messages touch each others, for efficient updates
        self.touch_map = {}
        self.touched = oset()
        self.exprs = {}

        # populate any messages
        for ix, tids in self.tn.ind_map.items():
            if ix in self.output_inds:
                continue

            # no hyper indices assumed: every inner index joins two tensors
            tida, tidb = tids
            jx = ix + "*"
            ta, tb = self.tn._tids_get(tida, tidb)

            for tid, t, t_in in ((tida, ta, tb), (tidb, tb, ta)):
                # messages leaving ``tid`` along its other inner bonds depend
                # on the message (ix, tid), so record them for re-checking
                this_touchmap = []
                for nx in t.inds:
                    if nx in self.output_inds or nx == ix:
                        continue
                    # where this message will be sent on to
                    (tidn,) = (n for n in self.tn.ind_map[nx] if n != tid)
                    this_touchmap.append((nx, tidn))
                self.touch_map[ix, tid] = this_touchmap

                if (ix, tid) not in self.messages:
                    # initial guess: the sending tensor contracted with its
                    # conjugate over everything except the bond itself.
                    # ``tm`` is already the raw backend array - it must not be
                    # unwrapped again via ``.data`` (which on numpy is a raw
                    # buffer and on torch a detached tensor).
                    tm = (t_in.reindex({ix: jx}).conj_() @ t_in).data
                    self.messages[ix, tid] = self._normalize(tm)

        # for efficiency setup all the contraction expressions ahead of time
        for ix, tids in self.tn.ind_map.items():
            if ix in self.output_inds:
                continue

            # build the expression for both directions along the bond
            for tida, tidb in (sorted(tids), sorted(tids, reverse=True)):
                ta = self.tn.tensor_map[tida]
                kix = ta.inds
                # bra copy: inner indices get a '*' suffix, output indices are
                # shared with the ket so that they are traced over
                bix = tuple(
                    i if i in self.output_inds else i + "*" for i in kix
                )
                inputs = [kix, bix]
                # NOTE: references to the tensor arrays as they are now are
                # stored here and reused by ``iterate``
                data = [ta.data, ta.data.conj()]
                shapes = [ta.shape, ta.shape]
                for i in kix:
                    if (i != ix) and i not in self.output_inds:
                        inputs.append((i + "*", i))
                        # placeholder key, resolved to the current message
                        # at contraction time
                        data.append((i, tida))
                        shapes.append(self.messages[i, tida].shape)

                expr = array_contract_expression(
                    inputs=inputs,
                    output=(ix + "*", ix),
                    shapes=shapes,
                    **self.contract_opts,
                )
                self.exprs[ix, tidb] = expr, data

    def update_touched_from_tids(self, *tids):
        """Specify that the messages for the given ``tids`` have changed."""
        for tid in tids:
            t = self.tn.tensor_map[tid]
            for ix in t.inds:
                if ix in self.output_inds:
                    continue
                # the message leaving ``tid`` along ``ix`` is keyed by the
                # neighbouring (receiving) tensor id
                (ntid,) = (n for n in self.tn.ind_map[ix] if n != tid)
                self.touched.add((ix, ntid))

    def update_touched_from_tags(self, tags, which="any"):
        """Specify that the messages for the messages touching ``tags`` have
        changed.
        """
        tids = self.tn._get_tids_from_tags(tags, which)
        self.update_touched_from_tids(*tids)

    def update_touched_from_inds(self, inds, which="any"):
        """Specify that the messages for the messages touching ``inds`` have
        changed.
        """
        tids = self.tn._get_tids_from_inds(inds, which)
        self.update_touched_from_tids(*tids)

    def iterate(self, tol=5e-6):
        """Perform a single iteration of dense 2-norm belief propagation.

        Parameters
        ----------
        tol : float, optional
            A message whose L1 change is at most ``tol`` counts as converged;
            messages changing by more mark their dependents for the next
            iteration.

        Returns
        -------
        nconv : int
            Number of checked messages that changed by at most ``tol``.
        ncheck : int
            Number of messages checked this iteration.
        max_mdiff : float
            Largest message change seen (``inf`` if a distance could not be
            converted to a float, ``-1.0`` if nothing was checked).

        Raises
        ------
        ValueError
            If ``self.update`` is not ``'parallel'`` or ``'sequential'``.
        """

        if (not self.local_convergence) or (not self.touched):
            # assume if asked to iterate that we want to check all messages
            self.touched.update(self.exprs.keys())

        ncheck = len(self.touched)
        nconv = 0
        max_mdiff = -1.0
        new_touched = oset()

        def _compute_m(key):
            expr, data = self.exprs[key]
            # first two entries are the ket/bra arrays, the rest are keys of
            # incoming messages to look up now
            m = expr(*data[:2], *(self.messages[mkey] for mkey in data[2:]))
            # enforce hermiticity and normalize
            return self._normalize(m + ar.dag(m))

        def _update_m(key, new_m):
            nonlocal nconv, max_mdiff

            old_m = self.messages[key]
            if self.damping > 0.0:
                new_m = self._normalize(
                    self.damping * old_m + (1 - self.damping) * new_m
                )
            try:
                mdiff = float(self._distance(old_m, new_m))
            except (TypeError, ValueError):
                # handle e.g. lazy arrays
                mdiff = float("inf")
            if mdiff > tol:
                # mark touching messages for update
                new_touched.update(self.touch_map[key])
            else:
                nconv += 1
            max_mdiff = max(max_mdiff, mdiff)
            self.messages[key] = new_m

        if self.update == "parallel":
            new_messages = {}
            # compute all new messages
            while self.touched:
                key = self.touched.pop()
                new_messages[key] = _compute_m(key)
            # insert all new messages
            for key, new_m in new_messages.items():
                _update_m(key, new_m)

        elif self.update == "sequential":
            # compute each new message and immediately re-insert it
            while self.touched:
                key = self.touched.pop()
                new_m = _compute_m(key)
                _update_m(key, new_m)

        else:
            # previously an unknown mode silently did nothing and reported
            # max_mdiff=-1.0, which looks like convergence
            raise ValueError(
                f"Unrecognized update mode {self.update!r}, "
                "expected 'parallel' or 'sequential'."
            )

        self.touched = new_touched

        return nconv, ncheck, max_mdiff

    def compute_marginal(self, ind):
        """Compute the marginal for the index ``ind``.

        ``ind`` must belong to exactly one tensor. Every other index of that
        tensor is closed either with its current incoming message or, if no
        message exists for it (e.g. another output index), by a direct trace.

        Returns
        -------
        array_like
            The real part of the diagonal over ``ind``, normalized to sum to
            one.
        """
        (tid,) = self.tn.ind_map[ind]
        t = self.tn.tensor_map[tid]

        arrays = [t.data, ar.do("conj", t.data)]
        k_input = []
        b_input = []
        m_inputs = []
        for j, jx in enumerate(t.inds, 1):
            k_input.append(j)

            if jx == ind:
                # output index -> take diagonal
                output = (j,)
                b_input.append(j)
            else:
                try:
                    # partial trace with message
                    m = self.messages[jx, tid]
                    arrays.append(m)
                    b_input.append(-j)
                    m_inputs.append((-j, j))
                except KeyError:
                    # direct partial trace
                    b_input.append(j)

        p = qtn.array_contract(
            arrays,
            inputs=(tuple(k_input), tuple(b_input), *m_inputs),
            output=output,
            **self.contract_opts,
        )
        p = ar.do("real", p)
        return p / ar.do("sum", p)

    def contract(self, strip_exponent=False):
        """Estimate the total contraction, i.e. the 2-norm.

        Each tensor is contracted with its conjugate and all its incoming
        messages, each bond's pair of messages is contracted together, and the
        two lists of local values are combined by
        ``combine_local_contractions``.

        Parameters
        ----------
        strip_exponent : bool, optional
            Whether to strip the exponent from the final result. If ``True``
            then the returned result is ``(mantissa, exponent)``.

        Returns
        -------
        scalar or (scalar, float)
        """
        tvals = []

        for tid, t in self.tn.tensor_map.items():
            arrays = [t.data, ar.do("conj", t.data)]
            k_input = []
            b_input = []
            m_inputs = []
            for i, ix in enumerate(t.inds, 1):
                k_input.append(i)
                if ix in self.output_inds:
                    # physical index shared between ket and bra -> traced
                    b_input.append(i)
                else:
                    b_input.append(-i)
                    m_inputs.append((-i, i))
                    arrays.append(self.messages[ix, tid])

            inputs = (tuple(k_input), tuple(b_input), *m_inputs)
            output = ()
            tval = qtn.array_contract(
                arrays, inputs, output, **self.contract_opts
            )
            tvals.append(tval)

        mvals = []
        for ix, tids in self.tn.ind_map.items():
            if ix in self.output_inds:
                continue
            tida, tidb = tids
            ml = self.messages[ix, tidb]
            mr = self.messages[ix, tida]
            # overlap of the two opposing messages on this bond
            mval = qtn.array_contract(
                (ml, mr), ((1, 2), (1, 2)), (), **self.contract_opts
            )
            mvals.append(mval)

        return combine_local_contractions(
            tvals, mvals, self.backend, strip_exponent=strip_exponent
        )

    def compress(
        self,
        max_bond,
        cutoff=0.0,
        cutoff_mode=4,
        renorm=0,
        inplace=False,
    ):
        """Compress the initial tensor network using the current messages.

        For every inner bond, the two messages are reduced to factors, oblique
        projectors are computed from them and gated into the two tensors.

        Parameters
        ----------
        max_bond : int
            The maximum bond dimension to compress to.
        cutoff : float, optional
            Truncation cutoff, passed to ``compute_oblique_projectors``.
        cutoff_mode : int or str, optional
            Truncation mode, passed to ``compute_oblique_projectors``.
        renorm : int, optional
            Passed to ``compute_oblique_projectors``.
        inplace : bool, optional
            Whether to modify ``self.tn`` directly. If ``True`` the stored
            messages on each compressed bond are also replaced by their
            projected versions.

        Returns
        -------
        TensorNetwork
        """
        tn = self.tn if inplace else self.tn.copy()

        for ix, tids in tn.ind_map.items():
            if len(tids) != 2 or ix in self.output_inds:
                # dangling bonds, and any bond explicitly marked as output,
                # carry no messages and so cannot be compressed here
                continue
            tida, tidb = tids

            # messages are left and right factors squared already
            ta = tn.tensor_map[tida]
            dm = ta.ind_size(ix)
            dl = ta.size // dm
            ml = self.messages[ix, tidb]
            Rl = qtn.decomp.squared_op_to_reduced_factor(
                ml, dl, dm, right=True
            )

            tb = tn.tensor_map[tidb]
            dr = tb.size // dm
            mr = self.messages[ix, tida].T
            Rr = qtn.decomp.squared_op_to_reduced_factor(
                mr, dm, dr, right=False
            )

            # compute the compressors
            Pl, Pr = qtn.decomp.compute_oblique_projectors(
                Rl,
                Rr,
                max_bond=max_bond,
                cutoff=cutoff,
                cutoff_mode=cutoff_mode,
                renorm=renorm,
            )

            # contract the compressors into the tensors
            tn.tensor_map[tida].gate_(Pl.T, ix)
            tn.tensor_map[tidb].gate_(Pr, ix)

            # update messages with projections
            if inplace:
                new_Ra = Rl @ Pl
                new_Rb = Pr @ Rr
                self.messages[ix, tidb] = ar.dag(new_Ra) @ new_Ra
                self.messages[ix, tida] = new_Rb @ ar.dag(new_Rb)

        return tn
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def contract_d2bp(
    tn,
    messages=None,
    output_inds=None,
    optimize="auto-hq",
    damping=0.0,
    update="sequential",
    local_convergence=True,
    max_iterations=1000,
    tol=5e-6,
    strip_exponent=False,
    info=None,
    progbar=False,
    **contract_opts,
):
    """Estimate the norm squared of ``tn`` with dense 2-norm belief
    propagation.

    A :class:`D2BP` instance is built, run until convergence (or until
    ``max_iterations``), and its contraction estimate returned.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to form the 2-norm of and run BP on.
    messages : dict[(str, int), array_like], optional
        The initial messages to use, effectively defaults to all ones if not
        specified.
    output_inds : set[str], optional
        The indices to consider as output (dangling) indices of the tn.
        Computed automatically if not specified.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.

    Returns
    -------
    scalar or (scalar, float)
    """
    solver = D2BP(
        tn,
        messages=messages,
        output_inds=output_inds,
        optimize=optimize,
        damping=damping,
        local_convergence=local_convergence,
        update=update,
        **contract_opts,
    )
    run_opts = dict(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    solver.run(**run_opts)
    estimate = solver.contract(strip_exponent=strip_exponent)
    return estimate
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def compress_d2bp(
    tn,
    max_bond,
    cutoff=0.0,
    cutoff_mode="rsum2",
    renorm=0,
    messages=None,
    output_inds=None,
    optimize="auto-hq",
    damping=0.0,
    update="sequential",
    local_convergence=True,
    max_iterations=1000,
    tol=5e-6,
    inplace=False,
    info=None,
    progbar=False,
    **contract_opts,
):
    """Compress the tensor network ``tn`` using dense 2-norm belief
    propagation.

    Builds a :class:`D2BP` instance, runs it, then uses the resulting
    messages to compress every inner bond via ``D2BP.compress``.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to form the 2-norm of, run BP on and then compress.
    max_bond : int
        The maximum bond dimension to compress to.
    cutoff : float, optional
        The cutoff to use when compressing.
    cutoff_mode : int or str, optional
        The cutoff mode to use when compressing, e.g. the default ``"rsum2"``.
        Passed on to the oblique projector computation.
    renorm : int, optional
        Passed on to the oblique projector computation.
    messages : dict[(str, int), array_like], optional
        The initial messages to use, effectively defaults to all ones if not
        specified.
    output_inds : set[str], optional
        The indices to consider as output (dangling) indices of the tn.
        Computed automatically if not specified.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    inplace : bool, optional
        Whether to perform the compression inplace.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.

    Returns
    -------
    TensorNetwork
    """
    bp = D2BP(
        tn,
        messages=messages,
        output_inds=output_inds,
        optimize=optimize,
        damping=damping,
        update=update,
        local_convergence=local_convergence,
        **contract_opts,
    )
    bp.run(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    return bp.compress(
        max_bond=max_bond,
        cutoff=cutoff,
        cutoff_mode=cutoff_mode,
        renorm=renorm,
        inplace=inplace,
    )
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def sample_d2bp(
    tn,
    output_inds=None,
    messages=None,
    max_iterations=100,
    tol=1e-2,
    bias=None,
    seed=None,
    local_convergence=True,
    progbar=False,
    **contract_opts,
):
    """Sample a configuration from ``tn`` using dense 2-norm belief
    propagation.

    Indices are fixed one at a time. At each step the BP marginal of every
    remaining index is computed, and the index whose marginal has the largest
    single probability is sampled next and then sliced out of the network.
    After each slice, BP is rebuilt and re-run, reusing the previous messages.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to sample from. It is copied, not modified.
    output_inds : set[str], optional
        Which indices to sample.
    messages : dict[(str, int), array_like], optional
        The initial messages to use, effectively defaults to all ones if not
        specified.
    max_iterations : int, optional
        The maximum number of iterations to perform, per marginal.
    tol : float, optional
        The convergence tolerance for messages.
    bias : float, optional
        Bias the sampling towards more locally likely values. This is done by
        raising each local marginal probability to this power and
        renormalizing.
    seed : int, optional
        A random seed for reproducibility.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.

    Returns
    -------
    config : dict[str, int]
        The sampled configuration, a mapping of output indices to values.
    tn_config : TensorNetwork
        The tensor network with the sampled configuration applied.
    omega : float
        The BP probability of the sampled configuration. If ``bias`` is
        given, this is the product of the biased probabilities.
    """
    import numpy as np

    if output_inds is None:
        output_inds = tn.outer_inds()

    rng = np.random.default_rng(seed)
    config = {}
    omega = 1.0

    # work on a copy so that the caller's network is left untouched
    tn = tn.copy()
    bp = D2BP(
        tn,
        messages=messages,
        local_convergence=local_convergence,
        **contract_opts,
    )
    bp.run(max_iterations=max_iterations, tol=tol)

    marginals = dict.fromkeys(output_inds)

    if progbar:
        import tqdm

        pbar = tqdm.tqdm(total=len(marginals))
    else:
        pbar = None

    while marginals:
        for ix in marginals:
            marginals[ix] = bp.compute_marginal(ix)

        # greedily fix the index whose marginal is most sharply peaked
        ix, p = max(marginals.items(), key=lambda x: max(x[1]))
        p = ar.to_numpy(p)

        if bias is not None:
            # bias distribution towards more locally likely values
            p = p**bias
            p /= np.sum(p)

        # draw from every value the index can take, not just {0, 1}, so that
        # physical dimensions other than 2 are supported
        v = rng.choice(p.size, p=p)
        config[ix] = v
        del marginals[ix]

        # remember which tensors the index touches before slicing it away
        tids = tuple(tn.ind_map[ix])
        tn.isel_({ix: v})

        omega *= p[v]
        if progbar:
            pbar.update(1)
            pbar.set_description(f"{ix}->{v}", refresh=False)

        # the network changed, so rebuild BP (reusing the current messages)
        # and mark the messages leaving the sliced tensors as changed
        bp = D2BP(
            tn,
            messages=bp.messages,
            local_convergence=local_convergence,
            **contract_opts,
        )
        bp.update_touched_from_tids(*tids)
        bp.run(tol=tol, max_iterations=max_iterations)

    if progbar:
        pbar.close()

    return config, tn, omega
|