Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/trajectory.py +2 -2
  7. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
  8. trajectree-0.0.2.dist-info/RECORD +16 -0
  9. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  10. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  11. trajectree/quimb/docs/conf.py +0 -158
  12. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  13. trajectree/quimb/quimb/__init__.py +0 -507
  14. trajectree/quimb/quimb/calc.py +0 -1491
  15. trajectree/quimb/quimb/core.py +0 -2279
  16. trajectree/quimb/quimb/evo.py +0 -712
  17. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  18. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  19. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  20. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  21. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  22. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  23. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  24. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  25. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  26. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  27. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  28. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  29. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  30. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  31. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  32. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  33. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  34. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  35. trajectree/quimb/quimb/gates.py +0 -36
  36. trajectree/quimb/quimb/gen/__init__.py +0 -2
  37. trajectree/quimb/quimb/gen/operators.py +0 -1167
  38. trajectree/quimb/quimb/gen/rand.py +0 -713
  39. trajectree/quimb/quimb/gen/states.py +0 -479
  40. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  41. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  42. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  43. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  44. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  45. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  46. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  47. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  48. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  49. trajectree/quimb/quimb/schematic.py +0 -1518
  50. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  51. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  52. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  53. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  54. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  55. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  56. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  57. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  58. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  59. trajectree/quimb/quimb/tensor/interface.py +0 -114
  60. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  61. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  62. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  63. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  64. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  65. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  66. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  67. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  68. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  69. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  70. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  71. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  72. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  74. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  75. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  76. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  77. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  78. trajectree/quimb/quimb/utils.py +0 -892
  79. trajectree/quimb/tests/__init__.py +0 -0
  80. trajectree/quimb/tests/test_accel.py +0 -501
  81. trajectree/quimb/tests/test_calc.py +0 -788
  82. trajectree/quimb/tests/test_core.py +0 -847
  83. trajectree/quimb/tests/test_evo.py +0 -565
  84. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  85. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  86. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  87. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  88. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  89. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  90. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  91. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  92. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  93. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  94. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  95. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  103. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  104. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  105. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  106. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  107. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  108. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  109. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  110. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  111. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  112. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  113. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  114. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  115. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  116. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  117. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  118. trajectree/quimb/tests/test_utils.py +0 -85
  119. trajectree-0.0.1.dist-info/RECORD +0 -126
  120. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
@@ -1,571 +0,0 @@
1
- """Hyper dense belief propagation for arbitrary `quimb` tensor networks. This
2
- is the classic 1-norm version of belief propagation, which treats the tensor
3
- network directly as a factor graph. Messages are processed one at a time.
4
-
5
- TODO:
6
-
7
- - [ ] implement 'touching', so that only necessary messages are updated
8
- - [ ] implement sequential update
9
-
10
- """
11
- import autoray as ar
12
- import quimb.tensor as qtn
13
-
14
- from .bp_common import (
15
- BeliefPropagationCommon,
16
- compute_all_index_marginals_from_messages,
17
- contract_hyper_messages,
18
- initialize_hyper_messages,
19
- prod,
20
- )
21
-
22
-
23
def compute_all_hyperind_messages_prod(ms, smudge_factor=1e-12):
    """Compute every outgoing message of a single (hyper) index at once.

    Each outgoing message is the elementwise product of all the *other*
    incoming messages. Instead of forming each of those products separately,
    the full product is formed once and the excluded message is divided back
    out (the 'product' implementation).

    Parameters
    ----------
    ms : sequence of array_like
        The messages incident to the index.
    smudge_factor : float, optional
        Added to each divisor to avoid division by zero.

    Returns
    -------
    list of array_like
        ``out[i]`` is the message sent back along the ``i``-th connection.
    """
    if len(ms) == 2:
        # with exactly two messages each output is simply the other input
        first, second = ms
        return [second, first]

    total = prod(ms)
    return [total / (m + smudge_factor) for m in ms]
