Trajectree 0.0.0-py3-none-any.whl → 0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py
@@ -0,0 +1,537 @@
+ import math
+
+ import autoray as ar
+
+ import quimb.tensor as qtn
+ from quimb.utils import oset
+
+ from .bp_common import (
+     BeliefPropagationCommon,
+     combine_local_contractions,
+     create_lazy_community_edge_map,
+ )
+
+
+ class L2BP(BeliefPropagationCommon):
+     """Lazy (as in multiple uncontracted tensors per site) 2-norm (as in for
+     wavefunctions and operators) belief propagation.
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to form the 2-norm of and run BP on.
+     site_tags : sequence of str, optional
+         The tags identifying the sites in ``tn``, each tag forms a region,
+         which should not overlap. If the tensor network is structured, then
+         these are inferred automatically.
+     damping : float, optional
+         The damping parameter to use, defaults to no damping.
+     update : {'parallel', 'sequential'}, optional
+         Whether to update all messages in parallel or sequentially.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     optimize : str or PathOptimizer, optional
+         The path optimizer to use when contracting the messages.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+     """
+
+     def __init__(
+         self,
+         tn,
+         site_tags=None,
+         damping=0.0,
+         update="sequential",
+         local_convergence=True,
+         optimize="auto-hq",
+         **contract_opts,
+     ):
+         self.backend = next(t.backend for t in tn)
+         self.damping = damping
+         self.local_convergence = local_convergence
+         self.update = update
+         self.optimize = optimize
+         self.contract_opts = contract_opts
+
+         if site_tags is None:
+             self.site_tags = tuple(tn.site_tags)
+         else:
+             self.site_tags = tuple(site_tags)
+
+         (
+             self.edges,
+             self.neighbors,
+             self.local_tns,
+             self.touch_map,
+         ) = create_lazy_community_edge_map(tn, site_tags)
+         self.touched = oset()
+
+         _abs = ar.get_lib_fn(self.backend, "abs")
+         _sum = ar.get_lib_fn(self.backend, "sum")
+         _transpose = ar.get_lib_fn(self.backend, "transpose")
+         _conj = ar.get_lib_fn(self.backend, "conj")
+
+         def _normalize(x):
+             return x / _sum(x)
+
+         def _symmetrize(x):
+             N = ar.ndim(x)
+             perm = (*range(N // 2, N), *range(0, N // 2))
+             return x + _conj(_transpose(x, perm))
+
+         def _distance(x, y):
+             return _sum(_abs(x - y))
+
+         self._normalize = _normalize
+         self._symmetrize = _symmetrize
+         self._distance = _distance
+
+         # initialize messages
+         self.messages = {}
+
+         for pair, bix in self.edges.items():
+             cix = tuple(ix + "_l2bp*" for ix in bix)
+             remapper = dict(zip(bix, cix))
+             output_inds = cix + bix
+
+             # compute leftwards and rightwards messages
+             for i, j in (sorted(pair), sorted(pair, reverse=True)):
+                 tn_i = self.local_tns[i]
+                 tn_i2 = tn_i & tn_i.conj().reindex_(remapper)
+                 tm = tn_i2.contract(
+                     all,
+                     output_inds=output_inds,
+                     optimize=self.optimize,
+                     drop_tags=True,
+                     **self.contract_opts,
+                 )
+                 tm.modify(apply=self._symmetrize)
+                 tm.modify(apply=self._normalize)
+                 self.messages[i, j] = tm
+
+         # initialize contractions
+         self.contraction_tns = {}
+         for pair, bix in self.edges.items():
+             for i, j in (sorted(pair), sorted(pair, reverse=True)):
+                 # form the ket side and messages
+                 tn_i_left = self.local_tns[i]
+                 # get other incident nodes which aren't j
+                 ks = [k for k in self.neighbors[i] if k != j]
+                 tks = [self.messages[k, i] for k in ks]
+
+                 # form the 'bra' side
+                 tn_i_right = tn_i_left.conj()
+                 # get the bonds that attach the bra to messages
+                 outer_bix = {
+                     ix for k in ks for ix in self.edges[tuple(sorted((k, i)))]
+                 }
+                 # need to reindex to join message bonds, and create bra outputs
+                 remapper = {}
+                 for ix in tn_i_right.ind_map:
+                     if ix in bix:
+                         # bra outputs
+                         remapper[ix] = ix + "_l2bp**"
+                     elif ix in outer_bix:
+                         # messages connected
+                         remapper[ix] = ix + "_l2bp*"
+                     # remaining indices are either internal and will be mangled
+                     # or global outer indices and will be contracted directly
+
+                 tn_i_right.reindex_(remapper)
+
+                 self.contraction_tns[i, j] = qtn.TensorNetwork(
+                     (tn_i_left, *tks, tn_i_right), virtual=True
+                 )
+
+     def iterate(self, tol=5e-6):
+         if (not self.local_convergence) or (not self.touched):
+             # assume if asked to iterate that we want to check all messages
+             self.touched.update(
+                 pair for edge in self.edges for pair in (edge, edge[::-1])
+             )
+
+         ncheck = len(self.touched)
+         nconv = 0
+         max_mdiff = -1.0
+         new_touched = oset()
+
+         def _compute_m(key):
+             i, j = key
+             bix = self.edges[(i, j) if i < j else (j, i)]
+             cix = tuple(ix + "_l2bp**" for ix in bix)
+             output_inds = cix + bix
+
+             tn_i_to_j = self.contraction_tns[i, j]
+
+             tm_new = tn_i_to_j.contract(
+                 all,
+                 output_inds=output_inds,
+                 drop_tags=True,
+                 optimize=self.optimize,
+                 **self.contract_opts,
+             )
+             tm_new.modify(apply=self._symmetrize)
+             tm_new.modify(apply=self._normalize)
+             return tm_new.data
+
+         def _update_m(key, data):
+             nonlocal nconv, max_mdiff
+
+             tm = self.messages[key]
+
+             if self.damping > 0.0:
+                 data = (1 - self.damping) * data + self.damping * tm.data
+
+             try:
+                 mdiff = float(self._distance(tm.data, data))
+             except (TypeError, ValueError):
+                 # handle e.g. lazy arrays
+                 mdiff = float("inf")
+
+             if mdiff > tol:
+                 # mark touching messages for update
+                 new_touched.update(self.touch_map[key])
+             else:
+                 nconv += 1
+
+             max_mdiff = max(max_mdiff, mdiff)
+             tm.modify(data=data)
+
+         if self.update == "parallel":
+             new_data = {}
+             # compute all new messages
+             while self.touched:
+                 key = self.touched.pop()
+                 new_data[key] = _compute_m(key)
+             # insert all new messages
+             for key, data in new_data.items():
+                 _update_m(key, data)
+
+         elif self.update == "sequential":
+             # compute each new message and immediately re-insert it
+             while self.touched:
+                 key = self.touched.pop()
+                 data = _compute_m(key)
+                 _update_m(key, data)
+
+         self.touched = new_touched
+
+         return nconv, ncheck, max_mdiff
+
+     def normalize_messages(self):
+         """Normalize all messages such that for each bond `<m_i|m_j> = 1` and
+         `<m_i|m_i> = <m_j|m_j>` (but in general != 1).
+         """
+         for i, j in self.edges:
+             tmi = self.messages[i, j]
+             tmj = self.messages[j, i]
+             nij = (tmi @ tmj)**0.5
+             nii = (tmi @ tmi)**0.25
+             njj = (tmj @ tmj)**0.25
+             tmi /= (nij * nii / njj)
+             tmj /= (nij * njj / nii)
+
+     def contract(self, strip_exponent=False):
+         """Estimate the contraction of the norm squared using the current
+         messages.
+         """
+         tvals = []
+         for i, ket in self.local_tns.items():
+             # we allow missing keys here for tensors which are just
+             # disconnected but still appear in local_tns
+             ks = self.neighbors.get(i, ())
+             bix = [ix for k in ks for ix in self.edges[tuple(sorted((k, i)))]]
+             bra = ket.H.reindex_({ix: ix + "_l2bp*" for ix in bix})
+             tni = qtn.TensorNetwork(
+                 (
+                     ket,
+                     *(self.messages[k, i] for k in ks),
+                     bra,
+                 )
+             )
+             tvals.append(
+                 tni.contract(all, optimize=self.optimize, **self.contract_opts)
+             )
+
+         mvals = []
+         for i, j in self.edges:
+             mvals.append(
+                 (self.messages[i, j] & self.messages[j, i]).contract(
+                     all,
+                     optimize=self.optimize,
+                     **self.contract_opts,
+                 )
+             )
+
+         return combine_local_contractions(
+             tvals, mvals, self.backend, strip_exponent=strip_exponent
+         )
+
+     def partial_trace(
+         self,
+         site,
+         normalized=True,
+         optimize="auto-hq",
+     ):
+         example_tn = next(tn for tn in self.local_tns.values())
+
+         site_tag = example_tn.site_tag(site)
+         ket_site_ind = example_tn.site_ind(site)
+
+         ks = self.neighbors[site_tag]
+         tn_rho_i = self.local_tns[site_tag].copy()
+         tn_bra_i = tn_rho_i.H
+
+         for k in ks:
+             tn_rho_i &= self.messages[k, site_tag]
+
+         outer_bix = {
+             ix for k in ks for ix in self.edges[tuple(sorted((k, site_tag)))]
+         }
+
+         ind_changes = {}
+         for ix in tn_bra_i.ind_map:
+             if ix == ket_site_ind:
+                 # open up the site index
+                 bra_site_ind = ix + "_l2bp**"
+                 ind_changes[ix] = bra_site_ind
+             if ix in outer_bix:
+                 # attach bra message indices
+                 ind_changes[ix] = ix + "_l2bp*"
+         tn_bra_i.reindex_(ind_changes)
+
+         tn_rho_i &= tn_bra_i
+
+         rho_i = tn_rho_i.to_dense(
+             [ket_site_ind],
+             [bra_site_ind],
+             optimize=optimize,
+             **self.contract_opts,
+         )
+         if normalized:
+             rho_i = rho_i / ar.do("trace", rho_i)
+
+         return rho_i
+
+     def compress(
+         self,
+         tn,
+         max_bond=None,
+         cutoff=5e-6,
+         cutoff_mode="rsum2",
+         renorm=0,
+         lazy=False,
+     ):
+         """Compress the state ``tn``, assumed to match this L2BP instance,
+         using the messages stored.
+         """
+         for (i, j), bix in self.edges.items():
+             tml = self.messages[i, j]
+             tmr = self.messages[j, i]
+
+             bix_sizes = [tml.ind_size(ix) for ix in bix]
+             dm = math.prod(bix_sizes)
+
+             ml = ar.reshape(tml.data, (dm, dm))
+             dl = self.local_tns[i].outer_size() // dm
+             Rl = qtn.decomp.squared_op_to_reduced_factor(
+                 ml, dl, dm, right=True
+             )
+
+             mr = ar.reshape(tmr.data, (dm, dm)).T
+             dr = self.local_tns[j].outer_size() // dm
+             Rr = qtn.decomp.squared_op_to_reduced_factor(
+                 mr, dm, dr, right=False
+             )
+
+             Pl, Pr = qtn.decomp.compute_oblique_projectors(
+                 Rl,
+                 Rr,
+                 cutoff_mode=cutoff_mode,
+                 renorm=renorm,
+                 max_bond=max_bond,
+                 cutoff=cutoff,
+             )
+
+             Pl = ar.do("reshape", Pl, (*bix_sizes, -1))
+             Pr = ar.do("reshape", Pr, (-1, *bix_sizes))
+
+             ltn = tn.select(i)
+             rtn = tn.select(j)
+
+             new_lix = [qtn.rand_uuid() for _ in bix]
+             new_rix = [qtn.rand_uuid() for _ in bix]
+             new_bix = [qtn.rand_uuid()]
+             ltn.reindex_(dict(zip(bix, new_lix)))
+             rtn.reindex_(dict(zip(bix, new_rix)))
+
+             # ... and insert the new projectors in place
+             tn |= qtn.Tensor(Pl, inds=new_lix + new_bix, tags=(i,))
+             tn |= qtn.Tensor(Pr, inds=new_bix + new_rix, tags=(j,))
+
+         if not lazy:
+             for st in self.site_tags:
+                 try:
+                     tn.contract_tags_(
+                         st, optimize=self.optimize, **self.contract_opts
+                     )
+                 except KeyError:
+                     pass
+
+         return tn
+
+
+ def contract_l2bp(
+     tn,
+     site_tags=None,
+     damping=0.0,
+     update="sequential",
+     local_convergence=True,
+     optimize="auto-hq",
+     max_iterations=1000,
+     tol=5e-6,
+     strip_exponent=False,
+     info=None,
+     progbar=False,
+     **contract_opts,
+ ):
+     """Estimate the norm squared of ``tn`` using lazy belief propagation.
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to estimate the norm squared of.
+     site_tags : sequence of str, optional
+         The tags identifying the sites in ``tn``, each tag forms a region.
+     damping : float, optional
+         The damping parameter to use, defaults to no damping.
+     update : {'parallel', 'sequential'}, optional
+         Whether to update all messages in parallel or sequentially.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     optimize : str or PathOptimizer, optional
+         The contraction strategy to use.
+     max_iterations : int, optional
+         The maximum number of iterations to perform.
+     tol : float, optional
+         The convergence tolerance for messages.
+     strip_exponent : bool, optional
+         Whether to strip the exponent from the final result. If ``True``
+         then the returned result is ``(mantissa, exponent)``.
+     info : dict, optional
+         If specified, update this dictionary with information about the
+         belief propagation run.
+     progbar : bool, optional
+         Whether to show a progress bar.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+     """
+     bp = L2BP(
+         tn,
+         site_tags=site_tags,
+         damping=damping,
+         update=update,
+         local_convergence=local_convergence,
+         optimize=optimize,
+         **contract_opts,
+     )
+     bp.run(
+         max_iterations=max_iterations,
+         tol=tol,
+         info=info,
+         progbar=progbar,
+     )
+     return bp.contract(strip_exponent=strip_exponent)
+
+
+ def compress_l2bp(
+     tn,
+     max_bond,
+     cutoff=0.0,
+     cutoff_mode="rsum2",
+     max_iterations=1000,
+     tol=5e-6,
+     site_tags=None,
+     damping=0.0,
+     update="sequential",
+     local_convergence=True,
+     optimize="auto-hq",
+     lazy=False,
+     inplace=False,
+     info=None,
+     progbar=False,
+     **contract_opts,
+ ):
+     """Compress ``tn`` using lazy belief propagation, producing a tensor
+     network with a single tensor per site.
+
+     Parameters
+     ----------
+     tn : TensorNetwork
+         The tensor network to form the 2-norm of, run BP on and then compress.
+     max_bond : int
+         The maximum bond dimension to compress to.
+     cutoff : float, optional
+         The cutoff to use when compressing.
+     cutoff_mode : str or int, optional
+         The cutoff mode to use when compressing.
+     max_iterations : int, optional
+         The maximum number of iterations to perform.
+     tol : float, optional
+         The convergence tolerance for messages.
+     site_tags : sequence of str, optional
+         The tags identifying the sites in ``tn``, each tag forms a region. If
+         the tensor network is structured, then these are inferred
+         automatically.
+     damping : float, optional
+         The damping parameter to use, defaults to no damping.
+     update : {'parallel', 'sequential'}, optional
+         Whether to update all messages in parallel or sequentially.
+     local_convergence : bool, optional
+         Whether to allow messages to locally converge - i.e. if all their
+         input messages have converged then stop updating them.
+     optimize : str or PathOptimizer, optional
+         The path optimizer to use when contracting the messages.
+     lazy : bool, optional
+         Whether to perform the compression lazily, i.e. to leave the computed
+         compression projectors uncontracted.
+     inplace : bool, optional
+         Whether to perform the compression inplace.
+     info : dict, optional
+         If specified, update this dictionary with information about the
+         belief propagation run.
+     progbar : bool, optional
+         Whether to show a progress bar.
+     contract_opts
+         Other options supplied to ``cotengra.array_contract``.
+
+     Returns
+     -------
+     TensorNetwork
+     """
+     tnc = tn if inplace else tn.copy()
+
+     bp = L2BP(
+         tnc,
+         site_tags=site_tags,
+         damping=damping,
+         update=update,
+         local_convergence=local_convergence,
+         optimize=optimize,
+         **contract_opts,
+     )
+     bp.run(
+         max_iterations=max_iterations,
+         tol=tol,
+         info=info,
+         progbar=progbar,
+     )
+     return bp.compress(
+         tnc,
+         max_bond=max_bond,
+         cutoff=cutoff,
+         cutoff_mode=cutoff_mode,
+         lazy=lazy,
+     )
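
For context, a minimal usage sketch of the module-level helpers added in this file. It is not part of the published diff: the import path assumes the vendored layout shown in the file list above is importable as a package, and the random PEPS constructor is a standard quimb helper used purely for illustration.

import quimb.tensor as qtn

# hypothetical import path, assuming the vendored copy resolves as a package
from trajectree.quimb.quimb.experimental.belief_propagation.l2bp import (
    L2BP,
    compress_l2bp,
    contract_l2bp,
)

# a small random PEPS; because it is structured, site tags are inferred automatically
peps = qtn.PEPS.rand(4, 4, bond_dim=4, seed=42)

# estimate the squared norm <psi|psi> with lazy 2-norm belief propagation
norm_sq = contract_l2bp(peps, tol=1e-6)

# compress the same state down to bond dimension 2 using the converged messages
peps_c = compress_l2bp(peps, max_bond=2, cutoff=1e-10)
print(norm_sq, peps_c.max_bond())

# the class can also be driven directly, e.g. for an approximate local density matrix
bp = L2BP(peps)
bp.run(max_iterations=100, tol=1e-6)
rho = bp.partial_trace((0, 0))  # normalized by default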