Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122):
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/trajectory.py +2 -2
  7. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
  8. trajectree-0.0.2.dist-info/RECORD +16 -0
  9. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  10. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  11. trajectree/quimb/docs/conf.py +0 -158
  12. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  13. trajectree/quimb/quimb/__init__.py +0 -507
  14. trajectree/quimb/quimb/calc.py +0 -1491
  15. trajectree/quimb/quimb/core.py +0 -2279
  16. trajectree/quimb/quimb/evo.py +0 -712
  17. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  18. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  19. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  20. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  21. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  22. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  23. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  24. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  25. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  26. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  27. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  28. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  29. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  30. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  31. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  32. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  33. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  34. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  35. trajectree/quimb/quimb/gates.py +0 -36
  36. trajectree/quimb/quimb/gen/__init__.py +0 -2
  37. trajectree/quimb/quimb/gen/operators.py +0 -1167
  38. trajectree/quimb/quimb/gen/rand.py +0 -713
  39. trajectree/quimb/quimb/gen/states.py +0 -479
  40. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  41. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  42. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  43. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  44. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  45. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  46. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  47. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  48. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  49. trajectree/quimb/quimb/schematic.py +0 -1518
  50. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  51. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  52. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  53. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  54. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  55. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  56. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  57. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  58. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  59. trajectree/quimb/quimb/tensor/interface.py +0 -114
  60. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  61. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  62. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  63. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  64. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  65. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  66. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  67. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  68. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  69. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  70. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  71. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  72. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  74. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  75. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  76. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  77. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  78. trajectree/quimb/quimb/utils.py +0 -892
  79. trajectree/quimb/tests/__init__.py +0 -0
  80. trajectree/quimb/tests/test_accel.py +0 -501
  81. trajectree/quimb/tests/test_calc.py +0 -788
  82. trajectree/quimb/tests/test_core.py +0 -847
  83. trajectree/quimb/tests/test_evo.py +0 -565
  84. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  85. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  86. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  87. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  88. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  89. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  90. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  91. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  92. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  93. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  94. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  95. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  103. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  104. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  105. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  106. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  107. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  108. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  109. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  110. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  111. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  112. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  113. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  114. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  115. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  116. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  117. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  118. trajectree/quimb/tests/test_utils.py +0 -85
  119. trajectree-0.0.1.dist-info/RECORD +0 -126
  120. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
@@ -1,610 +0,0 @@
1
- """Backend agnostic array operations."""
2
-
3
- import functools
4
- import itertools
5
-
6
- import numpy
7
- from autoray import (
8
- compose,
9
- do,
10
- get_dtype_name,
11
- get_lib_fn,
12
- infer_backend,
13
- reshape,
14
- )
15
-
16
- from ..core import njit, qarray
17
- from ..linalg.base_linalg import norm_fro_dense
18
- from ..utils import compose as fn_compose
19
-
20
-
21
def asarray(array):
    """Maybe convert ``array`` into something a tensor can hold as data.

    * ``numpy.matrix`` and ``qarray`` instances are converted to plain numpy
      arrays.
    * Any other object with a ``.shape`` attribute is returned untouched.
    * Anything else is treated as a (possibly nested) python iterable. Its
      leaves are inspected to choose which library's ``array`` constructor to
      use, defaulting to ``numpy`` when every leaf is a builtin or numpy
      object. A ``ValueError`` is raised (by tuple unpacking) if leaves from
      more than one other library are found.
    """
    # strip numpy subclasses back down to a plain ndarray
    if isinstance(array, (numpy.matrix, qarray)):
        return numpy.asarray(array)

    # already array-like: leave it alone
    if hasattr(array, "shape"):
        return array

    # libraries seen among the leaves of the nested python structure
    found = set()

    def _walk(item):
        # strings iterate into strings forever, so never descend into them
        if isinstance(item, str):
            return item

        lib = infer_backend(item)
        if lib != "builtins":
            # foreign objects are recorded but not iterated into
            found.add(lib)
            return item

        # builtin container -> recurse into it; builtin scalar -> keep as is
        try:
            return [_walk(child) for child in item]
        except TypeError:
            return item

    nested = _walk(array)

    # builtins and numpy leaves can both be handled by numpy
    foreign = found - {"builtins", "numpy"}
    if foreign:
        (backend,) = foreign
    else:
        backend = "numpy"

    return do("array", nested, like=backend)