33
-
34
-
35
def compute_all_hyperind_messages_tree(ms):
    """Compute every outgoing message of a single (hyper) index by bisection.

    This is the 'tree' implementation: the incoming messages are repeatedly
    split into a left and right half. Each half's environment is the running
    product multiplied by all messages of the opposite half. Splitting stops
    once one or two messages remain, at which point the outputs can be read
    off directly. No division is needed, unlike the 'product' implementation.

    Parameters
    ----------
    ms : sequence of array_like
        The messages incident to the index.

    Returns
    -------
    list of array_like
        ``out[i]`` is the product of all incoming messages except ``ms[i]``.
    """
    num = len(ms)
    if num == 2:
        # shortcut: each output is simply the other input
        return [ms[1], ms[0]]

    outputs = [None] * num
    # each entry: (positions still to resolve, product of every message
    # outside those positions, the messages at those positions)
    stack = [(tuple(range(num)), 1, ms)]

    while stack:
        positions, acc, group = stack.pop()
        size = len(group)

        if size == 1:
            # the environment of a lone message is just the accumulator
            outputs[positions[0]] = acc
        elif size == 2:
            outputs[positions[0]] = acc * group[1]
            outputs[positions[1]] = group[0] * acc
        else:
            half = size // 2
            left_pos, right_pos = positions[:half], positions[half:]
            left_ms, right_ms = group[:half], group[half:]

            # the left half's environment absorbs every right-hand message,
            # and vice versa
            left_acc = prod((*right_ms, acc))
            right_acc = prod((*left_ms, acc))

            stack.append((left_pos, left_acc, left_ms))
            stack.append((right_pos, right_acc, right_ms))

    return outputs
77
-
78
-
79
def compute_all_tensor_messages_shortcuts(x, ms, ndim):
    """Handle the trivial cases of computing a tensor's outgoing messages.

    Parameters
    ----------
    x : array_like
        The tensor data.
    ms : sequence of array_like
        The incoming messages, one per dimension of ``x``.
    ndim : int
        The number of incoming messages.

    Returns
    -------
    list of array_like or None
        The outgoing messages when ``ndim`` is 0, 1 or 2. For any other value
        ``None`` is returned; callers only use this for ``ndim <= 2``.
    """
    if ndim == 0:
        # no indices, so nothing to send
        return []
    if ndim == 1:
        # a single index: the outgoing message is the data itself
        return [x]
    if ndim == 2:
        # a matrix: contract the opposite message into each side
        return [x @ ms[1], ms[0] @ x]
    return None
89
-
90
-
91
def compute_all_tensor_messages_prod(
    x,
    ms,
    backend=None,
    smudge_factor=1e-12,
):
    """Given set of messages ``ms`` incident to tensor with data ``x``, compute
    the corresponding next output messages, using the 'prod' implementation.

    The method has two steps:

    - The tensor is multiplied elementwise by *every* incoming message,
      keeping all of its indices open.
    - For each index ``j``, that product is summed over all other indices and
      multiplied by ``1 / (ms[j] + smudge_factor)``. This removes index
      ``j``'s own incoming message from the result.

    Parameters
    ----------
    x : array_like
        The tensor data, one dimension per incoming message.
    ms : sequence of array_like
        Incoming messages, ``ms[j]`` being the message on dimension ``j``.
    backend : str, optional
        Array backend forwarded to every ``qtn.array_contract`` call.
    smudge_factor : float, optional
        Added to each message before inverting, to avoid division by zero.

    Returns
    -------
    list of array_like
        ``out[j]`` is the outgoing message along dimension ``j``.
    """
    ndim = len(ms)
    if ndim <= 2:
        return compute_all_tensor_messages_shortcuts(x, ms, ndim)

    js = tuple(range(ndim))

    # multiply every incoming message into x, keeping all indices open
    # (the requested backend is honoured here too, not just below)
    mx = qtn.array_contract(
        arrays=(x, *ms),
        inputs=(js, *((j,) for j in js)),
        output=js,
        backend=backend,
    )

    # for each index, sum out the others and divide out its own message
    return [
        qtn.array_contract(
            arrays=(mx, 1 / (g + smudge_factor)),
            inputs=(js, (j,)),
            output=(j,),
            backend=backend,
        )
        for j, g in enumerate(ms)
    ]
