spacecore 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacecore/__init__.py ADDED
@@ -0,0 +1,46 @@
1
# Package version string (kept in sync with the released wheel).
__version__: str = "0.1.1"
2
+
3
+
4
+ from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class
5
+ from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp
6
+ from .space import VectorSpace, HermitianSpace, Space, ProductSpace
7
+ from .types import DenseArray, SparseArray, ArrayLike
8
+
9
+ from ._contextual.manager import (
10
+ set_context, get_context,
11
+ register_ops,
12
+ set_resolution_policy, set_dtype_resolution_policy,
13
+ get_resolution_policy, get_dtype_resolution_policy
14
+ )
15
+
16
# Public API of the package, grouped by the subpackage each name is imported from.
__all__ = [
    # backend
    "Context",
    "BackendOps", "JaxOps", "jax_pytree_class", "NumpyOps",
    # linear operators
    "DenseLinOp", "SparseLinOp", "BlockDiagonalLinOp",
    "SumToSingleLinOp", "StackedLinOp",
    # spaces
    "VectorSpace", "HermitianSpace", "ProductSpace", "Space",
    # array types (from .types)
    "DenseArray", "SparseArray", "ArrayLike",
    # global context management
    "set_context", "get_context", "register_ops",
    "set_resolution_policy", "set_dtype_resolution_policy",
    "get_resolution_policy", "get_dtype_resolution_policy",
]
@@ -0,0 +1,10 @@
1
+ from .bound import ContextBound as ContextBound
2
+ from .manager import (
3
+ ctx_manager as ctx_manager,
4
+ set_context as set_context,
5
+ register_ops as register_ops,
6
+ set_resolution_policy as set_resolution_policy,
7
+ set_dtype_resolution_policy as set_dtype_resolution_policy,
8
+ get_resolution_policy as get_resolution_policy,
9
+ get_dtype_resolution_policy as get_dtype_resolution_policy,
10
+ )
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from typing import Self
5
+
6
+ from ..backend import Context, BackendOps, BackendFamily
7
+ from ..types import DType
8
+ from .manager import ctx_manager
9
+
10
+ class ContextBound(ABC):
11
+ def __init__(self, ctx: Context | str | None = None):
12
+ ctx = ctx_manager.normalize_context(ctx)
13
+ self._ctx = ctx
14
+
15
+ @property
16
+ def ops(self) -> BackendOps:
17
+ return self.ctx.ops
18
+
19
+ @property
20
+ def dtype(self) -> DType:
21
+ return self.ctx.dtype
22
+
23
+ @property
24
+ def ctx(self) -> Context:
25
+ return self._ctx
26
+
27
+ def _convert(self, new_ctx: Context) -> Self:
28
+ raise NotImplementedError()
29
+
30
+ def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self:
31
+ _, new_ctx = ctx_manager.enforce_convert_policy(self, new_ctx)
32
+ return self._convert(new_ctx)
@@ -0,0 +1,354 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Any, Iterable, Tuple
4
+ from enum import StrEnum, auto
5
+ from warnings import warn
6
+
7
+ from ..types import DType
8
+ from ..backend import Context, NumpyOps, JaxOps, BackendFamily, BackendOps
9
+
10
+
11
+ class ContextPolicy(StrEnum):
12
+ warning = auto()
13
+ error = auto()
14
+ silent = auto()
15
+
16
+ class DtypePreservePolicy(StrEnum):
17
+ keep_native = auto()
18
+ convert = auto()
19
+
20
+
21
class ContextError(RuntimeError):
    """Root of all context-resolution errors raised by this module."""


class ContextInferenceError(ContextError):
    """A backend context could not be inferred unambiguously from a value."""


class ContextConflictError(ContextError):
    """A registration clashes with an already registered backend."""


class UnknownBackendError(ContextError):
    """A backend name does not match any registered backend."""


class ContextConversionError(ContextError):
    """A cross-backend conversion was forbidden by the active policy."""
