Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +0 -3
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
- trajectree-0.0.2.dist-info/RECORD +16 -0
- trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
- trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
- trajectree/quimb/docs/conf.py +0 -158
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
- trajectree/quimb/quimb/__init__.py +0 -507
- trajectree/quimb/quimb/calc.py +0 -1491
- trajectree/quimb/quimb/core.py +0 -2279
- trajectree/quimb/quimb/evo.py +0 -712
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +0 -129
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
- trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
- trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
- trajectree/quimb/quimb/experimental/schematic.py +0 -7
- trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
- trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
- trajectree/quimb/quimb/gates.py +0 -36
- trajectree/quimb/quimb/gen/__init__.py +0 -2
- trajectree/quimb/quimb/gen/operators.py +0 -1167
- trajectree/quimb/quimb/gen/rand.py +0 -713
- trajectree/quimb/quimb/gen/states.py +0 -479
- trajectree/quimb/quimb/linalg/__init__.py +0 -6
- trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
- trajectree/quimb/quimb/linalg/autoblock.py +0 -258
- trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
- trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
- trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
- trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
- trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
- trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
- trajectree/quimb/quimb/schematic.py +0 -1518
- trajectree/quimb/quimb/tensor/__init__.py +0 -401
- trajectree/quimb/quimb/tensor/array_ops.py +0 -610
- trajectree/quimb/quimb/tensor/circuit.py +0 -4824
- trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
- trajectree/quimb/quimb/tensor/contraction.py +0 -336
- trajectree/quimb/quimb/tensor/decomp.py +0 -1255
- trajectree/quimb/quimb/tensor/drawing.py +0 -1646
- trajectree/quimb/quimb/tensor/fitting.py +0 -385
- trajectree/quimb/quimb/tensor/geometry.py +0 -583
- trajectree/quimb/quimb/tensor/interface.py +0 -114
- trajectree/quimb/quimb/tensor/networking.py +0 -1058
- trajectree/quimb/quimb/tensor/optimize.py +0 -1818
- trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
- trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
- trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
- trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
- trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
- trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
- trajectree/quimb/quimb/utils.py +0 -892
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +0 -501
- trajectree/quimb/tests/test_calc.py +0 -788
- trajectree/quimb/tests/test_core.py +0 -847
- trajectree/quimb/tests/test_evo.py +0 -565
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +0 -361
- trajectree/quimb/tests/test_gen/test_rand.py +0 -296
- trajectree/quimb/tests/test_gen/test_states.py +0 -261
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
- trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
- trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
- trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
- trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
- trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
- trajectree/quimb/tests/test_utils.py +0 -85
- trajectree-0.0.1.dist-info/RECORD +0 -126
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
|
@@ -1,775 +0,0 @@
|
|
|
1
|
-
"""Hyper, vectorized, 1-norm, belief propagation.
|
|
2
|
-
"""
|
|
3
|
-
|
|
4
|
-
import autoray as ar
|
|
5
|
-
|
|
6
|
-
from quimb.tensor.contraction import array_contract
|
|
7
|
-
from .bp_common import (
|
|
8
|
-
BeliefPropagationCommon,
|
|
9
|
-
compute_all_index_marginals_from_messages,
|
|
10
|
-
contract_hyper_messages,
|
|
11
|
-
initialize_hyper_messages,
|
|
12
|
-
maybe_get_thread_pool,
|
|
13
|
-
)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def initialize_messages_batched(tn, messages=None):
    """Pack the BP messages and tensors of ``tn`` into per-rank stacks.

    Messages are grouped by 'rank' (the number of tensors attached to an
    index, or the number of indices of a tensor) so that each group can be
    handled as a single stacked array.

    Parameters
    ----------
    tn : TensorNetwork
        Network providing ``ind_map`` and ``tensor_map``.
    messages : dict, optional
        Messages keyed by ``(tid, ix)`` and ``(ix, tid)``. If not given they
        are created with ``initialize_hyper_messages(tn)``.

    Returns
    -------
    batched_inputs_m, batched_inputs_t : dict
        Stacked input messages, keyed by rank, for the index update and the
        tensor update respectively.
    batched_tensors : dict
        Stacked tensor data keyed by rank (rank-0 tensors are skipped).
    input_locs_m, input_locs_t : dict
        Map from message key to its ``(rank, position, batch)`` slot.
    masks_m, masks_t : dict
        For each ``(input rank, output rank)`` pair, two integer arrays of
        ``[position, batch]`` rows pairing input slots with output slots.
    """
    if messages is None:
        messages = initialize_hyper_messages(tn)

    backend = ar.infer_backend(next(iter(messages.values())))
    _stack = ar.get_lib_fn(backend, "stack")
    _array = ar.get_lib_fn(backend, "array")

    # --- messages arriving at indices (sent by tensors) ---
    batched_inputs_m = {}
    input_locs_m = {}
    output_locs_m = {}
    for ix, tids in tn.ind_map.items():
        rank = len(tids)
        if rank not in batched_inputs_m:
            batched_inputs_m[rank] = [[] for _ in range(rank)]
        slots = batched_inputs_m[rank]

        for pos, tid in enumerate(tids):
            slot = slots[pos]
            # the next free place in this slot's stack
            loc = (rank, pos, len(slot))
            input_locs_m[tid, ix] = loc
            output_locs_m[ix, tid] = loc
            slot.append(messages[tid, ix])

    # --- messages arriving at tensors (sent by indices) ---
    batched_tensors = {}
    batched_inputs_t = {}
    input_locs_t = {}
    output_locs_t = {}
    for tid, t in tn.tensor_map.items():
        rank = t.ndim
        if rank == 0:
            # scalars exchange no messages
            continue

        if rank not in batched_inputs_t:
            batched_inputs_t[rank] = [[] for _ in range(rank)]
            batched_tensors[rank] = []
        slots = batched_inputs_t[rank]

        for pos, ix in enumerate(t.inds):
            slot = slots[pos]
            loc = (rank, pos, len(slot))
            input_locs_t[ix, tid] = loc
            output_locs_t[tid, ix] = loc
            slot.append(messages[ix, tid])

        batched_tensors[rank] = batched_tensors[rank]
        batched_tensors[rank].append(t.data)

    # collapse the nested lists into one array per rank
    # NOTE(review): stacking presumably requires every message / tensor in a
    # rank group to share the same shape -- confirm against callers.
    for grouped in (batched_inputs_m, batched_inputs_t):
        for rank, slots in grouped.items():
            grouped[rank] = _stack(tuple(_stack(slot) for slot in slots))
    for rank, tensors in batched_tensors.items():
        batched_tensors[rank] = _stack(tensors)

    # integer masks routing each computed output to the matching input slot
    masks_m = {}
    masks_t = {}
    routing = (
        (masks_m, input_locs_m, output_locs_t),
        (masks_t, input_locs_t, output_locs_m),
    )
    for masks, input_locs, output_locs in routing:
        for pair, (ranki, ii, bi) in input_locs.items():
            (ranko, io, bo) = output_locs[pair]
            maskin, maskout = masks.setdefault((ranki, ranko), ([], []))
            maskin.append([ii, bi])
            maskout.append([io, bo])

        for key, (maskin, maskout) in masks.items():
            masks[key] = _array(maskin), _array(maskout)

    return (
        batched_inputs_m,
        batched_inputs_t,
        batched_tensors,
        input_locs_m,
        input_locs_t,
        masks_m,
        masks_t,
    )
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def _compute_all_hyperind_messages_tree_batched(bm):
    """For each leading position of ``bm``, compute the elementwise product
    of the entries at all *other* leading positions.

    The work is done by repeated halving, so no division is needed.

    Parameters
    ----------
    bm : array
        Stacked messages; ``len(bm)`` is the number of messages combined.

    Returns
    -------
    array
        Array like ``bm`` holding the output message for each position.
    """
    nmsg = len(bm)

    if nmsg == 2:
        # a plain bond: the two messages simply trade places
        return ar.do("flip", bm, (0,))

    backend = ar.infer_backend(bm)
    _prod = ar.get_lib_fn(backend, "prod")
    _empty_like = ar.get_lib_fn(backend, "empty_like")

    out = _empty_like(bm)
    # each entry: (positions still to resolve, product of everything
    # outside that group, the messages belonging to the group)
    stack = [(tuple(range(nmsg)), 1, bm)]

    while stack:
        positions, env, group = stack.pop()
        remaining = len(group)

        if remaining == 1:
            # only one message left: its output is the environment
            out[positions[0]] = env
            continue

        if remaining == 2:
            # two left: each receives the environment times the other
            out[positions[0]] = env * group[1]
            out[positions[1]] = group[0] * env
            continue

        # otherwise halve the group; each half absorbs the other half
        half = remaining // 2
        left, right = group[:half], group[half:]
        env_left = env * _prod(right, axis=0)
        env_right = _prod(left, axis=0) * env
        stack.append((positions[:half], env_left, left))
        stack.append((positions[half:], env_right, right))

    return out
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def _compute_all_hyperind_messages_prod_batched(bm, smudge_factor=1e-12):
    """Compute all index output messages by taking the product of every
    stacked message and dividing each one back out.

    Parameters
    ----------
    bm : array
        Stacked messages; ``len(bm)`` is the number of messages combined.
    smudge_factor : float, optional
        Added to ``bm`` in the denominator.

    Returns
    -------
    array
        The output messages, one per leading position of ``bm``.
    """
    backend = ar.infer_backend(bm)
    _prod = ar.get_lib_fn(backend, "prod")
    _reshape = ar.get_lib_fn(backend, "reshape")

    if len(bm) == 2:
        # a plain bond: the two messages simply trade places
        return ar.do("flip", bm, (0,))

    total = _prod(bm, axis=0)
    # restore a leading axis of size one so the product lines up with ``bm``
    numerator = _reshape(total, (1, *ar.shape(total)))
    denominator = bm + smudge_factor
    return numerator / denominator
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
def _compute_all_tensor_messages_tree_batched(bx, bm):
    """Compute every output message of a stack of tensors by contracting
    each tensor with all of its input messages but one, via repeated halving.

    Parameters
    ----------
    bx : array
        Stacked tensor data; the contractions below treat its first axis as
        the shared batch axis.
    bm : array
        Stacked input messages, ``len(bm)`` being the number of tensor
        indices.

    Returns
    -------
    array
        The output messages stacked in the same order as ``bm``.
    """
    backend = ar.infer_backend_multi(bx, bm)
    _stack = ar.get_lib_fn(backend, "stack")

    nmsg = len(bm)
    results = [None] * nmsg
    # each entry: (tensor axes still open, partially contracted tensor,
    # messages belonging to those axes)
    stack = [(tuple(range(nmsg)), bx, bm)]

    while stack:
        axes, x, msgs = stack.pop()
        remaining = len(msgs)

        if remaining == 1:
            # everything else is already absorbed
            results[axes[0]] = x
            continue

        if remaining == 2:
            # two axes left: contract each against the other's message
            results[axes[0]] = array_contract(
                arrays=(x, msgs[1]),
                inputs=(("X", "a", "b"), ("X", "b")),
                output=("X", "a"),
                backend=backend,
            )
            results[axes[1]] = array_contract(
                arrays=(msgs[0], x),
                inputs=(("X", "a"), ("X", "a", "b")),
                output=("X", "b"),
                backend=backend,
            )
            continue

        # otherwise halve the open axes
        half = remaining // 2
        axes_l, axes_r = axes[:half], axes[half:]
        msgs_l, msgs_r = msgs[:half], msgs[half:]

        # absorb the right-hand messages, leaving the left axes open
        x_left = array_contract(
            arrays=(x, *(msgs_r[i] for i in range(msgs_r.shape[0]))),
            inputs=((-1, *axes), *((-1, j) for j in axes_r)),
            output=(-1, *axes_l),
            backend=backend,
        )

        # absorb the left-hand messages, leaving the right axes open
        x_right = array_contract(
            arrays=(x, *(msgs_l[i] for i in range(msgs_l.shape[0]))),
            inputs=((-1, *axes), *((-1, j) for j in axes_l)),
            output=(-1, *axes_r),
            backend=backend,
        )

        stack.append((axes_l, x_left, msgs_l))
        stack.append((axes_r, x_right, msgs_r))

    return _stack(tuple(results))
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12):
    """Compute every output message of a stack of tensors by multiplying all
    input messages into the tensors once, then dividing each message back
    out of the corresponding marginal.

    Parameters
    ----------
    bx : array
        Stacked tensor data; its first axis is used as the batch axis.
    bm : array
        Stacked input messages, ``len(bm)`` being the number of tensor
        indices.
    smudge_factor : float, optional
        Added to ``bm`` before inverting it.

    Returns
    -------
    array
        The output messages stacked in the same order as ``bm``.
    """
    backend = ar.infer_backend_multi(bx, bm)
    _stack = ar.get_lib_fn(backend, "stack")

    ndim = len(bm)
    x_inds = (-1, *range(ndim))
    m_inds = [(-1, i) for i in range(ndim)]

    # gate every message onto the tensors, keeping all indices open
    bmx = array_contract(
        arrays=(bx, *bm),
        inputs=(x_inds, *m_inds),
        output=x_inds,
    )

    bminv = 1 / (bm + smudge_factor)

    # sum over all but the ith index, then undo the ith message
    mouts = [
        array_contract(
            arrays=(bmx, bminv[i]),
            inputs=(x_inds, m_inds[i]),
            output=m_inds[i],
        )
        for i in range(ndim)
    ]

    return _stack(mouts)
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
def _compute_output_single_t(
    bm,
    bx,
    _reshape,
    _sum,
    smudge_factor=1e-12,
):
    """Compute and normalize the tensor output messages for one rank group.

    ``bm`` are the stacked input messages and ``bx`` the stacked tensors;
    ``_reshape`` and ``_sum`` are the backend's array functions.
    ``smudge_factor`` is accepted but not used by the tree algorithm called
    here (it would only matter for the 'prod' variant).
    """
    outputs = _compute_all_tensor_messages_tree_batched(bx, bm)

    # rescale in place so each message sums to one over its last axis
    norms = _sum(outputs, axis=-1)
    outputs /= _reshape(norms, (*ar.shape(outputs)[:-1], 1))
    return outputs
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
def _compute_output_single_m(bm, _reshape, _sum, smudge_factor=1e-12):
    """Compute and normalize the index output messages for one rank group.

    ``bm`` are the stacked input messages; ``_reshape`` and ``_sum`` are the
    backend's array functions. The 'prod' algorithm is used here (the 'tree'
    variant is the division-free alternative).
    """
    outputs = _compute_all_hyperind_messages_prod_batched(bm, smudge_factor)

    # rescale in place so each message sums to one over its last axis
    norms = _sum(outputs, axis=-1)
    outputs /= _reshape(norms, (*ar.shape(outputs)[:-1], 1))
    return outputs
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
def _compute_outputs_batched(
    batched_inputs,
    batched_tensors=None,
    smudge_factor=1e-12,
    _pool=None,
):
    """Compute the stacked output messages for every rank group.

    Parameters
    ----------
    batched_inputs : dict
        Stacked input messages keyed by rank.
    batched_tensors : dict, optional
        Stacked tensors keyed by rank. When given, tensor messages are
        computed; when ``None``, index messages are computed.
    smudge_factor : float, optional
        Forwarded to the per-group worker.
    _pool : executor, optional
        If given, each rank group is submitted to it via ``submit``.

    Returns
    -------
    dict
        Stacked output messages keyed by rank.
    """
    backend = ar.infer_backend(next(iter(batched_inputs.values())))
    _sum = ar.get_lib_fn(backend, "sum")
    _reshape = ar.get_lib_fn(backend, "reshape")

    if batched_tensors is None:
        # index messages
        worker = _compute_output_single_m
        jobs = {
            rank: (bm, _reshape, _sum, smudge_factor)
            for rank, bm in batched_inputs.items()
        }
    else:
        # tensor messages
        worker = _compute_output_single_t
        jobs = {
            rank: (bm, batched_tensors[rank], _reshape, _sum, smudge_factor)
            for rank, bm in batched_inputs.items()
        }

    if _pool is None:
        # run the groups one after another
        return {rank: worker(*args) for rank, args in jobs.items()}

    # submit everything first, then gather the results
    futures = {
        rank: _pool.submit(worker, *args) for rank, args in jobs.items()
    }
    return {rank: fut.result() for rank, fut in futures.items()}
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
def _update_output_to_input_single_batched(
    bi,
    bo,
    maskin,
    maskout,
    _max,
    _sum,
    _abs,
    damping=0.0,
):
    """Copy freshly computed output messages into the matching input slots
    of ``bi`` (in place) and report the largest change.

    Parameters
    ----------
    bi : array
        Stacked input messages, updated in place.
    bo : array
        Stacked output messages to read from.
    maskin, maskout : integer array
        Rows of ``[position, batch]`` selecting, pairwise, the slots of
        ``bi`` to write and the slots of ``bo`` to read.
    _max, _sum, _abs : callable
        The backend's array functions.
    damping : float, optional
        If positive, the value written is
        ``(1 - damping) * new + damping * old`` rather than ``new``.

    Returns
    -------
    scalar
        The maximum over selected messages of the summed absolute
        difference between the old and the newly written message.
    """
    # do a vectorized update
    select_in = (maskin[:, 0], maskin[:, 1], slice(None))
    select_out = (maskout[:, 0], maskout[:, 1], slice(None))
    bim = bi[select_in]
    bom = bo[select_out]

    if damping > 0.0:
        # blend the new message with the old one; the blended value is what
        # must be compared and stored, otherwise damping has no effect
        bom = (1 - damping) * bom + damping * bim

    # first check the change
    dm = _max(_sum(_abs(bim - bom), axis=-1))

    # update the input
    bi[select_in] = bom

    return dm
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
def _update_outputs_to_inputs_batched(
    batched_inputs, batched_outputs, masks, damping=0.0, _pool=None
):
    """Write the stacked output messages into the stacked input messages
    for every ``(input rank, output rank)`` pair in ``masks``.

    Parameters
    ----------
    batched_inputs : dict
        Stacked input messages keyed by rank, updated in place.
    batched_outputs : dict
        Stacked output messages keyed by rank.
    masks : dict
        Maps ``(input rank, output rank)`` to a ``(maskin, maskout)`` pair.
    damping : float, optional
        Forwarded to the per-pair update.
    _pool : executor, optional
        If given, each pair is submitted to it via ``submit``.

    Returns
    -------
    scalar
        The largest change reported by any of the per-pair updates.
    """
    backend = ar.infer_backend(next(iter(batched_outputs.values())))
    _max = ar.get_lib_fn(backend, "max")
    _sum = ar.get_lib_fn(backend, "sum")
    _abs = ar.get_lib_fn(backend, "abs")

    update = _update_output_to_input_single_batched
    jobs = (
        (
            batched_inputs[rank_in],
            batched_outputs[rank_out],
            mask_in,
            mask_out,
            _max,
            _sum,
            _abs,
            damping,
        )
        for (rank_in, rank_out), (mask_in, mask_out) in masks.items()
    )

    if _pool is None:
        # run the updates one after another
        return max(update(*args) for args in jobs)

    # submit everything first, then gather the results
    futures = [_pool.submit(update, *args) for args in jobs]
    return max(fut.result() for fut in futures)
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
def _extract_messages_from_inputs_batched(
    batched_inputs_m,
    batched_inputs_t,
    input_locs_m,
    input_locs_t,
):
    """Unpack the stacked input arrays into a dict of individual messages.

    ``input_locs_m`` / ``input_locs_t`` map each message key to its
    ``(rank, position, batch)`` location inside ``batched_inputs_m`` /
    ``batched_inputs_t``. Index-side entries are inserted first, then the
    tensor-side ones.
    """
    messages = {
        pair: batched_inputs_m[rank][pos, batch, :]
        for pair, (rank, pos, batch) in input_locs_m.items()
    }
    messages.update(
        (pair, batched_inputs_t[rank][pos, batch, :])
        for pair, (rank, pos, batch) in input_locs_t.items()
    )
    return messages
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
def iterate_belief_propagation_batched(
    batched_inputs_m,
    batched_inputs_t,
    batched_tensors,
    masks_m,
    masks_t,
    smudge_factor=1e-12,
    damping=0.0,
    _pool=None,
):
    """Run one sweep of batched belief propagation.

    Tensor output messages are computed and written into the index inputs,
    then index output messages are computed and written into the tensor
    inputs. Both input dicts are updated in place.

    Returns
    -------
    batched_inputs_m, batched_inputs_t : dict
        The same two dicts that were passed in.
    max_dm : scalar
        The largest message change seen in either half of the sweep.
    """
    # tensors -> indices
    outputs_t = _compute_outputs_batched(
        batched_inputs=batched_inputs_t,
        batched_tensors=batched_tensors,
        smudge_factor=smudge_factor,
        _pool=_pool,
    )
    change_t = _update_outputs_to_inputs_batched(
        batched_inputs_m,
        outputs_t,
        masks_m,
        damping=damping,
        _pool=_pool,
    )

    # indices -> tensors
    outputs_m = _compute_outputs_batched(
        batched_inputs=batched_inputs_m,
        batched_tensors=None,
        smudge_factor=smudge_factor,
        _pool=_pool,
    )
    change_m = _update_outputs_to_inputs_batched(
        batched_inputs_t,
        outputs_m,
        masks_t,
        damping=damping,
        _pool=_pool,
    )

    max_dm = max(change_t, change_m)
    return batched_inputs_m, batched_inputs_t, max_dm
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
class HV1BP(BeliefPropagationCommon):
    """Object interface for hyper, vectorized, 1-norm, belief propagation.
    This is the fast version of belief propagation, possible when there are
    many small tensors of matching sizes.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    messages : dict, optional
        Initial messages to use, if not given then uniform messages are used.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    damping : float, optional
        The damping factor forwarded to each iteration, 0.0 means no damping.
    thread_pool : bool or int, optional
        Whether to use a thread pool for parallelization, if ``True`` use the
        default number of threads, if an integer use that many threads.
    """

    def __init__(
        self,
        tn,
        messages=None,
        smudge_factor=1e-12,
        damping=0.0,
        thread_pool=False,
    ):
        self.tn = tn
        self.backend = next(t.backend for t in tn)
        self.smudge_factor = smudge_factor
        self.damping = damping
        self.pool = maybe_get_thread_pool(thread_pool)

        # build the stacked representation once, up front
        batched = initialize_messages_batched(tn, messages)
        (
            self.batched_inputs_m,
            self.batched_inputs_t,
            self.batched_tensors,
            self.input_locs_m,
            self.input_locs_t,
            self.masks_m,
            self.masks_t,
        ) = batched

    def iterate(self, **kwargs):
        """Perform a single BP sweep; ``kwargs`` are accepted but unused.

        Returns ``(None, None, max_dm)`` where ``max_dm`` is the largest
        message change of the sweep.
        """
        inputs_m, inputs_t, max_dm = iterate_belief_propagation_batched(
            self.batched_inputs_m,
            self.batched_inputs_t,
            self.batched_tensors,
            self.masks_m,
            self.masks_t,
            damping=self.damping,
            smudge_factor=self.smudge_factor,
            _pool=self.pool,
        )
        self.batched_inputs_m = inputs_m
        self.batched_inputs_t = inputs_t
        return None, None, max_dm

    def get_messages(self):
        """Return the current messages as a dict of individual arrays,
        unpacked from the batched stacks.
        """
        return _extract_messages_from_inputs_batched(
            self.batched_inputs_m,
            self.batched_inputs_t,
            self.input_locs_m,
            self.input_locs_t,
        )

    def contract(self, strip_exponent=False):
        """Estimate the contraction of the network from the current
        messages via ``contract_hyper_messages``.
        """
        messages = self.get_messages()
        return contract_hyper_messages(
            self.tn,
            messages,
            strip_exponent=strip_exponent,
            backend=self.backend,
        )
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
def contract_hv1bp(
    tn,
    messages=None,
    max_iterations=1000,
    tol=5e-6,
    smudge_factor=1e-12,
    damping=0.0,
    strip_exponent=False,
    info=None,
    progbar=False,
):
    """Estimate the contraction of ``tn`` with hyper, vectorized, 1-norm
    belief propagation, via the exponential of the Bethe free entropy.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on, can have hyper indices.
    messages : dict, optional
        Initial messages to use, if not given then uniform messages are used.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    damping : float, optional
        The damping factor to use, 0.0 means no damping.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.

    Returns
    -------
    scalar or (scalar, float)
    """
    solver = HV1BP(
        tn,
        messages=messages,
        damping=damping,
        smudge_factor=smudge_factor,
    )
    run_options = {
        "max_iterations": max_iterations,
        "tol": tol,
        "info": info,
        "progbar": progbar,
    }
    solver.run(**run_options)
    return solver.contract(strip_exponent=strip_exponent)
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
def run_belief_propagation_hv1bp(
    tn,
    messages=None,
    max_iterations=1000,
    tol=5e-6,
    damping=0.0,
    smudge_factor=1e-12,
    info=None,
    progbar=False,
):
    """Run belief propagation on a tensor network until it converges.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    messages : dict, optional
        The current messages. For every index and tensor id pair, there should
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
        If not given, then messages are initialized as uniform.
    max_iterations : int, optional
        The maximum number of iterations to run for.
    tol : float, optional
        The convergence tolerance.
    damping : float, optional
        The damping factor to use, 0.0 means no damping.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.

    Returns
    -------
    messages : dict
        The final messages.
    converged : bool
        Whether the algorithm converged.
    """
    solver = HV1BP(
        tn,
        messages=messages,
        damping=damping,
        smudge_factor=smudge_factor,
    )
    solver.run(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    final_messages = solver.get_messages()
    return final_messages, solver.converged
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
def sample_hv1bp(
    tn,
    messages=None,
    output_inds=None,
    max_iterations=1000,
    tol=1e-2,
    damping=0.0,
    smudge_factor=1e-12,
    bias=False,
    seed=None,
    progbar=False,
):
    """Sample all indices of a tensor network using repeated belief propagation
    runs and decimation.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to sample.
    messages : dict, optional
        The current messages. For every index and tensor id pair, there should
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
        If not given, then messages are initialized as uniform.
    output_inds : sequence of str, optional
        The indices to sample. If not given, then all indices are sampled.
    max_iterations : int, optional
        The maximum number of iterations for each message passing run.
    tol : float, optional
        The convergence tolerance for each message passing run.
    damping : float, optional
        The damping factor for each message passing run, 0.0 means no
        damping.
    smudge_factor : float, optional
        A small number to add to each message to avoid zeros. Making this large
        is similar to adding a temperature, which can aid convergence but
        likely produces less accurate marginals.
    bias : bool or float, optional
        Whether to bias the sampling towards the largest marginal. If ``False``
        (the default), then indices are sampled proportional to their
        marginals. If ``True``, then each index is 'sampled' to be its largest
        weight value always. If a float, then the local probability
        distribution is raised to this power before sampling.
    seed : int, optional
        A random seed to use for the sampling. It controls every random
        draw, including those made when ``bias`` is a float.
    progbar : bool, optional
        Whether to show a progress bar.

    Returns
    -------
    config : dict[str, int]
        The sample configuration, mapping indices to values.
    tn_config : TensorNetwork
        The tensor network with all index values (or just those in
        `output_inds` if supplied) selected. Contracting this tensor network
        (which will just be a sequence of scalars if all index values have been
        sampled) gives the weight of the sample, e.g. should be 1 for a SAT
        problem and valid assignment.
    omega : float
        The probability of choosing this sample (i.e. product of marginal
        values). Useful possibly for importance sampling.
    """
    import numpy as np

    rng = np.random.default_rng(seed)

    # work on a copy, indices are fixed in place as they are sampled
    tn_config = tn.copy()

    if messages is None:
        messages = initialize_hyper_messages(tn_config)

    if output_inds is None:
        output_inds = tn_config.ind_map.keys()
    output_inds = set(output_inds)

    config = {}
    omega = 1.0

    if progbar:
        import tqdm

        pbar = tqdm.tqdm(total=len(output_inds))
    else:
        pbar = None

    while output_inds:
        # re-converge the messages on the partially decimated network
        messages, _ = run_belief_propagation_hv1bp(
            tn_config,
            messages,
            max_iterations=max_iterations,
            tol=tol,
            damping=damping,
            smudge_factor=smudge_factor,
        )

        marginals = compute_all_index_marginals_from_messages(
            tn_config, messages
        )

        # choose most peaked marginal
        ix, p = max(
            (m for m in marginals.items() if m[0] in output_inds),
            key=lambda ix_p: max(ix_p[1]),
        )

        if bias is False:
            # sample the value according to the marginal
            v = rng.choice(np.arange(p.size), p=p)
        elif bias is True:
            v = np.argmax(p)
            # in some sense omega is really 1.0 here
        else:
            # bias towards larger marginals by raising to a power
            p = p**bias
            p /= np.sum(p)
            # use the seeded generator so that ``seed`` is honoured here too
            v = rng.choice(np.arange(p.size), p=p)

        omega *= p[v]
        config[ix] = v

        # clean up messages
        for tid in tn_config.ind_map[ix]:
            del messages[ix, tid]
            del messages[tid, ix]

        # remove index
        tn_config.isel_({ix: v})
        output_inds.remove(ix)

        if progbar:
            pbar.update(1)
            pbar.set_description(f"{ix}->{v}", refresh=False)

    if progbar:
        pbar.close()

    return config, tn_config, omega
|