122
-
123
-
124
def compute_all_tensor_messages_tree(x, ms, backend=None):
    """Compute every outgoing message of a tensor by recursive bisection.

    This is the 'tree' implementation. The incoming messages ``ms`` are
    split into a left and right half. Each half's environment is formed by
    contracting the opposite half's messages into the (partially contracted)
    tensor data. This repeats on each half until one or two messages remain.

    Parameters
    ----------
    x : array_like
        The tensor data, one dimension per incoming message.
    ms : sequence of array_like
        Incoming messages, ``ms[j]`` being the message on dimension ``j``.
    backend : str, optional
        Array backend forwarded to ``qtn.array_contract``.

    Returns
    -------
    list of array_like
        ``out[j]`` is the outgoing message along dimension ``j``.
    """
    n = len(ms)
    if n <= 2:
        return compute_all_tensor_messages_shortcuts(x, ms, n)

    outputs = [None] * n
    # each entry: (dimension labels still open, data with exactly those
    # dimensions, the incoming messages on those dimensions)
    pending = [(tuple(range(n)), x, ms)]

    while pending:
        idxs, data, group = pending.pop()
        size = len(group)

        if size == 1:
            # reached a single open dimension
            outputs[idxs[0]] = data
        elif size == 2:
            # two open dimensions left: a matrix-vector product each way
            outputs[idxs[0]] = data @ group[1]
            outputs[idxs[1]] = group[0] @ data
        else:
            half = size // 2
            idxs_l, idxs_r = idxs[:half], idxs[half:]
            group_l, group_r = group[:half], group[half:]

            # absorb the right-hand messages, leaving only left dimensions
            data_l = qtn.array_contract(
                arrays=(data, *group_r),
                inputs=(idxs, *((i,) for i in idxs_r)),
                output=idxs_l,
                backend=backend,
            )

            # absorb the left-hand messages, leaving only right dimensions
            data_r = qtn.array_contract(
                arrays=(data, *group_l),
                inputs=(idxs, *((i,) for i in idxs_l)),
                output=idxs_r,
                backend=backend,
            )

            pending.append((idxs_l, data_l, group_l))
            pending.append((idxs_r, data_r, group_r))

    return outputs
175
-
176
-
177
def iterate_belief_propagation_basic(
    tn,
    messages,
    damping=None,
    smudge_factor=1e-12,
    tol=None,
):
    """Run a single iteration of belief propagation. This is the basic version
    that does not vectorize contractions.

    All messages keyed ``(ix, tid)`` are recomputed first, at each index from
    the ``(tid, ix)`` messages, using the 'product' implementation. Then all
    messages keyed ``(tid, ix)`` are recomputed, at each tensor from the
    ``(ix, tid)`` messages, using the 'tree' implementation. ``messages`` is
    updated in place, so the second phase already sees the messages
    produced by the first.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    messages : dict
        The current messages. For every index and tensor id pair, there should
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
        This dict is modified in place.
    damping : float, optional
        If given, each new (normalized) message is mixed with the old one as
        ``damping * old + (1 - damping) * new``.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    tol : float, optional
        Not used by this function; accepted so it can be passed through.

    Returns
    -------
    new_messages : dict
        The new messages (the same dict object as ``messages``).
    max_dm : scalar
        The largest summed absolute difference between any message and its
        previous value during this iteration.
    """
    backend = ar.infer_backend(next(iter(messages.values())))
    _sum = ar.get_lib_fn(backend, "sum")
    _abs = ar.get_lib_fn(backend, "abs")

    def _normalize_and_insert(k, m, max_dm):
        # normalize so the message sums to one
        m = m / _sum(m)

        old_m = messages[k]

        if damping is not None:
            # mix old and new
            m = damping * old_m + (1 - damping) * m

        # compare to the old message and track the largest change so far
        dm = _sum(_abs(m - old_m))
        max_dm = max(dm, max_dm)

        messages[k] = m
        return max_dm

    max_dm = 0.0

    # hyper index messages
    for ix, tids in tn.ind_map.items():
        ms = compute_all_hyperind_messages_prod(
            [messages[tid, ix] for tid in tids], smudge_factor
        )
        for tid, m in zip(tids, ms):
            max_dm = _normalize_and_insert((ix, tid), m, max_dm)

    # tensor messages
    for tid, t in tn.tensor_map.items():
        inds = t.inds
        ms = compute_all_tensor_messages_tree(
            t.data,
            [messages[ix, tid] for ix in inds],
        )
        for ix, m in zip(inds, ms):
            max_dm = _normalize_and_insert((tid, ix), m, max_dm)

    return messages, max_dm
