Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122):
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,583 @@
1
+ """Some functions for generating the edges of a lattice."""
2
+
3
+ import itertools
4
+ import random
5
+
6
+
7
def sort_unique(edges):
    """Make sure there are no duplicate edges and that for each
    ``coo_a <= coo_b``.

    Parameters
    ----------
    edges : iterable of (node, node)
        Undirected edges given as pairs of mutually orderable, hashable node
        labels. The same edge may appear several times and in either
        orientation.

    Returns
    -------
    edges : tuple[(node, node)]
        The sorted tuple of unique edges, each ordered so that
        ``coo_a <= coo_b``. A self-loop ``(a, a)`` is kept as the 2-tuple
        ``(a, a)`` rather than being collapsed.
    """
    # Sorting each pair gives it a canonical orientation, so a set of the
    # sorted pairs removes duplicates regardless of direction. Going via
    # ``frozenset`` instead would turn a self-loop (possible for tiny cyclic
    # lattices, e.g. a 1-site periodic chain) into a malformed 1-tuple.
    return tuple(sorted({tuple(sorted(edge)) for edge in edges}))
14
+
15
+
16
+ # ----------------------------------- 1D ------------------------------------ #
17
+
18
+
19
def edges_1d_chain(L, cyclic=False):
    """Build the edges of a 1D chain of ``L`` sites labelled ``0 .. L - 1``.

    Parameters
    ----------
    L : int
        The number of cells (sites) in the chain.
    cyclic : bool, optional
        If true, also join the last site back to the first (periodic
        boundary conditions).

    Returns
    -------
    edges : tuple[(int, int)]
        The sorted, de-duplicated edges, as produced by ``sort_unique``.
    """
    # with open boundaries the final site has no right-hand neighbour
    num_bonds = L if cyclic else L - 1
    bonds = [(site, (site + 1) % L) for site in range(num_bonds)]
    return sort_unique(bonds)
39
+
40
+
41
+ # ----------------------------------- 2D ------------------------------------ #
42
+
43
+
44
def check_2d(coo, Lx, Ly, cyclic):
    """Map ``coo`` onto a possibly periodic ``Lx`` by ``Ly`` lattice.

    Returns the wrapped coordinate ``(x % Lx, y % Ly)``, or ``None`` when
    the lattice is open (``cyclic`` falsy) and ``coo`` lies outside it.
    """
    x, y = coo
    # short-circuits exactly like the open-boundary test it replaces
    if cyclic or (0 <= x < Lx and 0 <= y < Ly):
        return (x % Lx, y % Ly)
    return None
50
+
51
+
52
def edges_2d_square(Lx, Ly, cyclic=False, cells=None):
    """Build the edges of a finite 2D square lattice whose sites are
    labelled ``(i, j)``.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : iterable of (int, int), optional
        The cells to generate bonds from. Defaults to
        ``itertools.product(range(Lx), range(Ly))``.

    Returns
    -------
    edges : tuple[((int, int), (int, int))]
        The sorted, de-duplicated edges, as produced by ``sort_unique``.
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly))

    # Only the +y and +x neighbours are visited from each cell; on a full
    # lattice the -y / -x bonds come from the neighbouring cells instead.
    steps = ((0, 1), (1, 0))

    edges = []
    for i, j in cells:
        for di, dj in steps:
            target = check_2d((i + di, j + dj), Lx, Ly, cyclic)
            if target is not None:
                edges.append(((i, j), target))
    return sort_unique(edges)
83
+
84
+
85
def edges_2d_hexagonal(Lx, Ly, cyclic=False, cells=None):
    """Build the edges of a finite 2D hexagonal (honeycomb) lattice.

    Each cell holds two sites, ``'A'`` and ``'B'``, so nodes are labelled
    ``(i, j, s)`` with ``s`` in ``'AB'``. The cells do not form a square
    tiling.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : iterable of (int, int), optional
        The cells to generate bonds from. Defaults to
        ``itertools.product(range(Lx), range(Ly))``.

    Returns
    -------
    edges : tuple[((int, int, str), (int, int, str))]
        The sorted, de-duplicated edges, as produced by ``sort_unique``.
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly))

    # for each sublattice: ((di, dj) cell offset, neighbouring sublattice)
    bonds_from = {
        "A": [((0, 0), "B"), ((0, -1), "B"), ((-1, 0), "B")],
        "B": [((0, 0), "A"), ((0, 1), "A"), ((1, 0), "A")],
    }

    edges = []
    for i, j in cells:
        for site, neighbours in bonds_from.items():
            for (di, dj), other in neighbours:
                target = check_2d((i + di, j + dj), Lx, Ly, cyclic)
                if target is not None:
                    edges.append(((i, j, site), (*target, other)))
    return sort_unique(edges)