38
+
39
+
40
class Contextual:
    """
    Backend resolver.

    Owns the default ``Context``, a registry of ``BackendOps`` classes keyed
    by lower-case family name (see ``_backend_key``), and two policies:

    * ``resolution_policy`` (``ContextPolicy``): what ``enforce_convert_policy``
      does when a value would move to a different backend family (warn,
      raise ``ContextConversionError``, or stay silent).
    * ``dtype_resolution_policy`` (``DtypePreservePolicy``): whether
      ``normalize_context_like`` keeps the source value's dtype
      (``keep_native``) or uses the target context's own dtype.
    """
    _default_ctx: Context
    _available_ops: Dict[str, type[BackendOps]]
    _resolution_policy: ContextPolicy
    _dtype_resolution_policy: DtypePreservePolicy

    _default_policy: ContextPolicy = ContextPolicy.warning
    _default_dtype_resolution_policy: DtypePreservePolicy = DtypePreservePolicy.keep_native
    _default_dtype: DType | None = None
    _default_enable_checks: bool = False

    def __init__(self,
                 resolution_policy: str | ContextPolicy | None = None,
                 dtype_resolution_policy: str | DtypePreservePolicy | None = None
                 ) -> None:
        """Create a resolver with a NumPy default context and the NumPy/JAX backends registered.

        Args:
            resolution_policy: Initial ``resolution_policy``; ``None`` selects
                ``_default_policy``.
            dtype_resolution_policy: Initial ``dtype_resolution_policy``;
                ``None`` selects ``_default_dtype_resolution_policy``.
        """
        ops = NumpyOps()
        # The default_ctx setter normalizes via the Context branch of
        # normalize_context, which does not need the registry yet.
        self.default_ctx = Context(
            ops=ops,
            dtype=ops.sanitize_dtype(self._default_dtype),
            enable_checks=self._default_enable_checks,
        )

        self._available_ops = {
            self._backend_key(NumpyOps): NumpyOps,
            self._backend_key(JaxOps): JaxOps,
        }

        self.resolution_policy = resolution_policy
        self.dtype_resolution_policy = dtype_resolution_policy

    def normalize_context(self,
                          ctx: Context | BackendFamily | str | None = None,
                          dtype: Any = None,
                          enable_checks: bool | None = None
                          ) -> Context:
        """Coerce *ctx* into a concrete ``Context``.

        Args:
            ctx: ``None`` returns the current default context; a ``Context``
                is copied with its dtype re-sanitized; a backend name or
                ``BackendFamily`` builds a new context for that registered
                backend.
            dtype: dtype used only when building from a backend name. A
                ``UserWarning`` is emitted if it is given together with a
                ``None`` or ``Context`` *ctx*, and it is then ignored.
            enable_checks: Handled like *dtype*; ``None`` falls back to
                ``_default_enable_checks`` when building from a backend name.

        Raises:
            UnknownBackendError: if a backend name is not registered.
            TypeError: for any other type of *ctx*.
        """
        if ctx is None:
            if dtype is not None or enable_checks is not None:
                warn(
                    'Provided context is None, dtype and enable_checks parameters are ignored.',
                    UserWarning,
                )
            return self.default_ctx
        if isinstance(ctx, Context):
            if dtype is not None or enable_checks is not None:
                warn(
                    'Provided concrete context, dtype and enable_checks parameters are ignored.',
                    UserWarning,
                )
            return Context(
                ops=ctx.ops,
                dtype=ctx.ops.sanitize_dtype(ctx.dtype),
                enable_checks=ctx.enable_checks,
            )
        if isinstance(ctx, (str, BackendFamily)):
            ops = self.get_ops(self._backend_key(ctx))
            return self.ctx_from_ops(ops, dtype=dtype, enable_checks=enable_checks)
        raise TypeError(f'Expected Context, BackendFamily, str, or None, got {type(ctx)}.')

    def ctx_like(self, base: Context | None, ctx: Context) -> Context:
        """Return a copy of *ctx* whose dtype is taken from *base*.

        The backend and ``enable_checks`` come from *ctx*; *base*'s dtype is
        re-sanitized for *ctx*'s backend. When *base* is not a ``Context``
        there is no dtype to carry over, so *ctx* is returned normalized.
        """
        if isinstance(base, Context):
            return Context(
                ops=ctx.ops,
                dtype=ctx.ops.sanitize_dtype(base.dtype),
                enable_checks=ctx.enable_checks,
            )
        # Returning default_ctx here would silently discard the requested ctx.
        return self.normalize_context(ctx)

    def normalize_context_like(self, base: Context | None, ctx: Context | BackendFamily | str | None = None) -> Context:
        """Resolve the target context *ctx* for a value whose native context is *base*.

        Under ``DtypePreservePolicy.keep_native`` with a ``Context`` *base*,
        the result uses *ctx*'s backend (the default backend when *ctx* is
        ``None``) but keeps *base*'s dtype. In every other case this is
        ``normalize_context(ctx)``.

        Raises:
            TypeError: if *ctx* is not a Context, backend name, or ``None``.
        """
        if not (self.dtype_resolution_policy is DtypePreservePolicy.keep_native
                and isinstance(base, Context)):
            return self.normalize_context(ctx)

        if ctx is None:
            return self.ctx_like(base, self.default_ctx)
        if isinstance(ctx, Context):
            return self.ctx_like(base, ctx)
        if isinstance(ctx, (str, BackendFamily)):
            ops = self.get_ops(self._backend_key(ctx))
            return self.ctx_from_ops(
                ops=ops,
                dtype=base.dtype,
                enable_checks=self._default_enable_checks,
            )
        raise TypeError(f'Expected Context, BackendFamily, str, or None, got {type(ctx)}.')

    def ctx_from_ops(self, ops: BackendOps, dtype: DType | None = None, enable_checks: bool | None = None) -> Context:
        """Build a ``Context`` for *ops*.

        *dtype* is sanitized by *ops*; ``enable_checks=None`` falls back to
        ``_default_enable_checks``.
        """
        dtype = ops.sanitize_dtype(dtype)
        if enable_checks is None:
            enable_checks = self._default_enable_checks
        return Context(ops=ops,
                       dtype=dtype,
                       enable_checks=enable_checks)

    @property
    def default_ctx(self) -> Context:
        """Context returned when no explicit context is requested."""
        return self._default_ctx

    @default_ctx.setter
    def default_ctx(self, ctx: Context | BackendFamily | str | None = None) -> None:
        self._default_ctx = self.normalize_context(ctx)

    @staticmethod
    def _coerce_policy(enum_type, policy, default, label: str):
        """Return *policy* as a member of *enum_type*; ``None`` maps to *default*.

        Raises:
            ValueError: if *policy* is not a valid value of *enum_type*; the
                message names *label* and lists the allowed values.
        """
        if policy is None:
            return default
        if isinstance(policy, enum_type):
            return policy
        try:
            return enum_type(policy)
        except ValueError as e:
            allowed = ", ".join(p.value for p in enum_type)
            raise ValueError(
                f"Unknown {label}={policy!r}. "
                f"Expected one of: {allowed}"
            ) from e

    @property
    def resolution_policy(self) -> ContextPolicy:
        """Active ``ContextPolicy`` for cross-backend conversions."""
        return self._resolution_policy

    @resolution_policy.setter
    def resolution_policy(self, policy: str | ContextPolicy | None = ContextPolicy.warning.value) -> None:
        self._resolution_policy = self._coerce_policy(
            ContextPolicy, policy, self._default_policy, "resolution_policy"
        )

    @property
    def dtype_resolution_policy(self) -> DtypePreservePolicy:
        """Active ``DtypePreservePolicy`` for conversions."""
        return self._dtype_resolution_policy

    @dtype_resolution_policy.setter
    def dtype_resolution_policy(self, policy: str | DtypePreservePolicy | None = DtypePreservePolicy.keep_native.value) -> None:
        self._dtype_resolution_policy = self._coerce_policy(
            DtypePreservePolicy, policy, self._default_dtype_resolution_policy, "dtype_resolution_policy"
        )

    def get_ops(self, name: str | BackendFamily | BackendOps | type[BackendOps] | Context) -> BackendOps:
        """Instantiate the registered ``BackendOps`` class identified by *name*.

        A new instance is created on every call.

        Raises:
            UnknownBackendError: if no backend is registered under that key.
        """
        key = self._backend_key(name)
        if key not in self.available_ops:
            allowed = ", ".join(self.available_ops)
            raise UnknownBackendError(
                f"Unknown backend: {key!r}. "
                f"Expected one of: {allowed}"
            )
        return self.available_ops[key]()

    @property
    def available_ops(self) -> Dict[str, type[BackendOps]]:
        """Registry mapping backend key to ``BackendOps`` class (the live dict)."""
        return self._available_ops

    def register_ops(self, ops: type[BackendOps]) -> type[BackendOps]:
        """Register a ``BackendOps`` subclass and return it unchanged.

        Raises:
            TypeError: if *ops* is not a ``BackendOps`` subclass.
            ContextConflictError: if its family key is already registered.
        """
        if not (isinstance(ops, type) and issubclass(ops, BackendOps)):
            raise TypeError(f"Expected type[BackendOps], got {type(ops)!r}")
        family = self._backend_key(ops)
        if family in self.available_ops:
            raise ContextConflictError(f'BackendOps {family} is already registered.')
        self._available_ops[family] = ops
        return ops

    def infer_context(self, x: Any, enable_checks: bool | None = None) -> Context | None:
        """
        Infer a context from a single value.

        Intended precedence:
        1. objects carrying `.ctx`
        2. backend-native arrays recognized by registered backends

        A ``Context`` argument is returned as is. Returns ``None`` when no
        registered backend recognizes *x*.

        Raises:
            ContextInferenceError: if more than one backend claims *x*.
        """
        if isinstance(x, Context):
            return x

        ctx = getattr(x, "ctx", None)
        if isinstance(ctx, Context):
            return ctx

        matched: list[BackendOps] = []
        for ops_cls in self.available_ops.values():
            try:
                candidate = ops_cls()
                if candidate.is_array(x):
                    matched.append(candidate)
            except Exception:
                # Keep inference conservative: a backend that cannot be
                # instantiated, or that errors on this value, does not match.
                continue

        if not matched:
            return None
        if len(matched) > 1:
            raise ContextInferenceError(
                f"Ambiguous backend inference for object of type {type(x)!r}: {matched!r}."
            )

        ops = matched[0]
        try:
            dtype = ops.get_dtype(x)
        except Exception:
            dtype = getattr(x, "dtype", self.default_ctx.dtype)

        return self.ctx_from_ops(ops, dtype, enable_checks)

    def infer_contexts(self, values: Iterable[Any]) -> Tuple[Context, ...]:
        """Infer a context for each value, dropping values with no inferable context."""
        out: list[Context] = []
        for x in values:
            ctx = self.infer_context(x)
            if ctx is not None:
                out.append(ctx)
        return tuple(out)

    def are_compatible_contexts(self, *ctxs: Context) -> bool:
        """True if all contexts share the same backend family (vacuously for < 2)."""
        if len(ctxs) < 2:
            return True
        first = ctxs[0]
        return all(ctx.ops.family == first.ops.family for ctx in ctxs[1:])

    def are_compatible_values(self, *values: Any) -> bool:
        """True if the contexts inferable from *values* share one backend family."""
        return self.are_compatible_contexts(*self.infer_contexts(values))

    def are_compatible_ops(self, *ops: BackendOps) -> bool:
        """True if all *ops* share the same backend family (vacuously when empty)."""
        if not ops:
            return True
        first = ops[0]
        return all(op.family == first.family for op in ops)

    def enforce_convert_policy(self, x: Any, to: Context | BackendFamily | str | None = None) -> Tuple[Any, Context]:
        """Resolve the target context for converting *x* and apply ``resolution_policy``.

        *x* itself is returned unchanged together with the resolved context;
        the caller performs the actual conversion.

        Raises:
            ContextConversionError: under ``ContextPolicy.error`` when *x*'s
                inferred backend family differs from the target's.
        """
        native_ctx = self.infer_context(x)
        ctx = self.normalize_context_like(native_ctx, to)
        if self.resolution_policy is not ContextPolicy.silent:
            if native_ctx is not None and not self.are_compatible_contexts(native_ctx, ctx):
                if self.resolution_policy is ContextPolicy.warning:
                    warn(
                        f"Converting from {native_ctx!r} to {ctx!r}.",
                        UserWarning,
                    )
                else:
                    raise ContextConversionError(
                        f"Conversion from {native_ctx!r} to {ctx!r} is forbidden by policy {self.resolution_policy.value!r}."
                    )
        return x, ctx

    def _backend_key(self, x: str | BackendFamily | BackendOps | type[BackendOps] | Context) -> str:
        """Map any backend identifier to its lower-case registry key.

        Raises:
            TypeError: for unsupported identifier types.
        """
        if isinstance(x, Context):
            return self._backend_key(x.ops)
        if isinstance(x, BackendOps):
            return self._backend_key(x.family)
        if isinstance(x, type) and issubclass(x, BackendOps):
            return self._backend_key(x._family)
        # BackendFamily is a StrEnum, so it must be tested before plain str.
        if isinstance(x, BackendFamily):
            return x.value.lower()
        if isinstance(x, str):
            return x.lower()
        raise TypeError(f"Unsupported backend key source: {type(x)!r}")

    def resolve_context_priority(
        self,
        priority_ctx: Context | BackendFamily | str | None = None,
        *other_ctx: object,
    ) -> Context:
        """
        Resolve context with priority:

        1. If priority_ctx is not None, return its normalized form.
        2. Otherwise, infer contexts from other_ctx and return the best inferred one.
        3. If inference fails, return default_ctx.

        Policy:
        - explicit context always wins
        - inferred contexts must be family-compatible
        - if inferred dtypes differ, take the most general of them
          (``numpy.result_type``)
        - ``enable_checks`` is on only if every inferred context has it on

        Raises:
            ValueError: if the inferred contexts span several backend families.
        """
        if priority_ctx is not None:
            return self.normalize_context(priority_ctx)

        inferred = self.infer_contexts(other_ctx)
        if not inferred:
            return self.default_ctx

        if not self.are_compatible_contexts(*inferred):
            fams = tuple(ctx.ops.family for ctx in inferred)
            raise ValueError(f"Incompatible inferred contexts: {fams!r}")

        first = inferred[0]
        ops = type(first.ops)()
        dtype = self._join_dtypes(ops, *(ctx.dtype for ctx in inferred))

        return self.ctx_from_ops(
            ops=ops,
            dtype=dtype,
            enable_checks=all(ctx.enable_checks for ctx in inferred),
        )

    def _join_dtypes(self, ops: BackendOps, *dtypes: DType | None) -> DType | None:
        """Promote *dtypes* (``None`` entries skipped) with ``numpy.result_type``.

        Returns ``ops.sanitize_dtype(None)`` when no dtype is given.
        """
        clean = [ops.sanitize_dtype(dt) for dt in dtypes if dt is not None]
        if not clean:
            return ops.sanitize_dtype(None)

        np_ops = NumpyOps()
        joined = np_ops.np.result_type(*clean)
        return ops.sanitize_dtype(joined)