250
-
251
-
252
class HD1BP(BeliefPropagationCommon):
    """Hyper, dense, 1-norm belief propagation on a tensor network.

    The network is treated directly as a factor graph. Indices, which may be
    hyper indices, and tensors exchange messages, and contractions are not
    vectorized. This is standard belief propagation in tensor network form.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    messages : dict, optional
        Initial messages to use, if not given then uniform messages are used.
    damping : float, optional
        If given, each new message is mixed with the old one as
        ``damping * old + (1 - damping) * new``.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    """

    def __init__(
        self,
        tn,
        messages=None,
        damping=None,
        smudge_factor=1e-12,
    ):
        self.tn = tn
        # the array backend is taken from the first tensor of the network
        self.backend = next(t.backend for t in tn)
        self.damping = damping
        self.smudge_factor = smudge_factor
        self.messages = (
            initialize_hyper_messages(tn, smudge_factor=smudge_factor)
            if messages is None
            else messages
        )

    def iterate(self, **kwargs):
        """Perform one round of message updates.

        Extra keyword arguments are forwarded to
        ``iterate_belief_propagation_basic``.

        Returns
        -------
        tuple
            ``(None, None, max_dm)``, where ``max_dm`` is the largest change
            in any single message during this round.
        """
        new_messages, max_dm = iterate_belief_propagation_basic(
            self.tn,
            self.messages,
            damping=self.damping,
            smudge_factor=self.smudge_factor,
            **kwargs,
        )
        self.messages = new_messages
        return None, None, max_dm

    def get_gauged_tn(self):
        """Assuming the supplied tensor network has no hyper or dangling
        indices, gauge it by inserting the BP-approximated transfer matrix
        eigenvectors, which may be complex. The BP-contraction of this gauged
        network is then simply the product of zeroth entries of each tensor.
        """
        tn_gauged = self.tn.copy()
        for ix, (tid_a, tid_b) in self.tn.ind_map.items():
            msg_a = self.messages[ix, tid_a]
            msg_b = self.messages[ix, tid_b]

            evals, evecs = ar.do("linalg.eig", ar.do("outer", msg_a, msg_b))
            # order eigenvectors by decreasing eigenvalue magnitude
            order = ar.do("argsort", -ar.do("abs", evals))
            evecs = evecs[:, order]

            tn_gauged._insert_gauge_tids(
                ar.do("linalg.inv", evecs), tid_a, tid_b, evecs
            )
        return tn_gauged

    def contract(self, strip_exponent=False):
        """Estimate the total contraction, i.e. the exponential of the 'Bethe
        free entropy'.
        """
        return contract_hyper_messages(
            self.tn,
            self.messages,
            strip_exponent=strip_exponent,
            backend=self.backend,
        )