65
-
66
-
67
_blocksparselookup = {}


def isblocksparse(x):
    """Return whether ``x`` is a block-sparse array.

    Detection is by duck typing: ``x`` counts as block-sparse if it has an
    ``align_axes`` attribute. The answer for the first instance seen of each
    class is memoized in ``_blocksparselookup``, so later checks on that class
    are a single dict lookup.
    """
    cls = x.__class__
    cached = _blocksparselookup.get(cls)
    if cached is None:
        # XXX: make this a more established interface
        cached = _blocksparselookup[cls] = hasattr(x, "align_axes")
    return cached
79
-
80
-
81
_fermioniclookup = {}


def isfermionic(x):
    """Return whether ``x`` is a fermionic array.

    Detection is by duck typing: ``x`` counts as fermionic if it has a
    ``phase_flip`` attribute. The result is memoized per class (from the first
    instance seen) in ``_fermioniclookup``.
    """
    cls = x.__class__
    if cls not in _fermioniclookup:
        # XXX: make this a more established interface
        _fermioniclookup[cls] = hasattr(x, "phase_flip")
    return _fermioniclookup[cls]
93
-
94
-
95
@functools.lru_cache(2**14)
def calc_fuse_perm_and_shape(shape, axes_groups):
    """Compute the transpose permutation and new shape needed to fuse groups
    of axes of an array with shape ``shape``.

    All fused groups are placed, in group order, at the smallest axis index
    appearing in any group. Unfused axes keep their relative order on either
    side of that position. Results are cached, so both arguments must be
    hashable (tuples).

    Parameters
    ----------
    shape : tuple[int]
        The shape of the array being fused.
    axes_groups : tuple[tuple[int]]
        The groups of axes to fuse, each into a single new axis.

    Returns
    -------
    perm : tuple[int] or None
        The permutation to transpose by, or ``None`` if it is the identity.
    new_shape : tuple[int] or None
        The shape to reshape to after transposing, or ``None`` if it equals
        ``shape``.
    """
    n_in = len(shape)
    n_groups = len(axes_groups)

    # which group each fused axis belongs to (axes not in any group are absent)
    group_of = {}
    for gi, group in enumerate(axes_groups):
        for ax in group:
            group_of[ax] = gi

    # every new fused axis is inserted at the *first fused axis*
    insert_at = min(ax for group in axes_groups for ax in group)

    # unfused axes before / after the insertion point keep their order
    before = tuple(ax for ax in range(insert_at) if group_of.get(ax) is None)
    after = tuple(
        ax for ax in range(insert_at, n_in) if group_of.get(ax) is None
    )
    fused = tuple(ax for group in axes_groups for ax in group)
    perm = before + fused + after

    # where each original axis ends up in the fused array
    dest = {ax: ax for ax in before}
    for gi, group in enumerate(axes_groups):
        for ax in group:
            dest[ax] = insert_at + gi
    for k, ax in enumerate(after):
        dest[ax] = insert_at + n_groups + k

    out_shape = [1] * (len(before) + n_groups + len(after))
    for ax, size in enumerate(shape):
        if group_of.get(ax) is None:
            # untouched axis: size carried over directly
            out_shape[dest[ax]] = size
        else:
            # fused axis: sizes multiply together
            out_shape[dest[ax]] *= size

    if perm == tuple(range(len(perm))):
        # already in order, no transpose required
        perm = None

    out_shape = tuple(out_shape)
    if out_shape == shape:
        # nothing actually merged, no reshape required
        out_shape = None

    return perm, out_shape