@@ -0,0 +1,43 @@
1
+ from typing import Any
2
+
3
+ from ..backend import Context, BackendOps
4
+ from .contextual import Contextual, ContextPolicy, DtypePreservePolicy
5
+ from ..backend import BackendFamily
6
+
7
+
8
# Process-wide resolver instance; the module-level helper functions in this
# module read and mutate its state.
ctx_manager: Contextual = Contextual()
9
+
10
+
11
def set_context(
    ctx: Context | BackendFamily | str | None = None,
    dtype: Any = None,
    enable_checks: bool | None = None
) -> None:
    """Replace the process-wide default context.

    Args:
        ctx: A ``Context``, a backend name / ``BackendFamily``, or ``None``
            to keep the backend of the current default context.
        dtype: dtype for the new default. With ``ctx=None`` it is applied to
            the current default backend; ``None`` keeps the current dtype.
        enable_checks: Handled like *dtype*; ``None`` keeps the current value
            when ``ctx`` is ``None``.
    """
    if ctx is None and (dtype is not None or enable_checks is not None):
        # normalize_context(None, ...) would warn and ignore these arguments;
        # here the caller clearly wants to adjust the current default instead.
        current = ctx_manager.default_ctx
        new_ctx = ctx_manager.ctx_from_ops(
            current.ops,
            dtype=current.dtype if dtype is None else dtype,
            enable_checks=current.enable_checks if enable_checks is None else enable_checks,
        )
    else:
        new_ctx = ctx_manager.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks)
    ctx_manager.default_ctx = new_ctx
