Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122):
  1. trajectree/__init__.py +3 -0
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
  7. trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
  8. trajectree/quimb/docs/conf.py +158 -0
  9. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
  10. trajectree/quimb/quimb/__init__.py +507 -0
  11. trajectree/quimb/quimb/calc.py +1491 -0
  12. trajectree/quimb/quimb/core.py +2279 -0
  13. trajectree/quimb/quimb/evo.py +712 -0
  14. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  15. trajectree/quimb/quimb/experimental/autojittn.py +129 -0
  16. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
  17. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
  18. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
  19. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
  20. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
  21. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
  22. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
  23. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
  24. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
  25. trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
  26. trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
  27. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
  28. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
  29. trajectree/quimb/quimb/experimental/schematic.py +7 -0
  30. trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
  31. trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
  32. trajectree/quimb/quimb/gates.py +36 -0
  33. trajectree/quimb/quimb/gen/__init__.py +2 -0
  34. trajectree/quimb/quimb/gen/operators.py +1167 -0
  35. trajectree/quimb/quimb/gen/rand.py +713 -0
  36. trajectree/quimb/quimb/gen/states.py +479 -0
  37. trajectree/quimb/quimb/linalg/__init__.py +6 -0
  38. trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
  39. trajectree/quimb/quimb/linalg/autoblock.py +258 -0
  40. trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
  41. trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
  42. trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
  43. trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
  44. trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
  45. trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
  46. trajectree/quimb/quimb/schematic.py +1518 -0
  47. trajectree/quimb/quimb/tensor/__init__.py +401 -0
  48. trajectree/quimb/quimb/tensor/array_ops.py +610 -0
  49. trajectree/quimb/quimb/tensor/circuit.py +4824 -0
  50. trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
  51. trajectree/quimb/quimb/tensor/contraction.py +336 -0
  52. trajectree/quimb/quimb/tensor/decomp.py +1255 -0
  53. trajectree/quimb/quimb/tensor/drawing.py +1646 -0
  54. trajectree/quimb/quimb/tensor/fitting.py +385 -0
  55. trajectree/quimb/quimb/tensor/geometry.py +583 -0
  56. trajectree/quimb/quimb/tensor/interface.py +114 -0
  57. trajectree/quimb/quimb/tensor/networking.py +1058 -0
  58. trajectree/quimb/quimb/tensor/optimize.py +1818 -0
  59. trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
  60. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
  61. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
  62. trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
  63. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
  64. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
  65. trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
  66. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
  67. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
  68. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
  69. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
  70. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
  71. trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
  72. trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
  73. trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
  74. trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
  75. trajectree/quimb/quimb/utils.py +892 -0
  76. trajectree/quimb/tests/__init__.py +0 -0
  77. trajectree/quimb/tests/test_accel.py +501 -0
  78. trajectree/quimb/tests/test_calc.py +788 -0
  79. trajectree/quimb/tests/test_core.py +847 -0
  80. trajectree/quimb/tests/test_evo.py +565 -0
  81. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  82. trajectree/quimb/tests/test_gen/test_operators.py +361 -0
  83. trajectree/quimb/tests/test_gen/test_rand.py +296 -0
  84. trajectree/quimb/tests/test_gen/test_states.py +261 -0
  85. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  86. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
  87. trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
  88. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
  89. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
  90. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
  91. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
  92. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  93. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  94. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
  95. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
  100. trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
  101. trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
  102. trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
  103. trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
  104. trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
  105. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
  106. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
  107. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
  108. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
  109. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
  110. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
  111. trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
  112. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
  113. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
  114. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
  115. trajectree/quimb/tests/test_utils.py +85 -0
  116. trajectree/trajectory.py +2 -2
  117. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
  118. trajectree-0.0.1.dist-info/RECORD +126 -0
  119. trajectree-0.0.0.dist-info/RECORD +0 -16
  120. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,610 @@
