Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
1
+ import autoray as ar
2
+
3
+ import quimb.tensor as qtn
4
+ from quimb.utils import oset
5
+
6
+ from .bp_common import (
7
+ BeliefPropagationCommon,
8
+ combine_local_contractions,
9
+ create_lazy_community_edge_map,
10
+ )
11
+
12
+
13
class L1BP(BeliefPropagationCommon):
    """Lazy 1-norm belief propagation. BP is run between groups of tensors
    defined by ``site_tags``. The message updates are lazy contractions.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region,
        which should not overlap. If the tensor network is structured, then
        these are inferred automatically.
    damping : float or callable, optional
        The damping parameter to use, defaults to no damping. If callable, it
        is called with no arguments on every message update and its return
        value is used as the damping for that update.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    message_init_function : callable, optional
        If given, called as ``message_init_function(shape)`` to produce the
        data of each initial message; that data is used as-is (it is not
        normalized). If not given, each initial message is its local region
        contracted down to the shared bond indices, normalized to sum to one.
    contract_opts
        Other options forwarded to every contraction call (e.g. on to
        ``cotengra.array_contract``).
    """

    def __init__(
        self,
        tn,
        site_tags=None,
        damping=0.0,
        update="sequential",
        local_convergence=True,
        optimize="auto-hq",
        message_init_function=None,
        **contract_opts,
    ):
        # the array backend is taken from the first tensor in the network
        self.backend = next(t.backend for t in tn)
        self.damping = damping
        self.local_convergence = local_convergence
        self.update = update
        self.optimize = optimize
        self.contract_opts = contract_opts

        if site_tags is None:
            self.site_tags = tuple(tn.site_tags)
        else:
            self.site_tags = tuple(site_tags)

        # pass the normalized tuple, not the raw argument: a one-shot
        # iterable of tags has already been consumed by ``tuple`` above
        (
            self.edges,
            self.neighbors,
            self.local_tns,
            self.touch_map,
        ) = create_lazy_community_edge_map(tn, self.site_tags)
        # directed message keys ``(i, j)`` pending recomputation
        self.touched = oset()

        # backend-specific array functions
        self._abs = ar.get_lib_fn(self.backend, "abs")
        self._max = ar.get_lib_fn(self.backend, "max")
        self._sum = ar.get_lib_fn(self.backend, "sum")
        self._norm = ar.get_lib_fn(self.backend, "linalg.norm")

        def _normalize(x):
            # rescale so that the entries of the message sum to one
            return x / self._sum(x)

        def _distance(x, y):
            # entrywise L1 distance between two message arrays
            return self._sum(self._abs(x - y))

        self._normalize = _normalize
        self._distance = _distance

        # for each meta bond create initial messages
        self.messages = {}
        for pair, bix in self.edges.items():
            # compute leftwards and rightwards messages
            for i, j in (sorted(pair), sorted(pair, reverse=True)):
                tn_i = self.local_tns[i]

                if message_init_function is None:
                    # initial message just sums over dangling bonds
                    tm = tn_i.contract(
                        all,
                        output_inds=bix,
                        optimize=self.optimize,
                        drop_tags=True,
                        **self.contract_opts,
                    )
                    tm.modify(apply=self._normalize)
                else:
                    shape = tuple(tn_i.ind_size(ix) for ix in bix)
                    tm = qtn.Tensor(
                        data=message_init_function(shape),
                        inds=bix,
                    )

                self.messages[i, j] = tm

        # build, for each directed pair, the lazy network whose contraction
        # gives the outgoing message i -> j
        self.contraction_tns = {}
        for pair, bix in self.edges.items():
            for i, j in (sorted(pair), sorted(pair, reverse=True)):
                tn_i = self.local_tns[i].copy()
                # attach incoming messages from every neighbor except j
                tks = [
                    self.messages[k, i] for k in self.neighbors[i] if k != j
                ]
                # virtual so that in-place message updates are seen here
                tn_i_to_j = qtn.TensorNetwork((tn_i, *tks), virtual=True)
                self.contraction_tns[i, j] = tn_i_to_j

    def iterate(self, tol=5e-6):
        """Perform one round of message updates.

        Parameters
        ----------
        tol : float, optional
            A message whose L1 change is at most ``tol`` counts as converged.

        Returns
        -------
        nconv : int
            Number of recomputed messages that changed by at most ``tol``.
        ncheck : int
            Number of messages recomputed this round.
        max_mdiff : float
            Largest L1 change of any message, ``-1.0`` if none was checked.

        Raises
        ------
        ValueError
            If ``self.update`` is not ``'parallel'`` or ``'sequential'``.
        """
        # validate before touching any state, otherwise an unknown mode
        # would silently drop every pending message update
        if self.update not in ("parallel", "sequential"):
            raise ValueError(
                f"Unknown update mode {self.update!r}, expected "
                "'parallel' or 'sequential'."
            )

        if (not self.local_convergence) or (not self.touched):
            # assume if asked to iterate that we want to check all messages
            self.touched.update(
                pair for edge in self.edges for pair in (edge, edge[::-1])
            )

        ncheck = len(self.touched)
        nconv = 0
        max_mdiff = -1.0
        new_touched = oset()

        def _compute_m(key):
            i, j = key
            # edges are stored under the ordered pair
            bix = self.edges[(i, j) if i < j else (j, i)]
            tn_i_to_j = self.contraction_tns[i, j]
            tm_new = tn_i_to_j.contract(
                all,
                output_inds=bix,
                optimize=self.optimize,
                **self.contract_opts,
            )
            return self._normalize(tm_new.data)

        def _update_m(key, data):
            nonlocal nconv, max_mdiff

            tm = self.messages[key]

            if callable(self.damping):
                damping_m = self.damping()
                data = (1 - damping_m) * data + damping_m * tm.data
            elif self.damping != 0.0:
                data = (1 - self.damping) * data + self.damping * tm.data

            mdiff = float(self._distance(tm.data, data))

            if mdiff > tol:
                # mark touching messages for update
                new_touched.update(self.touch_map[key])
            else:
                nconv += 1

            max_mdiff = max(max_mdiff, mdiff)
            # modify in place so the virtual contraction networks see it
            tm.modify(data=data)

        if self.update == "parallel":
            new_data = {}
            # compute all new messages
            while self.touched:
                key = self.touched.pop()
                new_data[key] = _compute_m(key)
            # insert all new messages
            for key, data in new_data.items():
                _update_m(key, data)

        else:  # "sequential"
            # compute each new message and immediately re-insert it
            while self.touched:
                key = self.touched.pop()
                data = _compute_m(key)
                _update_m(key, data)

        self.touched = new_touched
        return nconv, ncheck, max_mdiff

    def contract(self, strip_exponent=False):
        """Estimate the contraction of the whole network from the current
        messages.

        Each region is contracted with all of its incoming messages, and each
        bond's pair of opposing messages is contracted together; these
        scalars are then combined by ``combine_local_contractions``.

        Parameters
        ----------
        strip_exponent : bool, optional
            Forwarded to ``combine_local_contractions``.
        """
        tvals = []
        for site, tn_ic in self.local_tns.items():
            if site in self.neighbors:
                tval = qtn.tensor_contract(
                    *tn_ic,
                    *(self.messages[k, site] for k in self.neighbors[site]),
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            else:
                # site exists but has no neighbors
                tval = tn_ic.contract(
                    all,
                    output_inds=(),
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            tvals.append(tval)

        mvals = []
        for i, j in self.edges:
            mval = qtn.tensor_contract(
                self.messages[i, j],
                self.messages[j, i],
                optimize=self.optimize,
                **self.contract_opts,
            )
            mvals.append(mval)

        return combine_local_contractions(
            tvals, mvals, self.backend, strip_exponent=strip_exponent
        )

    def normalize_messages(self):
        """Normalize all messages such that for each bond `<m_i|m_j> = 1` and
        `<m_i|m_i> = <m_j|m_j>` (but in general != 1).

        Messages are rescaled in place.
        """
        for i, j in self.edges:
            tmi = self.messages[i, j]
            tmj = self.messages[j, i]
            nij = abs(tmi @ tmj) ** 0.5
            nii = (tmi @ tmi) ** 0.25
            njj = (tmj @ tmj) ** 0.25
            tmi /= nij * nii / njj
            tmj /= nij * njj / nii
249
+
250
+
251
def contract_l1bp(
    tn,
    max_iterations=1000,
    tol=5e-6,
    site_tags=None,
    damping=0.0,
    update="sequential",
    local_convergence=True,
    optimize="auto-hq",
    strip_exponent=False,
    info=None,
    progbar=False,
    **contract_opts,
):
    """Estimate the contraction of ``tn`` using lazy 1-norm belief propagation.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to contract.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region. If
        the tensor network is structured, then these are inferred
        automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    update : {'parallel', 'sequential'}, optional
        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    info : dict, optional
        If specified, update this dictionary with information about the
        belief propagation run.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options forwarded to ``L1BP`` (and from there to its
        contraction calls, e.g. ``cotengra.array_contract``).
    """
    # build the lazy BP state; extra keyword options ride along unchanged
    solver = L1BP(
        tn,
        site_tags=site_tags,
        damping=damping,
        update=update,
        local_convergence=local_convergence,
        optimize=optimize,
        **contract_opts,
    )

    # pass messages until converged or the iteration budget is spent
    solver.run(
        max_iterations=max_iterations, tol=tol, info=info, progbar=progbar
    )

    # fold local region and bond contractions into the global estimate
    estimate = solver.contract(strip_exponent=strip_exponent)
    return estimate