tnfr 6.0.0__py3-none-any.whl → 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tnfr might be problematic.
- tnfr/__init__.py +50 -5
- tnfr/__init__.pyi +0 -7
- tnfr/_compat.py +0 -1
- tnfr/_generated_version.py +34 -0
- tnfr/_version.py +44 -2
- tnfr/alias.py +14 -13
- tnfr/alias.pyi +5 -37
- tnfr/cache.py +9 -729
- tnfr/cache.pyi +8 -224
- tnfr/callback_utils.py +16 -31
- tnfr/callback_utils.pyi +3 -29
- tnfr/cli/__init__.py +17 -11
- tnfr/cli/__init__.pyi +0 -21
- tnfr/cli/arguments.py +175 -14
- tnfr/cli/arguments.pyi +5 -11
- tnfr/cli/execution.py +434 -48
- tnfr/cli/execution.pyi +14 -24
- tnfr/cli/utils.py +20 -3
- tnfr/cli/utils.pyi +5 -5
- tnfr/config/__init__.py +2 -1
- tnfr/config/__init__.pyi +2 -0
- tnfr/config/feature_flags.py +83 -0
- tnfr/config/init.py +1 -1
- tnfr/config/operator_names.py +1 -14
- tnfr/config/presets.py +6 -26
- tnfr/constants/__init__.py +10 -13
- tnfr/constants/__init__.pyi +10 -22
- tnfr/constants/aliases.py +31 -0
- tnfr/constants/core.py +4 -3
- tnfr/constants/init.py +1 -1
- tnfr/constants/metric.py +3 -3
- tnfr/dynamics/__init__.py +64 -10
- tnfr/dynamics/__init__.pyi +3 -4
- tnfr/dynamics/adaptation.py +79 -13
- tnfr/dynamics/aliases.py +10 -9
- tnfr/dynamics/coordination.py +77 -35
- tnfr/dynamics/dnfr.py +575 -274
- tnfr/dynamics/dnfr.pyi +1 -10
- tnfr/dynamics/integrators.py +47 -33
- tnfr/dynamics/integrators.pyi +0 -1
- tnfr/dynamics/runtime.py +489 -129
- tnfr/dynamics/sampling.py +2 -0
- tnfr/dynamics/selectors.py +101 -62
- tnfr/execution.py +15 -8
- tnfr/execution.pyi +5 -25
- tnfr/flatten.py +7 -3
- tnfr/flatten.pyi +1 -8
- tnfr/gamma.py +22 -26
- tnfr/gamma.pyi +0 -6
- tnfr/glyph_history.py +37 -26
- tnfr/glyph_history.pyi +1 -19
- tnfr/glyph_runtime.py +16 -0
- tnfr/glyph_runtime.pyi +9 -0
- tnfr/immutable.py +20 -15
- tnfr/immutable.pyi +4 -7
- tnfr/initialization.py +5 -7
- tnfr/initialization.pyi +1 -9
- tnfr/io.py +6 -305
- tnfr/io.pyi +13 -8
- tnfr/mathematics/__init__.py +81 -0
- tnfr/mathematics/backend.py +426 -0
- tnfr/mathematics/dynamics.py +398 -0
- tnfr/mathematics/epi.py +254 -0
- tnfr/mathematics/generators.py +222 -0
- tnfr/mathematics/metrics.py +119 -0
- tnfr/mathematics/operators.py +233 -0
- tnfr/mathematics/operators_factory.py +71 -0
- tnfr/mathematics/projection.py +78 -0
- tnfr/mathematics/runtime.py +173 -0
- tnfr/mathematics/spaces.py +247 -0
- tnfr/mathematics/transforms.py +292 -0
- tnfr/metrics/__init__.py +10 -10
- tnfr/metrics/coherence.py +123 -94
- tnfr/metrics/common.py +22 -13
- tnfr/metrics/common.pyi +42 -11
- tnfr/metrics/core.py +72 -14
- tnfr/metrics/diagnosis.py +48 -57
- tnfr/metrics/diagnosis.pyi +3 -7
- tnfr/metrics/export.py +3 -5
- tnfr/metrics/glyph_timing.py +41 -31
- tnfr/metrics/reporting.py +13 -6
- tnfr/metrics/sense_index.py +884 -114
- tnfr/metrics/trig.py +167 -11
- tnfr/metrics/trig.pyi +1 -0
- tnfr/metrics/trig_cache.py +112 -15
- tnfr/node.py +400 -17
- tnfr/node.pyi +55 -38
- tnfr/observers.py +111 -8
- tnfr/observers.pyi +0 -15
- tnfr/ontosim.py +9 -6
- tnfr/ontosim.pyi +0 -5
- tnfr/operators/__init__.py +529 -42
- tnfr/operators/__init__.pyi +14 -0
- tnfr/operators/definitions.py +350 -18
- tnfr/operators/definitions.pyi +0 -14
- tnfr/operators/grammar.py +760 -0
- tnfr/operators/jitter.py +28 -22
- tnfr/operators/registry.py +7 -12
- tnfr/operators/registry.pyi +0 -2
- tnfr/operators/remesh.py +38 -61
- tnfr/rng.py +17 -300
- tnfr/schemas/__init__.py +8 -0
- tnfr/schemas/grammar.json +94 -0
- tnfr/selector.py +3 -4
- tnfr/selector.pyi +1 -1
- tnfr/sense.py +22 -24
- tnfr/sense.pyi +0 -7
- tnfr/structural.py +504 -21
- tnfr/structural.pyi +41 -18
- tnfr/telemetry/__init__.py +23 -1
- tnfr/telemetry/cache_metrics.py +226 -0
- tnfr/telemetry/nu_f.py +423 -0
- tnfr/telemetry/nu_f.pyi +123 -0
- tnfr/tokens.py +1 -4
- tnfr/tokens.pyi +1 -6
- tnfr/trace.py +20 -53
- tnfr/trace.pyi +9 -37
- tnfr/types.py +244 -15
- tnfr/types.pyi +200 -14
- tnfr/units.py +69 -0
- tnfr/units.pyi +16 -0
- tnfr/utils/__init__.py +107 -48
- tnfr/utils/__init__.pyi +80 -11
- tnfr/utils/cache.py +1705 -65
- tnfr/utils/cache.pyi +370 -58
- tnfr/utils/chunks.py +104 -0
- tnfr/utils/chunks.pyi +21 -0
- tnfr/utils/data.py +95 -5
- tnfr/utils/data.pyi +8 -17
- tnfr/utils/graph.py +2 -4
- tnfr/utils/init.py +31 -7
- tnfr/utils/init.pyi +4 -11
- tnfr/utils/io.py +313 -14
- tnfr/{helpers → utils}/numeric.py +50 -24
- tnfr/utils/numeric.pyi +21 -0
- tnfr/validation/__init__.py +92 -4
- tnfr/validation/__init__.pyi +77 -17
- tnfr/validation/compatibility.py +79 -43
- tnfr/validation/compatibility.pyi +4 -6
- tnfr/validation/grammar.py +55 -133
- tnfr/validation/grammar.pyi +37 -8
- tnfr/validation/graph.py +138 -0
- tnfr/validation/graph.pyi +17 -0
- tnfr/validation/rules.py +161 -74
- tnfr/validation/rules.pyi +55 -18
- tnfr/validation/runtime.py +263 -0
- tnfr/validation/runtime.pyi +31 -0
- tnfr/validation/soft_filters.py +170 -0
- tnfr/validation/soft_filters.pyi +37 -0
- tnfr/validation/spectral.py +159 -0
- tnfr/validation/spectral.pyi +46 -0
- tnfr/validation/syntax.py +28 -139
- tnfr/validation/syntax.pyi +7 -4
- tnfr/validation/window.py +39 -0
- tnfr/validation/window.pyi +1 -0
- tnfr/viz/__init__.py +9 -0
- tnfr/viz/matplotlib.py +246 -0
- {tnfr-6.0.0.dist-info → tnfr-7.0.0.dist-info}/METADATA +63 -19
- tnfr-7.0.0.dist-info/RECORD +185 -0
- {tnfr-6.0.0.dist-info → tnfr-7.0.0.dist-info}/licenses/LICENSE.md +1 -1
- tnfr/constants_glyphs.py +0 -16
- tnfr/constants_glyphs.pyi +0 -12
- tnfr/grammar.py +0 -25
- tnfr/grammar.pyi +0 -13
- tnfr/helpers/__init__.py +0 -151
- tnfr/helpers/__init__.pyi +0 -66
- tnfr/helpers/numeric.pyi +0 -12
- tnfr/presets.py +0 -15
- tnfr/presets.pyi +0 -7
- tnfr/utils/io.pyi +0 -10
- tnfr/utils/validators.py +0 -130
- tnfr/utils/validators.pyi +0 -19
- tnfr-6.0.0.dist-info/RECORD +0 -157
- {tnfr-6.0.0.dist-info → tnfr-7.0.0.dist-info}/WHEEL +0 -0
- {tnfr-6.0.0.dist-info → tnfr-7.0.0.dist-info}/entry_points.txt +0 -0
- {tnfr-6.0.0.dist-info → tnfr-7.0.0.dist-info}/top_level.txt +0 -0
tnfr/mathematics/backend.py (new file)
@@ -0,0 +1,426 @@
+"""Backend abstraction for TNFR mathematical kernels.
+
+This module introduces a unified interface that maps core linear algebra
+operations to concrete numerical libraries. Keeping this layer small and
+canonical guarantees we can switch implementations without diluting the
+structural semantics required by TNFR (coherence, phase, νf, ΔNFR, etc.).
+
+The canonical entry point is :func:`get_backend`, which honours three lookup
+mechanisms in order of precedence:
+
+1. Explicit ``name`` argument.
+2. ``TNFR_MATH_BACKEND`` environment variable.
+3. ``tnfr.config.get_flags().math_backend``.
+
+If none of these provide a value we default to the NumPy backend. Optional
+backends are registered lazily so downstream environments without JAX or
+PyTorch remain functional.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+import os
+from typing import Any, Callable, ClassVar, Iterable, Mapping, MutableMapping, Protocol
+
+from ..utils import cached_import, get_logger
+
+logger = get_logger(__name__)
+
+
+class BackendUnavailableError(RuntimeError):
+    """Raised when a registered backend cannot be constructed."""
+
+
+class MathematicsBackend(Protocol):
+    """Structural numerical backend interface."""
+
+    name: str
+    supports_autodiff: bool
+
+    def as_array(self, value: Any, *, dtype: Any | None = None) -> Any:
+        """Convert ``value`` into a backend-native dense array."""
+
+    def eig(self, matrix: Any) -> tuple[Any, Any]:
+        """Return eigenvalues and eigenvectors for a general matrix."""
+
+    def eigh(self, matrix: Any) -> tuple[Any, Any]:
+        """Return eigenpairs for a Hermitian/symmetric matrix."""
+
+    def matrix_exp(self, matrix: Any) -> Any:
+        """Compute the matrix exponential of ``matrix``."""
+
+    def norm(self, value: Any, *, ord: Any | None = None, axis: Any | None = None) -> Any:
+        """Return the matrix or vector norm according to ``ord``."""
+
+    def einsum(self, pattern: str, *operands: Any, **kwargs: Any) -> Any:
+        """Evaluate an Einstein summation expression."""
+
+    def matmul(self, a: Any, b: Any) -> Any:
+        """Matrix multiplication that respects backend broadcasting rules."""
+
+    def conjugate_transpose(self, matrix: Any) -> Any:
+        """Hermitian conjugate of ``matrix`` († operator)."""
+
+    def stack(self, arrays: Iterable[Any], *, axis: int = 0) -> Any:
+        """Stack arrays along a new ``axis``."""
+
+    def to_numpy(self, value: Any) -> Any:
+        """Convert ``value`` to a ``numpy.ndarray`` when possible."""
+
+
+BackendFactory = Callable[[], MathematicsBackend]
+
+
+@dataclass(slots=True)
+class _NumpyBackend:
+    """NumPy backed implementation."""
+
+    _np: Any
+    _scipy_linalg: Any | None
+
+    name: ClassVar[str] = "numpy"
+    supports_autodiff: ClassVar[bool] = False
+
+    def as_array(self, value: Any, *, dtype: Any | None = None) -> Any:
+        return self._np.asarray(value, dtype=dtype)
+
+    def eig(self, matrix: Any) -> tuple[Any, Any]:
+        return self._np.linalg.eig(matrix)
+
+    def eigh(self, matrix: Any) -> tuple[Any, Any]:
+        return self._np.linalg.eigh(matrix)
+
+    def matrix_exp(self, matrix: Any) -> Any:
+        if self._scipy_linalg is not None:
+            return self._scipy_linalg.expm(matrix)
+        eigvals, eigvecs = self._np.linalg.eig(matrix)
+        inv = self._np.linalg.inv(eigvecs)
+        exp_vals = self._np.exp(eigvals)
+        return eigvecs @ self._np.diag(exp_vals) @ inv
+
+    def norm(self, value: Any, *, ord: Any | None = None, axis: Any | None = None) -> Any:
+        return self._np.linalg.norm(value, ord=ord, axis=axis)
+
+    def einsum(self, pattern: str, *operands: Any, **kwargs: Any) -> Any:
+        return self._np.einsum(pattern, *operands, **kwargs)
+
+    def matmul(self, a: Any, b: Any) -> Any:
+        return self._np.matmul(a, b)
+
+    def conjugate_transpose(self, matrix: Any) -> Any:
+        return self._np.conjugate(matrix).T
+
+    def stack(self, arrays: Iterable[Any], *, axis: int = 0) -> Any:
+        return self._np.stack(tuple(arrays), axis=axis)
+
+    def to_numpy(self, value: Any) -> Any:
+        return self._np.asarray(value)
+
+
+@dataclass(slots=True)
+class _JaxBackend:
+    """JAX backed implementation."""
+
+    _jnp: Any
+    _jax_linalg: Any
+    _jax: Any
+
+    name: ClassVar[str] = "jax"
+    supports_autodiff: ClassVar[bool] = True
+
+    def as_array(self, value: Any, *, dtype: Any | None = None) -> Any:
+        return self._jnp.asarray(value, dtype=dtype)
+
+    def eig(self, matrix: Any) -> tuple[Any, Any]:
+        return self._jnp.linalg.eig(matrix)
+
+    def eigh(self, matrix: Any) -> tuple[Any, Any]:
+        return self._jnp.linalg.eigh(matrix)
+
+    def matrix_exp(self, matrix: Any) -> Any:
+        return self._jax_linalg.expm(matrix)
+
+    def norm(self, value: Any, *, ord: Any | None = None, axis: Any | None = None) -> Any:
+        return self._jnp.linalg.norm(value, ord=ord, axis=axis)
+
+    def einsum(self, pattern: str, *operands: Any, **kwargs: Any) -> Any:
+        return self._jnp.einsum(pattern, *operands, **kwargs)
+
+    def matmul(self, a: Any, b: Any) -> Any:
+        return self._jnp.matmul(a, b)
+
+    def conjugate_transpose(self, matrix: Any) -> Any:
+        return self._jnp.conjugate(matrix).T
+
+    def stack(self, arrays: Iterable[Any], *, axis: int = 0) -> Any:
+        return self._jnp.stack(tuple(arrays), axis=axis)
+
+    def to_numpy(self, value: Any) -> Any:
+        np_mod = cached_import("numpy")
+        if np_mod is None:
+            raise BackendUnavailableError("NumPy is required to export JAX arrays")
+        return np_mod.asarray(self._jax.device_get(value))
+
+
+@dataclass(slots=True)
+class _TorchBackend:
+    """PyTorch backed implementation."""
+
+    _torch: Any
+    _torch_linalg: Any
+
+    name: ClassVar[str] = "torch"
+    supports_autodiff: ClassVar[bool] = True
+
+    def as_array(self, value: Any, *, dtype: Any | None = None) -> Any:
+        tensor = self._torch.as_tensor(value)
+        if dtype is None:
+            return tensor
+
+        target_dtype = self._normalise_dtype(dtype)
+        if target_dtype is None:
+            return tensor.to(dtype=dtype)
+
+        if tensor.dtype == target_dtype:
+            return tensor
+
+        return tensor.to(dtype=target_dtype)
+
+    def _normalise_dtype(self, dtype: Any) -> Any | None:
+        """Return a ``torch.dtype`` equivalent for ``dtype`` when available."""
+
+        if isinstance(dtype, self._torch.dtype):
+            return dtype
+
+        np_mod = cached_import("numpy")
+        if np_mod is None:
+            return None
+
+        try:
+            np_dtype = np_mod.dtype(dtype)
+        except TypeError:
+            return None
+
+        numpy_name = np_dtype.name
+        numpy_to_torch = {
+            "bool": self._torch.bool,
+            "uint8": self._torch.uint8,
+            "int8": self._torch.int8,
+            "int16": self._torch.int16,
+            "int32": self._torch.int32,
+            "int64": self._torch.int64,
+            "float16": self._torch.float16,
+            "float32": self._torch.float32,
+            "float64": self._torch.float64,
+            "complex64": getattr(self._torch, "complex64", None),
+            "complex128": getattr(self._torch, "complex128", None),
+            "bfloat16": getattr(self._torch, "bfloat16", None),
+        }
+
+        torch_dtype = numpy_to_torch.get(numpy_name)
+        return torch_dtype
+
+    def eig(self, matrix: Any) -> tuple[Any, Any]:
+        eigenvalues, eigenvectors = self._torch.linalg.eig(matrix)
+        return eigenvalues, eigenvectors
+
+    def eigh(self, matrix: Any) -> tuple[Any, Any]:
+        eigenvalues, eigenvectors = self._torch.linalg.eigh(matrix)
+        return eigenvalues, eigenvectors
+
+    def matrix_exp(self, matrix: Any) -> Any:
+        return self._torch_linalg.matrix_exp(matrix)
+
+    def norm(self, value: Any, *, ord: Any | None = None, axis: Any | None = None) -> Any:
+        if axis is None:
+            return self._torch.linalg.norm(value, ord=ord)
+        return self._torch.linalg.norm(value, ord=ord, dim=axis)
+
+    def einsum(self, pattern: str, *operands: Any, **kwargs: Any) -> Any:
+        return self._torch.einsum(pattern, *operands, **kwargs)
+
+    def matmul(self, a: Any, b: Any) -> Any:
+        return self._torch.matmul(a, b)
+
+    def conjugate_transpose(self, matrix: Any) -> Any:
+        return matrix.mH if hasattr(matrix, "mH") else matrix.conj().transpose(-2, -1)
+
+    def stack(self, arrays: Iterable[Any], *, axis: int = 0) -> Any:
+        return self._torch.stack(tuple(arrays), dim=axis)
+
+    def to_numpy(self, value: Any) -> Any:
+        np_mod = cached_import("numpy")
+        if np_mod is None:
+            raise BackendUnavailableError("NumPy is required to export Torch tensors")
+        if hasattr(value, "detach"):
+            return value.detach().cpu().numpy()
+        return np_mod.asarray(value)
+
+
+def _normalise_name(name: str) -> str:
+    return name.strip().lower()
+
+
+_BACKEND_FACTORIES: MutableMapping[str, BackendFactory] = {}
+_BACKEND_ALIASES: MutableMapping[str, str] = {}
+_BACKEND_CACHE: MutableMapping[str, MathematicsBackend] = {}
+
+
+def ensure_array(
+    value: Any,
+    *,
+    dtype: Any | None = None,
+    backend: MathematicsBackend | None = None,
+) -> Any:
+    """Return ``value`` as a backend-native dense array."""
+
+    resolved = backend or get_backend()
+    return resolved.as_array(value, dtype=dtype)
+
+
+def ensure_numpy(value: Any, *, backend: MathematicsBackend | None = None) -> Any:
+    """Export ``value`` from the backend into :class:`numpy.ndarray`."""
+
+    resolved = backend or get_backend()
+    return resolved.to_numpy(value)
+
+
+def register_backend(
+    name: str,
+    factory: BackendFactory,
+    *,
+    aliases: Iterable[str] | None = None,
+    override: bool = False,
+) -> None:
+    """Register a backend factory under ``name``.
+
+    Parameters
+    ----------
+    name:
+        Canonical backend identifier.
+    factory:
+        Callable that returns a :class:`MathematicsBackend` instance.
+    aliases:
+        Optional alternative identifiers that will resolve to ``name``.
+    override:
+        When ``True`` replaces existing registrations.
+    """
+
+    key = _normalise_name(name)
+    if not override and key in _BACKEND_FACTORIES:
+        raise ValueError(f"Backend '{name}' already registered")
+    _BACKEND_FACTORIES[key] = factory
+    if aliases:
+        for alias in aliases:
+            alias_key = _normalise_name(alias)
+            if not override and alias_key in _BACKEND_ALIASES:
+                raise ValueError(f"Backend alias '{alias}' already registered")
+            _BACKEND_ALIASES[alias_key] = key
+
+
+def _resolve_backend_name(name: str | None) -> str:
+    if name:
+        return _normalise_name(name)
+
+    env_choice = os.getenv("TNFR_MATH_BACKEND")
+    if env_choice:
+        return _normalise_name(env_choice)
+
+    backend_from_flags: str | None = None
+    try:
+        from ..config import get_flags  # Local import avoids circular dependency
+
+        backend_from_flags = getattr(get_flags(), "math_backend", None)
+    except Exception:  # pragma: no cover - defensive; config must not break selection
+        backend_from_flags = None
+
+    if backend_from_flags:
+        return _normalise_name(backend_from_flags)
+
+    return "numpy"
+
+
+def _resolve_factory(name: str) -> BackendFactory:
+    canonical = _BACKEND_ALIASES.get(name, name)
+    try:
+        return _BACKEND_FACTORIES[canonical]
+    except KeyError as exc:  # pragma: no cover - defensive path
+        raise LookupError(f"Unknown mathematics backend: {name}") from exc
+
+
+def get_backend(name: str | None = None) -> MathematicsBackend:
+    """Return a backend instance using the configured resolution order."""
+
+    resolved_name = _resolve_backend_name(name)
+    canonical = _BACKEND_ALIASES.get(resolved_name, resolved_name)
+    if canonical in _BACKEND_CACHE:
+        return _BACKEND_CACHE[canonical]
+
+    factory = _resolve_factory(canonical)
+    try:
+        backend = factory()
+    except BackendUnavailableError as exc:
+        logger.warning("Backend '%s' unavailable: %s", canonical, exc)
+        if canonical != "numpy":
+            logger.warning("Falling back to NumPy backend")
+            return get_backend("numpy")
+        raise
+
+    _BACKEND_CACHE[canonical] = backend
+    return backend
+
+
+def available_backends() -> Mapping[str, BackendFactory]:
+    """Return the registered backend factories."""
+
+    return dict(_BACKEND_FACTORIES)
+
+
+def _make_numpy_backend() -> MathematicsBackend:
+    np_module = cached_import("numpy")
+    if np_module is None:
+        raise BackendUnavailableError("NumPy is not installed")
+    scipy_linalg = cached_import("scipy.linalg")
+    if scipy_linalg is None:
+        logger.debug("SciPy not available; falling back to eigen decomposition for expm")
+    return _NumpyBackend(np_module, scipy_linalg)


+def _make_jax_backend() -> MathematicsBackend:
+    jnp_module = cached_import("jax.numpy")
+    if jnp_module is None:
+        raise BackendUnavailableError("jax.numpy is not available")
+    jax_scipy = cached_import("jax.scipy.linalg")
+    if jax_scipy is None:
+        raise BackendUnavailableError("jax.scipy.linalg is required for matrix_exp")
+    jax_module = cached_import("jax")
+    if jax_module is None:
+        raise BackendUnavailableError("jax core module is required")
+    return _JaxBackend(jnp_module, jax_scipy, jax_module)
+
+
+def _make_torch_backend() -> MathematicsBackend:
+    torch_module = cached_import("torch")
+    if torch_module is None:
+        raise BackendUnavailableError("PyTorch is not installed")
+    torch_linalg = cached_import("torch.linalg")
+    if torch_linalg is None:
+        raise BackendUnavailableError("torch.linalg is required for linear algebra operations")
+    return _TorchBackend(torch_module, torch_linalg)
+
+
+register_backend("numpy", _make_numpy_backend, aliases=("np",))
+register_backend("jax", _make_jax_backend)
+register_backend("torch", _make_torch_backend, aliases=("pytorch",))
+
+
+__all__ = [
+    "MathematicsBackend",
+    "BackendUnavailableError",
+    "register_backend",
+    "get_backend",
+    "available_backends",
+    "ensure_array",
+    "ensure_numpy",
+]
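As a reviewer aid, here is a minimal usage sketch of the backend API added above. It is not part of the diff: it assumes tnfr 7.0.0 with NumPy installed and the module importable as tnfr.mathematics.backend, and the "mine"/"custom" backend names are purely illustrative.

import math

from tnfr.mathematics.backend import (
    available_backends,
    ensure_numpy,
    get_backend,
    register_backend,
)

# Resolution precedence (per the module docstring):
# explicit name > TNFR_MATH_BACKEND env var > tnfr.config flags > "numpy".
backend = get_backend("numpy")  # explicit argument wins over env/flags
print(backend.name, backend.supports_autodiff)  # -> numpy False

# Round-trip through the backend-native array type: exp of a rotation
# generator should give cos/sin entries.
generator = backend.as_array([[0.0, 1.0], [-1.0, 0.0]])
rotation = ensure_numpy(backend.matrix_exp(generator), backend=backend)
assert math.isclose(float(rotation[0][0]), math.cos(1.0), rel_tol=1e-6)

# register_backend only stores the factory; construction is lazy, so a
# missing dependency surfaces as BackendUnavailableError (with a NumPy
# fallback for non-numpy names) on the first get_backend() call.
register_backend("mine", lambda: get_backend("numpy"), aliases=("custom",))
print(sorted(available_backends()))  # -> ['jax', 'mine', 'numpy', 'torch']

Note that get_backend caches constructed backends per canonical name, so repeated calls are cheap and custom factories run at most once unless construction fails.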