Trajectree 0.0.1-py3-none-any.whl → 0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +9 -9
  5. trajectree/fock_optics/outputs.py +10 -6
  6. trajectree/fock_optics/utils.py +9 -6
  7. trajectree/sequence/swap.py +5 -4
  8. trajectree/trajectory.py +5 -4
  9. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/METADATA +2 -3
  10. trajectree-0.0.3.dist-info/RECORD +16 -0
  11. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  12. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  13. trajectree/quimb/docs/conf.py +0 -158
  14. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  15. trajectree/quimb/quimb/__init__.py +0 -507
  16. trajectree/quimb/quimb/calc.py +0 -1491
  17. trajectree/quimb/quimb/core.py +0 -2279
  18. trajectree/quimb/quimb/evo.py +0 -712
  19. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  20. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  21. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  22. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  23. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  24. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  25. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  26. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  27. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  28. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  29. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  30. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  31. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  32. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  33. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  34. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  35. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  36. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  37. trajectree/quimb/quimb/gates.py +0 -36
  38. trajectree/quimb/quimb/gen/__init__.py +0 -2
  39. trajectree/quimb/quimb/gen/operators.py +0 -1167
  40. trajectree/quimb/quimb/gen/rand.py +0 -713
  41. trajectree/quimb/quimb/gen/states.py +0 -479
  42. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  43. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  44. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  45. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  46. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  47. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  48. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  49. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  50. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  51. trajectree/quimb/quimb/schematic.py +0 -1518
  52. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  53. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  54. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  55. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  56. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  57. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  58. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  59. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  60. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  61. trajectree/quimb/quimb/tensor/interface.py +0 -114
  62. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  63. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  64. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  65. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  66. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  67. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  68. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  69. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  70. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  71. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  72. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  74. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  75. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  76. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  77. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  78. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  79. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  80. trajectree/quimb/quimb/utils.py +0 -892
  81. trajectree/quimb/tests/__init__.py +0 -0
  82. trajectree/quimb/tests/test_accel.py +0 -501
  83. trajectree/quimb/tests/test_calc.py +0 -788
  84. trajectree/quimb/tests/test_core.py +0 -847
  85. trajectree/quimb/tests/test_evo.py +0 -565
  86. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  87. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  88. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  89. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  90. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  91. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  92. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  93. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  94. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  95. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  96. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  97. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  103. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  104. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  105. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  106. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  107. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  108. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  109. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  110. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  111. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  112. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  113. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  114. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  115. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  116. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  117. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  118. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  119. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  120. trajectree/quimb/tests/test_utils.py +0 -85
  121. trajectree-0.0.1.dist-info/RECORD +0 -126
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/WHEEL +0 -0
  123. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/licenses/LICENSE +0 -0
  124. {trajectree-0.0.1.dist-info → trajectree-0.0.3.dist-info}/top_level.txt +0 -0
trajectree/quimb/quimb/tensor/array_ops.py (deleted)
@@ -1,610 +0,0 @@
- """Backend agnostic array operations."""
-
- import functools
- import itertools
-
- import numpy
- from autoray import (
-     compose,
-     do,
-     get_dtype_name,
-     get_lib_fn,
-     infer_backend,
-     reshape,
- )
-
- from ..core import njit, qarray
- from ..linalg.base_linalg import norm_fro_dense
- from ..utils import compose as fn_compose
-
-
- def asarray(array):
-     """Maybe convert data for a tensor to use. If ``array`` already has a
-     ``.shape`` attribute, i.e. looks like an array, it is left as-is. Else the
-     elements are inspected to see which libraries' array constructor should be
-     used, defaulting to ``numpy`` if everything is builtin or numpy numbers.
-     """
-     if isinstance(array, (numpy.matrix, qarray)):
-         # if numpy make sure array not subclass
-         return numpy.asarray(array)
-
-     if hasattr(array, "shape"):
-         # otherwise don't touch things which are already array like
-         return array
-
-     # else we some kind of possibly nested python iterable -> inspect items
-     backends = set()
-
-     def _nd_py_iter(x):
-         if isinstance(x, str):
-             # handle recursion error
-             return x
-
-         backend = infer_backend(x)
-         if backend != "builtins":
-             # don't iterate any non-builtin containers
-             backends.add(backend)
-             return x
-
-         # is some kind of python container or element -> iterate or return
-         try:
-             return list(_nd_py_iter(sub) for sub in x)
-         except TypeError:
-             return x
-
-     nested_tup = _nd_py_iter(array)
-
-     # numpy and builtin elements treat as basic
-     backends -= {"builtins", "numpy"}
-     if not backends:
-         backend = "numpy"
-     else:
-         (backend,) = backends
-
-     return do("array", nested_tup, like=backend)
-
-
- _blocksparselookup = {}
-
-
- def isblocksparse(x):
-     """Check if `x` is a block-sparse array. Cached on class for speed."""
-     try:
-         return _blocksparselookup[x.__class__]
-     except KeyError:
-         # XXX: make this a more established interface
-         isbs = hasattr(x, "align_axes")
-         _blocksparselookup[x.__class__] = isbs
-         return isbs
-
-
- _fermioniclookup = {}
-
-
- def isfermionic(x):
-     """Check if `x` is a fermionic array. Cached on class for speed."""
-     try:
-         return _fermioniclookup[x.__class__]
-     except KeyError:
-         # XXX: make this a more established interface
-         isf = hasattr(x, "phase_flip")
-         _fermioniclookup[x.__class__] = isf
-         return isf
-
-
- @functools.lru_cache(2**14)
- def calc_fuse_perm_and_shape(shape, axes_groups):
-     ndim = len(shape)
-
-     # which group does each axis appear in, if any
-     num_groups = len(axes_groups)
-     ax2group = {ax: g for g, axes in enumerate(axes_groups) for ax in axes}
-
-     # the permutation will be the same for every block: precalculate
-     # n.b. all new groups will be inserted at the *first fused axis*
-     position = min(g for gax in axes_groups for g in gax)
-     axes_before = tuple(
-         ax for ax in range(position) if ax2group.setdefault(ax, None) is None
-     )
-     axes_after = tuple(
-         ax
-         for ax in range(position, ndim)
-         if ax2group.setdefault(ax, None) is None
-     )
-     perm = (*axes_before, *(ax for g in axes_groups for ax in g), *axes_after)
-
-     # track where each axis will be in the new array
-     new_axes = {ax: ax for ax in axes_before}
-     for i, g in enumerate(axes_groups):
-         for ax in g:
-             new_axes[ax] = position + i
-     for i, ax in enumerate(axes_after):
-         new_axes[ax] = position + num_groups + i
-     new_ndim = len(axes_before) + num_groups + len(axes_after)
-
-     new_shape = [1] * new_ndim
-     for i, d in enumerate(shape):
-         g = ax2group[i]
-         new_ax = new_axes[i]
-         if g is None:
-             # not fusing, new value is just copied
-             new_shape[new_ax] = d
-         else:
-             # fusing: need to accumulate
-             new_shape[new_ax] *= d
-
-     if all(i == ax for i, ax in enumerate(perm)):
-         # no need to transpose
-         perm = None
-
-     new_shape = tuple(new_shape)
-     if shape == new_shape:
-         # no need to reshape
-         new_shape = None
-
-     return perm, new_shape
-
-
- @compose
- def fuse(x, *axes_groups, backend=None):
-     """Fuse the give group or groups of axes. The new fused axes will be
-     inserted at the minimum index of any fused axis (even if it is not in
-     the first group). For example, ``fuse(x, [5, 3], [7, 2, 6])`` will
-     produce an array with axes like::
-
-         groups inserted at axis 2, removed beyond that.
-                ......<--
-         (0, 1, g0, g1, 4, 8, ...)
-                |   |
-                |   g1=(7, 2, 6)
-                g0=(5, 3)
-
-     Parameters
-     ----------
-     axes_groups : sequence of sequences of int
-         The axes to fuse. Each group of axes will be fused into a single
-         axis.
-     """
-     if backend is None:
-         backend = infer_backend(x)
-     _transpose = get_lib_fn(backend, "transpose")
-     _reshape = get_lib_fn(backend, "reshape")
-
-     axes_groups = tuple(map(tuple, axes_groups))
-     if not any(axes_groups):
-         return x
-
-     shape = tuple(map(int, x.shape))
-     perm, new_shape = calc_fuse_perm_and_shape(shape, axes_groups)
-
-     if perm is not None:
-         x = _transpose(x, perm)
-     if new_shape is not None:
-         x = _reshape(x, new_shape)
-
-     return x
-
-
- def ndim(array):
-     """The number of dimensions of an array."""
-     try:
-         return array.ndim
-     except AttributeError:
-         return len(array.shape)
-
-
- @compose
- def multiply_diagonal(x, v, axis, backend=None):
-     """Multiply v into x as if contracting in a diagonal matrix."""
-     newshape = tuple((-1 if i == axis else 1) for i in range(ndim(x)))
-     v_broadcast = do("reshape", v, newshape, like=backend)
-     return x * v_broadcast
-
-
- @compose
- def align_axes(*arrays, axes, backend=None):
-     """Prepare a set of arrays that should be contractible along ``axes``.
-
-     For example, block symmetric arrays need to have aligned sectors prior to
-     fusing.
-     """
-     # default implementation is nothing
-     return arrays
-
-
- # ------------- miscelleneous other backend agnostic functions -------------- #
-
-
- def iscomplex(x):
-     """Does ``x`` have a complex dtype?"""
-     if infer_backend(x) == "builtins":
-         return isinstance(x, complex)
-     return "complex" in get_dtype_name(x)
-
-
- @compose
- def norm_fro(x):
-     """The frobenius norm of an array."""
-     try:
-         return do("linalg.norm", reshape(x, (-1,)))
-     except AttributeError:
-         return do("sum", do("abs", x) ** 2) ** 0.5
-
-
- norm_fro.register("numpy", norm_fro_dense)
-
-
- def sensibly_scale(x):
-     """Take an array and scale it *very* roughly such that random tensor
-     networks consisting of such arrays do not have gigantic norms.
-     """
-     return x / norm_fro(x) ** (1.5 / ndim(x))
-
-
- @njit
- def _numba_find_diag_axes(x, atol=1e-12):  # pragma: no cover
-     """Numba-compiled array diagonal axis finder.
-
-     Parameters
-     ----------
-     x : numpy.ndarray
-         The array to search for diagonal axes.
-     atol : float
-         The tolerance with which to compare to zero.
-
-     Returns
-     -------
-     diag_axes : set[tuple[int]]
-         The set of pairs of axes which are diagonal.
-     """
-
-     # create the set of pairs of matching size axes
-     diag_axes = set()
-     for d1 in range(x.ndim - 1):
-         for d2 in range(d1 + 1, x.ndim):
-             if x.shape[d1] == x.shape[d2]:
-                 diag_axes.add((d1, d2))
-
-     # enumerate through every array entry, eagerly invalidating axis pairs
-     for index, val in numpy.ndenumerate(x):
-         for d1, d2 in list(diag_axes):
-             if (index[d1] != index[d2]) and (abs(val) > atol):
-                 diag_axes.remove((d1, d2))
-
-         # all pairs invalid, nothing left to do
-         if len(diag_axes) == 0:
-             break
-
-     return diag_axes
-
-
- def find_diag_axes(x, atol=1e-12):
-     """Try and find a pair of axes of ``x`` in which it is diagonal.
-
-     Parameters
-     ----------
-     x : array-like
-         The array to search.
-     atol : float, optional
-         Tolerance with which to compare to zero.
-
-     Returns
-     -------
-     tuple[int] or None
-         The two axes if found else None.
-
-     Examples
-     --------
-
-     >>> x = np.array([[[1, 0], [0, 2]],
-     ...               [[3, 0], [0, 4]]])
-     >>> find_diag_axes(x)
-     (1, 2)
-
-     Which means we can reduce ``x`` without loss of information to:
-
-     >>> np.einsum('abb->ab', x)
-     array([[1, 2],
-            [3, 4]])
-
-     """
-     shape = x.shape
-     if len(shape) < 2:
-         return None
-
-     backend = infer_backend(x)
-     zero = do("zeros", (), like=x)
-
-     # use numba-accelerated version for numpy arrays
-     if backend == "numpy":
-         diag_axes = _numba_find_diag_axes(x, atol=atol)
-         if diag_axes:
-             # make it determinstic
-             return min(diag_axes)
-         return None
-     indxrs = do("indices", shape, like=backend)
-
-     for i, j in itertools.combinations(range(len(shape)), 2):
-         if shape[i] != shape[j]:
-             continue
-         if do(
-             "allclose",
-             x[indxrs[i] != indxrs[j]],
-             zero,
-             atol=atol,
-             like=backend,
-         ):
-             return (i, j)
-     return None
-
-
- @njit
- def _numba_find_antidiag_axes(x, atol=1e-12):  # pragma: no cover
-     """Numba-compiled array antidiagonal axis finder.
-
-     Parameters
-     ----------
-     x : numpy.ndarray
-         The array to search for anti-diagonal axes.
-     atol : float
-         The tolerance with which to compare to zero.
-
-     Returns
-     -------
-     antidiag_axes : set[tuple[int]]
-         The set of pairs of axes which are anti-diagonal.
-     """
-
-     # create the set of pairs of matching size axes
-     antidiag_axes = set()
-     for i in range(x.ndim - 1):
-         for j in range(i + 1, x.ndim):
-             if x.shape[i] == x.shape[j]:
-                 antidiag_axes.add((i, j))
-
-     # enumerate through every array entry, eagerly invalidating axis pairs
-     for index, val in numpy.ndenumerate(x):
-         for i, j in list(antidiag_axes):
-             d = x.shape[i]
-             if (index[i] != d - 1 - index[j]) and (abs(val) > atol):
-                 antidiag_axes.remove((i, j))
-
-         # all pairs invalid, nothing left to do
-         if len(antidiag_axes) == 0:
-             break
-
-     return antidiag_axes
-
-
- def find_antidiag_axes(x, atol=1e-12):
-     """Try and find a pair of axes of ``x`` in which it is anti-diagonal.
-
-     Parameters
-     ----------
-     x : array-like
-         The array to search.
-     atol : float, optional
-         Tolerance with which to compare to zero.
-
-     Returns
-     -------
-     tuple[int] or None
-         The two axes if found else None.
-
-     Examples
-     --------
-
-     >>> x = np.array([[[0, 1], [0, 2]],
-     ...               [[3, 0], [4, 0]]])
-     >>> find_antidiag_axes(x)
-     (0, 2)
-
-     Which means we can reduce ``x`` without loss of information to:
-
-     >>> np.einsum('aba->ab', x[::-1, :, :])
-     array([[3, 4],
-            [1, 2]])
-
-     as long as we flip the order of dimensions on other tensors corresponding
-     to the the same index.
-     """
-     shape = x.shape
-     if len(shape) < 2:
-         return None
-
-     backend = infer_backend(x)
-
-     # use numba-accelerated version for numpy arrays
-     if backend == "numpy":
-         antidiag_axes = _numba_find_antidiag_axes(x, atol=atol)
-         if antidiag_axes:
-             # make it determinstic
-             return min(antidiag_axes)
-         return None
-
-     indxrs = do("indices", shape, like=backend)
-     zero = do("zeros", (), like=x)
-
-     for i, j in itertools.combinations(range(len(shape)), 2):
-         di, dj = shape[i], shape[j]
-         if di != dj:
-             continue
-         if do(
-             "allclose",
-             x[indxrs[i] != dj - 1 - indxrs[j]],
-             zero,
-             atol=atol,
-             like=backend,
-         ):
-             return (i, j)
-     return None
-
-
- @njit
- def _numba_find_columns(x, atol=1e-12):  # pragma: no cover
-     """Numba-compiled single non-zero column axis finder.
-
-     Parameters
-     ----------
-     x : array
-         The array to search.
-     atol : float, optional
-         Absolute tolerance to compare to zero with.
-
-     Returns
-     -------
-     set[tuple[int]]
-         Set of pairs (axis, index) defining lone non-zero columns.
-     """
-
-     # possible pairings of axis + index
-     column_pairs = set()
-     for ax, d in enumerate(x.shape):
-         for i in range(d):
-             column_pairs.add((ax, i))
-
-     # enumerate over all array entries, invalidating potential column pairs
-     for index, val in numpy.ndenumerate(x):
-         if abs(val) > atol:
-             for ax, i in enumerate(index):
-                 for pax, pi in list(column_pairs):
-                     if ax == pax and pi != i:
-                         column_pairs.remove((pax, pi))
-
-         # all potential pairs invalidated
-         if not len(column_pairs):
-             break
-
-     return column_pairs
-
-
- def find_columns(x, atol=1e-12):
-     """Try and find columns of axes which are zero apart from a single index.
-
-     Parameters
-     ----------
-     x : array-like
-         The array to search.
-     atol : float, optional
-         Tolerance with which to compare to zero.
-
-     Returns
-     -------
-     tuple[int] or None
-         If found, the first integer is which axis, and the second is which
-         column of that axis, else None.
-
-     Examples
-     --------
-
-     >>> x = np.array([[[0, 1], [0, 2]],
-     ...               [[0, 3], [0, 4]]])
-     >>> find_columns(x)
-     (2, 1)
-
-     Which means we can happily slice ``x`` without loss of information to:
-
-     >>> x[:, :, 1]
-     array([[1, 2],
-            [3, 4]])
-
-     """
-     shape = x.shape
-     if len(shape) < 1:
-         return None
-
-     backend = infer_backend(x)
-
-     # use numba-accelerated version for numpy arrays
-     if backend == "numpy":
-         columns_pairs = _numba_find_columns(x, atol)
-         if columns_pairs:
-             return min(columns_pairs)
-         return None
-
-     indxrs = do("indices", shape, like=backend)
-     zero = do("zeros", (), like=x)
-
-     for i in range(len(shape)):
-         for j in range(shape[i]):
-             if do(
-                 "allclose", x[indxrs[i] != j], zero, atol=atol, like=backend
-             ):
-                 return (i, j)
-
-     return None
-
-
- class PArray:
-     """Simple array-like object that lazily generates the actual array by
-     calling a function with a set of parameters.
-
-     Parameters
-     ----------
-     fn : callable
-         The function that generates the tensor data from ``params``.
-     params : sequence of numbers
-         The initial parameters supplied to the generating function like
-         ``fn(params)``.
-
-     See Also
-     --------
-     PTensor
-     """
-
-     __slots__ = ("_fn", "_params", "_data", "_shape", "_shape_fn_id")
-
-     def __init__(self, fn, params, shape=None):
-         self.fn = fn
-         self.params = params
-         self._shape = shape
-         self._shape_fn_id = id(fn)
-
-     def copy(self):
-         new = PArray(self.fn, self.params, self.shape)
-         new._data = self._data  # for efficiency
-         return new
-
-     @property
-     def fn(self):
-         return self._fn
-
-     @fn.setter
-     def fn(self, x):
-         self._fn = x
-         self._data = None
-
-     @property
-     def params(self):
-         return self._params
-
-     @params.setter
-     def params(self, x):
-         self._params = asarray(x)
-         self._data = None
-
-     @property
-     def data(self):
-         if self._data is None:
-             self._data = self._fn(self._params)
-         return self._data
-
-     @property
-     def shape(self):
-         # if we haven't calculated shape or have updated function, get shape
-         _shape_fn_id = id(self.fn)
-         if (self._shape is None) or (self._shape_fn_id != _shape_fn_id):
-             self._shape = self.data.shape
-             self._shape_fn_id = _shape_fn_id
-         return self._shape
-
-     @property
-     def ndim(self):
-         return len(self.shape)
-
-     def add_function(self, g):
-         """Chain the new function ``g`` on top of current function ``f`` like
-         ``g(f(params))``.
-         """
-         f = self.fn
-         self.fn = fn_compose(g, f)
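
For orientation, the axis-fusing logic in the removed module above boils down to a transpose followed by a reshape, with the fused axes inserted at the smallest fused position. A minimal numpy-only sketch (illustrative only, not code shipped in either package version) of what a call like fuse(x, [2, 1]) computes:

import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# bring the fused axes together in the requested order (2, 1), starting at
# the smallest fused position (axis 1), then merge them with a reshape
fused = x.transpose(0, 2, 1, 3).reshape(2, 4 * 3, 5)
print(fused.shape)  # (2, 12, 5)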