1
+ """Backend agnostic array operations."""
2
+
3
+ import functools
4
+ import itertools
5
+
6
+ import numpy
7
+ from autoray import (
8
+ compose,
9
+ do,
10
+ get_dtype_name,
11
+ get_lib_fn,
12
+ infer_backend,
13
+ reshape,
14
+ )
15
+
16
+ from ..core import njit, qarray
17
+ from ..linalg.base_linalg import norm_fro_dense
18
+ from ..utils import compose as fn_compose
19
+
20
+
21
def asarray(array):
    """Maybe convert data for a tensor to use.

    If ``array`` is a ``numpy.matrix`` or ``qarray`` it is converted to a
    plain ``numpy.ndarray``. Else if it already has a ``.shape`` attribute,
    i.e. looks like an array, it is returned as-is. Otherwise it is treated
    as a (possibly nested) python iterable: the elements are inspected to
    see which library's array constructor should be used, defaulting to
    ``numpy`` if everything is builtin or numpy numbers.

    Parameters
    ----------
    array : array_like or nested sequence
        The data to convert.

    Returns
    -------
    array_like

    Raises
    ------
    ValueError
        If the nested elements come from more than one non-numpy array
        library, so that no single backend can be chosen.
    """
    if isinstance(array, (numpy.matrix, qarray)):
        # strip the ndarray subclass so downstream code sees a plain ndarray
        return numpy.asarray(array)

    if hasattr(array, "shape"):
        # otherwise don't touch things which are already array like
        return array

    # else we have some kind of possibly nested python iterable -> inspect
    backends = set()

    def _nd_py_iter(x):
        if isinstance(x, str):
            # iterating a 1-char string yields itself -> infinite recursion
            return x

        backend = infer_backend(x)
        if backend != "builtins":
            # don't iterate any non-builtin containers
            backends.add(backend)
            return x

        # is some kind of python container or element -> iterate or return
        try:
            return [_nd_py_iter(sub) for sub in x]
        except TypeError:
            return x

    nested = _nd_py_iter(array)

    # numpy and builtin elements can both be handled by numpy
    backends -= {"builtins", "numpy"}
    if not backends:
        backend = "numpy"
    elif len(backends) == 1:
        (backend,) = backends
    else:
        raise ValueError(
            "Cannot convert nested data containing elements from multiple "
            f"array libraries: {sorted(backends)}."
        )

    return do("array", nested, like=backend)
65
+
66
+
67
# maps array class -> whether instances of it are treated as block-sparse
_blocksparselookup = {}


def isblocksparse(x):
    """Return whether ``x`` looks like a block-sparse array.

    Detection is by the presence of an ``align_axes`` attribute, and the
    answer is memoized per class so each class is only inspected once.
    """
    cls = x.__class__
    verdict = _blocksparselookup.get(cls)
    if verdict is None:
        # XXX: make this a more established interface
        verdict = _blocksparselookup[cls] = hasattr(x, "align_axes")
    return verdict
79
+
80
+
81
# maps array class -> whether instances of it are treated as fermionic
_fermioniclookup = {}


def isfermionic(x):
    """Return whether ``x`` looks like a fermionic array.

    Detection is by the presence of a ``phase_flip`` attribute, and the
    answer is memoized per class so each class is only inspected once.
    """
    cls = x.__class__
    verdict = _fermioniclookup.get(cls)
    if verdict is None:
        # XXX: make this a more established interface
        verdict = _fermioniclookup[cls] = hasattr(x, "phase_flip")
    return verdict