130
+
131
+
132
def edges_2d_triangular(Lx, Ly, cyclic=False, cells=None):
    """Return the graph edges of a finite 2D triangular lattice. There is a
    single site per cell, and note the cells do not form a square tiling.
    The nodes (sites) are labelled like ``(i, j)``.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : list, optional
        A list of cells to use. If not given the cells used are
        ``itertools.product(range(Lx), range(Ly))``.

    Returns
    -------
    edges : tuple[((int, int), (int, int))]
        The sorted, unique edges (see ``sort_unique``).
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly))

    edges = []
    for i, j in cells:
        # Only three neighbours are visited from each cell; on a full lattice
        # the opposite three bonds are produced from the neighbouring cells.
        for coob in [(i, j + 1), (i + 1, j), (i + 1, j - 1)]:
            coob = check_2d(coob, Lx, Ly, cyclic)
            if coob:
                edges.append(((i, j), coob))

    return sort_unique(edges)
166
+
167
+
168
def edges_2d_triangular_rectangular(Lx, Ly, cyclic=False, cells=None):
    """Return the graph edges of a finite 2D triangular lattice tiled in a
    rectangular geometry. There are two sites per rectangular cell. The nodes
    (sites) are labelled like ``(i, j, s)`` for ``s`` in ``'AB'``.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : list, optional
        A list of cells to use. If not given the cells used are
        ``itertools.product(range(Lx), range(Ly))``.

    Returns
    -------
    edges : tuple[((int, int, str), (int, int, str))]
        The sorted, unique edges (see ``sort_unique``).
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly))

    edges = []
    for i, j in cells:
        # bonds leaving the 'A' site of cell (i, j)
        for *coob, lbl in [
            (i, j, "B"),
            (i, j - 1, "B"),
            (i, j + 1, "A"),
        ]:
            coob = check_2d(coob, Lx, Ly, cyclic)
            if coob:
                edges.append(((i, j, "A"), (*coob, lbl)))

        # bonds leaving the 'B' site of cell (i, j)
        for *coob, lbl in [
            (i + 1, j, "A"),
            (i, j + 1, "B"),
            (i + 1, j + 1, "A"),
        ]:
            coob = check_2d(coob, Lx, Ly, cyclic)
            if coob:
                edges.append(((i, j, "B"), (*coob, lbl)))

    return sort_unique(edges)