146
-
147
-
148
@compose
def fuse(x, *axes_groups, backend=None):
    """Fuse one or more groups of axes of ``x``, each into a single axis.

    The new fused axes are inserted, in group order, at the minimum index of
    any fused axis (even if it is not in the first group). Unfused axes keep
    their relative order. For example, ``fuse(x, [5, 3], [7, 2, 6])`` will
    produce an array with axes like::

        groups inserted at axis 2, removed beyond that.
               ......<--
        (0, 1, g0, g1, 4, 8, ...)
               |   |
               |   g1=(7, 2, 6)
               g0=(5, 3)

    Parameters
    ----------
    x : array_like
        The array whose axes to fuse.
    axes_groups : sequence of sequences of int
        The axes to fuse. Each group of axes will be fused into a single
        axis.
    backend : str, optional
        The backend whose ``transpose`` and ``reshape`` functions to use,
        inferred from ``x`` if not given.

    Returns
    -------
    array_like
        The fused array. ``x`` is returned unchanged if every group is empty,
        or if no transpose or reshape turns out to be needed.
    """
    if backend is None:
        backend = infer_backend(x)
    transpose_fn = get_lib_fn(backend, "transpose")
    reshape_fn = get_lib_fn(backend, "reshape")

    # tuples are required since the permutation calculation is cached
    groups = tuple(tuple(group) for group in axes_groups)
    if not any(groups):
        return x

    int_shape = tuple(int(d) for d in x.shape)
    perm, new_shape = calc_fuse_perm_and_shape(int_shape, groups)

    if perm is not None:
        x = transpose_fn(x, perm)
    if new_shape is not None:
        x = reshape_fn(x, new_shape)
    return x
186
-
187
-
188
def ndim(array):
    """Return the number of dimensions of ``array``.

    Uses ``array.ndim`` when that attribute exists, otherwise falls back to
    ``len(array.shape)``.
    """
    missing = object()
    n = getattr(array, "ndim", missing)
    return len(array.shape) if n is missing else n
194
-
195
-
196
@compose
def multiply_diagonal(x, v, axis, backend=None):
    """Multiply ``v`` into ``x`` as if contracting in a diagonal matrix.

    ``v`` is reshaped so that it varies only along ``axis`` and is then
    broadcast-multiplied into ``x``. This is equivalent to contracting ``x``
    with ``diag(v)`` on that axis.

    Parameters
    ----------
    x : array_like
        The array to scale.
    v : array_like
        The diagonal entries. Presumably 1D with length ``x.shape[axis]``
        (TODO confirm with callers).
    axis : int
        The axis of ``x`` to scale. Negative values count from the end, as
        in numpy.
    backend : str, optional
        Passed as ``like`` to the backend ``reshape``.

    Returns
    -------
    array_like
        The product ``x * v`` broadcast along ``axis``.
    """
    nd = ndim(x)
    if axis < 0:
        # normalise numpy-style negative axes. Previously these matched no
        # dimension, giving an all-ones target shape and a wrong or failing
        # reshape of ``v``.
        axis += nd
    newshape = tuple((-1 if i == axis else 1) for i in range(nd))
    v_broadcast = do("reshape", v, newshape, like=backend)
    return x * v_broadcast
202
-
203
-
204
@compose
def align_axes(*arrays, axes, backend=None):
    """Make ``arrays`` ready to be contracted together along ``axes``.

    Some array types need this step. For example, block symmetric arrays
    must have their sectors aligned before fusing. This default version does
    nothing and hands back the input tuple of arrays unchanged.
    """
    # dense arrays need no alignment
    return arrays