93
+
94
+
95
@functools.lru_cache(2**14)
def calc_fuse_perm_and_shape(shape, axes_groups):
    """Compute the transpose and reshape needed to fuse groups of axes.

    Results are memoized with ``functools.lru_cache`` (up to ``2**14``
    entries), so both arguments must be hashable (e.g. tuples).

    Parameters
    ----------
    shape : tuple[int]
        The shape of the array to fuse.
    axes_groups : tuple[tuple[int]]
        The groups of axes to fuse, each group becoming a single new axis.
        The new axes are placed, in group order, starting at the smallest
        axis index appearing in any group.

    Returns
    -------
    perm : tuple[int] or None
        The axes permutation to apply first, or ``None`` if it would be the
        identity.
    new_shape : tuple[int] or None
        The shape to reshape the permuted array to, or ``None`` if it
        equals ``shape``.

    Raises
    ------
    ValueError
        If every group is empty (``min`` of an empty sequence).
    """
    ndim = len(shape)

    # which group does each axis appear in, if any
    num_groups = len(axes_groups)
    ax2group = {ax: g for g, axes in enumerate(axes_groups) for ax in axes}

    # the permutation will be the same for every block: precalculate
    # n.b. all new groups will be inserted at the *first fused axis*
    position = min(g for gax in axes_groups for g in gax)
    # NOTE: ``setdefault`` here also records every unfused axis as ``None``
    # in ``ax2group``, which the shape loop below relies on
    axes_before = tuple(
        ax for ax in range(position) if ax2group.setdefault(ax, None) is None
    )
    axes_after = tuple(
        ax
        for ax in range(position, ndim)
        if ax2group.setdefault(ax, None) is None
    )
    perm = (*axes_before, *(ax for g in axes_groups for ax in g), *axes_after)

    # track where each axis will be in the new array
    new_axes = {ax: ax for ax in axes_before}
    for i, g in enumerate(axes_groups):
        for ax in g:
            new_axes[ax] = position + i
    for i, ax in enumerate(axes_after):
        new_axes[ax] = position + num_groups + i
    new_ndim = len(axes_before) + num_groups + len(axes_after)

    # start from ones so fused sizes can be accumulated by multiplication
    # (an empty group therefore yields a size-1 axis)
    new_shape = [1] * new_ndim
    for i, d in enumerate(shape):
        g = ax2group[i]
        new_ax = new_axes[i]
        if g is None:
            # not fusing, new value is just copied
            new_shape[new_ax] = d
        else:
            # fusing: need to accumulate
            new_shape[new_ax] *= d

    if all(i == ax for i, ax in enumerate(perm)):
        # no need to transpose
        perm = None

    new_shape = tuple(new_shape)
    if shape == new_shape:
        # no need to reshape
        new_shape = None

    return perm, new_shape
146
+
147
+
148
@compose
def fuse(x, *axes_groups, backend=None):
    """Fuse the given group or groups of axes. The new fused axes will be
    inserted at the minimum index of any fused axis (even if it is not in
    the first group). For example, ``fuse(x, [5, 3], [7, 2, 6])`` will
    produce an array with axes like::

        groups inserted at axis 2, removed beyond that.
               ......<--
        (0, 1, g0, g1, 4, 8, ...)
               |   |
               |   g1=(7, 2, 6)
               g0=(5, 3)

    Parameters
    ----------
    x : array_like
        The array whose axes should be fused.
    axes_groups : sequence of sequences of int
        The axes to fuse. Each group of axes will be fused into a single
        axis.
    backend : str, optional
        The library whose ``transpose`` and ``reshape`` are used. Inferred
        from ``x`` if not given.

    Returns
    -------
    array_like
        The fused array. ``x`` itself is returned if every group is empty.
    """
    backend = infer_backend(x) if backend is None else backend
    transpose_fn = get_lib_fn(backend, "transpose")
    reshape_fn = get_lib_fn(backend, "reshape")

    groups = tuple(tuple(group) for group in axes_groups)
    if not any(groups):
        # every group is empty: nothing to fuse
        return x

    old_shape = tuple(int(d) for d in x.shape)
    perm, new_shape = calc_fuse_perm_and_shape(old_shape, groups)

    # the helper reports ``None`` for any step that would be a no-op
    out = x if perm is None else transpose_fn(x, perm)
    return out if new_shape is None else reshape_fn(out, new_shape)
186
+
187
+
188
def ndim(array):
    """Return the number of dimensions of ``array``.

    Uses the ``.ndim`` attribute when available, otherwise falls back to
    the length of ``.shape``.
    """
    try:
        nd = array.ndim
    except AttributeError:
        # no ``ndim`` attribute -> derive it from the shape instead
        nd = len(array.shape)
    return nd
194
+
195
+
196
@compose
def multiply_diagonal(x, v, axis, backend=None):
    """Multiply ``v`` into ``x`` along ``axis``, as if contracting ``x``
    with the diagonal matrix ``diag(v)`` on that axis.

    Parameters
    ----------
    x : array_like
        The array to scale.
    v : array_like
        The diagonal entries. After reshaping, they must broadcast against
        ``x`` along ``axis``, i.e. have size ``x.shape[axis]`` (or 1).
    axis : int
        Which axis of ``x`` to multiply along. Negative values count from
        the end, as in numpy.
    backend : str, optional
        Passed as ``like=`` when reshaping ``v``.

    Returns
    -------
    array_like
    """
    nd = ndim(x)
    if axis < 0:
        # normalize so the ``i == axis`` test below can actually match;
        # previously negative axes silently produced an all-ones shape
        axis += nd
    # reshape v to (1, ..., 1, -1, 1, ..., 1) so it broadcasts along ``axis``
    newshape = tuple((-1 if i == axis else 1) for i in range(nd))
    v_broadcast = do("reshape", v, newshape, like=backend)
    return x * v_broadcast
