spacecore 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {spacecore-0.2.0 → spacecore-0.3.0}/PKG-INFO +7 -7
- {spacecore-0.2.0 → spacecore-0.3.0}/README.md +6 -6
- {spacecore-0.2.0 → spacecore-0.3.0}/pyproject.toml +4 -1
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/__init__.py +40 -15
- spacecore-0.3.0/spacecore/_batching.py +18 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/_checks.py +19 -2
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/_contextual/__init__.py +1 -15
- spacecore-0.3.0/spacecore/_contextual/_bound.py +57 -0
- spacecore-0.3.0/spacecore/_contextual/_policies.py +16 -0
- spacecore-0.3.0/spacecore/_contextual/_state.py +388 -0
- spacecore-0.3.0/spacecore/_tree.py +26 -0
- spacecore-0.3.0/spacecore/_version.py +3 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/__init__.py +7 -2
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/_context.py +10 -6
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/_ops.py +5 -0
- spacecore-0.3.0/spacecore/backend/jax/__init__.py +14 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/jax/_ops.py +5 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/torch/_ops.py +5 -0
- spacecore-0.3.0/spacecore/functional/_base.py +138 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/functional/_linear.py +55 -11
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/functional/_quadratic.py +37 -14
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/_cg.py +24 -13
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/_expm.py +12 -5
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/_lanczos.py +20 -12
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/_lsqr.py +56 -16
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/_power.py +79 -37
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linop/_algebra.py +261 -92
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linop/_base.py +56 -69
- spacecore-0.3.0/spacecore/linop/_dense.py +310 -0
- spacecore-0.3.0/spacecore/linop/_diagonal.py +248 -0
- spacecore-0.3.0/spacecore/linop/_metric.py +117 -0
- spacecore-0.3.0/spacecore/linop/_sparse.py +345 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linop/product/_base.py +8 -2
- spacecore-0.3.0/spacecore/linop/product/_block.py +125 -0
- spacecore-0.3.0/spacecore/linop/product/_from_single.py +179 -0
- spacecore-0.3.0/spacecore/linop/product/_to_single.py +179 -0
- spacecore-0.3.0/spacecore/space/__init__.py +69 -0
- spacecore-0.3.0/spacecore/space/_structure.py +118 -0
- spacecore-0.3.0/spacecore/space/base/__init__.py +24 -0
- spacecore-0.3.0/spacecore/space/base/_coordinate.py +72 -0
- spacecore-0.3.0/spacecore/space/base/_inner_product.py +154 -0
- spacecore-0.3.0/spacecore/space/base/_jordan.py +51 -0
- spacecore-0.3.0/spacecore/space/base/_space.py +49 -0
- spacecore-0.3.0/spacecore/space/base/_star.py +21 -0
- spacecore-0.3.0/spacecore/space/base/_vector.py +33 -0
- {spacecore-0.2.0/spacecore/space → spacecore-0.3.0/spacecore/space/checks}/__init__.py +3 -13
- spacecore-0.2.0/spacecore/space/_checks.py → spacecore-0.3.0/spacecore/space/checks/_base.py +119 -18
- spacecore-0.3.0/spacecore/space/checks/_coordinate.py +3 -0
- spacecore-0.3.0/spacecore/space/checks/_product.py +3 -0
- spacecore-0.3.0/spacecore/space/concrete/__init__.py +16 -0
- spacecore-0.3.0/spacecore/space/concrete/_dense_coordinate.py +102 -0
- spacecore-0.3.0/spacecore/space/concrete/_dense_vector.py +226 -0
- spacecore-0.2.0/spacecore/space/_herm.py → spacecore-0.3.0/spacecore/space/concrete/_hermitian.py +59 -14
- spacecore-0.3.0/spacecore/space/concrete/_product.py +672 -0
- spacecore-0.3.0/spacecore/space/concrete/_stacked.py +391 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore.egg-info/PKG-INFO +7 -7
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore.egg-info/SOURCES.txt +22 -7
- spacecore-0.2.0/spacecore/_contextual/_bound.py +0 -201
- spacecore-0.2.0/spacecore/_contextual/_manager.py +0 -266
- spacecore-0.2.0/spacecore/_contextual/_policies.py +0 -63
- spacecore-0.2.0/spacecore/_contextual/_state.py +0 -512
- spacecore-0.2.0/spacecore/backend/jax/__init__.py +0 -4
- spacecore-0.2.0/spacecore/functional/_base.py +0 -168
- spacecore-0.2.0/spacecore/linop/_dense.py +0 -280
- spacecore-0.2.0/spacecore/linop/_diagonal.py +0 -163
- spacecore-0.2.0/spacecore/linop/_sparse.py +0 -274
- spacecore-0.2.0/spacecore/linop/product/_block.py +0 -108
- spacecore-0.2.0/spacecore/linop/product/_from_single.py +0 -118
- spacecore-0.2.0/spacecore/linop/product/_to_single.py +0 -118
- spacecore-0.2.0/spacecore/space/_base.py +0 -145
- spacecore-0.2.0/spacecore/space/_batch.py +0 -209
- spacecore-0.2.0/spacecore/space/_product.py +0 -280
- spacecore-0.2.0/spacecore/space/_vector.py +0 -151
- {spacecore-0.2.0 → spacecore-0.3.0}/LICENSE +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/setup.cfg +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/_family.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/cupy/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/cupy/_ops.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/jax/_pytree.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/numpy/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/numpy/_ops.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/backend/torch/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/functional/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/functional/_composed.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linalg/_utils.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linop/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/linop/product/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/types/__init__.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/types/_array.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/types/_dtype.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore/types/_misc.py +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore.egg-info/dependency_links.txt +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore.egg-info/requires.txt +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/spacecore.egg-info/top_level.txt +0 -0
- {spacecore-0.2.0 → spacecore-0.3.0}/tests/test_backend_ops_complex.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: spacecore
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Backend-agnostic vector spaces and linear operators.
|
|
5
5
|
Author: Pavlo Pelikh
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -56,7 +56,7 @@ import numpy as np
|
|
|
56
56
|
|
|
57
57
|
# Define a space, a linear operator, and solve Ax = b
|
|
58
58
|
ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
|
|
59
|
-
X = sc.
|
|
59
|
+
X = sc.DenseCoordinateSpace((100,), ctx)
|
|
60
60
|
A = sc.DenseLinOp(np.random.randn(100, 100) @ np.random.randn(100, 100).T + np.eye(100), X, X, ctx)
|
|
61
61
|
b = ctx.asarray(np.random.randn(100))
|
|
62
62
|
|
|
@@ -116,10 +116,10 @@ result = sc.lanczos_smallest(A, initial_vector, max_iter=50)
|
|
|
116
116
|
print(f"E_0 = {result.eigenvalue}, converged={result.converged}")
|
|
117
117
|
```
|
|
118
118
|
|
|
119
|
-
**3. Custom Hilbert spaces with non-Euclidean geometry.** Subclass `
|
|
119
|
+
**3. Custom Hilbert spaces with non-Euclidean geometry.** Subclass `DenseCoordinateSpace`, override `inner`, and every solver respects your geometry:
|
|
120
120
|
|
|
121
121
|
```python
|
|
122
|
-
class WeightedL2(sc.
|
|
122
|
+
class WeightedL2(sc.DenseCoordinateSpace):
|
|
123
123
|
def __init__(self, shape, weights, ctx=None):
|
|
124
124
|
super().__init__(shape, ctx)
|
|
125
125
|
self.weights = self.ctx.asarray(weights)
|
|
@@ -140,7 +140,7 @@ This is the basis for RKHS spaces, truncated Fock spaces (quantum many-body), fu
|
|
|
140
140
|
import spacecore as sc
|
|
141
141
|
|
|
142
142
|
ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
|
|
143
|
-
X = sc.
|
|
143
|
+
X = sc.DenseCoordinateSpace((1000,), ctx)
|
|
144
144
|
A = sc.DenseLinOp(make_spd_matrix(), X, X, ctx)
|
|
145
145
|
b = ctx.asarray(rhs)
|
|
146
146
|
|
|
@@ -203,7 +203,7 @@ This operator works on NumPy, JAX, and PyTorch backends without modification.
|
|
|
203
203
|
|
|
204
204
|
## Features at a glance
|
|
205
205
|
|
|
206
|
-
**Spaces.** `
|
|
206
|
+
**Spaces.** `DenseCoordinateSpace`, `DenseVectorSpace`, `ElementwiseJordanSpace`, `EuclideanElementwiseJordanSpace`, `HermitianSpace`, `ProductSpace`, and `StackedSpace`. Generic dense spaces can use custom inner products; `DenseVectorSpace` has no Jordan capability by default; real Euclidean elementwise spaces get the Euclidean-Jordan capability.
|
|
207
207
|
|
|
208
208
|
**Linear operators.** `DenseLinOp`, `SparseLinOp`, `DiagonalLinOp`, `MatrixFreeLinOp`, plus operator algebra (`A @ B`, `A + B`, `2 * A`, `A.H`, `IdentityLinOp`, `ZeroLinOp`).
|
|
209
209
|
|
|
@@ -215,7 +215,7 @@ This operator works on NumPy, JAX, and PyTorch backends without modification.
|
|
|
215
215
|
|
|
216
216
|
## Project status
|
|
217
217
|
|
|
218
|
-
**v0.
|
|
218
|
+
**v0.3 alpha.** API may still change in minor ways. Core abstractions are stable. Suitable for research code; not yet recommended for production deployment.
|
|
219
219
|
|
|
220
220
|
The library is being developed in the open and is looking for early users and feedback. If you try it on your problem, please open an issue with what worked and what didn't — that's the single most valuable contribution right now.
|
|
221
221
|
|
|
@@ -15,7 +15,7 @@ import numpy as np
|
|
|
15
15
|
|
|
16
16
|
# Define a space, a linear operator, and solve Ax = b
|
|
17
17
|
ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
|
|
18
|
-
X = sc.
|
|
18
|
+
X = sc.DenseCoordinateSpace((100,), ctx)
|
|
19
19
|
A = sc.DenseLinOp(np.random.randn(100, 100) @ np.random.randn(100, 100).T + np.eye(100), X, X, ctx)
|
|
20
20
|
b = ctx.asarray(np.random.randn(100))
|
|
21
21
|
|
|
@@ -75,10 +75,10 @@ result = sc.lanczos_smallest(A, initial_vector, max_iter=50)
|
|
|
75
75
|
print(f"E_0 = {result.eigenvalue}, converged={result.converged}")
|
|
76
76
|
```
|
|
77
77
|
|
|
78
|
-
**3. Custom Hilbert spaces with non-Euclidean geometry.** Subclass `
|
|
78
|
+
**3. Custom Hilbert spaces with non-Euclidean geometry.** Subclass `DenseCoordinateSpace`, override `inner`, and every solver respects your geometry:
|
|
79
79
|
|
|
80
80
|
```python
|
|
81
|
-
class WeightedL2(sc.
|
|
81
|
+
class WeightedL2(sc.DenseCoordinateSpace):
|
|
82
82
|
def __init__(self, shape, weights, ctx=None):
|
|
83
83
|
super().__init__(shape, ctx)
|
|
84
84
|
self.weights = self.ctx.asarray(weights)
|
|
@@ -99,7 +99,7 @@ This is the basis for RKHS spaces, truncated Fock spaces (quantum many-body), fu
|
|
|
99
99
|
import spacecore as sc
|
|
100
100
|
|
|
101
101
|
ctx = sc.Context(sc.NumpyOps(), dtype=np.float64)
|
|
102
|
-
X = sc.
|
|
102
|
+
X = sc.DenseCoordinateSpace((1000,), ctx)
|
|
103
103
|
A = sc.DenseLinOp(make_spd_matrix(), X, X, ctx)
|
|
104
104
|
b = ctx.asarray(rhs)
|
|
105
105
|
|
|
@@ -162,7 +162,7 @@ This operator works on NumPy, JAX, and PyTorch backends without modification.
|
|
|
162
162
|
|
|
163
163
|
## Features at a glance
|
|
164
164
|
|
|
165
|
-
**Spaces.** `
|
|
165
|
+
**Spaces.** `DenseCoordinateSpace`, `DenseVectorSpace`, `ElementwiseJordanSpace`, `EuclideanElementwiseJordanSpace`, `HermitianSpace`, `ProductSpace`, and `StackedSpace`. Generic dense spaces can use custom inner products; `DenseVectorSpace` has no Jordan capability by default; real Euclidean elementwise spaces get the Euclidean-Jordan capability.
|
|
166
166
|
|
|
167
167
|
**Linear operators.** `DenseLinOp`, `SparseLinOp`, `DiagonalLinOp`, `MatrixFreeLinOp`, plus operator algebra (`A @ B`, `A + B`, `2 * A`, `A.H`, `IdentityLinOp`, `ZeroLinOp`).
|
|
168
168
|
|
|
@@ -174,7 +174,7 @@ This operator works on NumPy, JAX, and PyTorch backends without modification.
|
|
|
174
174
|
|
|
175
175
|
## Project status
|
|
176
176
|
|
|
177
|
-
**v0.
|
|
177
|
+
**v0.3 alpha.** API may still change in minor ways. Core abstractions are stable. Suitable for research code; not yet recommended for production deployment.
|
|
178
178
|
|
|
179
179
|
The library is being developed in the open and is looking for early users and feedback. If you try it on your problem, please open an issue with what worked and what didn't — that's the single most valuable contribution right now.
|
|
180
180
|
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "spacecore"
|
|
7
|
-
|
|
7
|
+
dynamic = ["version"]
|
|
8
8
|
description = "Backend-agnostic vector spaces and linear operators."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.11"
|
|
@@ -68,6 +68,9 @@ dev = [
|
|
|
68
68
|
[tool.setuptools]
|
|
69
69
|
include-package-data = true
|
|
70
70
|
|
|
71
|
+
[tool.setuptools.dynamic]
|
|
72
|
+
version = {attr = "spacecore._version.__version__"}
|
|
73
|
+
|
|
71
74
|
[tool.setuptools.packages.find]
|
|
72
75
|
where = ["."]
|
|
73
76
|
include = ["spacecore*"]
|
|
@@ -1,14 +1,12 @@
|
|
|
1
1
|
"""Backend-agnostic vector spaces, linear operators, and solvers."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from ._version import __version__
|
|
4
4
|
|
|
5
|
+
from .backend import Context, BackendOps, NumpyOps, jax_pytree_class
|
|
5
6
|
try:
|
|
6
|
-
|
|
7
|
-
except
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class
|
|
7
|
+
from .backend import JaxOps as JaxOps
|
|
8
|
+
except ImportError:
|
|
9
|
+
pass
|
|
12
10
|
try:
|
|
13
11
|
from .backend import CuPyOps as CuPyOps
|
|
14
12
|
except ImportError:
|
|
@@ -58,18 +56,34 @@ from .linalg import (
|
|
|
58
56
|
power_iteration,
|
|
59
57
|
)
|
|
60
58
|
from .space import (
|
|
61
|
-
BatchSpace,
|
|
62
59
|
BackendCheck,
|
|
63
60
|
DTypeCheck,
|
|
64
61
|
HermitianCheck,
|
|
65
62
|
ProductComponentCheck,
|
|
63
|
+
CoordinateSpace,
|
|
64
|
+
DenseCoordinateSpace,
|
|
65
|
+
DenseVectorSpace,
|
|
66
|
+
ElementwiseJordanSpace,
|
|
67
|
+
EuclideanElementwiseJordanSpace,
|
|
68
|
+
EuclideanJordanAlgebraSpace,
|
|
69
|
+
InnerProductSpace,
|
|
70
|
+
JordanAlgebraSpace,
|
|
66
71
|
ProductSpace,
|
|
72
|
+
ProductSpectralDecomposition,
|
|
73
|
+
ProductStructure,
|
|
67
74
|
ProductStructureCheck,
|
|
75
|
+
PytreeStructure,
|
|
76
|
+
StackedSpace,
|
|
68
77
|
ShapeCheck,
|
|
78
|
+
InnerProduct,
|
|
79
|
+
EuclideanInnerProduct,
|
|
80
|
+
WeightedInnerProduct,
|
|
69
81
|
Space,
|
|
82
|
+
StarSpace,
|
|
70
83
|
SpaceCheck,
|
|
71
84
|
SpaceValidationError,
|
|
72
85
|
SquareMatrixCheck,
|
|
86
|
+
TupleStructure,
|
|
73
87
|
VectorSpace,
|
|
74
88
|
HermitianSpace,
|
|
75
89
|
)
|
|
@@ -81,12 +95,11 @@ from ._contextual import (
|
|
|
81
95
|
set_context, get_context,
|
|
82
96
|
resolve_context_priority,
|
|
83
97
|
register_ops,
|
|
84
|
-
set_resolution_policy, set_dtype_resolution_policy,
|
|
85
|
-
get_resolution_policy, get_dtype_resolution_policy,
|
|
86
98
|
normalize_ops, normalize_context,
|
|
87
99
|
)
|
|
88
100
|
|
|
89
101
|
__all__ = [
|
|
102
|
+
"__version__",
|
|
90
103
|
"Context",
|
|
91
104
|
|
|
92
105
|
"BackendOps",
|
|
@@ -137,10 +150,26 @@ __all__ = [
|
|
|
137
150
|
"ProductComponentCheck",
|
|
138
151
|
"ProductStructureCheck",
|
|
139
152
|
"ShapeCheck",
|
|
140
|
-
"
|
|
153
|
+
"InnerProduct",
|
|
154
|
+
"EuclideanInnerProduct",
|
|
155
|
+
"WeightedInnerProduct",
|
|
141
156
|
"VectorSpace",
|
|
142
157
|
"HermitianSpace",
|
|
143
158
|
"ProductSpace",
|
|
159
|
+
"ProductStructure",
|
|
160
|
+
"TupleStructure",
|
|
161
|
+
"PytreeStructure",
|
|
162
|
+
"StackedSpace",
|
|
163
|
+
"CoordinateSpace",
|
|
164
|
+
"InnerProductSpace",
|
|
165
|
+
"StarSpace",
|
|
166
|
+
"JordanAlgebraSpace",
|
|
167
|
+
"EuclideanJordanAlgebraSpace",
|
|
168
|
+
"DenseCoordinateSpace",
|
|
169
|
+
"DenseVectorSpace",
|
|
170
|
+
"ElementwiseJordanSpace",
|
|
171
|
+
"EuclideanElementwiseJordanSpace",
|
|
172
|
+
"ProductSpectralDecomposition",
|
|
144
173
|
"Space",
|
|
145
174
|
"SpaceCheck",
|
|
146
175
|
"SpaceValidationError",
|
|
@@ -156,10 +185,6 @@ __all__ = [
|
|
|
156
185
|
"get_context",
|
|
157
186
|
"resolve_context_priority",
|
|
158
187
|
"register_ops",
|
|
159
|
-
"set_resolution_policy",
|
|
160
|
-
"set_dtype_resolution_policy",
|
|
161
|
-
"get_resolution_policy",
|
|
162
|
-
"get_dtype_resolution_policy",
|
|
163
188
|
"normalize_ops",
|
|
164
189
|
"normalize_context",
|
|
165
190
|
]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .space.checks import _run_checks
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _check_batched(space: Any, xs: Any) -> None:
|
|
9
|
+
"""Raise if ``xs`` does not have ``space.shape`` as trailing dimensions."""
|
|
10
|
+
_run_checks(space, xs, allow_leading=True)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _batched_inner(space: Any, xs: Any, ys: Any) -> Any:
|
|
14
|
+
"""Return ``space.inner(xs[i], ys[i])`` for a leading-axis batch."""
|
|
15
|
+
xs_flat = space.flatten_batch(xs)
|
|
16
|
+
ys_dual = ys if space.is_euclidean else space.riesz(ys)
|
|
17
|
+
ys_flat = space.flatten_batch(ys_dual)
|
|
18
|
+
return space.ops.sum(space.ops.conj(xs_flat) * ys_flat, axis=1)
|
|
@@ -29,6 +29,8 @@ def checked_method(
|
|
|
29
29
|
out_space: str | None = None,
|
|
30
30
|
arg_pos: int | None = None,
|
|
31
31
|
arg_positions: int | tuple[int, ...] | None = None,
|
|
32
|
+
in_batched: bool = False,
|
|
33
|
+
out_batched: bool = False,
|
|
32
34
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
33
35
|
"""
|
|
34
36
|
Build a decorator that validates method inputs and outputs against spaces.
|
|
@@ -48,6 +50,10 @@ def checked_method(
|
|
|
48
50
|
arg_positions : int, tuple of int, or None, optional
|
|
49
51
|
Zero-based positions in ``*args`` of input values that should be checked
|
|
50
52
|
against ``in_space``. Defaults to ``(0,)``.
|
|
53
|
+
in_batched : bool, optional
|
|
54
|
+
Validate inputs as leading-axis batches instead of single elements.
|
|
55
|
+
out_batched : bool, optional
|
|
56
|
+
Validate outputs as leading-axis batches instead of single elements.
|
|
51
57
|
|
|
52
58
|
Returns
|
|
53
59
|
-------
|
|
@@ -64,12 +70,23 @@ def checked_method(
|
|
|
64
70
|
if self._enable_checks and in_space is not None:
|
|
65
71
|
check_target = _space_target(self, in_space)
|
|
66
72
|
for pos in positions:
|
|
67
|
-
|
|
73
|
+
if in_batched:
|
|
74
|
+
from ._batching import _check_batched
|
|
75
|
+
|
|
76
|
+
_check_batched(check_target, args[pos])
|
|
77
|
+
else:
|
|
78
|
+
check_target._check_member(args[pos])
|
|
68
79
|
|
|
69
80
|
y = method(self, *args, **kwargs)
|
|
70
81
|
|
|
71
82
|
if self._enable_checks and out_space is not None:
|
|
72
|
-
_space_target(self, out_space)
|
|
83
|
+
check_target = _space_target(self, out_space)
|
|
84
|
+
if out_batched:
|
|
85
|
+
from ._batching import _check_batched
|
|
86
|
+
|
|
87
|
+
_check_batched(check_target, y)
|
|
88
|
+
else:
|
|
89
|
+
check_target._check_member(y)
|
|
73
90
|
|
|
74
91
|
return y
|
|
75
92
|
|
|
@@ -1,45 +1,31 @@
|
|
|
1
1
|
from ._bound import ContextBound as ContextBound
|
|
2
|
-
from .
|
|
2
|
+
from ._state import (
|
|
3
3
|
enforce_convert_policy as enforce_convert_policy,
|
|
4
4
|
get_context as get_context,
|
|
5
|
-
get_dtype_resolution_policy as get_dtype_resolution_policy,
|
|
6
|
-
get_resolution_policy as get_resolution_policy,
|
|
7
5
|
normalize_context as normalize_context,
|
|
8
6
|
normalize_ops as normalize_ops,
|
|
9
7
|
register_ops as register_ops,
|
|
10
8
|
resolve_context_priority as resolve_context_priority,
|
|
11
9
|
set_context as set_context,
|
|
12
|
-
set_dtype_resolution_policy as set_dtype_resolution_policy,
|
|
13
|
-
set_resolution_policy as set_resolution_policy,
|
|
14
10
|
)
|
|
15
11
|
from ._policies import (
|
|
16
12
|
ContextConflictError as ContextConflictError,
|
|
17
|
-
ContextConversionError as ContextConversionError,
|
|
18
13
|
ContextError as ContextError,
|
|
19
14
|
ContextInferenceError as ContextInferenceError,
|
|
20
|
-
ContextPolicy as ContextPolicy,
|
|
21
|
-
DtypePreservePolicy as DtypePreservePolicy,
|
|
22
15
|
UnknownBackendError as UnknownBackendError,
|
|
23
16
|
)
|
|
24
17
|
|
|
25
18
|
__all__ = [
|
|
26
19
|
"ContextBound",
|
|
27
20
|
"ContextConflictError",
|
|
28
|
-
"ContextConversionError",
|
|
29
21
|
"ContextError",
|
|
30
22
|
"ContextInferenceError",
|
|
31
|
-
"ContextPolicy",
|
|
32
|
-
"DtypePreservePolicy",
|
|
33
23
|
"UnknownBackendError",
|
|
34
24
|
"enforce_convert_policy",
|
|
35
25
|
"get_context",
|
|
36
|
-
"get_dtype_resolution_policy",
|
|
37
|
-
"get_resolution_policy",
|
|
38
26
|
"normalize_context",
|
|
39
27
|
"normalize_ops",
|
|
40
28
|
"register_ops",
|
|
41
29
|
"resolve_context_priority",
|
|
42
30
|
"set_context",
|
|
43
|
-
"set_dtype_resolution_policy",
|
|
44
|
-
"set_resolution_policy",
|
|
45
31
|
]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import TYPE_CHECKING, Self
|
|
5
|
+
|
|
6
|
+
from ..types import DType
|
|
7
|
+
from ._state import enforce_convert_policy, normalize_context
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from ..backend import BackendFamily, BackendOps, Context
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _same_math_context(left: Context, right: Context) -> bool:
|
|
14
|
+
"""Return whether contexts match for algebra, ignoring validation checks."""
|
|
15
|
+
return left.ops == right.ops and left.dtype == right.dtype
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ContextBound(ABC):
|
|
19
|
+
"""
|
|
20
|
+
Base class for objects bound to a SpaceCore execution context.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
ctx : Context, str, or None, optional
|
|
25
|
+
Context specification used to resolve backend operations, dtype, and
|
|
26
|
+
validation policy.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def __init__(self, ctx: Context | str | None = None):
|
|
30
|
+
ctx = normalize_context(ctx)
|
|
31
|
+
self._ctx = ctx
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def ops(self) -> BackendOps:
|
|
35
|
+
"""Return backend operations associated with this object's context."""
|
|
36
|
+
return self.ctx.ops
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def dtype(self) -> DType:
|
|
40
|
+
"""Return the default dtype associated with this object's context."""
|
|
41
|
+
return self.ctx.dtype
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def ctx(self) -> Context:
|
|
45
|
+
"""Return the execution context bound to this object."""
|
|
46
|
+
return self._ctx
|
|
47
|
+
|
|
48
|
+
def _convert(self, new_ctx: Context) -> Self:
|
|
49
|
+
"""Rebuild this object in ``new_ctx``."""
|
|
50
|
+
raise NotImplementedError()
|
|
51
|
+
|
|
52
|
+
def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self:
|
|
53
|
+
"""Return this object represented in ``new_ctx``."""
|
|
54
|
+
_, new_ctx = enforce_convert_policy(self, new_ctx)
|
|
55
|
+
if self.ctx == new_ctx:
|
|
56
|
+
return self
|
|
57
|
+
return self._convert(new_ctx)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
class ContextError(RuntimeError):
|
|
4
|
+
pass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ContextInferenceError(ContextError):
|
|
8
|
+
pass
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContextConflictError(ContextError):
|
|
12
|
+
pass
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UnknownBackendError(ContextError):
|
|
16
|
+
pass
|