trajectree-0.0.1-py3-none-any.whl → trajectree-0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +9 -9
  5. trajectree/fock_optics/outputs.py +10 -6
  6. trajectree/fock_optics/utils.py +9 -6
  7. trajectree/sequence/swap.py +5 -4
  8. trajectree/trajectory.py +5 -4
  9. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/METADATA +2 -3
  10. trajectree-0.0.3.dist-info/RECORD +16 -0
  11. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  12. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  13. trajectree/quimb/docs/conf.py +0 -158
  14. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  15. trajectree/quimb/quimb/__init__.py +0 -507
  16. trajectree/quimb/quimb/calc.py +0 -1491
  17. trajectree/quimb/quimb/core.py +0 -2279
  18. trajectree/quimb/quimb/evo.py +0 -712
  19. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  20. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  21. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  22. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  23. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  24. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  25. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  26. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  27. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  28. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  29. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  30. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  31. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  32. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  33. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  34. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  35. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  36. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  37. trajectree/quimb/quimb/gates.py +0 -36
  38. trajectree/quimb/quimb/gen/__init__.py +0 -2
  39. trajectree/quimb/quimb/gen/operators.py +0 -1167
  40. trajectree/quimb/quimb/gen/rand.py +0 -713
  41. trajectree/quimb/quimb/gen/states.py +0 -479
  42. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  43. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  44. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  45. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  46. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  47. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  48. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  49. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  50. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  51. trajectree/quimb/quimb/schematic.py +0 -1518
  52. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  53. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  54. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  55. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  56. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  57. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  58. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  59. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  60. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  61. trajectree/quimb/quimb/tensor/interface.py +0 -114
  62. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  63. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  64. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  65. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  66. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  67. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  68. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  69. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  70. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  71. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  72. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  74. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  75. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  76. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  77. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  78. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  79. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  80. trajectree/quimb/quimb/utils.py +0 -892
  81. trajectree/quimb/tests/__init__.py +0 -0
  82. trajectree/quimb/tests/test_accel.py +0 -501
  83. trajectree/quimb/tests/test_calc.py +0 -788
  84. trajectree/quimb/tests/test_core.py +0 -847
  85. trajectree/quimb/tests/test_evo.py +0 -565
  86. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  87. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  88. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  89. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  90. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  91. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  92. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  93. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  94. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  95. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  96. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  97. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  103. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  104. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  105. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  106. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  107. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  108. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  109. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  110. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  111. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  112. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  113. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  114. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  115. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  116. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  117. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  118. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  119. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  120. trajectree/quimb/tests/test_utils.py +0 -85
  121. trajectree-0.0.1.dist-info/RECORD +0 -126
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/WHEEL +0 -0
  123. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/licenses/LICENSE +0 -0
  124. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py
@@ -1,571 +0,0 @@
- """Hyper dense belief propagation for arbitrary `quimb` tensor networks. This
- is the classic 1-norm version of belief propagation, which treats the tensor
- network directly as a factor graph. Messages are processed one at a time.
-
- TODO:
-
- - [ ] implement 'touching', so that only necessary messages are updated
- - [ ] implement sequential update
-
- """
- import autoray as ar
- import quimb.tensor as qtn
-
- from .bp_common import (
-     BeliefPropagationCommon,
-     compute_all_index_marginals_from_messages,
-     contract_hyper_messages,
-     initialize_hyper_messages,
-     prod,
- )
-
-
- def compute_all_hyperind_messages_prod(ms, smudge_factor=1e-12):
-     """Given set of messages ``ms`` incident to a single index, compute the
-     corresponding next output messages, using the 'product' implementation.
-     """
-     if len(ms) == 2:
-         # shortcut for 2 messages
-         return [ms[1], ms[0]]
-
-     x = prod(ms)
-     return [x / (m + smudge_factor) for m in ms]
-
-
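For reference, the 'product' implementation computes each outgoing message by dividing the elementwise product of all incoming messages by the message being excluded, with the smudge factor guarding the division. A minimal standalone numpy sketch of that identity (illustrative arrays, not package code):

    import numpy as np

    ms = [np.array([0.6, 0.4]), np.array([0.5, 0.5]), np.array([0.9, 0.1])]
    total = np.prod(ms, axis=0)               # elementwise product of all messages
    outs = [total / (m + 1e-12) for m in ms]  # exclude each message by division

    # each output is (up to the smudge factor) the product of the other messages
    assert np.allclose(outs[0], ms[1] * ms[2])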
- def compute_all_hyperind_messages_tree(ms):
-     """Given set of messages ``ms`` incident to a single index, compute the
-     corresponding next output messages, using the 'tree' implementation.
-     """
-     ndim = len(ms)
-     if len(ms) == 2:
-         # shortcut for 2 messages
-         return [ms[1], ms[0]]
-
-     mouts = [None for _ in range(ndim)]
-     queue = [(tuple(range(ndim)), 1, ms)]
-
-     while queue:
-         js, x, ms = queue.pop()
-
-         ndim = len(ms)
-         if ndim == 1:
-             # reached single message
-             mouts[js[0]] = x
-             continue
-         elif ndim == 2:
-             # shortcut for 2 messages left
-             mouts[js[0]] = x * ms[1]
-             mouts[js[1]] = ms[0] * x
-             continue
-
-         # else split in two and contract each half
-         k = ndim // 2
-         jl, jr = js[:k], js[k:]
-         ml, mr = ms[:k], ms[k:]
-
-         # contract the right messages to get new left array
-         xl = prod((*mr, x))
-
-         # contract the left messages to get new right array
-         xr = prod((*ml, x))
-
-         # add to the queue for possible further halving
-         queue.append((jl, xl, ml))
-         queue.append((jr, xr, mr))
-
-     return mouts
-
-
- def compute_all_tensor_messages_shortcuts(x, ms, ndim):
-     if ndim == 2:
-         # shortcut for 2 messages
-         return [x @ ms[1], ms[0] @ x]
-     elif ndim == 1:
-         # shortcut for single message
-         return [x]
-     elif ndim == 0:
-         # shortcut for no messages
-         return []
-
-
- def compute_all_tensor_messages_prod(
-     x,
-     ms,
-     backend=None,
-     smudge_factor=1e-12,
- ):
-     """Given set of messages ``ms`` incident to tensor with data ``x``, compute
-     the corresponding next output messages, using the 'prod' implementation.
-     """
-     ndim = len(ms)
-     if ndim <= 2:
-         return compute_all_tensor_messages_shortcuts(x, ms, ndim)
-
-     js = tuple(range(ndim))
-
-     mx = qtn.array_contract(
-         arrays=(x, *ms), inputs=(js, *((j,) for j in js)), output=js
-     )
-     mouts = []
-
-     for j, g in enumerate(ms):
-         mouts.append(
-             qtn.array_contract(
-                 arrays=(mx, 1 / (g + smudge_factor)),
-                 inputs=(js, (j,)),
-                 output=(j,),
-                 backend=backend,
-             )
-         )
-
-     return mouts
-
-
- def compute_all_tensor_messages_tree(x, ms, backend=None):
-     """Given set of messages ``ms`` incident to tensor with data ``x``, compute
-     the corresponding next output messages, using the 'tree' implementation.
-     """
-     ndim = len(ms)
-     if ndim <= 2:
-         return compute_all_tensor_messages_shortcuts(x, ms, ndim)
-
-     mouts = [None for _ in range(ndim)]
-     queue = [(tuple(range(ndim)), x, ms)]
-
-     while queue:
-         js, x, ms = queue.pop()
-
-         ndim = len(ms)
-         if ndim == 1:
-             # reached single message
-             mouts[js[0]] = x
-             continue
-         elif ndim == 2:
-             # shortcut for 2 messages left
-             mouts[js[0]] = x @ ms[1]
-             mouts[js[1]] = ms[0] @ x
-             continue
-
-         # else split in two and contract each half
-         k = ndim // 2
-         jl, jr = js[:k], js[k:]
-         ml, mr = ms[:k], ms[k:]
-
-         # contract the right messages to get new left array
-         xl = qtn.array_contract(
-             arrays=(x, *mr),
-             inputs=(js, *((j,) for j in jr)),
-             output=jl,
-             backend=backend,
-         )
-
-         # contract the left messages to get new right array
-         xr = qtn.array_contract(
-             arrays=(x, *ml),
-             inputs=(js, *((j,) for j in jl)),
-             output=jr,
-             backend=backend,
-         )
-
-         # add to the queue for possible further halving
-         queue.append((jl, xl, ml))
-         queue.append((jr, xr, mr))
-
-     return mouts
-
-
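The 'tree' implementations instead avoid the division entirely: they recursively halve the set of legs, absorbing the messages of one half to form the subproblem for the other, so all outputs share partial contractions. One halving step checked against the direct per-leg contractions, in plain numpy (illustrative, not package code):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.random((2, 3, 4))
    ms = [rng.random(2), rng.random(3), rng.random(4)]

    # direct: for each leg, contract x with every message except that leg's
    direct = [
        np.einsum('abc,b,c->a', x, ms[1], ms[2]),
        np.einsum('abc,a,c->b', x, ms[0], ms[2]),
        np.einsum('abc,a,b->c', x, ms[0], ms[1]),
    ]

    # one halving step: absorb the right half's messages to finish the left
    # subproblem, and the left half's message to form the right subproblem
    xl = np.einsum('abc,b,c->a', x, ms[1], ms[2])  # left half: single leg 'a'
    xr = np.einsum('abc,a->bc', x, ms[0])          # right half: legs 'b', 'c'
    tree = [xl, xr @ ms[2], ms[1] @ xr]

    assert all(np.allclose(d, t) for d, t in zip(direct, tree))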
- def iterate_belief_propagation_basic(
-     tn,
-     messages,
-     damping=None,
-     smudge_factor=1e-12,
-     tol=None,
- ):
-     """Run a single iteration of belief propagation. This is the basic version
-     that does not vectorize contractions.
-
-     Parameters
-     ----------
-     tn : TensorNetwork
-         The tensor network to run BP on.
-     messages : dict
-         The current messages. For every index and tensor id pair, there should
-         be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
-     damping : float, optional
-         The damping factor to use, ``None`` means no damping.
-     smudge_factor : float, optional
-         A small number to add to the denominator of messages to avoid division
-         by zero. Note when this happens the numerator will also be zero.
-
-     Returns
-     -------
-     new_messages : dict
-         The new messages.
-     """
-     backend = ar.infer_backend(next(iter(messages.values())))
-
-     # n.b. at small sizes python sum is faster than numpy sum
-     _sum = ar.get_lib_fn(backend, "sum")
-     _abs = ar.get_lib_fn(backend, "abs")
-
-     def _normalize_and_insert(k, m, max_dm):
-         # normalize and insert
-         m = m / _sum(m)
-
-         old_m = messages[k]
-
-         if damping is not None:
-             # mix old and new
-             m = damping * old_m + (1 - damping) * m
-
-         # compare to the old messages
-         dm = _sum(_abs(m - old_m))
-         max_dm = max(dm, max_dm)
-
-         # set and return the max diff so far
-         messages[k] = m
-         return max_dm
-
-     max_dm = 0.0
-
-     # hyper index messages
-     for ix, tids in tn.ind_map.items():
-         ms = compute_all_hyperind_messages_prod(
-             [messages[tid, ix] for tid in tids], smudge_factor
-         )
-         for tid, m in zip(tids, ms):
-             max_dm = _normalize_and_insert((ix, tid), m, max_dm)
-
-     # tensor messages
-     for tid, t in tn.tensor_map.items():
-         inds = t.inds
-         ms = compute_all_tensor_messages_tree(
-             t.data,
-             [messages[ix, tid] for ix in inds],
-         )
-         for ix, m in zip(inds, ms):
-             max_dm = _normalize_and_insert((tid, ix), m, max_dm)
-
-     return messages, max_dm
-
-
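In isolation, the normalize, damp, compare step that `_normalize_and_insert` applies to every message (and which drives `max_dm`, hence the `tol` convergence check) is just (hypothetical two-entry message, plain numpy):

    import numpy as np

    damping = 0.1
    old_m = np.array([0.7, 0.3])
    m = np.array([0.2, 0.6])

    m = m / np.sum(m)                        # normalize the raw update
    m = damping * old_m + (1 - damping) * m  # mix with the previous message
    dm = np.sum(np.abs(m - old_m))           # L1 change contributing to max_dm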
- class HD1BP(BeliefPropagationCommon):
-     """Object interface for hyper, dense, 1-norm belief propagation. This is
-     standard belief propagation in tensor network form.
-
-     Parameters
-     ----------
-     tn : TensorNetwork
-         The tensor network to run BP on.
-     messages : dict, optional
-         Initial messages to use, if not given then uniform messages are used.
-     damping : float, optional
-         The damping factor to use, ``None`` means no damping.
-     smudge_factor : float, optional
-         A small number to add to the denominator of messages to avoid division
-         by zero. Note when this happens the numerator will also be zero.
-     """
-
-     def __init__(
-         self,
-         tn,
-         messages=None,
-         damping=None,
-         smudge_factor=1e-12,
-     ):
-         self.tn = tn
-         self.backend = next(t.backend for t in tn)
-         self.smudge_factor = smudge_factor
-         self.damping = damping
-         if messages is None:
-             messages = initialize_hyper_messages(
-                 tn, smudge_factor=smudge_factor
-             )
-         self.messages = messages
-
-     def iterate(self, **kwargs):
-         self.messages, max_dm = iterate_belief_propagation_basic(
-             self.tn,
-             self.messages,
-             damping=self.damping,
-             smudge_factor=self.smudge_factor,
-             **kwargs,
-         )
-         return None, None, max_dm
-
-     def get_gauged_tn(self):
-         """Assuming the supplied tensor network has no hyper or dangling
-         indices, gauge it by inserting the BP-approximated transfer matrix
-         eigenvectors, which may be complex. The BP-contraction of this gauged
-         network is then simply the product of zeroth entries of each tensor.
-         """
-         tng = self.tn.copy()
-         for ind, tids in self.tn.ind_map.items():
-             tida, tidb = tids
-             ka = (ind, tida)
-             kb = (ind, tidb)
-             ma = self.messages[ka]
-             mb = self.messages[kb]
-
-             el, ev = ar.do('linalg.eig', ar.do('outer', ma, mb))
-             k = ar.do('argsort', -ar.do('abs', el))
-             ev = ev[:, k]
-             Uinv = ev
-             U = ar.do('linalg.inv', ev)
-             tng._insert_gauge_tids(U, tida, tidb, Uinv)
-         return tng
-
-     def contract(self, strip_exponent=False):
-         """Estimate the total contraction, i.e. the exponential of the 'Bethe
-         free entropy'.
-         """
-         return contract_hyper_messages(
-             self.tn,
-             self.messages,
-             strip_exponent=strip_exponent,
-             backend=self.backend,
-         )
-
-
- def contract_hd1bp(
-     tn,
-     messages=None,
-     max_iterations=1000,
-     tol=5e-6,
-     damping=0.0,
-     smudge_factor=1e-12,
-     strip_exponent=False,
-     info=None,
-     progbar=False,
- ):
-     """Estimate the contraction of ``tn`` with hyper, dense, 1-norm
-     belief propagation, via the exponential of the Bethe free entropy.
-
-     Parameters
-     ----------
-     tn : TensorNetwork
-         The tensor network to run BP on, can have hyper indices.
-     messages : dict, optional
-         Initial messages to use, if not given then uniform messages are used.
-     max_iterations : int, optional
-         The maximum number of iterations to perform.
-     tol : float, optional
-         The convergence tolerance for messages.
-     damping : float, optional
-         The damping factor to use, 0.0 means no damping.
-     smudge_factor : float, optional
-         A small number to add to the denominator of messages to avoid division
-         by zero. Note when this happens the numerator will also be zero.
-     strip_exponent : bool, optional
-         Whether to strip the exponent from the final result. If ``True``
-         then the returned result is ``(mantissa, exponent)``.
-     info : dict, optional
-         If specified, update this dictionary with information about the
-         belief propagation run.
-     progbar : bool, optional
-         Whether to show a progress bar.
-
-     Returns
-     -------
-     scalar or (scalar, float)
-     """
-     bp = HD1BP(
-         tn,
-         messages=messages,
-         damping=damping,
-         smudge_factor=smudge_factor,
-     )
-     bp.run(
-         max_iterations=max_iterations,
-         tol=tol,
-         info=info,
-         progbar=progbar,
-     )
-     return bp.contract(strip_exponent=strip_exponent)
-
-
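Before its removal, this interface could be exercised roughly as follows. A sketch only: it assumes the vendored subpackage was importable at the path it occupied in 0.0.1, and uses quimb's `TN_rand_reg` random regular-graph builder to make a small closed network:

    import quimb.tensor as qtn
    from trajectree.quimb.quimb.experimental.belief_propagation.hd1bp import (
        HD1BP,
        contract_hd1bp,
    )

    # small random tensor network on a 3-regular graph, bond dimension 2
    tn = qtn.TN_rand_reg(n=10, reg=3, D=2, seed=42)

    # one-shot estimate of the full contraction via the Bethe free entropy
    z_bp = contract_hd1bp(tn, tol=5e-6, damping=0.1)

    # or drive the object interface directly
    bp = HD1BP(tn, damping=0.1)
    bp.run(max_iterations=1000, tol=5e-6)
    z_bp2 = bp.contract()

    print(z_bp, z_bp2, tn.contract())  # BP estimates vs exact contraction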
- def run_belief_propagation_hd1bp(
-     tn,
-     messages=None,
-     max_iterations=1000,
-     tol=5e-6,
-     damping=0.0,
-     smudge_factor=1e-12,
-     info=None,
-     progbar=False,
- ):
-     """Run belief propagation on a tensor network until it converges. This
-     is the basic version that does not vectorize contractions.
-
-     Parameters
-     ----------
-     tn : TensorNetwork
-         The tensor network to run BP on.
-     messages : dict, optional
-         The current messages. For every index and tensor id pair, there should
-         be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
-         If not given, then messages are initialized as uniform.
-     max_iterations : int, optional
-         The maximum number of iterations to run for.
-     tol : float, optional
-         The convergence tolerance.
-     damping : float, optional
-         The damping factor to use, 0.0 means no damping.
-     smudge_factor : float, optional
-         A small number to add to the denominator of messages to avoid division
-         by zero. Note when this happens the numerator will also be zero.
-     info : dict, optional
-         If specified, update this dictionary with information about the
-         belief propagation run.
-     progbar : bool, optional
-         Whether to show a progress bar.
-
-     Returns
-     -------
-     messages : dict
-         The final messages.
-     converged : bool
-         Whether the algorithm converged.
-     """
-     bp = HD1BP(
-         tn, messages=messages, damping=damping, smudge_factor=smudge_factor
-     )
-     bp.run(max_iterations=max_iterations, tol=tol, info=info, progbar=progbar)
-     return bp.messages, bp.converged
-
-
- def sample_hd1bp(
-     tn,
-     messages=None,
-     output_inds=None,
-     max_iterations=1000,
-     tol=1e-2,
-     damping=0.0,
-     smudge_factor=1e-12,
-     bias=False,
-     seed=None,
-     progbar=False,
- ):
-     """Sample all indices of a tensor network using repeated belief propagation
-     runs and decimation.
-
-     Parameters
-     ----------
-     tn : TensorNetwork
-         The tensor network to sample.
-     messages : dict, optional
-         The current messages. For every index and tensor id pair, there should
-         be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
-         If not given, then messages are initialized as uniform.
-     output_inds : sequence of str, optional
-         The indices to sample. If not given, then all indices are sampled.
-     max_iterations : int, optional
-         The maximum number of iterations for each message passing run.
-     tol : float, optional
-         The convergence tolerance for each message passing run.
-     damping : float, optional
-         The damping factor to use, 0.0 means no damping.
-     smudge_factor : float, optional
-         A small number to add to each message to avoid zeros. Making this large
-         is similar to adding a temperature, which can aid convergence but
-         likely produces less accurate marginals.
-     bias : bool or float, optional
-         Whether to bias the sampling towards the largest marginal. If ``False``
-         (the default), then indices are sampled proportional to their
-         marginals. If ``True``, then each index is 'sampled' to be its largest
-         weight value always. If a float, then the local probability
-         distribution is raised to this power before sampling.
-     seed : int, optional
-         A random seed to use for the sampling.
-     progbar : bool, optional
-         Whether to show a progress bar.
-
-     Returns
-     -------
-     config : dict[str, int]
-         The sample configuration, mapping indices to values.
-     tn_config : TensorNetwork
-         The tensor network with all index values (or just those in
-         `output_inds` if supplied) selected. Contracting this tensor network
-         (which will just be a sequence of scalars if all index values have been
-         sampled) gives the weight of the sample, e.g. should be 1 for a SAT
-         problem and valid assignment.
-     omega : float
-         The probability of choosing this sample (i.e. product of marginal
-         values). Useful possibly for importance sampling.
-     """
-     import numpy as np
-
-     rng = np.random.default_rng(seed)
-
-     tn_config = tn.copy()
-
-     if messages is None:
-         messages = initialize_hyper_messages(tn_config)
-
-     if output_inds is None:
-         output_inds = tn_config.ind_map.keys()
-     output_inds = set(output_inds)
-
-     config = {}
-     omega = 1.0
-
-     if progbar:
-         import tqdm
-
-         pbar = tqdm.tqdm(total=len(output_inds))
-     else:
-         pbar = None
-
-     while output_inds:
-         messages, _ = run_belief_propagation_hd1bp(
-             tn_config,
-             messages,
-             max_iterations=max_iterations,
-             tol=tol,
-             damping=damping,
-             smudge_factor=smudge_factor,
-             progbar=True,
-         )
-
-         marginals = compute_all_index_marginals_from_messages(
-             tn_config, messages
-         )
-
-         # choose most peaked marginal
-         ix, p = max(
-             (m for m in marginals.items() if m[0] in output_inds),
-             key=lambda ix_p: max(ix_p[1]),
-         )
-
-         if bias is False:
-             # sample the value according to the marginal
-             v = rng.choice(np.arange(p.size), p=p)
-         elif bias is True:
-             v = np.argmax(p)
-             # in some sense omega is really 1.0 here
-         else:
-             # bias towards larger marginals by raising to a power
-             p = p**bias
-             p /= np.sum(p)
-             v = rng.choice(np.arange(p.size), p=p)
-
-         omega *= p[v]
-         config[ix] = v
-
-         # clean up messages
-         for tid in tn_config.ind_map[ix]:
-             del messages[ix, tid]
-             del messages[tid, ix]
-
-         # remove index
-         tn_config.isel_({ix: v})
-         output_inds.remove(ix)
-
-         if progbar:
-             pbar.update(1)
-             pbar.set_description(f"{ix}->{v}", refresh=False)
-
-     if progbar:
-         pbar.close()
-
-     return config, tn_config, omega
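And the decimation sampler could be driven in the same spirit (again a sketch against the hypothetical 0.0.1 import path; `omega` is the product of the marginal probabilities actually used, e.g. for importance weighting):

    import quimb.tensor as qtn
    from trajectree.quimb.quimb.experimental.belief_propagation.hd1bp import (
        sample_hd1bp,
    )

    tn = qtn.TN_rand_reg(n=10, reg=3, D=2, seed=7)

    # repeatedly run BP, fix the most peaked index, and decimate the network
    config, tn_config, omega = sample_hd1bp(tn, tol=1e-2, seed=7)

    weight = tn_config.contract()  # scalar weight of the sampled configuration
    print(config, omega, weight)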