213
+
214
+
215
def edges_2d_kagome(Lx, Ly, cyclic=False, cells=None):
    """Build the edges of a finite 2D kagome lattice.

    Each cell holds three sites, so nodes are labelled ``(i, j, s)`` with
    ``s`` in ``'ABC'``. The cells do not form a square tiling.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : iterable of (int, int), optional
        The cells to generate bonds from. Defaults to
        ``itertools.product(range(Lx), range(Ly))``.

    Returns
    -------
    edges : tuple[((int, int, str), (int, int, str))]
        The sorted, de-duplicated edges, as produced by ``sort_unique``.
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly))

    # for each sublattice: ((di, dj) cell offset, neighbouring sublattice)
    bonds_from = {
        "A": [((0, 0), "B"), ((0, -1), "B"), ((0, 0), "C"), ((-1, 0), "C")],
        "B": [((0, 0), "C"), ((-1, 1), "C"), ((0, 0), "A"), ((0, 1), "A")],
        "C": [((0, 0), "A"), ((1, 0), "A"), ((0, 0), "B"), ((1, -1), "B")],
    }

    edges = []
    for i, j in cells:
        for site, neighbours in bonds_from.items():
            for (di, dj), other in neighbours:
                target = check_2d((i + di, j + dj), Lx, Ly, cyclic)
                if target is not None:
                    edges.append(((i, j, site), (*target, other)))
    return sort_unique(edges)
272
+
273
+
274
+ # ----------------------------------- 3D ------------------------------------ #
275
+
276
+
277
def check_3d(coo, Lx, Ly, Lz, cyclic):
    """Map ``coo`` onto a possibly periodic ``Lx`` by ``Ly`` by ``Lz``
    lattice.

    Returns the wrapped coordinate ``(x % Lx, y % Ly, z % Lz)``, or ``None``
    when the lattice is open (``cyclic`` falsy) and ``coo`` lies outside it.
    """
    x, y, z = coo
    inside = (0 <= x < Lx) and (0 <= y < Ly) and (0 <= z < Lz)
    if not (cyclic or inside):
        return None
    return (x % Lx, y % Ly, z % Lz)
285
+
286
+
287
def edges_3d_cubic(Lx, Ly, Lz, cyclic=False, cells=None):
    """Build the edges of a finite 3D simple cubic lattice whose sites are
    labelled ``(i, j, k)``.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    Lz : int
        The number of cells along the z-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : iterable of (int, int, int), optional
        The cells to generate bonds from. Defaults to
        ``itertools.product(range(Lx), range(Ly), range(Lz))``.

    Returns
    -------
    edges : tuple[((int, int, int), (int, int, int))]
        The sorted, de-duplicated edges, as produced by ``sort_unique``.
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly), range(Lz))

    # only the +z, +y and +x neighbours are visited from each cell
    steps = ((0, 0, 1), (0, 1, 0), (1, 0, 0))

    edges = []
    for i, j, k in cells:
        for di, dj, dk in steps:
            target = check_3d((i + di, j + dj, k + dk), Lx, Ly, Lz, cyclic)
            if target is not None:
                edges.append(((i, j, k), target))
    return sort_unique(edges)