213
-
214
-
215
- # ------------- miscellaneous other backend agnostic functions -------------- #
216
-
217
-
218
def iscomplex(x):
    """Return whether ``x`` has a complex dtype.

    Plain python scalars are checked with ``isinstance(x, complex)``. For
    everything else, the dtype name is searched for the substring
    ``"complex"``.
    """
    if infer_backend(x) != "builtins":
        return "complex" in get_dtype_name(x)
    return isinstance(x, complex)
223
-
224
-
225
@compose
def norm_fro(x):
    """Return the Frobenius norm of ``x``.

    The array is flattened and passed to the backend's ``linalg.norm``. If
    that raises ``AttributeError``, the norm is computed directly as
    ``sum(abs(x) ** 2) ** 0.5``.
    """
    try:
        flat = reshape(x, (-1,))
        return do("linalg.norm", flat)
    except AttributeError:
        # e.g. a backend without ``linalg.norm``: use the definition instead
        squared = do("abs", x) ** 2
        return do("sum", squared) ** 0.5


# numpy arrays use the dense implementation from ``linalg.base_linalg``
norm_fro.register("numpy", norm_fro_dense)
235
-
236
-
237
def sensibly_scale(x):
    """Rescale ``x`` *very* roughly so that random tensor networks built from
    such arrays do not end up with gigantic norms.

    The array is divided by its Frobenius norm raised to ``1.5 / ndim(x)``.
    """
    exponent = 1.5 / ndim(x)
    return x / norm_fro(x) ** exponent