326
-
327
-
328
def contract_hd1bp(
    tn,
    messages=None,
    max_iterations=1000,
    tol=5e-6,
    damping=0.0,
    smudge_factor=1e-12,
    strip_exponent=False,
    info=None,
    progbar=False,
):
    """Estimate the contraction of ``tn`` with hyper, dense (non-vectorized),
    1-norm belief propagation, via the exponential of the Bethe free entropy.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on, can have hyper indices.
    messages : dict, optional
        Initial messages to use, if not given then uniform messages are used.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    damping : float, optional
        The damping factor to use, 0.0 means no damping.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.

    Returns
    -------
    scalar or (scalar, float)
    """
    bp = HD1BP(
        tn,
        messages=messages,
        damping=damping,
        smudge_factor=smudge_factor,
    )
    bp.run(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    return bp.contract(strip_exponent=strip_exponent)
383
-
384
-
385
def run_belief_propagation_hd1bp(
    tn,
    messages=None,
    max_iterations=1000,
    tol=5e-6,
    damping=0.0,
    smudge_factor=1e-12,
    info=None,
    progbar=False,
):
    """Run dense (non-vectorized) hyper belief propagation on ``tn`` until it
    converges or the iteration limit is hit.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    messages : dict, optional
        The current messages. For every index and tensor id pair, there should
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
        If not given, then messages are initialized as uniform.
    max_iterations : int, optional
        The maximum number of iterations to run for.
    tol : float, optional
        The convergence tolerance.
    damping : float, optional
        The damping factor to use, 0.0 means no damping.
    smudge_factor : float, optional
        A small number to add to the denominator of messages to avoid division
        by zero. Note when this happens the numerator will also be zero.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.

    Returns
    -------
    messages : dict
        The final messages.
    converged : bool
        Whether the algorithm converged.
    """
    solver = HD1BP(
        tn,
        messages=messages,
        damping=damping,
        smudge_factor=smudge_factor,
    )
    solver.run(
        max_iterations=max_iterations,
        tol=tol,
        info=info,
        progbar=progbar,
    )
    return solver.messages, solver.converged
431
-
432
-
433
def sample_hd1bp(
    tn,
    messages=None,
    output_inds=None,
    max_iterations=1000,
    tol=1e-2,
    damping=0.0,
    smudge_factor=1e-12,
    bias=False,
    seed=None,
    progbar=False,
):
    """Sample all indices of a tensor network using repeated belief propagation
    runs and decimation.

    Each step proceeds as follows:

    - BP is run on the current network.
    - The remaining output index whose marginal has the largest single entry
      is chosen, and a value for it is drawn.
    - That index is fixed in the network, and its messages are discarded.

    This repeats until every requested index has been assigned.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to sample.
    messages : dict, optional
        The current messages. For every index and tensor id pair, there should
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
        If not given, then messages are initialized as uniform.
    output_inds : sequence of str, optional
        The indices to sample. If not given, then all indices are sampled.
    max_iterations : int, optional
        The maximum number of iterations for each message passing run.
    tol : float, optional
        The convergence tolerance for each message passing run.
    damping : float, optional
        The damping factor for each message passing run, 0.0 means no
        damping.
    smudge_factor : float, optional
        A small number to add to each message to avoid zeros. Making this large
        is similar to adding a temperature, which can aid convergence but
        likely produces less accurate marginals.
    bias : bool or float, optional
        Whether to bias the sampling towards the largest marginal. If ``False``
        (the default), then indices are sampled proportional to their
        marginals. If ``True``, then each index is 'sampled' to be its largest
        weight value always. If a float, then the local probability
        distribution is raised to this power, renormalized, and then sampled
        from; ``omega`` then uses this biased distribution.
    seed : int, optional
        A random seed to use for the sampling. Every random draw goes through
        the generator created from this seed, so results are reproducible.
    progbar : bool, optional
        Whether to show a progress bar over the sampled indices. The inner
        message passing runs never show their own progress bars.

    Returns
    -------
    config : dict[str, int]
        The sample configuration, mapping indices to values.
    tn_config : TensorNetwork
        The tensor network with all index values (or just those in
        `output_inds` if supplied) selected. Contracting this tensor network
        (which will just be a sequence of scalars if all index values have been
        sampled) gives the weight of the sample, e.g. should be 1 for a SAT
        problem and valid assignment.
    omega : float
        The probability of choosing this sample (i.e. product of marginal
        values). Useful possibly for importance sampling.
    """
    import numpy as np

    rng = np.random.default_rng(seed)

    tn_config = tn.copy()

    if messages is None:
        messages = initialize_hyper_messages(tn_config)

    if output_inds is None:
        output_inds = tn_config.ind_map.keys()
    output_inds = set(output_inds)

    config = {}
    omega = 1.0

    if progbar:
        import tqdm

        pbar = tqdm.tqdm(total=len(output_inds))
    else:
        pbar = None

    try:
        while output_inds:
            # the inner BP runs must not open their own progress bars,
            # regardless of ``progbar`` - only the outer loop reports progress
            messages, _ = run_belief_propagation_hd1bp(
                tn_config,
                messages,
                max_iterations=max_iterations,
                tol=tol,
                damping=damping,
                smudge_factor=smudge_factor,
                progbar=False,
            )

            marginals = compute_all_index_marginals_from_messages(
                tn_config, messages
            )

            # choose most peaked marginal
            ix, p = max(
                (m for m in marginals.items() if m[0] in output_inds),
                key=lambda ix_p: max(ix_p[1]),
            )

            if bias is False:
                # sample the value according to the marginal
                v = rng.choice(np.arange(p.size), p=p)
            elif bias is True:
                v = np.argmax(p)
                # in some sense omega is really 1.0 here
            else:
                # bias towards larger marginals by raising to a power, drawing
                # from the seeded generator so results are reproducible
                p = p**bias
                p /= np.sum(p)
                v = rng.choice(np.arange(p.size), p=p)

            omega *= p[v]
            config[ix] = v

            # clean up messages
            for tid in tn_config.ind_map[ix]:
                del messages[ix, tid]
                del messages[tid, ix]

            # remove index
            tn_config.isel_({ix: v})
            output_inds.remove(ix)

            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f"{ix}->{v}", refresh=False)
    finally:
        # always release the progress bar, even if sampling fails
        if pbar is not None:
            pbar.close()

    return config, tn_config, omega