320
+
321
+
322
def edges_3d_pyrochlore(Lx, Ly, Lz, cyclic=False, cells=None):
    """Return the graph edges of a finite 3D pyrochlore lattice. There are
    four sites per cell, and note the cells do not form a cubic tiling. The
    nodes (sites) are labelled like ``(i, j, k, s)`` for ``s`` in ``'ABCD'``.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    Lz : int
        The number of cells along the z-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : list, optional
        A list of cells to use. If not given the cells used are
        ``itertools.product(range(Lx), range(Ly), range(Lz))``.

    Returns
    -------
    edges : tuple[((int, int, int, str), (int, int, int, str))]
        The sorted, unique edges (see ``sort_unique``).
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly), range(Lz))

    edges = []
    for i, j, k in cells:
        # Each site lists all six of its bonds; every bond is therefore
        # generated twice (once from each end) and deduplicated by
        # ``sort_unique``.
        for *coob, lbl in [
            (i, j, k, "B"),
            (i, j - 1, k, "B"),
            (i, j, k, "C"),
            (i - 1, j, k, "C"),
            (i, j, k, "D"),
            (i, j, k - 1, "D"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "A"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j, k, "C"),
            (i - 1, j + 1, k, "C"),
            (i, j, k, "D"),
            (i, j + 1, k - 1, "D"),
            (i, j, k, "A"),
            (i, j + 1, k, "A"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "B"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j, k, "D"),
            (i + 1, j, k - 1, "D"),
            (i, j, k, "A"),
            (i + 1, j, k, "A"),
            (i, j, k, "B"),
            (i + 1, j - 1, k, "B"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "C"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j, k, "A"),
            (i, j, k + 1, "A"),
            (i, j, k, "B"),
            (i, j - 1, k + 1, "B"),
            (i, j, k, "C"),
            (i - 1, j, k + 1, "C"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "D"), (*coob, lbl)))

    return sort_unique(edges)
399
+
400
+
401
def edges_3d_diamond(Lx, Ly, Lz, cyclic=False, cells=None):
    """Build the edges of a finite 3D diamond lattice.

    Each cell holds two sites, so nodes are labelled ``(i, j, k, s)`` with
    ``s`` in ``'AB'``. The cells do not form a cubic tiling.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    Lz : int
        The number of cells along the z-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : iterable of (int, int, int), optional
        The cells to generate bonds from. Defaults to
        ``itertools.product(range(Lx), range(Ly), range(Lz))``.

    Returns
    -------
    edges : tuple[((int, int, int, str), (int, int, int, str))]
        The sorted, de-duplicated edges, as produced by ``sort_unique``.
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly), range(Lz))

    # for each sublattice: ((di, dj, dk) cell offset, neighbouring sublattice)
    bonds_from = {
        "A": [((0, 0, 0), "B")],
        "B": [((0, 0, 1), "A"), ((0, 1, 0), "A"), ((1, 0, 0), "A")],
    }

    edges = []
    for i, j, k in cells:
        for site, neighbours in bonds_from.items():
            for (di, dj, dk), other in neighbours:
                target = check_3d(
                    (i + di, j + dj, k + dk), Lx, Ly, Lz, cyclic
                )
                if target is not None:
                    edges.append(((i, j, k, site), (*target, other)))
    return sort_unique(edges)
446
+
447
+
448
def edges_3d_diamond_cubic(Lx, Ly, Lz, cyclic=False, cells=None):
    """Return the graph edges of a finite 3D diamond lattice tiled in a cubic
    geometry. There are eight sites per cubic cell. The nodes (sites) are
    labelled like ``(i, j, k, s)`` for ``s`` in ``'ABCDEFGH'``.

    With periodic boundary conditions (and a lattice large enough that no
    bonds coincide) every site has exactly four neighbours; sites on open
    boundaries have fewer.

    Parameters
    ----------
    Lx : int
        The number of cells along the x-direction.
    Ly : int
        The number of cells along the y-direction.
    Lz : int
        The number of cells along the z-direction.
    cyclic : bool, optional
        Whether to use periodic boundary conditions.
    cells : list, optional
        A list of cells to use. If not given the cells used are
        ``itertools.product(range(Lx), range(Ly), range(Lz))``.

    Returns
    -------
    edges : tuple[((int, int, int, str), (int, int, int, str))]
        The sorted, unique edges (see ``sort_unique``).
    """
    if cells is None:
        cells = itertools.product(range(Lx), range(Ly), range(Lz))

    edges = []
    for i, j, k in cells:
        for *coob, lbl in [
            (i, j, k, "E"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "A"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j, k, "E"),
            (i, j, k, "F"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "B"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j, k, "E"),
            (i, j, k, "G"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "C"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j, k, "E"),
            (i, j, k, "H"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "D"), (*coob, lbl)))

        # 'E' needs no bonds of its own: all four of its neighbours (the A, B,
        # C and D sites of the same cell) already list it above.

        # NOTE: the bonds listed here are consistent with A, B, C, D at the
        # fcc positions (0,0,0), (.5,.5,0), (.5,0,.5), (0,.5,.5) and E-H at
        # those positions + (.25,.25,.25). That puts F at (.75,.75,.25), so its
        # fourth neighbour is the A site at (1,1,0), i.e. in cell
        # (i + 1, j + 1, k). Without this bond both A and F would be only
        # 3-coordinated.
        for *coob, lbl in [
            (i, j + 1, k, "C"),
            (i + 1, j, k, "D"),
            (i + 1, j + 1, k, "A"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "F"), (*coob, lbl)))

        for *coob, lbl in [
            (i + 1, j, k + 1, "A"),
            (i, j, k + 1, "B"),
            (i + 1, j, k, "D"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "G"), (*coob, lbl)))

        for *coob, lbl in [
            (i, j + 1, k + 1, "A"),
            (i, j, k + 1, "B"),
            (i, j + 1, k, "C"),
        ]:
            coob = check_3d(coob, Lx, Ly, Lz, cyclic)
            if coob:
                edges.append(((i, j, k, "H"), (*coob, lbl)))

    return sort_unique(edges)
