Trajectree 0.0.0__py3-none-any.whl → 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +3 -0
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/quimb/docs/_pygments/_pygments_dark.py +118 -0
- trajectree/quimb/docs/_pygments/_pygments_light.py +118 -0
- trajectree/quimb/docs/conf.py +158 -0
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +62 -0
- trajectree/quimb/quimb/__init__.py +507 -0
- trajectree/quimb/quimb/calc.py +1491 -0
- trajectree/quimb/quimb/core.py +2279 -0
- trajectree/quimb/quimb/evo.py +712 -0
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +129 -0
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +109 -0
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +397 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +653 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +571 -0
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +775 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +316 -0
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +537 -0
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +194 -0
- trajectree/quimb/quimb/experimental/cluster_update.py +286 -0
- trajectree/quimb/quimb/experimental/merabuilder.py +865 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +15 -0
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +1631 -0
- trajectree/quimb/quimb/experimental/schematic.py +7 -0
- trajectree/quimb/quimb/experimental/tn_marginals.py +130 -0
- trajectree/quimb/quimb/experimental/tnvmc.py +1483 -0
- trajectree/quimb/quimb/gates.py +36 -0
- trajectree/quimb/quimb/gen/__init__.py +2 -0
- trajectree/quimb/quimb/gen/operators.py +1167 -0
- trajectree/quimb/quimb/gen/rand.py +713 -0
- trajectree/quimb/quimb/gen/states.py +479 -0
- trajectree/quimb/quimb/linalg/__init__.py +6 -0
- trajectree/quimb/quimb/linalg/approx_spectral.py +1109 -0
- trajectree/quimb/quimb/linalg/autoblock.py +258 -0
- trajectree/quimb/quimb/linalg/base_linalg.py +719 -0
- trajectree/quimb/quimb/linalg/mpi_launcher.py +397 -0
- trajectree/quimb/quimb/linalg/numpy_linalg.py +244 -0
- trajectree/quimb/quimb/linalg/rand_linalg.py +514 -0
- trajectree/quimb/quimb/linalg/scipy_linalg.py +293 -0
- trajectree/quimb/quimb/linalg/slepc_linalg.py +892 -0
- trajectree/quimb/quimb/schematic.py +1518 -0
- trajectree/quimb/quimb/tensor/__init__.py +401 -0
- trajectree/quimb/quimb/tensor/array_ops.py +610 -0
- trajectree/quimb/quimb/tensor/circuit.py +4824 -0
- trajectree/quimb/quimb/tensor/circuit_gen.py +411 -0
- trajectree/quimb/quimb/tensor/contraction.py +336 -0
- trajectree/quimb/quimb/tensor/decomp.py +1255 -0
- trajectree/quimb/quimb/tensor/drawing.py +1646 -0
- trajectree/quimb/quimb/tensor/fitting.py +385 -0
- trajectree/quimb/quimb/tensor/geometry.py +583 -0
- trajectree/quimb/quimb/tensor/interface.py +114 -0
- trajectree/quimb/quimb/tensor/networking.py +1058 -0
- trajectree/quimb/quimb/tensor/optimize.py +1818 -0
- trajectree/quimb/quimb/tensor/tensor_1d.py +4778 -0
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +1854 -0
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +662 -0
- trajectree/quimb/quimb/tensor/tensor_2d.py +5954 -0
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +96 -0
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +1230 -0
- trajectree/quimb/quimb/tensor/tensor_3d.py +2869 -0
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +46 -0
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +60 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +3237 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +565 -0
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +1138 -0
- trajectree/quimb/quimb/tensor/tensor_builder.py +5411 -0
- trajectree/quimb/quimb/tensor/tensor_core.py +11179 -0
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +1472 -0
- trajectree/quimb/quimb/tensor/tensor_mera.py +204 -0
- trajectree/quimb/quimb/utils.py +892 -0
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +501 -0
- trajectree/quimb/tests/test_calc.py +788 -0
- trajectree/quimb/tests/test_core.py +847 -0
- trajectree/quimb/tests/test_evo.py +565 -0
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +361 -0
- trajectree/quimb/tests/test_gen/test_rand.py +296 -0
- trajectree/quimb/tests/test_gen/test_states.py +261 -0
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +368 -0
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +351 -0
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +127 -0
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +84 -0
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +134 -0
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +283 -0
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +39 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +67 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +64 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +51 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +142 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +101 -0
- trajectree/quimb/tests/test_tensor/test_circuit.py +816 -0
- trajectree/quimb/tests/test_tensor/test_contract.py +67 -0
- trajectree/quimb/tests/test_tensor/test_decomp.py +40 -0
- trajectree/quimb/tests/test_tensor/test_mera.py +52 -0
- trajectree/quimb/tests/test_tensor/test_optimizers.py +488 -0
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +1171 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +606 -0
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +144 -0
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +123 -0
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +226 -0
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +441 -0
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +2066 -0
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +388 -0
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +63 -0
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +270 -0
- trajectree/quimb/tests/test_utils.py +85 -0
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/METADATA +2 -2
- trajectree-0.0.1.dist-info/RECORD +126 -0
- trajectree-0.0.0.dist-info/RECORD +0 -16
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/WHEEL +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.0.dist-info → trajectree-0.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
"""Backend agnostic array operations."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import itertools
|
|
5
|
+
|
|
6
|
+
import numpy
|
|
7
|
+
from autoray import (
|
|
8
|
+
compose,
|
|
9
|
+
do,
|
|
10
|
+
get_dtype_name,
|
|
11
|
+
get_lib_fn,
|
|
12
|
+
infer_backend,
|
|
13
|
+
reshape,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from ..core import njit, qarray
|
|
17
|
+
from ..linalg.base_linalg import norm_fro_dense
|
|
18
|
+
from ..utils import compose as fn_compose
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def asarray(array):
    """Maybe convert data for a tensor to use. If ``array`` already has a
    ``.shape`` attribute, i.e. looks like an array, it is left as-is. Else the
    elements are inspected to see which libraries' array constructor should be
    used, defaulting to ``numpy`` if everything is builtin or numpy numbers.

    Parameters
    ----------
    array : array-like or (possibly nested) python iterable
        The data to convert.

    Returns
    -------
    array-like
        A plain ``numpy.ndarray`` if ``array`` was a ``numpy.matrix`` or
        ``qarray``, ``array`` itself if it already has a ``.shape``, otherwise
        a new array built with the inferred backend's ``array`` function.

    Raises
    ------
    ValueError
        If the nested elements come from more than one non-numpy backend, in
        which case no single array constructor can be chosen.
    """
    if isinstance(array, (numpy.matrix, qarray)):
        # if numpy make sure array not subclass
        return numpy.asarray(array)

    if hasattr(array, "shape"):
        # otherwise don't touch things which are already array like
        return array

    # else we have some kind of possibly nested python iterable -> inspect
    # items, recording every non-builtin backend encountered along the way
    backends = set()

    def _nd_py_iter(x):
        if isinstance(x, str):
            # a 1-char string iterates to itself -> avoid infinite recursion
            return x

        backend = infer_backend(x)
        if backend != "builtins":
            # don't iterate any non-builtin containers
            backends.add(backend)
            return x

        # is some kind of python container or element -> iterate or return
        try:
            return [_nd_py_iter(sub) for sub in x]
        except TypeError:
            # not iterable, i.e. a scalar element
            return x

    nested = _nd_py_iter(array)

    # numpy and builtin elements treat as basic
    backends -= {"builtins", "numpy"}
    if not backends:
        backend = "numpy"
    elif len(backends) == 1:
        (backend,) = backends
    else:
        raise ValueError(
            "Cannot infer a single array backend from elements belonging to "
            f"multiple backends: {sorted(backends)}."
        )

    return do("array", nested, like=backend)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
_blocksparselookup = {}


def isblocksparse(x):
    """Return whether ``x`` is a block-sparse array.

    The test is duck-typed: an array exposing an ``align_axes`` attribute
    counts as block-sparse. The answer is memoized per class in
    ``_blocksparselookup`` so repeated queries on the same type are cheap.
    """
    cls = x.__class__
    cached = _blocksparselookup.get(cls)
    if cached is not None:
        return cached

    # XXX: make this a more established interface
    result = hasattr(x, "align_axes")
    _blocksparselookup[cls] = result
    return result
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
_fermioniclookup = {}


def isfermionic(x):
    """Return whether ``x`` is a fermionic array.

    The test is duck-typed: an array exposing a ``phase_flip`` attribute
    counts as fermionic. The answer is memoized per class in
    ``_fermioniclookup`` so repeated queries on the same type are cheap.
    """
    cls = x.__class__
    cached = _fermioniclookup.get(cls)
    if cached is not None:
        return cached

    # XXX: make this a more established interface
    result = hasattr(x, "phase_flip")
    _fermioniclookup[cls] = result
    return result
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@functools.lru_cache(2**14)
def calc_fuse_perm_and_shape(shape, axes_groups):
    """Compute the transpose permutation and new shape needed to fuse groups
    of axes of an array with shape ``shape``.

    All fused groups are placed, in order, at the position of the smallest
    axis appearing in any group. Unfused axes keep their relative order
    around them.

    Parameters
    ----------
    shape : tuple[int]
        The shape of the array to fuse.
    axes_groups : tuple[tuple[int]]
        The groups of axes to fuse. Negative axes count from the end, as in
        numpy.

    Returns
    -------
    perm : tuple[int] or None
        The permutation to transpose by first, or ``None`` if no transpose is
        needed.
    new_shape : tuple[int] or None
        The shape to reshape the (transposed) array to, or ``None`` if it
        already has that shape.
    """
    ndim = len(shape)

    # normalize negative axes, e.g. so that ``(-2, -1)`` fuses the last two
    axes_groups = tuple(
        tuple(ax + ndim if ax < 0 else ax for ax in group)
        for group in axes_groups
    )

    # which group does each axis appear in, if any
    num_groups = len(axes_groups)
    ax2group = {ax: g for g, axes in enumerate(axes_groups) for ax in axes}

    # the permutation will be the same for every block: precalculate
    # n.b. all new groups will be inserted at the *first fused axis*
    position = min(ax for group in axes_groups for ax in group)
    axes_before = tuple(ax for ax in range(position) if ax not in ax2group)
    axes_after = tuple(
        ax for ax in range(position, ndim) if ax not in ax2group
    )
    perm = (*axes_before, *(ax for g in axes_groups for ax in g), *axes_after)

    # track where each axis will be in the new array
    new_axes = {ax: ax for ax in axes_before}
    for i, g in enumerate(axes_groups):
        for ax in g:
            new_axes[ax] = position + i
    for i, ax in enumerate(axes_after):
        new_axes[ax] = position + num_groups + i
    new_ndim = len(axes_before) + num_groups + len(axes_after)

    # unfused axes land alone (1 * d), fused axes accumulate their product
    new_shape = [1] * new_ndim
    for i, d in enumerate(shape):
        new_shape[new_axes[i]] *= d
    new_shape = tuple(new_shape)

    if all(i == ax for i, ax in enumerate(perm)):
        # no need to transpose
        perm = None
        transposed_shape = shape
    else:
        transposed_shape = tuple(shape[ax] for ax in perm)

    if transposed_shape == new_shape:
        # no need to reshape - n.b. compare with the *transposed* shape, the
        # original shape can coincide with ``new_shape`` even when the
        # transpose has moved dimensions around (e.g. with empty groups)
        new_shape = None

    return perm, new_shape
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@compose
def fuse(x, *axes_groups, backend=None):
    """Fuse the given group or groups of axes of ``x``.

    The new fused axes are inserted at the minimum index of any fused axis
    (even if that axis is not in the first group). For example,
    ``fuse(x, [5, 3], [7, 2, 6])`` will produce an array with axes like::

                 groups inserted at axis 2, removed beyond that.
                        ......<--
        (0, 1, g0, g1, 4, 8, ...)
               |   |
               |   g1=(7, 2, 6)
               g0=(5, 3)

    Parameters
    ----------
    x : array_like
        The array whose axes should be fused.
    axes_groups : sequence of sequences of int
        The axes to fuse. Each group of axes will be fused into a single
        axis.
    backend : str, optional
        The backend to take ``transpose`` and ``reshape`` from. Inferred from
        ``x`` if not given.

    Returns
    -------
    array_like
        ``x`` itself if every group is empty, else the transposed and/or
        reshaped array.
    """
    if backend is None:
        backend = infer_backend(x)
    transpose_fn = get_lib_fn(backend, "transpose")
    reshape_fn = get_lib_fn(backend, "reshape")

    groups = tuple(tuple(group) for group in axes_groups)
    if not any(groups):
        # nothing to fuse
        return x

    current_shape = tuple(int(d) for d in x.shape)
    perm, new_shape = calc_fuse_perm_and_shape(current_shape, groups)

    if perm is not None:
        x = transpose_fn(x, perm)
    if new_shape is not None:
        x = reshape_fn(x, new_shape)
    return x
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def ndim(array):
    """Return the number of dimensions of ``array``.

    Uses the ``.ndim`` attribute when present, and otherwise falls back to the
    length of ``.shape`` for array-likes that only expose a shape.
    """
    missing = object()
    n = getattr(array, "ndim", missing)
    if n is missing:
        return len(array.shape)
    return n
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@compose
def multiply_diagonal(x, v, axis, backend=None):
    """Multiply ``v`` into ``x`` as if contracting in a diagonal matrix.

    Each slice of ``x`` taken at index ``i`` along ``axis`` is scaled by
    ``v[i]``. This is done by broadcasting rather than forming the matrix.

    Parameters
    ----------
    x : array_like
        The array to scale.
    v : array_like
        The diagonal entries. Presumably 1D with length ``x.shape[axis]``,
        since it is reshaped to broadcast along ``axis``.
    axis : int
        The axis of ``x`` to scale along. Negative values count from the end.
    backend : str, optional
        The backend used for the reshape of ``v``.

    Returns
    -------
    array_like
    """
    nd = ndim(x)
    if axis < 0:
        # support numpy-style negative axes; previously no axis matched,
        # so ``v`` was reshaped to all ones and failed to broadcast
        axis += nd
    # shape (1, ..., 1, -1, 1, ..., 1) so that ``v`` broadcasts along ``axis``
    newshape = tuple((-1 if i == axis else 1) for i in range(nd))
    v_broadcast = do("reshape", v, newshape, like=backend)
    return x * v_broadcast
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@compose
def align_axes(*arrays, axes, backend=None):
    """Prepare ``arrays`` so that they are contractible along ``axes``.

    Structured backends, e.g. block symmetric arrays, need their sectors
    aligned prior to fusing, and can register their own implementation.
    This default implementation does no work.

    Returns
    -------
    tuple
        The input arrays, unchanged.
    """
    return arrays
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ------------- miscelleneous other backend agnostic functions -------------- #
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def iscomplex(x):
    """Return whether ``x`` has a complex dtype.

    Python builtin scalars are checked with ``isinstance``; any other object
    is checked by looking for ``"complex"`` in its dtype name.
    """
    is_builtin = infer_backend(x) == "builtins"
    if is_builtin:
        return isinstance(x, complex)
    dtype_name = get_dtype_name(x)
    return "complex" in dtype_name
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@compose
def norm_fro(x):
    """Compute the Frobenius norm of an array.

    The array is flattened and passed to the backend's ``linalg.norm``. If
    that raises ``AttributeError`` (e.g. no such function is available), the
    norm is computed manually as the square root of the sum of squared
    magnitudes.
    """
    try:
        flat = reshape(x, (-1,))
        return do("linalg.norm", flat)
    except AttributeError:
        squared = do("abs", x) ** 2
        return do("sum", squared) ** 0.5


# numpy arrays dispatch to the dedicated dense implementation
norm_fro.register("numpy", norm_fro_dense)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def sensibly_scale(x):
    """Scale ``x`` *very* roughly so that random tensor networks built from
    such arrays do not end up with gigantic norms.

    ``x`` is divided by its Frobenius norm raised to the power
    ``1.5 / ndim(x)``.
    """
    norm = norm_fro(x)
    exponent = 1.5 / ndim(x)
    return x / norm**exponent
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@njit
def _numba_find_diag_axes(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled array diagonal axis finder.

    A pair of equal-sized axes ``(d1, d2)`` counts as diagonal when every
    entry of ``x`` with ``index[d1] != index[d2]`` has magnitude at most
    ``atol``.

    Parameters
    ----------
    x : numpy.ndarray
        The array to search for diagonal axes.
    atol : float
        The tolerance with which to compare to zero.

    Returns
    -------
    diag_axes : set[tuple[int]]
        The set of pairs ``(d1, d2)``, with ``d1 < d2``, of axes which are
        diagonal. Empty if there are none.
    """

    # create the set of pairs of matching size axes - only axes of equal
    # size can be diagonal with respect to each other
    diag_axes = set()
    for d1 in range(x.ndim - 1):
        for d2 in range(d1 + 1, x.ndim):
            if x.shape[d1] == x.shape[d2]:
                diag_axes.add((d1, d2))

    # enumerate through every array entry, eagerly invalidating axis pairs
    for index, val in numpy.ndenumerate(x):
        # iterate over a copy, since pairs are removed inside the loop
        for d1, d2 in list(diag_axes):
            # a non-zero entry off the (d1, d2) diagonal rules the pair out
            if (index[d1] != index[d2]) and (abs(val) > atol):
                diag_axes.remove((d1, d2))

        # all pairs invalid, nothing left to do
        if len(diag_axes) == 0:
            break

    return diag_axes
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def find_diag_axes(x, atol=1e-12):
    """Try and find a pair of axes of ``x`` in which it is diagonal.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        The two axes if found else None. If several pairs qualify, the
        smallest pair is returned on the numpy path and the first pair in
        ``itertools.combinations`` order otherwise.

    Examples
    --------

    >>> x = np.array([[[1, 0], [0, 2]],
    ...               [[3, 0], [0, 4]]])
    >>> find_diag_axes(x)
    (1, 2)

    Which means we can reduce ``x`` without loss of information to:

    >>> np.einsum('abb->ab', x)
    array([[1, 2],
           [3, 4]])

    """
    shape = x.shape
    if len(shape) < 2:
        return None

    backend = infer_backend(x)

    # use numba-accelerated version for numpy arrays
    if backend == "numpy":
        diag_axes = _numba_find_diag_axes(x, atol=atol)
        if diag_axes:
            # make it deterministic
            return min(diag_axes)
        return None

    # generic fallback - the index grids and the scalar zero are only needed
    # here, so they are not created on the numpy fast path above
    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for i, j in itertools.combinations(range(len(shape)), 2):
        if shape[i] != shape[j]:
            continue
        # every entry off the (i, j) diagonal must be (close to) zero
        if do(
            "allclose",
            x[indxrs[i] != indxrs[j]],
            zero,
            atol=atol,
            like=backend,
        ):
            return (i, j)
    return None
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
@njit
def _numba_find_antidiag_axes(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled array antidiagonal axis finder.

    A pair of equal-sized axes ``(i, j)`` counts as anti-diagonal when every
    entry of ``x`` with ``index[i] != d - 1 - index[j]`` (``d`` being the
    common size of the two axes) has magnitude at most ``atol``.

    Parameters
    ----------
    x : numpy.ndarray
        The array to search for anti-diagonal axes.
    atol : float
        The tolerance with which to compare to zero.

    Returns
    -------
    antidiag_axes : set[tuple[int]]
        The set of pairs ``(i, j)``, with ``i < j``, of axes which are
        anti-diagonal. Empty if there are none.
    """

    # create the set of pairs of matching size axes
    antidiag_axes = set()
    for i in range(x.ndim - 1):
        for j in range(i + 1, x.ndim):
            if x.shape[i] == x.shape[j]:
                antidiag_axes.add((i, j))

    # enumerate through every array entry, eagerly invalidating axis pairs
    for index, val in numpy.ndenumerate(x):
        # iterate over a copy, since pairs are removed inside the loop
        for i, j in list(antidiag_axes):
            # both axes have size d: the anti-diagonal is where
            # index[i] == d - 1 - index[j]
            d = x.shape[i]
            if (index[i] != d - 1 - index[j]) and (abs(val) > atol):
                antidiag_axes.remove((i, j))

        # all pairs invalid, nothing left to do
        if len(antidiag_axes) == 0:
            break

    return antidiag_axes
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def find_antidiag_axes(x, atol=1e-12):
    """Try and find a pair of axes of ``x`` in which it is anti-diagonal.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        The two axes if found else None.

    Examples
    --------

    >>> x = np.array([[[0, 1], [0, 2]],
    ...               [[3, 0], [4, 0]]])
    >>> find_antidiag_axes(x)
    (0, 2)

    Which means we can reduce ``x`` without loss of information to:

    >>> np.einsum('aba->ab', x[::-1, :, :])
    array([[3, 4],
           [1, 2]])

    as long as we flip the order of dimensions on other tensors corresponding
    to the same index.
    """
    shape = x.shape
    if len(shape) < 2:
        return None

    backend = infer_backend(x)

    if backend == "numpy":
        # numba-accelerated search, min() makes the choice deterministic
        found = _numba_find_antidiag_axes(x, atol=atol)
        return min(found) if found else None

    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for i, j in itertools.combinations(range(len(shape)), 2):
        size = shape[j]
        if shape[i] == size:
            # everything off the anti-diagonal of axes (i, j) must vanish
            off_antidiag = x[indxrs[i] != size - 1 - indxrs[j]]
            if do("allclose", off_antidiag, zero, atol=atol, like=backend):
                return (i, j)
    return None
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
@njit
def _numba_find_columns(x, atol=1e-12):  # pragma: no cover
    """Numba-compiled single non-zero column axis finder.

    A pair ``(ax, i)`` survives when every entry of ``x`` with magnitude
    greater than ``atol`` has index ``i`` along axis ``ax``, i.e. ``x`` is
    zero outside the slice at index ``i`` of axis ``ax``.

    Parameters
    ----------
    x : array
        The array to search.
    atol : float, optional
        Absolute tolerance to compare to zero with.

    Returns
    -------
    set[tuple[int]]
        Set of pairs (axis, index) defining lone non-zero columns. If no
        entry of ``x`` exceeds ``atol``, every pair survives.
    """

    # possible pairings of axis + index
    column_pairs = set()
    for ax, d in enumerate(x.shape):
        for i in range(d):
            column_pairs.add((ax, i))

    # enumerate over all array entries, invalidating potential column pairs
    for index, val in numpy.ndenumerate(x):
        # only non-zero entries can rule a column out
        if abs(val) > atol:
            for ax, i in enumerate(index):
                # iterate over a copy, since pairs are removed inside the loop
                for pax, pi in list(column_pairs):
                    # a non-zero at index i of axis ax rules out every
                    # other index of that same axis
                    if ax == pax and pi != i:
                        column_pairs.remove((pax, pi))

        # all potential pairs invalidated
        if not len(column_pairs):
            break

    return column_pairs
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def find_columns(x, atol=1e-12):
    """Try and find columns of axes which are zero apart from a single index.

    Parameters
    ----------
    x : array-like
        The array to search.
    atol : float, optional
        Tolerance with which to compare to zero.

    Returns
    -------
    tuple[int] or None
        If found, the first integer is which axis, and the second is which
        column of that axis, else None.

    Examples
    --------

    >>> x = np.array([[[0, 1], [0, 2]],
    ...               [[0, 3], [0, 4]]])
    >>> find_columns(x)
    (2, 1)

    Which means we can happily slice ``x`` without loss of information to:

    >>> x[:, :, 1]
    array([[1, 2],
           [3, 4]])

    """
    shape = x.shape
    if len(shape) < 1:
        return None

    backend = infer_backend(x)

    if backend == "numpy":
        # numba-accelerated search, min() makes the choice deterministic
        found = _numba_find_columns(x, atol)
        return min(found) if found else None

    indxrs = do("indices", shape, like=backend)
    zero = do("zeros", (), like=x)

    for axis, size in enumerate(shape):
        for col in range(size):
            # everything outside column ``col`` of ``axis`` must vanish
            outside = x[indxrs[axis] != col]
            if do("allclose", outside, zero, atol=atol, like=backend):
                return (axis, col)

    return None
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class PArray:
    """Simple array-like object that lazily generates the actual array by
    calling a function with a set of parameters.

    Parameters
    ----------
    fn : callable
        The function that generates the tensor data from ``params``.
    params : sequence of numbers
        The initial parameters supplied to the generating function like
        ``fn(params)``.
    shape : tuple[int], optional
        The shape of the generated array, if already known. If not given, it
        is computed by generating the data the first time ``.shape`` is
        accessed.

    See Also
    --------
    PTensor
    """

    __slots__ = ("_fn", "_params", "_data", "_shape", "_shape_fn_id")

    def __init__(self, fn, params, shape=None):
        self.fn = fn
        self.params = params
        self._shape = shape
        self._shape_fn_id = id(fn)

    def copy(self):
        """Return a shallow copy sharing ``fn``, ``params`` and any data that
        has already been generated, without forcing generation of the data.
        """
        new = PArray(self.fn, self.params)
        # carry the cached state over directly (rather than via
        # ``self.shape``) so copying an unevaluated array doesn't call ``fn``
        new._shape = self._shape
        new._shape_fn_id = self._shape_fn_id
        new._data = self._data  # for efficiency
        return new

    @property
    def fn(self):
        return self._fn

    @fn.setter
    def fn(self, x):
        self._fn = x
        # cached data and shape belonged to the previous function. Clearing
        # the shape here (not just relying on the ``id`` check in ``shape``)
        # guards against a freed function's id being reused by a new one.
        self._data = None
        self._shape = None

    @property
    def params(self):
        return self._params

    @params.setter
    def params(self, x):
        self._params = asarray(x)
        self._data = None

    @property
    def data(self):
        """The generated array, computed as ``fn(params)`` and cached."""
        if self._data is None:
            self._data = self._fn(self._params)
        return self._data

    @property
    def shape(self):
        # if we haven't calculated shape or have updated function, get shape
        _shape_fn_id = id(self.fn)
        if (self._shape is None) or (self._shape_fn_id != _shape_fn_id):
            self._shape = self.data.shape
            self._shape_fn_id = _shape_fn_id
        return self._shape

    @property
    def ndim(self):
        return len(self.shape)

    def add_function(self, g):
        """Chain the new function ``g`` on top of current function ``f`` like
        ``g(f(params))``.
        """
        f = self.fn
        self.fn = fn_compose(g, f)
|