Trajectree 0.0.0-py3-none-any.whl → 0.0.1-py3-none-any.whl

Files changed (122)
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py (new file)
@@ -0,0 +1,653 @@
+ import autoray as ar
+
+ import quimb.tensor as qtn
+ from quimb.utils import oset
+
+ from .bp_common import (
+     BeliefPropagationCommon,
+     combine_local_contractions,
+ )
+
+
+ class D2BP(BeliefPropagationCommon):
+     """Dense (as in one tensor per site) 2-norm (as in for wavefunctions and
+     operators) belief propagation. Allows message reuse. This version assumes
+     no hyper indices (i.e. a standard PEPS-like tensor network).
+
+     Potential use cases for D2BP and a PEPS-like tensor network are:
+
+     - globally compressing it from bond dimension ``D`` to ``D'``
+     - eagerly applying gates and locally compressing back to ``D``
+     - sampling configurations
+     - estimating the norm of the tensor network
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to form the 2-norm of and run BP on.
+     messages : dict[(str, int), array_like], optional
+         The initial messages to use, effectively defaults to all ones if not
+         specified.
+     output_inds : set[str], optional
+         The indices to consider as output (dangling) indices of the tn.
+         Computed automatically if not specified.
+     optimize : str or PathOptimizer, optional
+         The path optimizer to use when contracting the messages.
+     damping : float, optional
+         The damping factor to use, 0.0 means no damping.
+     update : {'parallel', 'sequential'}, optional
+         Whether to update all messages in parallel or sequentially.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+     """
+
+     def __init__(
+         self,
+         tn,
+         messages=None,
+         output_inds=None,
+         optimize="auto-hq",
+         damping=0.0,
+         update="sequential",
+         local_convergence=True,
+         **contract_opts,
+     ):
+         from quimb.tensor.contraction import array_contract_expression
+
+         self.tn = tn
+         self.contract_opts = contract_opts
+         self.contract_opts.setdefault("optimize", optimize)
+         self.damping = damping
+         self.local_convergence = local_convergence
+         self.update = update
+
+         if output_inds is None:
+             self.output_inds = set(self.tn.outer_inds())
+         else:
+             self.output_inds = set(output_inds)
+
+         self.backend = next(t.backend for t in tn)
+         _abs = ar.get_lib_fn(self.backend, "abs")
+         _sum = ar.get_lib_fn(self.backend, "sum")
+
+         def _normalize(x):
+             return x / _sum(x)
+
+         def _distance(x, y):
+             return _sum(_abs(x - y))
+
+         self._normalize = _normalize
+         self._distance = _distance
+
+         if messages is None:
+             self.messages = {}
+         else:
+             self.messages = messages
+
+         # record which messages touch each other, for efficient updates
+         self.touch_map = {}
+         self.touched = oset()
+         self.exprs = {}
+
+         # populate any messages
+         for ix, tids in self.tn.ind_map.items():
+             if ix in self.output_inds:
+                 continue
+
+             tida, tidb = tids
+             jx = ix + "*"
+             ta, tb = self.tn._tids_get(tida, tidb)
+
+             for tid, t, t_in in ((tida, ta, tb), (tidb, tb, ta)):
+                 this_touchmap = []
+                 for nx in t.inds:
+                     if nx in self.output_inds or nx == ix:
+                         continue
+                     # where this message will be sent on to
+                     (tidn,) = (n for n in self.tn.ind_map[nx] if n != tid)
+                     this_touchmap.append((nx, tidn))
+                 self.touch_map[ix, tid] = this_touchmap
+
+                 if (ix, tid) not in self.messages:
+                     tm = (t_in.reindex({ix: jx}).conj_() @ t_in).data
+                     self.messages[ix, tid] = self._normalize(tm)
+
+         # for efficiency, set up all contraction expressions ahead of time
+         for ix, tids in self.tn.ind_map.items():
+             if ix in self.output_inds:
+                 continue
+
+             for tida, tidb in (sorted(tids), sorted(tids, reverse=True)):
+                 ta = self.tn.tensor_map[tida]
+                 kix = ta.inds
+                 bix = tuple(
+                     i if i in self.output_inds else i + "*" for i in kix
+                 )
+                 inputs = [kix, bix]
+                 data = [ta.data, ta.data.conj()]
+                 shapes = [ta.shape, ta.shape]
+                 for i in kix:
+                     if (i != ix) and i not in self.output_inds:
+                         inputs.append((i + "*", i))
+                         data.append((i, tida))
+                         shapes.append(self.messages[i, tida].shape)
+
+                 expr = array_contract_expression(
+                     inputs=inputs,
+                     output=(ix + "*", ix),
+                     shapes=shapes,
+                     **self.contract_opts,
+                 )
+                 self.exprs[ix, tidb] = expr, data
+
+     def update_touched_from_tids(self, *tids):
+         """Specify that the messages for the given ``tids`` have changed."""
+         for tid in tids:
+             t = self.tn.tensor_map[tid]
+             for ix in t.inds:
+                 if ix in self.output_inds:
+                     continue
+                 (ntid,) = (n for n in self.tn.ind_map[ix] if n != tid)
+                 self.touched.add((ix, ntid))
+
+     def update_touched_from_tags(self, tags, which="any"):
+         """Specify that the messages touching ``tags`` have changed."""
+         tids = self.tn._get_tids_from_tags(tags, which)
+         self.update_touched_from_tids(*tids)
+
+     def update_touched_from_inds(self, inds, which="any"):
+         """Specify that the messages touching ``inds`` have changed."""
+         tids = self.tn._get_tids_from_inds(inds, which)
+         self.update_touched_from_tids(*tids)
+
+     def iterate(self, tol=5e-6):
+         """Perform a single iteration of dense 2-norm belief propagation."""
+
+         if (not self.local_convergence) or (not self.touched):
+             # assume if asked to iterate that we want to check all messages
+             self.touched.update(self.exprs.keys())
+
+         ncheck = len(self.touched)
+         nconv = 0
+         max_mdiff = -1.0
+         new_touched = oset()
+
+         def _compute_m(key):
+             expr, data = self.exprs[key]
+             m = expr(*data[:2], *(self.messages[mkey] for mkey in data[2:]))
+             # enforce hermiticity and normalize
+             return self._normalize(m + ar.dag(m))
+
+         def _update_m(key, new_m):
+             nonlocal nconv, max_mdiff
+
+             old_m = self.messages[key]
+             if self.damping > 0.0:
+                 new_m = self._normalize(
+                     self.damping * old_m + (1 - self.damping) * new_m
+                 )
+             try:
+                 mdiff = float(self._distance(old_m, new_m))
+             except (TypeError, ValueError):
+                 # handle e.g. lazy arrays
+                 mdiff = float("inf")
+             if mdiff > tol:
+                 # mark touching messages for update
+                 new_touched.update(self.touch_map[key])
+             else:
+                 nconv += 1
+             max_mdiff = max(max_mdiff, mdiff)
+             self.messages[key] = new_m
+
+         if self.update == "parallel":
+             new_messages = {}
+             # compute all new messages
+             while self.touched:
+                 key = self.touched.pop()
+                 new_messages[key] = _compute_m(key)
+             # insert all new messages
+             for key, new_m in new_messages.items():
+                 _update_m(key, new_m)
+
+         elif self.update == "sequential":
+             # compute each new message and immediately re-insert it
+             while self.touched:
+                 key = self.touched.pop()
+                 new_m = _compute_m(key)
+                 _update_m(key, new_m)
+
+         self.touched = new_touched
+
+         return nconv, ncheck, max_mdiff
+
+     def compute_marginal(self, ind):
+         """Compute the marginal for the index ``ind``."""
+         (tid,) = self.tn.ind_map[ind]
+         t = self.tn.tensor_map[tid]
+
+         arrays = [t.data, ar.do("conj", t.data)]
+         k_input = []
+         b_input = []
+         m_inputs = []
+         for j, jx in enumerate(t.inds, 1):
+             k_input.append(j)
+
+             if jx == ind:
+                 # output index -> take diagonal
+                 output = (j,)
+                 b_input.append(j)
+             else:
+                 try:
+                     # partial trace with message
+                     m = self.messages[jx, tid]
+                     arrays.append(m)
+                     b_input.append(-j)
+                     m_inputs.append((-j, j))
+                 except KeyError:
+                     # direct partial trace
+                     b_input.append(j)
+
+         p = qtn.array_contract(
+             arrays,
+             inputs=(tuple(k_input), tuple(b_input), *m_inputs),
+             output=output,
+             **self.contract_opts,
+         )
+         p = ar.do("real", p)
+         return p / ar.do("sum", p)
+
+     def contract(self, strip_exponent=False):
+         """Estimate the total contraction, i.e. the 2-norm.
+
+         Parameters
+         ----------
+         strip_exponent : bool, optional
+             Whether to strip the exponent from the final result. If ``True``
+             then the returned result is ``(mantissa, exponent)``.
+
+         Returns
+         -------
+         scalar or (scalar, float)
+         """
+         tvals = []
+
+         for tid, t in self.tn.tensor_map.items():
+             arrays = [t.data, ar.do("conj", t.data)]
+             k_input = []
+             b_input = []
+             m_inputs = []
+             for i, ix in enumerate(t.inds, 1):
+                 k_input.append(i)
+                 if ix in self.output_inds:
+                     b_input.append(i)
+                 else:
+                     b_input.append(-i)
+                     m_inputs.append((-i, i))
+                     arrays.append(self.messages[ix, tid])
+
+             inputs = (tuple(k_input), tuple(b_input), *m_inputs)
+             output = ()
+             tval = qtn.array_contract(
+                 arrays, inputs, output, **self.contract_opts
+             )
+             tvals.append(tval)
+
+         mvals = []
+         for ix, tids in self.tn.ind_map.items():
+             if ix in self.output_inds:
+                 continue
+             tida, tidb = tids
+             ml = self.messages[ix, tidb]
+             mr = self.messages[ix, tida]
+             mval = qtn.array_contract(
+                 (ml, mr), ((1, 2), (1, 2)), (), **self.contract_opts
+             )
+             mvals.append(mval)
+
+         return combine_local_contractions(
+             tvals, mvals, self.backend, strip_exponent=strip_exponent
+         )
+
+     def compress(
+         self,
+         max_bond,
+         cutoff=0.0,
+         cutoff_mode=4,
+         renorm=0,
+         inplace=False,
+     ):
+         """Compress the initial tensor network using the current messages."""
+         tn = self.tn if inplace else self.tn.copy()
+
+         for ix, tids in tn.ind_map.items():
+             if len(tids) != 2:
+                 continue
+             tida, tidb = tids
+
+             # messages are left and right factors squared already
+             ta = tn.tensor_map[tida]
+             dm = ta.ind_size(ix)
+             dl = ta.size // dm
+             ml = self.messages[ix, tidb]
+             Rl = qtn.decomp.squared_op_to_reduced_factor(
+                 ml, dl, dm, right=True
+             )
+
+             tb = tn.tensor_map[tidb]
+             dr = tb.size // dm
+             mr = self.messages[ix, tida].T
+             Rr = qtn.decomp.squared_op_to_reduced_factor(
+                 mr, dm, dr, right=False
+             )
+
+             # compute the compressors
+             Pl, Pr = qtn.decomp.compute_oblique_projectors(
+                 Rl,
+                 Rr,
+                 max_bond=max_bond,
+                 cutoff=cutoff,
+                 cutoff_mode=cutoff_mode,
+                 renorm=renorm,
+             )
+
+             # contract the compressors into the tensors
+             tn.tensor_map[tida].gate_(Pl.T, ix)
+             tn.tensor_map[tidb].gate_(Pr, ix)
+
+             # update messages with projections
+             if inplace:
+                 new_Ra = Rl @ Pl
+                 new_Rb = Pr @ Rr
+                 self.messages[ix, tidb] = ar.dag(new_Ra) @ new_Ra
+                 self.messages[ix, tida] = new_Rb @ ar.dag(new_Rb)
+
+         return tn
+
+
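A minimal usage sketch of the class above, for orientation (illustrative, not part of the diff; it assumes quimb's ``PEPS.rand`` builder, and ``run`` is inherited from ``BeliefPropagationCommon``):

    import quimb.tensor as qtn

    # a small random PEPS; D2BP estimates its 2-norm <psi|psi>
    psi = qtn.PEPS.rand(3, 3, bond_dim=3, seed=42)

    bp = D2BP(psi)
    bp.run(max_iterations=100, tol=1e-6)  # run() from BeliefPropagationCommon

    est = bp.contract()    # BP estimate of <psi|psi>
    exact = psi.H @ psi    # exact contraction, feasible at this small size
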
+ def contract_d2bp(
+     tn,
+     messages=None,
+     output_inds=None,
+     optimize="auto-hq",
+     damping=0.0,
+     update="sequential",
+     local_convergence=True,
+     max_iterations=1000,
+     tol=5e-6,
+     strip_exponent=False,
+     info=None,
+     progbar=False,
+     **contract_opts,
+ ):
+     """Estimate the norm squared of ``tn`` using dense 2-norm belief
+     propagation.
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to form the 2-norm of and run BP on.
+     messages : dict[(str, int), array_like], optional
+         The initial messages to use, effectively defaults to all ones if not
+         specified.
+     max_iterations : int, optional
+         The maximum number of iterations to perform.
+     tol : float, optional
+         The convergence tolerance for messages.
+     output_inds : set[str], optional
+         The indices to consider as output (dangling) indices of the tn.
+         Computed automatically if not specified.
+     optimize : str or PathOptimizer, optional
+         The path optimizer to use when contracting the messages.
+     damping : float, optional
+         The damping parameter to use, defaults to no damping.
+     update : {'parallel', 'sequential'}, optional
+         Whether to update all messages in parallel or sequentially.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     strip_exponent : bool, optional
+         Whether to strip the exponent from the final result. If ``True``
+         then the returned result is ``(mantissa, exponent)``.
+     info : dict, optional
+         If specified, update this dictionary with information about the
+         belief propagation run.
+     progbar : bool, optional
+         Whether to show a progress bar.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+
+     Returns
+     -------
+     scalar or (scalar, float)
+     """
+     bp = D2BP(
+         tn,
+         messages=messages,
+         output_inds=output_inds,
+         optimize=optimize,
+         damping=damping,
+         local_convergence=local_convergence,
+         update=update,
+         **contract_opts,
+     )
+     bp.run(
+         max_iterations=max_iterations,
+         tol=tol,
+         info=info,
+         progbar=progbar,
+     )
+     return bp.contract(strip_exponent=strip_exponent)
+
+
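A quick sanity-check sketch for the wrapper above, on a network small enough to contract exactly (again assuming quimb's ``PEPS.rand``; on loopy graphs the BP value is an approximation, not an exact result):

    import quimb.tensor as qtn

    psi = qtn.PEPS.rand(3, 3, bond_dim=2, seed=0)

    est = contract_d2bp(psi, max_iterations=200, tol=1e-8)
    exact = psi.H @ psi
    # est approximates exact, agreeing only up to BP's loop error
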
+ def compress_d2bp(
+     tn,
+     max_bond,
+     cutoff=0.0,
+     cutoff_mode="rsum2",
+     renorm=0,
+     messages=None,
+     output_inds=None,
+     optimize="auto-hq",
+     damping=0.0,
+     update="sequential",
+     local_convergence=True,
+     max_iterations=1000,
+     tol=5e-6,
+     inplace=False,
+     info=None,
+     progbar=False,
+     **contract_opts,
+ ):
+     """Compress the tensor network ``tn`` using dense 2-norm belief
+     propagation.
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to form the 2-norm of, run BP on and then compress.
+     max_bond : int
+         The maximum bond dimension to compress to.
+     cutoff : float, optional
+         The cutoff to use when compressing.
+     cutoff_mode : str or int, optional
+         The cutoff mode to use when compressing.
+     messages : dict[(str, int), array_like], optional
+         The initial messages to use, effectively defaults to all ones if not
+         specified.
+     max_iterations : int, optional
+         The maximum number of iterations to perform.
+     tol : float, optional
+         The convergence tolerance for messages.
+     output_inds : set[str], optional
+         The indices to consider as output (dangling) indices of the tn.
+         Computed automatically if not specified.
+     optimize : str or PathOptimizer, optional
+         The path optimizer to use when contracting the messages.
+     damping : float, optional
+         The damping parameter to use, defaults to no damping.
+     update : {'parallel', 'sequential'}, optional
+         Whether to update all messages in parallel or sequentially.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     inplace : bool, optional
+         Whether to perform the compression inplace.
+     info : dict, optional
+         If specified, update this dictionary with information about the
+         belief propagation run.
+     progbar : bool, optional
+         Whether to show a progress bar.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+
+     Returns
+     -------
+     TensorNetwork
+     """
+     bp = D2BP(
+         tn,
+         messages=messages,
+         output_inds=output_inds,
+         optimize=optimize,
+         damping=damping,
+         update=update,
+         local_convergence=local_convergence,
+         **contract_opts,
+     )
+     bp.run(
+         max_iterations=max_iterations,
+         tol=tol,
+         info=info,
+         progbar=progbar,
+     )
+     return bp.compress(
+         max_bond=max_bond,
+         cutoff=cutoff,
+         cutoff_mode=cutoff_mode,
+         renorm=renorm,
+         inplace=inplace,
+     )
+
+
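A sketch of the compression entry point above, assuming quimb's ``PEPS.rand``: every bond of a D=4 PEPS is truncated to at most 2, using the converged BP messages as local environments:

    import quimb.tensor as qtn

    psi = qtn.PEPS.rand(4, 4, bond_dim=4, seed=0)

    psi_c = compress_d2bp(psi, max_bond=2)
    assert psi_c.max_bond() == 2
    # fidelity can be gauged by comparing psi_c.H @ psi against psi.H @ psi
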
+ def sample_d2bp(
+     tn,
+     output_inds=None,
+     messages=None,
+     max_iterations=100,
+     tol=1e-2,
+     bias=None,
+     seed=None,
+     local_convergence=True,
+     progbar=False,
+     **contract_opts,
+ ):
+     """Sample a configuration from ``tn`` using dense 2-norm belief
+     propagation.
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to sample from.
+     output_inds : set[str], optional
+         Which indices to sample.
+     messages : dict[(str, int), array_like], optional
+         The initial messages to use, effectively defaults to all ones if not
+         specified.
+     max_iterations : int, optional
+         The maximum number of iterations to perform, per marginal.
+     tol : float, optional
+         The convergence tolerance for messages.
+     bias : float, optional
+         Bias the sampling towards more locally likely bit-strings. This is
+         done by raising the probability of each bit-string to this power.
+     seed : int, optional
+         A random seed for reproducibility.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     progbar : bool, optional
+         Whether to show a progress bar.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+
+     Returns
+     -------
+     config : dict[str, int]
+         The sampled configuration, a mapping of output indices to values.
+     tn_config : TensorNetwork
+         The tensor network with the sampled configuration applied.
+     omega : float
+         The BP probability of the sampled configuration.
+     """
+     import numpy as np
+
+     if output_inds is None:
+         output_inds = tn.outer_inds()
+
+     rng = np.random.default_rng(seed)
+     config = {}
+     omega = 1.0
+
+     tn = tn.copy()
+     bp = D2BP(
+         tn,
+         messages=messages,
+         local_convergence=local_convergence,
+         **contract_opts,
+     )
+     bp.run(max_iterations=max_iterations, tol=tol)
+
+     marginals = dict.fromkeys(output_inds)
+
+     if progbar:
+         import tqdm
+
+         pbar = tqdm.tqdm(total=len(marginals))
+     else:
+         pbar = None
+
+     while marginals:
+         for ix in marginals:
+             marginals[ix] = bp.compute_marginal(ix)
+
+         ix, p = max(marginals.items(), key=lambda x: max(x[1]))
+         p = ar.to_numpy(p)
+
+         if bias is not None:
+             # bias distribution towards more locally likely bit-strings
+             p = p**bias
+             p /= np.sum(p)
+
+         v = rng.choice([0, 1], p=p)
+         config[ix] = v
+         del marginals[ix]
+
+         tids = tuple(tn.ind_map[ix])
+         tn.isel_({ix: v})
+
+         omega *= p[v]
+         if progbar:
+             pbar.update(1)
+             pbar.set_description(f"{ix}->{v}", refresh=False)
+
+         bp = D2BP(
+             tn,
+             messages=bp.messages,
+             local_convergence=local_convergence,
+             **contract_opts,
+         )
+         bp.update_touched_from_tids(*tids)
+         bp.run(tol=tol, max_iterations=max_iterations)
+
+     if progbar:
+         pbar.close()
+
+     return config, tn, omega
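
Finally, a sketch of drawing a single sample with the function above (assuming quimb's ``PEPS.rand``, whose physical indices default to names like ``k0,0``):

    import quimb.tensor as qtn

    psi = qtn.PEPS.rand(3, 3, bond_dim=2, seed=1)

    config, tn_config, omega = sample_d2bp(psi, seed=1)
    # config maps each physical index to 0 or 1, e.g. {'k0,0': 1, ...}
    # omega is the BP probability of this sample, useful for reweighting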