202
+
203
+
204
@compose
def align_axes(*arrays, axes, backend=None):
    """Make ``arrays`` ready to be contracted together along ``axes``.

    Some array types, for example block symmetric arrays, need their
    sectors aligned before fusing. This generic default does no work.

    Parameters
    ----------
    arrays : array_like
        The arrays to prepare.
    axes
        The axes along which the arrays will be contracted (unused here).
    backend : str, optional
        Unused by this default implementation.

    Returns
    -------
    tuple
        The input ``arrays``, unchanged.
    """
    # plain arrays need no alignment: pass them straight through
    return arrays
213
+
214
+
215
# ------------- miscellaneous other backend agnostic functions -------------- #
216
+
217
+
218
def iscomplex(x):
    """Return whether ``x`` has a complex dtype.

    Builtin python scalars are tested with ``isinstance(x, complex)``.
    Anything else is tested by looking for ``"complex"`` in its dtype name.
    """
    if infer_backend(x) != "builtins":
        return "complex" in get_dtype_name(x)
    return isinstance(x, complex)
223
+
224
+
225
@compose
def norm_fro(x):
    """Return the Frobenius norm of ``x``.

    The array is flattened and passed to the backend's ``linalg.norm``.
    If that raises ``AttributeError`` (presumably because the backend lacks
    it), the norm is computed directly as ``sum(abs(x) ** 2) ** 0.5``.
    """
    try:
        flat = reshape(x, (-1,))
        return do("linalg.norm", flat)
    except AttributeError:
        abs_sq = do("abs", x) ** 2
        return do("sum", abs_sq) ** 0.5


# numpy arrays dispatch to the dense implementation from ..linalg.base_linalg
norm_fro.register("numpy", norm_fro_dense)
235
+
236
+
237
def sensibly_scale(x):
    """Very roughly normalize ``x`` so that random tensor networks built
    from such arrays do not end up with gigantic norms.

    The array is divided by its Frobenius norm raised to ``1.5 / ndim(x)``.
    0-d inputs are not supported, because the exponent divides by zero.
    """
    exponent = 1.5 / ndim(x)
    return x / norm_fro(x) ** exponent