540
+
541
+
542
def edges_tree_rand(n, max_degree=None, seed=None):
    """Generate the edges of a random tree on the nodes ``0, ..., n - 1``.

    This is a convenience helper for testing: each new node is attached to
    a randomly chosen earlier node, so trees are *not* sampled uniformly
    (for that see ``networkx.random_labeled_tree``).

    Parameters
    ----------
    n : int
        The number of nodes.
    max_degree : int, optional
        If given, a node is no longer offered as a parent once its degree
        reaches ``max_degree``. For example ``max_degree=3`` generates a
        binary tree.
    seed : int, optional
        Seed for the private ``random.Random`` instance used.

    Returns
    -------
    edges : list[(int, int)]
        One ``(parent, child)`` pair per non-root node, in the order the
        children were added.
    """
    rng = random.Random(seed)
    edges = []

    if max_degree is None:
        for child in range(1, n):
            # ``range(child)`` is exactly the set of nodes added so far
            parent = rng.choice(range(child))
            edges.append((parent, child))
        return edges

    # nodes still accepting children -> current degree; insertion order
    # matters since it fixes which node ``rng.choice`` picks
    open_degree = {0: 0}
    for child in range(1, n):
        parent = rng.choice(list(open_degree))
        edges.append((parent, child))
        open_degree[child] = 1  # the edge to its parent
        new_degree = open_degree[parent] + 1
        if new_degree == max_degree:
            # the parent is now full
            del open_degree[parent]
        else:
            open_degree[parent] = new_degree
    return edges
@@ -0,0 +1,114 @@
1
+ """Tools for interfacing the tensor and tensor network objects with other
2
+ libraries.
3
+ """
4
+
5
+ import functools
6
+
7
+ from ..utils import tree_map
8
+ from .tensor_core import Tensor, TensorNetwork
9
+
10
+
11
class Placeholder:
    """Stand-in for an array that keeps only its ``shape``.

    ``shape`` is read from ``x.shape`` when that attribute exists, and is
    ``None`` otherwise. No reference to ``x`` itself is retained.
    """

    __slots__ = ("shape",)

    def __init__(self, x):
        try:
            self.shape = x.shape
        except AttributeError:
            self.shape = None

    def __repr__(self):
        shape = self.shape
        return f"Placeholder(shape={shape})"
19
+
20
+
21
def pack(obj):
    """Take a tensor or tensor network like object and return a skeleton needed
    to reconstruct it, and a pytree of raw parameters.

    Objects lacking any of ``copy``, ``get_params`` or ``set_params`` are
    treated as raw arrays: they are returned unchanged as the parameters,
    with a ``Placeholder`` (recording only ``shape``) as the skeleton.

    Parameters
    ----------
    obj : Tensor, TensorNetwork, or similar
        Something that has ``copy``, ``set_params``, and ``get_params``
        methods, or a raw array.

    Returns
    -------
    params : pytree
        A pytree of raw parameter arrays.
    skeleton : Tensor, TensorNetwork, or similar
        A copy of ``obj`` with all references to the original data removed.
    """
    # Decide up front rather than via a broad ``try``: that avoided neither a
    # full (discarded) copy of raw arrays, which have ``.copy`` but no
    # ``get_params``, nor silently swallowing real AttributeErrors raised
    # inside the tensor network methods.
    if not all(
        hasattr(obj, name) for name in ("copy", "get_params", "set_params")
    ):
        # assume it's a raw array
        return obj, Placeholder(obj)

    skeleton = obj.copy()
    params = skeleton.get_params()
    # swap every array for a shape-only stand-in, dropping data references
    skeleton.set_params(tree_map(Placeholder, params))
    return params, skeleton