242
-
243
-
244
@njit
def _numba_find_diag_axes(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled array diagonal axis finder.

    A pair of equal-sized axes ``(d1, d2)`` with ``d1 < d2`` is kept only if
    every entry whose index differs along ``d1`` and ``d2`` has absolute
    value at most ``atol``.

    Parameters
    ----------
    x : numpy.ndarray
        The array to search for diagonal axes.
    atol : float
        The tolerance with which to compare to zero.

    Returns
    -------
    diag_axes : set[tuple[int]]
        The set of pairs of axes which are diagonal (empty if none are).
    """

    # create the set of pairs of matching size axes
    diag_axes = set()
    for d1 in range(x.ndim - 1):
        for d2 in range(d1 + 1, x.ndim):
            if x.shape[d1] == x.shape[d2]:
                diag_axes.add((d1, d2))

    # enumerate through every array entry, eagerly invalidating axis pairs
    for index, val in numpy.ndenumerate(x):
        # iterate over a copy since pairs are removed during the loop
        for d1, d2 in list(diag_axes):
            if (index[d1] != index[d2]) and (abs(val) > atol):
                diag_axes.remove((d1, d2))

        # all pairs invalid, nothing left to do
        if len(diag_axes) == 0:
            break

    return diag_axes
279
-
280
-
281
def find_diag_axes(x, atol=1e-12):
    """Try and find a pair of axes of ``x`` in which it is diagonal.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        The two axes if found else None. If several pairs qualify, the
        smallest (lexicographically first) pair is returned.

    Examples
    --------

    >>> x = np.array([[[1, 0], [0, 2]],
    ...               [[3, 0], [0, 4]]])
    >>> find_diag_axes(x)
    (1, 2)

    Which means we can reduce ``x`` without loss of information to:

    >>> np.einsum('abb->ab', x)
    array([[1, 2],
           [3, 4]])

    """
    shape = x.shape
    if len(shape) < 2:
        return None

    backend = infer_backend(x)

    # use numba-accelerated version for numpy arrays
    if backend == "numpy":
        diag_axes = _numba_find_diag_axes(x, atol=atol)
        if diag_axes:
            # make it deterministic
            return min(diag_axes)
        return None

    # generic path: these helpers are only needed for non-numpy backends, so
    # only build them here (matching ``find_antidiag_axes``)
    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for i, j in itertools.combinations(range(len(shape)), 2):
        if shape[i] != shape[j]:
            continue
        # every off-diagonal entry w.r.t. axes (i, j) must be ~zero
        if do(
            "allclose",
            x[indxrs[i] != indxrs[j]],
            zero,
            atol=atol,
            like=backend,
        ):
            return (i, j)
    return None
339
-
340
-
341
@njit
def _numba_find_antidiag_axes(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled array antidiagonal axis finder.

    A pair of equal-sized axes ``(i, j)`` with ``i < j`` is kept only if
    every entry with ``index[i] != d - 1 - index[j]`` (``d`` being the shared
    axis size) has absolute value at most ``atol``.

    Parameters
    ----------
    x : numpy.ndarray
        The array to search for anti-diagonal axes.
    atol : float
        The tolerance with which to compare to zero.

    Returns
    -------
    antidiag_axes : set[tuple[int]]
        The set of pairs of axes which are anti-diagonal (empty if none are).
    """

    # create the set of pairs of matching size axes
    antidiag_axes = set()
    for i in range(x.ndim - 1):
        for j in range(i + 1, x.ndim):
            if x.shape[i] == x.shape[j]:
                antidiag_axes.add((i, j))

    # enumerate through every array entry, eagerly invalidating axis pairs
    for index, val in numpy.ndenumerate(x):
        # iterate over a copy since pairs are removed during the loop
        for i, j in list(antidiag_axes):
            d = x.shape[i]
            # the anti-diagonal pairs index k on axis i with d - 1 - k on axis j
            if (index[i] != d - 1 - index[j]) and (abs(val) > atol):
                antidiag_axes.remove((i, j))

        # all pairs invalid, nothing left to do
        if len(antidiag_axes) == 0:
            break

    return antidiag_axes
377
-
378
-
379
def find_antidiag_axes(x, atol=1e-12):
    """Try and find a pair of axes of ``x`` in which it is anti-diagonal.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        The two axes if found else None. For numpy arrays the smallest
        qualifying pair is returned; otherwise the first pair found while
        scanning axis pairs in order.

    Examples
    --------

    >>> x = np.array([[[0, 1], [0, 2]],
    ...               [[3, 0], [4, 0]]])
    >>> find_antidiag_axes(x)
    (0, 2)

    Which means we can reduce ``x`` without loss of information to:

    >>> np.einsum('aba->ab', x[::-1, :, :])
    array([[3, 4],
           [1, 2]])

    as long as we flip the order of dimensions on other tensors corresponding
    to the same index.
    """
    shape = x.shape
    if len(shape) < 2:
        return None

    backend = infer_backend(x)

    if backend == "numpy":
        # numba fast path; ``min`` keeps the answer independent of set order
        found = _numba_find_antidiag_axes(x, atol=atol)
        return min(found) if found else None

    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for i, j in itertools.combinations(range(len(shape)), 2):
        size = shape[j]
        if shape[i] != size:
            continue
        # entries off the anti-diagonal of axes (i, j) must all be ~zero
        off_antidiag = x[indxrs[i] != size - 1 - indxrs[j]]
        if do("allclose", off_antidiag, zero, atol=atol, like=backend):
            return (i, j)

    return None
441
-
442
-
443
@njit
def _numba_find_columns(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled single non-zero column axis finder.

    A pair ``(ax, i)`` survives only if every entry with absolute value
    greater than ``atol`` has index ``i`` along axis ``ax``. If ``x`` has no
    such entries at all, every pair survives.

    Parameters
    ----------
    x : array
        The array to search.
    atol : float, optional
        Absolute tolerance to compare to zero with.

    Returns
    -------
    set[tuple[int]]
        Set of pairs (axis, index) defining lone non-zero columns.
    """

    # possible pairings of axis + index
    column_pairs = set()
    for ax, d in enumerate(x.shape):
        for i in range(d):
            column_pairs.add((ax, i))

    # enumerate over all array entries, invalidating potential column pairs
    for index, val in numpy.ndenumerate(x):
        if abs(val) > atol:
            # a non-zero at position ``i`` on axis ``ax`` rules out every
            # other column ``pi`` of that same axis
            for ax, i in enumerate(index):
                # iterate over a copy since pairs are removed during the loop
                for pax, pi in list(column_pairs):
                    if ax == pax and pi != i:
                        column_pairs.remove((pax, pi))

        # all potential pairs invalidated
        if not len(column_pairs):
            break

    return column_pairs
479
-
480
-
481
def find_columns(x, atol=1e-12):
    """Try and find columns of axes which are zero apart from a single index.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        If found, the first integer is which axis, and the second is which
        column of that axis, else None.

    Examples
    --------

    >>> x = np.array([[[0, 1], [0, 2]],
    ...               [[0, 3], [0, 4]]])
    >>> find_columns(x)
    (2, 1)

    Which means we can happily slice ``x`` without loss of information to:

    >>> x[:, :, 1]
    array([[1, 2],
           [3, 4]])

    """
    shape = x.shape
    if len(shape) == 0:
        # scalars have no axes to search
        return None

    backend = infer_backend(x)

    if backend == "numpy":
        # numba fast path; ``min`` keeps the answer independent of set order
        found = _numba_find_columns(x, atol)
        return min(found) if found else None

    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for ax, size in enumerate(shape):
        for col in range(size):
            # everything outside column ``col`` of axis ``ax`` must be ~zero
            outside = x[indxrs[ax] != col]
            if do("allclose", outside, zero, atol=atol, like=backend):
                return (ax, col)

    return None
536
-
537
-
538
class PArray:
    """Simple array-like object that lazily generates the actual array by
    calling a function with a set of parameters.

    Parameters
    ----------
    fn : callable
        The function that generates the tensor data from ``params``.
    params : sequence of numbers
        The initial parameters supplied to the generating function like
        ``fn(params)``.
    shape : tuple[int], optional
        The shape of the generated array, if already known. If not given, it
        is found lazily, by generating the data, the first time it is needed.

    See Also
    --------
    PTensor
    """

    __slots__ = ("_fn", "_params", "_data", "_shape", "_shape_fn_id")

    def __init__(self, fn, params, shape=None):
        self.fn = fn
        self.params = params
        # assigned after ``fn``, since setting ``fn`` clears the cached shape
        self._shape = shape

    def copy(self):
        """Return a new ``PArray`` with the same function, parameters and
        shape, sharing any data that has already been generated.
        """
        new = PArray(self.fn, self.params, self.shape)
        new._data = self._data  # for efficiency
        return new

    @property
    def fn(self):
        """The function that generates the array from ``params``."""
        return self._fn

    @fn.setter
    def fn(self, x):
        # A different function may generate an array of a different shape,
        # so drop the cached shape. The function objects themselves are
        # compared, rather than their ``id``s. Ids can be reused once a
        # function is garbage collected, which previously could return a
        # stale shape.
        if getattr(self, "_fn", None) is not x:
            self._shape = None
        self._fn = x
        # no longer used for caching, kept up to date for backwards compat
        self._shape_fn_id = id(x)
        self._data = None

    @property
    def params(self):
        """The parameters passed to ``fn``, stored via ``asarray``."""
        return self._params

    @params.setter
    def params(self, x):
        self._params = asarray(x)
        self._data = None

    @property
    def data(self):
        """The generated array, computed as ``fn(params)`` on first access and
        cached until ``fn`` or ``params`` is reassigned.
        """
        if self._data is None:
            self._data = self._fn(self._params)
        return self._data

    @property
    def shape(self):
        """The shape of the generated array, generating the data if the
        shape is not already known for the current ``fn``.
        """
        if self._shape is None:
            self._shape = self.data.shape
        return self._shape

    @property
    def ndim(self):
        """The number of dimensions of the generated array."""
        return len(self.shape)

    def add_function(self, g):
        """Chain the new function ``g`` on top of current function ``f`` like
        ``g(f(params))``.
        """
        f = self.fn
        self.fn = fn_compose(g, f)