trajectree-0.0.0-py3-none-any.whl → trajectree-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py (new file)
@@ -0,0 +1,775 @@
+"""Hyper, vectorized, 1-norm, belief propagation."""
+
+import autoray as ar
+
+from quimb.tensor.contraction import array_contract
+from .bp_common import (
+    BeliefPropagationCommon,
+    compute_all_index_marginals_from_messages,
+    contract_hyper_messages,
+    initialize_hyper_messages,
+    maybe_get_thread_pool,
+)
+
+
+def initialize_messages_batched(tn, messages=None):
+    """Initialize batched messages for belief propagation, as the uniform
+    distribution.
+    """
+    if messages is None:
+        messages = initialize_hyper_messages(tn)
+
+    backend = ar.infer_backend(next(iter(messages.values())))
+    _stack = ar.get_lib_fn(backend, "stack")
+    _array = ar.get_lib_fn(backend, "array")
+
+    # prepare index messages
+    batched_inputs_m = {}
+    input_locs_m = {}
+    output_locs_m = {}
+    for ix, tids in tn.ind_map.items():
+        rank = len(tids)
+        try:
+            batch = batched_inputs_m[rank]
+        except KeyError:
+            batch = batched_inputs_m[rank] = [[] for _ in range(rank)]
+
+        for i, tid in enumerate(tids):
+            batch_i = batch[i]
+            # position in the stack
+            b = len(batch_i)
+            input_locs_m[tid, ix] = (rank, i, b)
+            output_locs_m[ix, tid] = (rank, i, b)
+            batch_i.append(messages[tid, ix])
+
+    # prepare tensor messages
+    batched_tensors = {}
+    batched_inputs_t = {}
+    input_locs_t = {}
+    output_locs_t = {}
+    for tid, t in tn.tensor_map.items():
+        rank = t.ndim
+        if rank == 0:
+            continue
+
+        try:
+            batch = batched_inputs_t[rank]
+            batch_t = batched_tensors[rank]
+        except KeyError:
+            batch = batched_inputs_t[rank] = [[] for _ in range(rank)]
+            batch_t = batched_tensors[rank] = []
+
+        for i, ix in enumerate(t.inds):
+            batch_i = batch[i]
+            # position in the stack
+            b = len(batch_i)
+            input_locs_t[ix, tid] = (rank, i, b)
+            output_locs_t[tid, ix] = (rank, i, b)
+            batch_i.append(messages[ix, tid])
+
+        batch_t.append(t.data)
+
+    # stack messages into single arrays
+    for batched_inputs in (batched_inputs_m, batched_inputs_t):
+        for key, batch in batched_inputs.items():
+            batched_inputs[key] = _stack(
+                tuple(_stack(batch_i) for batch_i in batch)
+            )
+    for rank, tensors in batched_tensors.items():
+        batched_tensors[rank] = _stack(tensors)
+
+    # make numeric masks for updating output to input messages
+    masks_m = {}
+    masks_t = {}
+    for masks, input_locs, output_locs in [
+        (masks_m, input_locs_m, output_locs_t),
+        (masks_t, input_locs_t, output_locs_m),
+    ]:
+        for pair in input_locs:
+            (ranki, ii, bi) = input_locs[pair]
+            (ranko, io, bo) = output_locs[pair]
+            key = (ranki, ranko)
+            try:
+                maskin, maskout = masks[key]
+            except KeyError:
+                maskin, maskout = masks[key] = [], []
+            maskin.append([ii, bi])
+            maskout.append([io, bo])
+
+        for key, (maskin, maskout) in masks.items():
+            masks[key] = _array(maskin), _array(maskout)
+
+    return (
+        batched_inputs_m,
+        batched_inputs_t,
+        batched_tensors,
+        input_locs_m,
+        input_locs_t,
+        masks_m,
+        masks_t,
+    )
+
+
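+# Layout sketch (derived from the code above): for an index ``ix`` of
+# dimension d shared by three tensors, each incoming message occupies one
+# slot of the rank-3 stack, so ``batched_inputs_m[3]`` has shape (3, n3, d)
+# with n3 the number of rank-3 indices, and ``input_locs_m[tid, ix] =
+# (3, i, b)`` records its (slot, batch) coordinates. The masks pair such
+# (rank, i, b) coordinates between input and output stacks, so that message
+# updates reduce to pure fancy indexing.
+
+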
+def _compute_all_hyperind_messages_tree_batched(bm):
+    """Compute all outgoing index messages as 'all but one' products of the
+    incoming messages, using a divide and conquer tree.
+    """
+    ndim = len(bm)
+
+    if ndim == 2:
+        # shortcut for 'bonds', which just swap places
+        return ar.do("flip", bm, (0,))
+
+    backend = ar.infer_backend(bm)
+    _prod = ar.get_lib_fn(backend, "prod")
+    _empty_like = ar.get_lib_fn(backend, "empty_like")
+
+    bmo = _empty_like(bm)
+    queue = [(tuple(range(ndim)), 1, bm)]
+
+    while queue:
+        js, x, bm = queue.pop()
+
+        ndim = len(bm)
+        if ndim == 1:
+            # reached single message
+            bmo[js[0]] = x
+            continue
+        elif ndim == 2:
+            # shortcut for 2 messages left
+            bmo[js[0]] = x * bm[1]
+            bmo[js[1]] = bm[0] * x
+            continue
+
+        # else split in two and contract each half
+        k = ndim // 2
+        jl, jr = js[:k], js[k:]
+        bml, bmr = bm[:k], bm[k:]
+
+        # contract the right messages to get new left array
+        xl = x * _prod(bmr, axis=0)
+
+        # contract the left messages to get new right array
+        xr = _prod(bml, axis=0) * x
+
+        # add to the queue for possible further halving
+        queue.append((jl, xl, bml))
+        queue.append((jr, xr, bmr))
+
+    return bmo
+
+
+def _compute_all_hyperind_messages_prod_batched(bm, smudge_factor=1e-12):
+    """Compute all outgoing index messages as 'all but one' products of the
+    incoming messages, by dividing the total product by each message,
+    smudged to avoid division by zero.
+    """
+    backend = ar.infer_backend(bm)
+    _prod = ar.get_lib_fn(backend, "prod")
+    _reshape = ar.get_lib_fn(backend, "reshape")
+
+    ndim = len(bm)
+    if ndim == 2:
+        # shortcut for 'bonds', which just swap
+        return ar.do("flip", bm, (0,))
+
+    combined = _prod(bm, axis=0)
+    return _reshape(combined, (1, *ar.shape(combined))) / (bm + smudge_factor)
+
+
+def _compute_all_tensor_messages_tree_batched(bx, bm):
+    """Compute all output messages for a stacked tensor and messages."""
+    backend = ar.infer_backend_multi(bx, bm)
+    _stack = ar.get_lib_fn(backend, "stack")
+
+    ndim = len(bm)
+    mouts = [None for _ in range(ndim)]
+    queue = [(tuple(range(ndim)), bx, bm)]
+
+    while queue:
+        js, bx, bm = queue.pop()
+
+        ndim = len(bm)
+        if ndim == 1:
+            # reached single message
+            mouts[js[0]] = bx
+            continue
+        elif ndim == 2:
+            # shortcut for 2 messages left
+            mouts[js[0]] = array_contract(
+                arrays=(bx, bm[1]),
+                inputs=(("X", "a", "b"), ("X", "b")),
+                output=("X", "a"),
+                backend=backend,
+            )
+            mouts[js[1]] = array_contract(
+                arrays=(bm[0], bx),
+                inputs=(("X", "a"), ("X", "a", "b")),
+                output=("X", "b"),
+                backend=backend,
+            )
+            continue
+
+        # else split in two and contract each half
+        k = ndim // 2
+        jl, jr = js[:k], js[k:]
+        ml, mr = bm[:k], bm[k:]
+
+        # contract the right messages to get new left array
+        xl = array_contract(
+            arrays=(bx, *(mr[i] for i in range(mr.shape[0]))),
+            inputs=((-1, *js), *((-1, j) for j in jr)),
+            output=(-1, *jl),
+            backend=backend,
+        )
+
+        # contract the left messages to get new right array
+        xr = array_contract(
+            arrays=(bx, *(ml[i] for i in range(ml.shape[0]))),
+            inputs=((-1, *js), *((-1, j) for j in jl)),
+            output=(-1, *jr),
+            backend=backend,
+        )
+
+        # add to the queue for possible further halving
+        queue.append((jl, xl, ml))
+        queue.append((jr, xr, mr))
+
+    return _stack(tuple(mouts))
+
+
+def _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor=1e-12):
+    """Compute all output messages for a stacked tensor and messages, by
+    contracting the tensor with all messages then dividing each one out.
+    """
+    backend = ar.infer_backend_multi(bx, bm)
+    _stack = ar.get_lib_fn(backend, "stack")
+
+    ndim = len(bm)
+    x_inds = (-1, *range(ndim))
+    m_inds = [(-1, i) for i in range(ndim)]
+    bmx = array_contract(
+        arrays=(bx, *bm),
+        inputs=(x_inds, *m_inds),
+        output=x_inds,
+    )
+
+    bminv = 1 / (bm + smudge_factor)
+
+    mouts = []
+    for i in range(ndim):
+        # sum all but ith index, apply inverse gate to that
+        mouts.append(
+            array_contract(
+                arrays=(bmx, bminv[i]),
+                inputs=(x_inds, m_inds[i]),
+                output=m_inds[i],
+            )
+        )
+
+    return _stack(mouts)
+
+
+def _compute_output_single_t(
+    bm,
+    bx,
+    _reshape,
+    _sum,
+    smudge_factor=1e-12,
+):
+    # tensor messages
+    bmo = _compute_all_tensor_messages_tree_batched(bx, bm)
+    # bmo = _compute_all_tensor_messages_prod_batched(bx, bm, smudge_factor)
+    # normalize
+    bmo /= _reshape(_sum(bmo, axis=-1), (*ar.shape(bmo)[:-1], 1))
+    return bmo
+
+
+def _compute_output_single_m(bm, _reshape, _sum, smudge_factor=1e-12):
+    # index messages
+    # bmo = _compute_all_hyperind_messages_tree_batched(bm)
+    bmo = _compute_all_hyperind_messages_prod_batched(bm, smudge_factor)
+    # normalize
+    bmo /= _reshape(_sum(bmo, axis=-1), (*ar.shape(bmo)[:-1], 1))
+    return bmo
+
+
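+# Note the division-free tree scheme is used for tensor messages, while the
+# smudged product scheme (one big product, then divide each message out) is
+# used for index messages - the commented lines above show the alternative
+# pairing.
+
+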
+def _compute_outputs_batched(
+    batched_inputs,
+    batched_tensors=None,
+    smudge_factor=1e-12,
+    _pool=None,
+):
+    """Given stacked messages and tensors, compute stacked output messages."""
+    backend = ar.infer_backend(next(iter(batched_inputs.values())))
+    _sum = ar.get_lib_fn(backend, "sum")
+    _reshape = ar.get_lib_fn(backend, "reshape")
+
+    if batched_tensors is not None:
+        # tensor messages
+        f = _compute_output_single_t
+        f_args = {
+            rank: (bm, batched_tensors[rank], _reshape, _sum, smudge_factor)
+            for rank, bm in batched_inputs.items()
+        }
+    else:
+        # index messages
+        f = _compute_output_single_m
+        f_args = {
+            rank: (bm, _reshape, _sum, smudge_factor)
+            for rank, bm in batched_inputs.items()
+        }
+
+    batched_outputs = {}
+    if _pool is None:
+        # sequential process
+        for rank, args in f_args.items():
+            batched_outputs[rank] = f(*args)
+    else:
+        # parallel process
+        for rank, args in f_args.items():
+            batched_outputs[rank] = _pool.submit(f, *args)
+        for key, fut in batched_outputs.items():
+            batched_outputs[key] = fut.result()
+
+    return batched_outputs
+
+
+def _update_output_to_input_single_batched(
+    bi,
+    bo,
+    maskin,
+    maskout,
+    _max,
+    _sum,
+    _abs,
+    damping=0.0,
+):
+    # do a vectorized update
+    select_in = (maskin[:, 0], maskin[:, 1], slice(None))
+    select_out = (maskout[:, 0], maskout[:, 1], slice(None))
+    bim = bi[select_in]
+    bom = bo[select_out]
+
+    if damping > 0.0:
+        # mix the new output messages with the old input messages
+        bom = (1 - damping) * bom + damping * bim
+
+    # first check the change
+    dm = _max(_sum(_abs(bim - bom), axis=-1))
+
+    # update the input
+    bi[select_in] = bom
+
+    return dm
+
+
+def _update_outputs_to_inputs_batched(
+    batched_inputs, batched_outputs, masks, damping=0.0, _pool=None
+):
+    """Update the stacked input messages from the stacked output messages."""
+    backend = ar.infer_backend(next(iter(batched_outputs.values())))
+    _max = ar.get_lib_fn(backend, "max")
+    _sum = ar.get_lib_fn(backend, "sum")
+    _abs = ar.get_lib_fn(backend, "abs")
+
+    f = _update_output_to_input_single_batched
+    f_args = (
+        (
+            batched_inputs[ranki],
+            batched_outputs[ranko],
+            maskin,
+            maskout,
+            _max,
+            _sum,
+            _abs,
+            damping,
+        )
+        for (ranki, ranko), (maskin, maskout) in masks.items()
+    )
+
+    if _pool is None:
+        # sequential process
+        dms = (f(*args) for args in f_args)
+    else:
+        # parallel process
+        futs = [_pool.submit(f, *args) for args in f_args]
+        dms = (fut.result() for fut in futs)
+
+    return max(dms)
+
+
+def _extract_messages_from_inputs_batched(
+    batched_inputs_m,
+    batched_inputs_t,
+    input_locs_m,
+    input_locs_t,
+):
+    """Get all messages as a dict from the batch stacked input form."""
+    messages = {}
+    for pair, (rank, i, b) in input_locs_m.items():
+        messages[pair] = batched_inputs_m[rank][i, b, :]
+    for pair, (rank, i, b) in input_locs_t.items():
+        messages[pair] = batched_inputs_t[rank][i, b, :]
+    return messages
+
+
+def iterate_belief_propagation_batched(
+    batched_inputs_m,
+    batched_inputs_t,
+    batched_tensors,
+    masks_m,
+    masks_t,
+    smudge_factor=1e-12,
+    damping=0.0,
+    _pool=None,
+):
+    """Perform a single round of belief propagation in batched form: compute
+    all tensor messages and update the index inputs, then compute all index
+    messages and update the tensor inputs.
+    """
+    # compute tensor messages
+    batched_outputs_t = _compute_outputs_batched(
+        batched_inputs=batched_inputs_t,
+        batched_tensors=batched_tensors,
+        smudge_factor=smudge_factor,
+        _pool=_pool,
+    )
+    # update the index input messages
+    t_max_dm = _update_outputs_to_inputs_batched(
+        batched_inputs_m,
+        batched_outputs_t,
+        masks_m,
+        damping=damping,
+        _pool=_pool,
+    )
+
+    # compute index messages
+    batched_outputs_m = _compute_outputs_batched(
+        batched_inputs=batched_inputs_m,
+        batched_tensors=None,
+        smudge_factor=smudge_factor,
+        _pool=_pool,
+    )
+    # update the tensor input messages
+    m_max_dm = _update_outputs_to_inputs_batched(
+        batched_inputs_t,
+        batched_outputs_m,
+        masks_t,
+        damping=damping,
+        _pool=_pool,
+    )
+    return batched_inputs_m, batched_inputs_t, max(t_max_dm, m_max_dm)
+
+
+class HV1BP(BeliefPropagationCommon):
+    """Object interface for hyper, vectorized, 1-norm, belief propagation.
+    This is the fast version of belief propagation, possible when there are
+    many small tensors of matching sizes, which can then be stacked and
+    updated together.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to run BP on.
+    messages : dict, optional
+        Initial messages to use, if not given then uniform messages are used.
+    smudge_factor : float, optional
+        A small number to add to the denominator of messages to avoid division
+        by zero. Note when this happens the numerator will also be zero.
+    damping : float, optional
+        The damping factor to use, 0.0 means no damping.
+    thread_pool : bool or int, optional
+        Whether to use a thread pool for parallelization, if ``True`` use the
+        default number of threads, if an integer use that many threads.
+    """
+
+    def __init__(
+        self,
+        tn,
+        messages=None,
+        smudge_factor=1e-12,
+        damping=0.0,
+        thread_pool=False,
+    ):
+        self.tn = tn
+        self.backend = next(t.backend for t in tn)
+        self.smudge_factor = smudge_factor
+        self.damping = damping
+        self.pool = maybe_get_thread_pool(thread_pool)
+        (
+            self.batched_inputs_m,
+            self.batched_inputs_t,
+            self.batched_tensors,
+            self.input_locs_m,
+            self.input_locs_t,
+            self.masks_m,
+            self.masks_t,
+        ) = initialize_messages_batched(tn, messages)
+
+    def iterate(self, **kwargs):
+        (
+            self.batched_inputs_m,
+            self.batched_inputs_t,
+            max_dm,
+        ) = iterate_belief_propagation_batched(
+            self.batched_inputs_m,
+            self.batched_inputs_t,
+            self.batched_tensors,
+            self.masks_m,
+            self.masks_t,
+            damping=self.damping,
+            smudge_factor=self.smudge_factor,
+            _pool=self.pool,
+        )
+        return None, None, max_dm
+
+    def get_messages(self):
+        """Get messages in individual form from the batched stacks."""
+        return _extract_messages_from_inputs_batched(
+            self.batched_inputs_m,
+            self.batched_inputs_t,
+            self.input_locs_m,
+            self.input_locs_t,
+        )
+
+    def contract(self, strip_exponent=False):
+        return contract_hyper_messages(
+            self.tn,
+            self.get_messages(),
+            strip_exponent=strip_exponent,
+            backend=self.backend,
+        )
+
+
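+# A minimal usage sketch of the class interface, assuming a suitable
+# ``TensorNetwork`` ``tn`` (``run`` and ``converged`` are provided by
+# ``BeliefPropagationCommon``):
+#
+#     bp = HV1BP(tn, damping=0.1)
+#     bp.run(max_iterations=1000, tol=5e-6)
+#     value = bp.contract()
+
+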
+def contract_hv1bp(
+    tn,
+    messages=None,
+    max_iterations=1000,
+    tol=5e-6,
+    smudge_factor=1e-12,
+    damping=0.0,
+    strip_exponent=False,
+    info=None,
+    progbar=False,
+):
+    """Estimate the contraction of ``tn`` with hyper, vectorized, 1-norm
+    belief propagation, via the exponential of the Bethe free entropy.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to run BP on, can have hyper indices.
+    messages : dict, optional
+        Initial messages to use, if not given then uniform messages are used.
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    smudge_factor : float, optional
+        A small number to add to the denominator of messages to avoid division
+        by zero. Note when this happens the numerator will also be zero.
+    damping : float, optional
+        The damping factor to use, 0.0 means no damping.
+    strip_exponent : bool, optional
+        Whether to strip the exponent from the final result. If ``True``
+        then the returned result is ``(mantissa, exponent)``.
+    info : dict, optional
+        If specified, update this dictionary with information about the
+        belief propagation run.
+    progbar : bool, optional
+        Whether to show a progress bar.
+
+    Returns
+    -------
+    scalar or (scalar, float)
+    """
+    bp = HV1BP(
+        tn,
+        messages=messages,
+        damping=damping,
+        smudge_factor=smudge_factor,
+    )
+    bp.run(
+        max_iterations=max_iterations,
+        tol=tol,
+        info=info,
+        progbar=progbar,
+    )
+    return bp.contract(strip_exponent=strip_exponent)
+
+
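+# For example, a sketch of estimating a model-counting style quantity
+# (``HTN_random_ksat`` and its arguments are assumed here for illustration):
+#
+#     import quimb.tensor as qtn
+#
+#     htn = qtn.HTN_random_ksat(3, 50, alpha=2.0, seed=42)
+#     mantissa, exponent = contract_hv1bp(htn, strip_exponent=True)
+#     # estimated count ~ mantissa * 10**exponent (log-10 convention assumed)
+
+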
+def run_belief_propagation_hv1bp(
+    tn,
+    messages=None,
+    max_iterations=1000,
+    tol=5e-6,
+    damping=0.0,
+    smudge_factor=1e-12,
+    info=None,
+    progbar=False,
+):
+    """Run belief propagation on a tensor network until it converges.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to run BP on.
+    messages : dict, optional
+        The current messages. For every index and tensor id pair, there should
+        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
+        If not given, then messages are initialized as uniform.
+    max_iterations : int, optional
+        The maximum number of iterations to run for.
+    tol : float, optional
+        The convergence tolerance.
+    damping : float, optional
+        The damping factor to use, 0.0 means no damping.
+    smudge_factor : float, optional
+        A small number to add to the denominator of messages to avoid division
+        by zero. Note when this happens the numerator will also be zero.
+    info : dict, optional
+        If specified, update this dictionary with information about the
+        belief propagation run.
+    progbar : bool, optional
+        Whether to show a progress bar.
+
+    Returns
+    -------
+    messages : dict
+        The final messages.
+    converged : bool
+        Whether the algorithm converged.
+    """
+    bp = HV1BP(
+        tn, messages=messages, damping=damping, smudge_factor=smudge_factor
+    )
+    bp.run(max_iterations=max_iterations, tol=tol, info=info, progbar=progbar)
+    return bp.get_messages(), bp.converged
+
+
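+# For example (sketch): ``messages, converged = run_belief_propagation_hv1bp(tn)``
+# returns converged messages in dict form, which can be reused as a warm start
+# or passed to the decimation-based ``sample_hv1bp`` below.
+
+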
+def sample_hv1bp(
+    tn,
+    messages=None,
+    output_inds=None,
+    max_iterations=1000,
+    tol=1e-2,
+    damping=0.0,
+    smudge_factor=1e-12,
+    bias=False,
+    seed=None,
+    progbar=False,
+):
+    """Sample all indices of a tensor network using repeated belief
+    propagation runs and decimation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to sample.
+    messages : dict, optional
+        The current messages. For every index and tensor id pair, there should
+        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
+        If not given, then messages are initialized as uniform.
+    output_inds : sequence of str, optional
+        The indices to sample. If not given, then all indices are sampled.
+    max_iterations : int, optional
+        The maximum number of iterations for each message passing run.
+    tol : float, optional
+        The convergence tolerance for each message passing run.
+    damping : float, optional
+        The damping factor to use, 0.0 means no damping.
+    smudge_factor : float, optional
+        A small number to add to each message to avoid zeros. Making this
+        large is similar to adding a temperature, which can aid convergence
+        but likely produces less accurate marginals.
+    bias : bool or float, optional
+        Whether to bias the sampling towards the largest marginal. If
+        ``False`` (the default), then indices are sampled proportional to
+        their marginals. If ``True``, then each index is 'sampled' to be its
+        largest weight value always. If a float, then the local probability
+        distribution is raised to this power before sampling.
+    seed : int, optional
+        A random seed to use for the sampling.
+    progbar : bool, optional
+        Whether to show a progress bar.
+
+    Returns
+    -------
+    config : dict[str, int]
+        The sample configuration, mapping indices to values.
+    tn_config : TensorNetwork
+        The tensor network with all index values (or just those in
+        ``output_inds`` if supplied) selected. Contracting this tensor
+        network (which will just be a sequence of scalars if all index values
+        have been sampled) gives the weight of the sample, e.g. should be 1
+        for a SAT problem and valid assignment.
+    omega : float
+        The probability of choosing this sample (i.e. product of marginal
+        values). Useful possibly for importance sampling.
+    """
+    import numpy as np
+
+    rng = np.random.default_rng(seed)
+
+    tn_config = tn.copy()
+
+    if messages is None:
+        messages = initialize_hyper_messages(tn_config)
+
+    if output_inds is None:
+        output_inds = tn_config.ind_map.keys()
+    output_inds = set(output_inds)
+
+    config = {}
+    omega = 1.0
+
+    if progbar:
+        import tqdm
+
+        pbar = tqdm.tqdm(total=len(output_inds))
+    else:
+        pbar = None
+
+    while output_inds:
+        messages, _ = run_belief_propagation_hv1bp(
+            tn_config,
+            messages,
+            max_iterations=max_iterations,
+            tol=tol,
+            damping=damping,
+            smudge_factor=smudge_factor,
+        )
+
+        marginals = compute_all_index_marginals_from_messages(
+            tn_config, messages
+        )
+
+        # choose the most peaked marginal
+        ix, p = max(
+            (m for m in marginals.items() if m[0] in output_inds),
+            key=lambda ix_p: max(ix_p[1]),
+        )
+
+        if bias is False:
+            # sample the value according to the marginal
+            v = rng.choice(np.arange(p.size), p=p)
+        elif bias is True:
+            v = np.argmax(p)
+            # in some sense omega is really 1.0 here
+        else:
+            # bias towards larger marginals by raising to a power
+            p = p**bias
+            p /= np.sum(p)
+            v = rng.choice(np.arange(p.size), p=p)
+
+        omega *= p[v]
+        config[ix] = v
+
+        # clean up messages
+        for tid in tn_config.ind_map[ix]:
+            del messages[ix, tid]
+            del messages[tid, ix]
+
+        # remove index
+        tn_config.isel_({ix: v})
+        output_inds.remove(ix)
+
+        if progbar:
+            pbar.update(1)
+            pbar.set_description(f"{ix}->{v}", refresh=False)
+
+    if progbar:
+        pbar.close()
+
+    return config, tn_config, omega
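
Taken together, sample_hv1bp composes the pieces above: run BP to convergence, fix the most peaked index, and repeat until all requested indices are decimated. A minimal end-to-end sketch, assuming the module is importable via its upstream quimb path and using an assumed random k-SAT builder purely for illustration:

    import quimb.tensor as qtn
    from quimb.experimental.belief_propagation.hv1bp import (
        contract_hv1bp,
        sample_hv1bp,
    )

    # a small random 3-SAT instance as a hyper tensor network (assumed API)
    htn = qtn.HTN_random_ksat(3, 20, alpha=2.0, seed=7)

    # estimate the number of satisfying assignments
    num_solutions = contract_hv1bp(htn)

    # draw one assignment by BP-guided decimation; omega is its draw probability
    config, htn_fixed, omega = sample_hv1bp(htn, seed=7)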