spacecore 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacecore/__init__.py +46 -0
- spacecore/_contextual/__init__.py +10 -0
- spacecore/_contextual/bound.py +32 -0
- spacecore/_contextual/contextual.py +354 -0
- spacecore/_contextual/manager.py +43 -0
- spacecore/backend/__init__.py +15 -0
- spacecore/backend/_context.py +50 -0
- spacecore/backend/_family.py +6 -0
- spacecore/backend/_ops.py +354 -0
- spacecore/backend/jax/__init__.py +2 -0
- spacecore/backend/jax/_ops.py +679 -0
- spacecore/backend/jax/_pytree.py +20 -0
- spacecore/backend/numpy/__init__.py +1 -0
- spacecore/backend/numpy/_ops.py +956 -0
- spacecore/linop/__init__.py +14 -0
- spacecore/linop/_base.py +72 -0
- spacecore/linop/_dense.py +100 -0
- spacecore/linop/_sparse.py +98 -0
- spacecore/linop/product/__init__.py +11 -0
- spacecore/linop/product/_base.py +61 -0
- spacecore/linop/product/_block.py +57 -0
- spacecore/linop/product/_from_single.py +63 -0
- spacecore/linop/product/_to_single.py +63 -0
- spacecore/space/__init__.py +11 -0
- spacecore/space/_base.py +91 -0
- spacecore/space/_herm.py +110 -0
- spacecore/space/_product.py +142 -0
- spacecore/space/_vector.py +66 -0
- spacecore/types/__init__.py +16 -0
- spacecore/types/_array.py +60 -0
- spacecore/types/_dtype.py +3 -0
- spacecore/types/_misc.py +14 -0
- spacecore-0.1.1.dist-info/METADATA +121 -0
- spacecore-0.1.1.dist-info/RECORD +37 -0
- spacecore-0.1.1.dist-info/WHEEL +5 -0
- spacecore-0.1.1.dist-info/licenses/LICENSE +201 -0
- spacecore-0.1.1.dist-info/top_level.txt +1 -0
spacecore/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Package version string; matches the distribution name (spacecore-0.1.1).
__version__ = "0.1.1"
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class
|
|
5
|
+
from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp
|
|
6
|
+
from .space import VectorSpace, HermitianSpace, Space, ProductSpace
|
|
7
|
+
from .types import DenseArray, SparseArray, ArrayLike
|
|
8
|
+
|
|
9
|
+
from ._contextual.manager import (
|
|
10
|
+
set_context, get_context,
|
|
11
|
+
register_ops,
|
|
12
|
+
set_resolution_policy, set_dtype_resolution_policy,
|
|
13
|
+
get_resolution_policy, get_dtype_resolution_policy
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Public API re-exported at package level, grouped by subpackage.
__all__ = [
    # backend
    "Context",
    "BackendOps",
    "JaxOps",
    "jax_pytree_class",
    "NumpyOps",
    # linear operators
    "DenseLinOp",
    "SparseLinOp",
    "BlockDiagonalLinOp",
    "SumToSingleLinOp",
    "StackedLinOp",
    # spaces
    "VectorSpace",
    "HermitianSpace",
    "ProductSpace",
    "Space",
    # array types
    "DenseArray",
    "SparseArray",
    "ArrayLike",
    # global context management
    "set_context",
    "get_context",
    "register_ops",
    "set_resolution_policy",
    "set_dtype_resolution_policy",
    "get_resolution_policy",
    "get_dtype_resolution_policy",
]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from .bound import ContextBound as ContextBound
|
|
2
|
+
from .manager import (
|
|
3
|
+
ctx_manager as ctx_manager,
|
|
4
|
+
set_context as set_context,
|
|
5
|
+
register_ops as register_ops,
|
|
6
|
+
set_resolution_policy as set_resolution_policy,
|
|
7
|
+
set_dtype_resolution_policy as set_dtype_resolution_policy,
|
|
8
|
+
get_resolution_policy as get_resolution_policy,
|
|
9
|
+
get_dtype_resolution_policy as get_dtype_resolution_policy,
|
|
10
|
+
)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Self
|
|
5
|
+
|
|
6
|
+
from ..backend import Context, BackendOps, BackendFamily
|
|
7
|
+
from ..types import DType
|
|
8
|
+
from .manager import ctx_manager
|
|
9
|
+
|
|
10
|
+
class ContextBound(ABC):
|
|
11
|
+
def __init__(self, ctx: Context | str | None = None):
|
|
12
|
+
ctx = ctx_manager.normalize_context(ctx)
|
|
13
|
+
self._ctx = ctx
|
|
14
|
+
|
|
15
|
+
@property
|
|
16
|
+
def ops(self) -> BackendOps:
|
|
17
|
+
return self.ctx.ops
|
|
18
|
+
|
|
19
|
+
@property
|
|
20
|
+
def dtype(self) -> DType:
|
|
21
|
+
return self.ctx.dtype
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
def ctx(self) -> Context:
|
|
25
|
+
return self._ctx
|
|
26
|
+
|
|
27
|
+
def _convert(self, new_ctx: Context) -> Self:
|
|
28
|
+
raise NotImplementedError()
|
|
29
|
+
|
|
30
|
+
def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self:
|
|
31
|
+
_, new_ctx = ctx_manager.enforce_convert_policy(self, new_ctx)
|
|
32
|
+
return self._convert(new_ctx)
|
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Any, Iterable, Tuple
|
|
4
|
+
from enum import StrEnum, auto
|
|
5
|
+
from warnings import warn
|
|
6
|
+
|
|
7
|
+
from ..types import DType
|
|
8
|
+
from ..backend import Context, NumpyOps, JaxOps, BackendFamily, BackendOps
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContextPolicy(StrEnum):
|
|
12
|
+
warning = auto()
|
|
13
|
+
error = auto()
|
|
14
|
+
silent = auto()
|
|
15
|
+
|
|
16
|
+
class DtypePreservePolicy(StrEnum):
|
|
17
|
+
keep_native = auto()
|
|
18
|
+
convert = auto()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ContextError(RuntimeError):
    """Root of all context-resolution errors raised by this module."""


class ContextInferenceError(ContextError):
    """A value matched more than one registered backend during inference."""


class ContextConflictError(ContextError):
    """A backend with the same key is already registered."""


class UnknownBackendError(ContextError):
    """The requested backend key is not registered."""


class ContextConversionError(ContextError):
    """A cross-backend conversion was forbidden by the resolution policy."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Contextual:
    """
    Backend resolver.

    Owns three things:

    * the default :class:`Context`;
    * a registry mapping lower-case backend keys to ``BackendOps`` classes;
    * two policies:

      * ``resolution_policy`` (:class:`ContextPolicy`) controls what
        :meth:`enforce_convert_policy` does when a value changes backend
        family;
      * ``dtype_resolution_policy`` (:class:`DtypePreservePolicy`) controls
        whether :meth:`normalize_context_like` carries the source dtype over.
    """
    _default_ctx: Context
    _available_ops: Dict[str, type[BackendOps]]
    _resolution_policy: ContextPolicy
    _dtype_resolution_policy: DtypePreservePolicy

    # Class-level fallbacks used when a constructor/setter argument is None.
    _default_policy: ContextPolicy = ContextPolicy.warning
    _default_dtype_resolution_policy: DtypePreservePolicy = DtypePreservePolicy.keep_native
    _default_dtype: DType | None = None
    _default_enable_checks: bool = False

    def __init__(self,
                 resolution_policy: str | ContextPolicy | None = None,
                 dtype_resolution_policy: str | DtypePreservePolicy | None = None
                 ) -> None:
        """
        Create a resolver whose default context uses the NumPy backend.

        Parameters
        ----------
        resolution_policy:
            Initial cross-backend conversion policy; ``None`` selects
            ``_default_policy``.
        dtype_resolution_policy:
            Initial dtype preservation policy; ``None`` selects
            ``_default_dtype_resolution_policy``.
        """
        ops = NumpyOps()
        self.default_ctx = Context(
            ops=ops,
            dtype=ops.sanitize_dtype(self._default_dtype),
            enable_checks=self._default_enable_checks,
        )

        # Built-in backends; further ones are added through register_ops().
        self._available_ops = {
            self._backend_key(NumpyOps): NumpyOps,
            self._backend_key(JaxOps): JaxOps,
        }

        self.resolution_policy = resolution_policy
        self.dtype_resolution_policy = dtype_resolution_policy

    def normalize_context(self,
                          ctx: Context | BackendFamily | str | None = None,
                          dtype: Any = None,
                          enable_checks: bool | None = None
                          ) -> Context:
        """
        Turn a context specification into a concrete :class:`Context`.

        * ``None`` -> the current default context.
        * a ``Context`` -> a copy whose dtype is re-sanitized by its own ops.
        * a backend name / ``BackendFamily`` -> a fresh context for that
          backend, built from ``dtype`` and ``enable_checks``.

        ``dtype`` and ``enable_checks`` only apply to the last case. Passing
        them together with ``None`` or a ``Context`` emits a ``UserWarning``,
        and they are ignored.

        Raises
        ------
        UnknownBackendError
            If a backend name is not registered.
        TypeError
            If ``ctx`` has an unsupported type.
        """
        if ctx is None:
            if dtype is not None or enable_checks is not None:
                warn(
                    'Provided context is None, dtype and enable_checks parameters are ignored.',
                    UserWarning,
                )
            return self.default_ctx
        if isinstance(ctx, Context):
            if dtype is not None or enable_checks is not None:
                warn(
                    'Provided concrete context, dtype and enable_checks parameters are ignored.',
                    UserWarning,
                )
            return Context(
                ops=ctx.ops,
                dtype=ctx.ops.sanitize_dtype(ctx.dtype),
                enable_checks=ctx.enable_checks,
            )
        if isinstance(ctx, (str, BackendFamily)):
            ops = self.get_ops(self._backend_key(ctx))
            return self.ctx_from_ops(ops, dtype=dtype, enable_checks=enable_checks)
        raise TypeError(f'Expected Context, BackendFamily, str, or None, got {type(ctx)}.')

    def ctx_like(self, base: Context | None, ctx: Context) -> Context:
        """
        Return ``ctx`` re-typed with ``base``'s dtype.

        The backend ops and ``enable_checks`` come from ``ctx``; the dtype comes
        from ``base`` and is sanitized by ``ctx``'s ops. When ``base`` is not a
        Context, there is no dtype to carry over, so ``ctx`` is returned
        unchanged instead of being silently replaced by the default context.
        """
        if not isinstance(base, Context):
            return ctx
        return Context(
            ops=ctx.ops,
            dtype=ctx.ops.sanitize_dtype(base.dtype),
            enable_checks=ctx.enable_checks,
        )

    def normalize_context_like(self, base: Context | None, ctx: Context | BackendFamily | str | None = None) -> Context:
        """
        Normalize ``ctx``, optionally inheriting ``base``'s dtype.

        Under ``DtypePreservePolicy.keep_native`` with a concrete ``base``, the
        result uses ``base.dtype`` together with the backend selected by
        ``ctx``. When ``ctx`` is None, the default context's backend is used.
        In every other case, this is exactly ``normalize_context(ctx)``.

        Raises
        ------
        UnknownBackendError
            If a backend name is not registered.
        TypeError
            If ``ctx`` has an unsupported type.
        """
        if self.dtype_resolution_policy is not DtypePreservePolicy.keep_native or not isinstance(base, Context):
            return self.normalize_context(ctx)

        if ctx is None:
            return self.ctx_like(base, self.default_ctx)
        if isinstance(ctx, Context):
            return self.ctx_like(base, ctx)
        if isinstance(ctx, (str, BackendFamily)):
            ops = self.get_ops(self._backend_key(ctx))
            return self.ctx_from_ops(
                ops=ops,
                dtype=base.dtype,
                enable_checks=self._default_enable_checks,
            )
        raise TypeError(f'Expected Context, BackendFamily, str, or None, got {type(ctx)}.')

    def ctx_from_ops(self, ops: BackendOps, dtype: DType | None = None, enable_checks: bool | None = None) -> Context:
        """
        Build a Context for ``ops``.

        ``dtype`` is sanitized by ``ops``. ``enable_checks=None`` falls back to
        the class-level ``_default_enable_checks``.
        """
        if enable_checks is None:
            enable_checks = self._default_enable_checks
        return Context(ops=ops,
                       dtype=ops.sanitize_dtype(dtype),
                       enable_checks=enable_checks)

    @property
    def default_ctx(self) -> Context:
        """The context used when no explicit context is given."""
        return self._default_ctx

    @default_ctx.setter
    def default_ctx(self, ctx: Context | BackendFamily | str | None = None) -> None:
        self._default_ctx = self.normalize_context(ctx)

    @property
    def resolution_policy(self) -> ContextPolicy:
        """Policy applied by :meth:`enforce_convert_policy` on family changes."""
        return self._resolution_policy

    @resolution_policy.setter
    def resolution_policy(self, policy: str | ContextPolicy | None = ContextPolicy.warning.value) -> None:
        # None restores the class default; strings are parsed into the enum.
        if policy is None:
            self._resolution_policy = self._default_policy
            return

        try:
            self._resolution_policy = (
                policy
                if isinstance(policy, ContextPolicy)
                else ContextPolicy(policy)
            )
        except ValueError as e:
            allowed = ", ".join(p.value for p in ContextPolicy)
            raise ValueError(
                f"Unknown resolution_policy={policy!r}. "
                f"Expected one of: {allowed}"
            ) from e

    @property
    def dtype_resolution_policy(self) -> DtypePreservePolicy:
        """Policy applied by :meth:`normalize_context_like` to dtypes."""
        return self._dtype_resolution_policy

    @dtype_resolution_policy.setter
    def dtype_resolution_policy(self, policy: str | DtypePreservePolicy | None = DtypePreservePolicy.keep_native.value) -> None:
        # None restores the class default; strings are parsed into the enum.
        if policy is None:
            self._dtype_resolution_policy = self._default_dtype_resolution_policy
            return

        try:
            self._dtype_resolution_policy = (
                policy
                if isinstance(policy, DtypePreservePolicy)
                else DtypePreservePolicy(policy)
            )
        except ValueError as e:
            allowed = ", ".join(p.value for p in DtypePreservePolicy)
            raise ValueError(
                f"Unknown dtype_resolution_policy={policy!r}. "
                f"Expected one of: {allowed}"
            ) from e

    def get_ops(self, name: str | BackendFamily | BackendOps | type[BackendOps] | Context) -> BackendOps:
        """
        Instantiate the registered BackendOps class for ``name``.

        Raises
        ------
        UnknownBackendError
            If no backend is registered under the derived key.
        """
        key = self._backend_key(name)
        if key not in self.available_ops:
            allowed = ", ".join(self.available_ops)
            raise UnknownBackendError(
                f"Unknown backend: {key!r}. "
                f"Expected one of: {allowed}"
            )
        return self.available_ops[key]()

    @property
    def available_ops(self) -> Dict[str, type[BackendOps]]:
        """Registry of backend key -> BackendOps class (the live dict, not a copy)."""
        return self._available_ops

    def register_ops(self, ops: type[BackendOps]) -> type[BackendOps]:
        """
        Register a BackendOps subclass under its family key.

        Returns ``ops`` unchanged, so this can be used as a class decorator.

        Raises
        ------
        TypeError
            If ``ops`` is not a BackendOps subclass.
        ContextConflictError
            If a backend with the same key is already registered.
        """
        if not isinstance(ops, type) or not issubclass(ops, BackendOps):
            raise TypeError(f"Expected type[BackendOps], got {type(ops)!r}")
        family = self._backend_key(ops)
        if family in self.available_ops:
            raise ContextConflictError(f'BackendOps {family} is already registered.')
        self._available_ops[family] = ops
        return ops

    def infer_context(self, x: Any, enable_checks: bool | None = None) -> Context | None:
        """
        Infer a context from a single value.

        Intended precedence:
            1. objects carrying `.ctx`
            2. backend-native arrays recognized by registered backends

        Returns ``None`` when no backend recognizes ``x``.

        Raises
        ------
        ContextInferenceError
            If more than one registered backend recognizes ``x``.
        """
        if isinstance(x, Context):
            return x

        carried = getattr(x, "ctx", None)
        if isinstance(carried, Context):
            return carried

        matched: list[BackendOps] = []
        for ops_cls in self.available_ops.values():
            try:
                candidate = ops_cls()
                if candidate.is_array(x):
                    matched.append(candidate)
            except Exception:
                # Keep inference conservative: a backend that cannot be
                # instantiated or whose is_array check raises does not match.
                continue

        if not matched:
            return None
        if len(matched) > 1:
            raise ContextInferenceError(
                f"Ambiguous backend inference for object of type {type(x)!r}: {matched!r}."
            )

        ops = matched[0]
        try:
            dtype = ops.get_dtype(x)
        except Exception:
            dtype = getattr(x, "dtype", self.default_ctx.dtype)

        return self.ctx_from_ops(ops, dtype, enable_checks)

    def infer_contexts(self, values: Iterable[Any]) -> Tuple[Context, ...]:
        """Infer contexts for ``values``, skipping values with no inferable context."""
        inferred = (self.infer_context(x) for x in values)
        return tuple(ctx for ctx in inferred if ctx is not None)

    def are_compatible_contexts(self, *ctxs: Context) -> bool:
        """True if all contexts share the same backend family (vacuously for < 2)."""
        if len(ctxs) < 2:
            return True
        first = ctxs[0]
        return all(ctx.ops.family == first.ops.family for ctx in ctxs[1:])

    def are_compatible_values(self, *values: Any) -> bool:
        """True if all values with an inferable context share a backend family."""
        return self.are_compatible_contexts(*self.infer_contexts(values))

    def are_compatible_ops(self, *ops: BackendOps) -> bool:
        """True if all ops share the same backend family (vacuously for none)."""
        if not ops:
            return True
        first = ops[0]
        return all(op.family == first.family for op in ops)

    def enforce_convert_policy(self, x: Any, to: Context | BackendFamily | str | None = None) -> Tuple[Any, Context]:
        """
        Resolve the target context for converting ``x`` and apply the policy.

        The target is ``normalize_context_like(infer_context(x), to)``. If ``x``
        has an inferable context whose backend family differs from the
        target's, the outcome depends on the policy:

        * ``warning`` emits a UserWarning;
        * ``error`` raises ``ContextConversionError``;
        * ``silent`` does nothing.

        Returns ``(x, target_ctx)``; ``x`` itself is not modified.
        """
        native_ctx = self.infer_context(x)
        ctx = self.normalize_context_like(native_ctx, to)
        policy = self.resolution_policy
        family_changes = native_ctx is not None and not self.are_compatible_contexts(native_ctx, ctx)
        if policy is not ContextPolicy.silent and family_changes:
            if policy is ContextPolicy.warning:
                warn(
                    f"Converting from {native_ctx!r} to {ctx!r}.",
                    UserWarning,
                )
            else:
                raise ContextConversionError(
                    f"Conversion from {native_ctx!r} to {ctx!r} is forbidden by policy {self.resolution_policy.value!r}."
                )
        return x, ctx

    def _backend_key(self, x: str | BackendFamily | BackendOps | type[BackendOps] | Context) -> str:
        """Derive the lower-case registry key for a backend specification."""
        if isinstance(x, Context):
            return self._backend_key(x.ops)
        if isinstance(x, BackendOps):
            return self._backend_key(x.family)
        if isinstance(x, type) and issubclass(x, BackendOps):
            return self._backend_key(x._family)
        if isinstance(x, BackendFamily):
            return x.value.lower()
        if isinstance(x, str):
            return x.lower()
        raise TypeError(f"Unsupported backend key source: {type(x)!r}")

    def resolve_context_priority(
            self,
            priority_ctx: Context | BackendFamily | str | None = None,
            *other_ctx: object,
    ) -> Context:
        """
        Resolve context with priority:

        1. If priority_ctx is not None, return its normalized form.
        2. Otherwise, infer contexts from other_ctx and return the best inferred one.
        3. If inference fails, return default_ctx.

        Policy:
        - explicit context always wins
        - inferred contexts must be family-compatible (``ValueError`` otherwise)
        - if inferred dtypes differ, take the most general of them
        - ``enable_checks`` is True only if every inferred context enables it
        """
        if priority_ctx is not None:
            return self.normalize_context(priority_ctx)

        inferred = self.infer_contexts(other_ctx)
        if not inferred:
            return self.default_ctx

        if not self.are_compatible_contexts(*inferred):
            fams = tuple(ctx.ops.family for ctx in inferred)
            raise ValueError(f"Incompatible inferred contexts: {fams!r}")

        ops = type(inferred[0].ops)()
        dtype = self._join_dtypes(ops, *(ctx.dtype for ctx in inferred))

        return self.ctx_from_ops(
            ops=ops,
            dtype=dtype,
            enable_checks=all(ctx.enable_checks for ctx in inferred),
        )

    def _join_dtypes(self, ops: BackendOps, *dtypes: DType | None) -> DType | None:
        """
        Promote ``dtypes`` to a common dtype, sanitized for ``ops``.

        ``None`` entries are ignored. Promotion uses ``NumpyOps().np.result_type``
        (presumably ``numpy.result_type``). With no dtypes left,
        ``ops.sanitize_dtype(None)`` is returned.
        """
        clean = [ops.sanitize_dtype(dt) for dt in dtypes if dt is not None]
        if not clean:
            return ops.sanitize_dtype(None)

        joined = NumpyOps().np.result_type(*clean)
        return ops.sanitize_dtype(joined)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from ..backend import Context, BackendOps
|
|
4
|
+
from .contextual import Contextual, ContextPolicy, DtypePreservePolicy
|
|
5
|
+
from ..backend import BackendFamily
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Shared module-level resolver. The helper functions in this module delegate to
# it, and ContextBound normalizes and converts contexts through it.
ctx_manager = Contextual()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def set_context(
        ctx: Context | BackendFamily | str | None = None,
        dtype: Any = None,
        enable_checks: bool | None = None
) -> None:
    """
    Replace the global default context.

    ``ctx`` is normalized with ``dtype``/``enable_checks`` via
    ``ctx_manager.normalize_context``. Those two arguments are only honoured
    when ``ctx`` is a backend name or ``BackendFamily``; with ``None`` or a
    concrete Context, they are ignored with a warning.
    """
    ctx_manager.default_ctx = ctx_manager.normalize_context(
        ctx, dtype=dtype, enable_checks=enable_checks
    )
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_context() -> Context:
    """Return the current global default context (``ctx_manager.default_ctx``)."""
    return ctx_manager.default_ctx
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def register_ops(ops: type[BackendOps]) -> type[BackendOps]:
    """
    Register a BackendOps subclass with the global resolver.

    Returns ``ops`` itself, so this also works as a class decorator. Raises
    ``TypeError`` for non-BackendOps classes and ``ContextConflictError`` for
    duplicate keys.
    """
    registered = ctx_manager.register_ops(ops)
    return registered
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None:
    """
    Set the global cross-backend conversion policy.

    ``None`` restores the resolver's default; an unknown string raises
    ``ValueError``.
    """
    ctx_manager.resolution_policy = policy
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_resolution_policy() -> str:
    """Return the active resolution policy as its string value."""
    current = ctx_manager.resolution_policy
    return current.value
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def set_dtype_resolution_policy(
        policy: DtypePreservePolicy | str | None = None,
) -> None:
    """
    Set the global dtype preservation policy.

    ``None`` restores the resolver's default; an unknown string raises
    ``ValueError``.
    """
    ctx_manager.dtype_resolution_policy = policy
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_dtype_resolution_policy() -> str:
    """Return the active dtype preservation policy as its string value."""
    current = ctx_manager.dtype_resolution_policy
    return current.value
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from ._context import Context
|
|
2
|
+
from ._ops import BackendOps
|
|
3
|
+
from ._family import BackendFamily
|
|
4
|
+
from .jax import JaxOps, jax_pytree_class
|
|
5
|
+
from .numpy import NumpyOps
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Public names of the backend subpackage.
__all__ = [
    # core abstractions
    "Context",
    "BackendFamily",
    "BackendOps",
    # concrete backends
    "JaxOps",
    "jax_pytree_class",
    "NumpyOps",
]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from ._ops import BackendOps
|
|
5
|
+
from ..types import DenseArray, SparseArray, DType, ArrayLike
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True, slots=True)
class Context:
    """
    Immutable bundle of a backend, a dtype and a checks flag.

    Attributes
    ----------
    ops:
        Backend operations used to build and inspect arrays.
    dtype:
        Dtype used by :meth:`asarray` / :meth:`assparse`. It is normalized by
        ``ops.sanitize_dtype`` on construction.
    enable_checks:
        Flag carried with the context; this class itself does not read it.

    Equality compares ``ops`` and ``enable_checks`` only (``dtype`` is
    ignored). ``__hash__`` is defined over the same fields to keep the
    hash/eq contract: the dataclass-generated hash would otherwise include
    ``dtype``.
    """
    ops: BackendOps
    dtype: DType | None = None
    enable_checks: bool = True

    def __post_init__(self):
        if not isinstance(self.ops, BackendOps):
            raise TypeError("ops must be a BackendOps")

        # The dataclass is frozen, so bypass __setattr__ to store the
        # backend-normalized dtype.
        sanitized = self.ops.sanitize_dtype(self.dtype)
        object.__setattr__(self, "dtype", sanitized)

    def assert_dense(self, x: Any) -> DenseArray:
        """Return ``x`` if ``ops`` recognizes it as dense; raise TypeError otherwise."""
        if not self.ops.is_dense(x):
            raise TypeError(f"Expected dense array for {self.ops.family}, got {type(x).__name__}")
        return x

    def assert_sparse(self, x: Any) -> SparseArray:
        """
        Return ``x`` if ``ops`` recognizes it as sparse.

        Raises TypeError if the backend disallows sparse objects or ``x`` is
        not sparse.
        """
        if not self.ops.allow_sparse:
            raise TypeError("Sparse objects are disallowed by this backend.")
        if not self.ops.is_sparse(x):
            raise TypeError(f"Expected sparse array for {self.ops.family}, got {type(x).__name__}")
        return x

    def asarray(self, x: Any) -> DenseArray:
        """Convert ``x`` to a dense array of this context's dtype."""
        return self.ops.asarray(x, dtype=self.dtype)

    def assparse(self, x: Any) -> SparseArray:
        """Convert ``x`` to a sparse array of this context's dtype."""
        return self.ops.assparse(x, dtype=self.dtype)

    def convert(self, x: Any) -> ArrayLike:
        """
        Convert ``x`` with :meth:`asarray` or :meth:`assparse` by its kind.

        Raises
        ------
        NotImplementedError
            If ``ops`` recognizes ``x`` as neither dense nor sparse.
        """
        if self.ops.is_dense(x):
            return self.asarray(x)
        if self.ops.is_sparse(x):
            return self.assparse(x)
        raise NotImplementedError(
            f"Cannot convert object of type {type(x).__name__} with the "
            f"{self.ops.family} backend: it is neither dense nor sparse."
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Context):
            return NotImplemented
        return self.ops == other.ops and self.enable_checks == other.enable_checks

    def __hash__(self) -> int:
        # Must use exactly the fields compared in __eq__.
        return hash((self.ops, self.enable_checks))
|