48
+
49
+
50
+ def unpack(params, skeleton):
51
+ """Take a skeleton of a tensor or tensor network like object and a pytree
52
+ of raw parameters and return a new reconstructed object with those
53
+ parameters inserted.
54
+
55
+ Parameters
56
+ ----------
57
+ params : pytree
58
+ A pytree of raw parameter arrays, with the same structure as the
59
+ output of ``skeleton.get_params()``.
60
+ skeleton : Tensor, TensorNetwork, or similar
61
+ Something that has ``copy``, ``set_params``, and ``get_params``
62
+ methods.
63
+
64
+ Returns
65
+ -------
66
+ obj : Tensor, TensorNetwork, or similar
67
+ A copy of ``skeleton`` with parameters inserted.
68
+ """
69
+ try:
70
+ obj = skeleton.copy()
71
+ obj.set_params(params)
72
+ except AttributeError:
73
+ # assume it's a raw array
74
+ obj = params
75
+ return obj
76
+
77
+
78
+ # -------------------------------- jax -------------------------------------- #
79
+
80
+
81
# Classes already registered as jax pytree nodes by ``jax_register_pytree``;
# it checks and updates this set so no class is registered twice.
_JAX_REGISTERED_TN_CLASSES = set()
82
+
83
+
84
def jax_pack(obj):
    """Flatten ``obj`` for jax, returning ``(children, aux_data)``.

    ``children`` is a 1-tuple holding the parameter pytree from ``pack`` and
    ``aux_data`` is the corresponding skeleton.
    """
    params, skeleton = pack(obj)
    # jax requires the top level children to be a tuple
    return (params,), skeleton
89
+
90
+
91
def jax_unpack(aux, children):
    """Unflatten for jax: rebuild an object from the skeleton ``aux`` and the
    1-tuple ``children`` produced by ``jax_pack``.
    """
    # jax hands over (aux_data, children), the reverse of ``unpack``'s order
    [params] = children
    return unpack(params, aux)
95
+
96
+
97
def jax_register_pytree():
    """Register ``Tensor``, ``TensorNetwork`` and all their (transitive)
    subclasses as jax pytree nodes, flattened with ``jax_pack`` and
    rebuilt with ``jax_unpack``.

    Classes already recorded in ``_JAX_REGISTERED_TN_CLASSES`` are not
    registered again, so this may be called repeatedly. Subclasses defined
    after an earlier call are picked up by the next call.
    """
    import jax

    stack = [Tensor, TensorNetwork]
    visited = set()
    while stack:
        cls = stack.pop()
        # a class can be reached twice via multiple inheritance
        if cls in visited:
            continue
        visited.add(cls)

        if cls not in _JAX_REGISTERED_TN_CLASSES:
            jax.tree_util.register_pytree_node(cls, jax_pack, jax_unpack)
            _JAX_REGISTERED_TN_CLASSES.add(cls)

        # Always descend, even into already-registered classes: they may
        # have gained new subclasses since a previous call.
        stack.extend(cls.__subclasses__())
107
+
108
+
109
@functools.lru_cache(maxsize=1)
def get_jax():
    """Import and return the ``jax`` module, first registering the tensor
    network classes as jax pytree nodes.

    The result is cached, so registration only runs on the first successful
    call.
    """
    import jax as jax_module

    jax_register_pytree()
    return jax_module