18
+
19
+
20
def get_context() -> Context:
    """Return the current process-wide default ``Context``."""
    return ctx_manager.default_ctx
22
+
23
+
24
def register_ops(ops: type[BackendOps]) -> type[BackendOps]:
    """Register a ``BackendOps`` subclass with the global manager.

    Delegates to ``ctx_manager.register_ops``, which returns the class it was
    given, so this function can also be applied as a class decorator.
    """
    registered = ctx_manager.register_ops(ops)
    return registered
26
+
27
+
28
def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None:
    """Set the global cross-backend conversion policy.

    ``None`` restores the manager's default policy; an unknown string raises
    ``ValueError`` (from the manager's property setter).
    """
    manager = ctx_manager
    manager.resolution_policy = policy
30
+
31
+
32
def get_resolution_policy() -> str:
    """Return the active conversion policy as its string value (e.g. ``"warning"``)."""
    active = ctx_manager.resolution_policy
    return active.value
34
+
35
+
36
def set_dtype_resolution_policy(
    policy: DtypePreservePolicy | str | None = None,
) -> None:
    """Set the global dtype policy used when converting between contexts.

    ``None`` restores the manager's default; an unknown string raises
    ``ValueError`` (from the manager's property setter).
    """
    manager = ctx_manager
    manager.dtype_resolution_policy = policy