242
+
243
+
244
@njit
def _numba_find_diag_axes(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled array diagonal axis finder.

    A pair of axes ``(d1, d2)`` of equal size is "diagonal" if every entry
    whose ``d1`` and ``d2`` indices differ has magnitude at most ``atol``.

    Parameters
    ----------
    x : numpy.ndarray
        The array to search for diagonal axes.
    atol : float
        The tolerance with which to compare to zero.

    Returns
    -------
    diag_axes : set[tuple[int]]
        The set of pairs of axes which are diagonal, each with ``d1 < d2``.
    """

    # create the set of pairs of matching size axes
    diag_axes = set()
    for d1 in range(x.ndim - 1):
        for d2 in range(d1 + 1, x.ndim):
            if x.shape[d1] == x.shape[d2]:
                diag_axes.add((d1, d2))

    # enumerate through every array entry, eagerly invalidating axis pairs
    for index, val in numpy.ndenumerate(x):
        # iterate over a copy since pairs may be removed during the loop
        for d1, d2 in list(diag_axes):
            # an off-diagonal entry that is non-zero rules this pair out
            if (index[d1] != index[d2]) and (abs(val) > atol):
                diag_axes.remove((d1, d2))

        # all pairs invalid, nothing left to do
        if len(diag_axes) == 0:
            break

    return diag_axes
279
+
280
+
281
def find_diag_axes(x, atol=1e-12):
    """Try and find a pair of axes of ``x`` in which it is diagonal.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        The two axes if found else None.

    Examples
    --------

    >>> x = np.array([[[1, 0], [0, 2]],
    ...               [[3, 0], [0, 4]]])
    >>> find_diag_axes(x)
    (1, 2)

    Which means we can reduce ``x`` without loss of information to:

    >>> np.einsum('abb->ab', x)
    array([[1, 2],
           [3, 4]])

    """
    shape = x.shape
    if len(shape) < 2:
        return None

    backend = infer_backend(x)

    # use numba-accelerated version for numpy arrays
    if backend == "numpy":
        diag_axes = _numba_find_diag_axes(x, atol=atol)
        if diag_axes:
            # several pairs may qualify: take the smallest to be deterministic
            return min(diag_axes)
        return None

    # generic path only: the index grids and zero scalar are created here,
    # after the numpy fast path, so numpy inputs don't pay for them
    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for i, j in itertools.combinations(range(len(shape)), 2):
        if shape[i] != shape[j]:
            continue
        # every entry off the (i, j) diagonal must be ~zero
        if do(
            "allclose",
            x[indxrs[i] != indxrs[j]],
            zero,
            atol=atol,
            like=backend,
        ):
            return (i, j)
    return None
339
+
340
+
341
@njit
def _numba_find_antidiag_axes(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled array antidiagonal axis finder.

    A pair of axes ``(i, j)`` of equal size ``d`` is "anti-diagonal" if
    every entry whose ``i`` index differs from ``d - 1 - (j index)`` has
    magnitude at most ``atol``.

    Parameters
    ----------
    x : numpy.ndarray
        The array to search for anti-diagonal axes.
    atol : float
        The tolerance with which to compare to zero.

    Returns
    -------
    antidiag_axes : set[tuple[int]]
        The set of pairs of axes which are anti-diagonal, each with
        ``i < j``.
    """

    # create the set of pairs of matching size axes
    antidiag_axes = set()
    for i in range(x.ndim - 1):
        for j in range(i + 1, x.ndim):
            if x.shape[i] == x.shape[j]:
                antidiag_axes.add((i, j))

    # enumerate through every array entry, eagerly invalidating axis pairs
    for index, val in numpy.ndenumerate(x):
        # iterate over a copy since pairs may be removed during the loop
        for i, j in list(antidiag_axes):
            d = x.shape[i]
            # a non-zero entry off the anti-diagonal rules this pair out
            if (index[i] != d - 1 - index[j]) and (abs(val) > atol):
                antidiag_axes.remove((i, j))

        # all pairs invalid, nothing left to do
        if len(antidiag_axes) == 0:
            break

    return antidiag_axes
377
+
378
+
379
def find_antidiag_axes(x, atol=1e-12):
    """Try and find a pair of axes of ``x`` in which it is anti-diagonal.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        The two axes if found else None.

    Examples
    --------

    >>> x = np.array([[[0, 1], [0, 2]],
    ...               [[3, 0], [4, 0]]])
    >>> find_antidiag_axes(x)
    (0, 2)

    Which means we can reduce ``x`` without loss of information to:

    >>> np.einsum('aba->ab', x[::-1, :, :])
    array([[3, 4],
           [1, 2]])

    as long as we flip the order of dimensions on other tensors corresponding
    to the same index.
    """
    shape = x.shape
    if len(shape) < 2:
        return None

    backend = infer_backend(x)

    if backend == "numpy":
        # fast path: numba-compiled scan over every entry
        candidates = _numba_find_antidiag_axes(x, atol=atol)
        # take the smallest pair so the answer doesn't depend on set order
        return min(candidates) if candidates else None

    grids = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    def _is_antidiag(i, j):
        # every entry off the anti-diagonal of axes (i, j) must be ~zero
        off_antidiag = grids[i] != shape[j] - 1 - grids[j]
        return do("allclose", x[off_antidiag], zero, atol=atol, like=backend)

    same_size_pairs = (
        (i, j)
        for i, j in itertools.combinations(range(len(shape)), 2)
        if shape[i] == shape[j]
    )
    return next((pair for pair in same_size_pairs if _is_antidiag(*pair)), None)
441
+
442
+
443
@njit
def _numba_find_columns(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled single non-zero column axis finder.

    A pair ``(ax, i)`` survives if every entry with magnitude above
    ``atol`` has index ``i`` along axis ``ax``.

    Parameters
    ----------
    x : array
        The array to search.
    atol : float, optional
        Absolute tolerance to compare to zero with.

    Returns
    -------
    set[tuple[int]]
        Set of pairs (axis, index) defining lone non-zero columns.
    """

    # possible pairings of axis + index
    column_pairs = set()
    for ax, d in enumerate(x.shape):
        for i in range(d):
            column_pairs.add((ax, i))

    # enumerate over all array entries, invalidating potential column pairs
    for index, val in numpy.ndenumerate(x):
        if abs(val) > atol:
            # a non-zero entry at index i along ax rules out every other
            # column (ax, pi != i) of that axis
            for ax, i in enumerate(index):
                # iterate over a copy since pairs may be removed
                for pax, pi in list(column_pairs):
                    if ax == pax and pi != i:
                        column_pairs.remove((pax, pi))

        # all potential pairs invalidated
        if not len(column_pairs):
            break

    return column_pairs
479
+
480
+
481
def find_columns(x, atol=1e-12):
    """Try and find columns of axes which are zero apart from a single index.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        If found, the first integer is which axis, and the second is which
        column of that axis, else None.

    Examples
    --------

    >>> x = np.array([[[0, 1], [0, 2]],
    ...               [[0, 3], [0, 4]]])
    >>> find_columns(x)
    (2, 1)

    Which means we can happily slice ``x`` without loss of information to:

    >>> x[:, :, 1]
    array([[1, 2],
           [3, 4]])

    """
    shape = x.shape
    if len(shape) == 0:
        # scalars have no axes to search
        return None

    backend = infer_backend(x)

    if backend == "numpy":
        # fast path: numba-compiled scan over every entry
        candidates = _numba_find_columns(x, atol)
        # take the smallest pair so the answer doesn't depend on set order
        return min(candidates) if candidates else None

    grids = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    # try every (axis, column) in order, stopping at the first that holds
    # all of the non-zero weight
    for axis, dim in enumerate(shape):
        for col in range(dim):
            outside = x[grids[axis] != col]
            if do("allclose", outside, zero, atol=atol, like=backend):
                return (axis, col)

    return None
536
+
537
+
538
class PArray:
    """Simple array-like object that lazily generates the actual array by
    calling a function with a set of parameters.

    The generated array is cached, and the cache is cleared whenever ``fn``
    or ``params`` is reassigned.

    Parameters
    ----------
    fn : callable
        The function that generates the tensor data from ``params``.
    params : sequence of numbers
        The initial parameters supplied to the generating function like
        ``fn(params)``. They are passed through ``asarray`` on assignment.
    shape : tuple[int], optional
        The shape of the generated array, if already known. Otherwise it
        is found by generating the array the first time it is needed.

    See Also
    --------
    PTensor
    """

    __slots__ = ("_fn", "_params", "_data", "_shape", "_shape_fn_id")

    def __init__(self, fn, params, shape=None):
        # both property setters also reset the cached ``_data``
        self.fn = fn
        self.params = params
        self._shape = shape
        # remember which function the (possibly supplied) shape belongs to
        self._shape_fn_id = id(fn)

    def copy(self):
        """Return a new ``PArray`` with the same function and parameters,
        sharing any array already generated. Reading ``shape`` here may
        itself trigger generation of the array.
        """
        fn, params, shape = self.fn, self.params, self.shape
        clone = PArray(fn, params, shape)
        # reuse the cached array rather than regenerating it
        clone._data = self._data
        return clone

    @property
    def fn(self):
        """The function generating the array from ``params``."""
        return self._fn

    @fn.setter
    def fn(self, x):
        self._fn = x
        # a new function invalidates any cached array
        self._data = None

    @property
    def params(self):
        """The parameters fed to ``fn``."""
        return self._params

    @params.setter
    def params(self, x):
        self._params = asarray(x)
        # new parameters invalidate any cached array
        self._data = None

    @property
    def data(self):
        """The generated array, computed as ``fn(params)`` on first access."""
        cached = self._data
        if cached is None:
            cached = self._data = self._fn(self._params)
        return cached

    @property
    def shape(self):
        """The shape of the generated array.

        It is recomputed (by generating the array) if no shape is known yet
        or if ``fn`` has been replaced since it was recorded.
        """
        current_id = id(self.fn)
        stale = self._shape is None or self._shape_fn_id != current_id
        if stale:
            self._shape = self.data.shape
            self._shape_fn_id = current_id
        return self._shape

    @property
    def ndim(self):
        """The number of dimensions of the generated array."""
        return len(self.shape)

    def add_function(self, g):
        """Chain the new function ``g`` on top of current function ``f`` like
        ``g(f(params))``.
        """
        self.fn = fn_compose(g, self.fn)