40
+
41
+
42
def get_dtype_resolution_policy() -> str:
    """Return the active dtype policy as its string value (e.g. ``"keep_native"``)."""
    active = ctx_manager.dtype_resolution_policy
    return active.value
@@ -0,0 +1,15 @@
1
+ from ._context import Context
2
+ from ._ops import BackendOps
3
+ from ._family import BackendFamily
4
+ from .jax import JaxOps, jax_pytree_class
5
+ from .numpy import NumpyOps
6
+
7
+
8
# Names re-exported by the backend subpackage.
__all__ = [
    "Context",
    "BackendFamily",
    "BackendOps",
    "JaxOps", "jax_pytree_class",
    "NumpyOps",
]
@@ -0,0 +1,50 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+
4
+ from ._ops import BackendOps
5
+ from ..types import DenseArray, SparseArray, DType, ArrayLike
6
+
7
+
8
+ @dataclass(frozen=True, slots=True)
9
+ class Context:
10
+ ops: BackendOps
11
+ dtype: DType | None = None
12
+ enable_checks: bool = True
13
+
14
+ def __post_init__(self):
15
+ if not isinstance(self.ops, BackendOps):
16
+ raise TypeError("ops must be a BackendOps")
17
+
18
+ sanitized = self.ops.sanitize_dtype(self.dtype)
19
+ object.__setattr__(self, "dtype", sanitized)
20
+
21
+ def assert_dense(self, x: Any) -> DenseArray:
22
+ if not self.ops.is_dense(x):
23
+ raise TypeError(f"Expected dense array for {self.ops.family}, got {type(x).__name__}")
24
+ return x
25
+
26
+ def assert_sparse(self, x: Any) -> SparseArray:
27
+ if not self.ops.allow_sparse:
28
+ raise TypeError("Sparse objects are disallowed by this backend.")
29
+ if not self.ops.is_sparse(x):
30
+ raise TypeError(f"Expected sparse array for {self.ops.family}, got {type(x).__name__}")
31
+ return x
32
+
33
+ def asarray(self, x: Any) -> DenseArray:
34
+ return self.ops.asarray(x, dtype=self.dtype)
35
+
36
+ def assparse(self, x: Any) -> SparseArray:
37
+ return self.ops.assparse(x, dtype=self.dtype)
38
+
39
+ def convert(self, x: Any) -> ArrayLike:
40
+ if self.ops.is_dense(x):
41
+ return self.asarray(x)
42
+ elif self.ops.is_sparse(x):
43
+ return self.assparse(x)
44
+ else:
45
+ raise NotImplementedError
46
+
47
+ def __eq__(self, other: Any) -> bool:
48
+ if isinstance(other, Context):
49
+ return self.ops == other.ops and self.enable_checks == other.enable_checks
50
+ return False
@@ -0,0 +1,6 @@
1
+ from enum import StrEnum, auto
2
+
3
+
4
+ class BackendFamily(StrEnum):
5
+ numpy = auto()
6
+ jax = auto()