ipax 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ipax/__init__.py +66 -0
- ipax/_logging.py +244 -0
- ipax/backend/__init__.py +21 -0
- ipax/backend/namespace.py +139 -0
- ipax/backend/operators.py +429 -0
- ipax/backend/sparse/__init__.py +83 -0
- ipax/backend/sparse/_routing.py +146 -0
- ipax/backend/sparse/cupy.py +719 -0
- ipax/backend/sparse/jax.py +39 -0
- ipax/backend/sparse/numpy_scipy.py +529 -0
- ipax/backend/sparse/torch.py +39 -0
- ipax/ipm/__init__.py +21 -0
- ipax/ipm/barrier.py +54 -0
- ipax/ipm/breedveld_ls.py +95 -0
- ipax/ipm/corrections.py +338 -0
- ipax/ipm/driver.py +1404 -0
- ipax/ipm/filter_ls.py +158 -0
- ipax/ipm/hessian.py +218 -0
- ipax/ipm/init.py +178 -0
- ipax/ipm/kkt.py +523 -0
- ipax/ipm/restoration.py +146 -0
- ipax/ipm/step.py +111 -0
- ipax/ipm/termination.py +186 -0
- ipax/linalg/__init__.py +21 -0
- ipax/linalg/dense.py +101 -0
- ipax/linalg/krylov.py +481 -0
- ipax/linalg/regularize.py +68 -0
- ipax/linalg/solver.py +91 -0
- ipax/linalg/sparse.py +126 -0
- ipax/options.py +390 -0
- ipax/problem/__init__.py +22 -0
- ipax/problem/autodiff/__init__.py +60 -0
- ipax/problem/autodiff/jax.py +50 -0
- ipax/problem/autodiff/torch.py +62 -0
- ipax/problem/base.py +123 -0
- ipax/problem/derivatives.py +254 -0
- ipax/problem/finitediff.py +95 -0
- ipax/problem/function.py +326 -0
- ipax/problem/scaling.py +275 -0
- ipax/py.typed +0 -0
- ipax/result.py +218 -0
- ipax/solve.py +387 -0
- ipax/testing/__init__.py +33 -0
- ipax/testing/backends.py +60 -0
- ipax/testing/problems.py +675 -0
- ipax/typing.py +37 -0
- ipax-0.1.1.dist-info/METADATA +324 -0
- ipax-0.1.1.dist-info/RECORD +51 -0
- ipax-0.1.1.dist-info/WHEEL +4 -0
- ipax-0.1.1.dist-info/licenses/LICENSE +201 -0
- ipax-0.1.1.dist-info/licenses/NOTICE +28 -0
ipax/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2026 Niklas Wahl
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""ipax — Array-API interior-point solver for nonlinear constrained optimization.
|
|
16
|
+
|
|
17
|
+
Public API. The package never imports a concrete array library at the top level
|
|
18
|
+
(invariant #1); the backend is inferred from the arrays a :class:`Problem`
|
|
19
|
+
returns.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
from ipax.options import (
|
|
25
|
+
AcceptableStoppingOptions,
|
|
26
|
+
CorrectionsOptions,
|
|
27
|
+
OptimalityConditionOptions,
|
|
28
|
+
Options,
|
|
29
|
+
ScalingOptions,
|
|
30
|
+
)
|
|
31
|
+
from ipax.problem.base import Problem
|
|
32
|
+
from ipax.problem.function import FunctionProblem, LinearProblem, QuadraticProblem
|
|
33
|
+
from ipax.result import (
|
|
34
|
+
DerivativeSources,
|
|
35
|
+
IterationCallback,
|
|
36
|
+
IterationInfo,
|
|
37
|
+
IterationRecord,
|
|
38
|
+
KKTResiduals,
|
|
39
|
+
Result,
|
|
40
|
+
Status,
|
|
41
|
+
WarmStart,
|
|
42
|
+
)
|
|
43
|
+
from ipax.solve import solve
|
|
44
|
+
|
|
45
|
+
# Public API of the package: these are exactly the names bound by
# ``from ipax import *``; everything else is reached via submodules.
__all__ = [
    "AcceptableStoppingOptions",
    "CorrectionsOptions",
    "DerivativeSources",
    "FunctionProblem",
    "IterationCallback",
    "IterationInfo",
    "IterationRecord",
    "KKTResiduals",
    "LinearProblem",
    "OptimalityConditionOptions",
    "Options",
    "Problem",
    "QuadraticProblem",
    "Result",
    "ScalingOptions",
    "Status",
    "WarmStart",
    "solve",
]

# Release version of the package (this file ships in the 0.1.1 wheel).
__version__ = "0.1.1"
|
ipax/_logging.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
# Copyright 2026 Niklas Wahl
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Layered diagnostics for the solver.
|
|
16
|
+
|
|
17
|
+
``ipax`` logs through a package logger carrying a :class:`logging.NullHandler`,
|
|
18
|
+
so importing the package never emits output on its own. ``Options.verbose`` opts
|
|
19
|
+
in to a console handler via :func:`configure_verbosity`; applications that
|
|
20
|
+
configure their own handlers on the ``"ipax"`` logger keep full control.
|
|
21
|
+
|
|
22
|
+
The verbosity ladder is expressed as **custom numeric logging levels**, one per
|
|
23
|
+
content tier, so a *single* handler/logger threshold selects what is shown and
|
|
24
|
+
downstream handlers (or pytest's ``caplog`` fixture) still receive every record regardless of
|
|
25
|
+
``verbose``:
|
|
26
|
+
|
|
27
|
+
====== ===================================== ==================
|
|
28
|
+
level content numeric log level
|
|
29
|
+
====== ===================================== ==================
|
|
30
|
+
0 (silent — only warnings/errors) —
|
|
31
|
+
1 result summary ``RESULT`` (25)
|
|
32
|
+
2 + per-iteration table & timing split ``ITERATION`` (22)
|
|
33
|
+
3 + problem structure ``PROBLEM`` (19)
|
|
34
|
+
4 + resolved solver setup ``SOLVER`` (16)
|
|
35
|
+
5 + every sub-option ``OPTIONS`` (13)
|
|
36
|
+
≥6 + debug diagnostics ``logging.DEBUG`` (10)
|
|
37
|
+
====== ===================================== ==================
|
|
38
|
+
|
|
39
|
+
Higher verbosity → lower threshold → more tiers pass. Emitting code guards the
|
|
40
|
+
formatting cost with ``logger.isEnabledFor(LEVEL)``.
|
|
41
|
+
|
|
42
|
+
This module holds no solver state (invariant #5) — the logger is a process-wide
|
|
43
|
+
sink, not mutable algorithm state.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
from __future__ import annotations
|
|
47
|
+
|
|
48
|
+
import dataclasses
|
|
49
|
+
import logging
|
|
50
|
+
from typing import TYPE_CHECKING
|
|
51
|
+
|
|
52
|
+
if TYPE_CHECKING:
|
|
53
|
+
from ipax.options import Options
|
|
54
|
+
from ipax.result import IterationRecord, Result
|
|
55
|
+
|
|
56
|
+
LOGGER_NAME = "ipax"

# Verbosity content tiers, expressed as custom numeric logging levels. A tier
# with a *higher* number still passes a *higher* threshold, so small
# ``verbose`` values only reveal the headline tiers. Neighbouring tiers are 3
# apart, and none of them collides with a stdlib level (DEBUG=10, INFO=20,
# WARNING=30).
RESULT = 25
ITERATION = 22
PROBLEM = 19
SOLVER = 16
OPTIONS = 13

# Display names registered with ``logging`` so ``%(levelname)s`` prints e.g.
# "ITER" rather than "Level 22". Registered in descending tier order.
_TIER_NAMES = {
    RESULT: "RESULT",
    ITERATION: "ITER",
    PROBLEM: "PROBLEM",
    SOLVER: "SOLVER",
    OPTIONS: "OPTIONS",
}
for _tier, _label in _TIER_NAMES.items():
    logging.addLevelName(_tier, _label)

# Package logger. The NullHandler means importing ipax never prints anything
# by itself; console output is opt-in.
logger = logging.getLogger(LOGGER_NAME)
logger.addHandler(logging.NullHandler())

# Attribute set on the one console handler this module creates, so repeated
# ``configure_verbosity`` calls (e.g. from nested ``solve`` invocations) find
# and reuse it instead of stacking duplicate handlers.
_VERBOSE_HANDLER_ATTR = "_ipax_verbose_handler"

# Per-iteration table header: right-aligned column titles, single-space
# separated, widths matching the rows built by ``format_record``.
_HEADER = " ".join(
    (
        f"{'iter':>4}",
        f"{'objective':>15}",
        f"{'infeas':>10}",
        f"{'kkt':>10}",
        f"{'mu':>10}",
        f"{'alpha_pr':>9}",
        f"{'alpha_du':>9}",
        f"{'reg':>9}",
        f"{'prob_s':>9}",
        f"{'step_s':>9}",
    )
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def verbosity_threshold(verbose: int) -> int:
    """Translate an ``Options.verbose`` value into a logging threshold level.

    * ``verbose <= 0`` gives ``RESULT + 1``. That is one step above the highest
      content tier, so none of ipax's tiers pass, while warnings and errors
      (level >= 30) still do.
    * ``1 .. 5`` give ``RESULT``, ``ITERATION``, ``PROBLEM``, ``SOLVER``,
      ``OPTIONS``. Each extra verbosity step lowers the threshold by 3.
    * ``verbose >= 6`` gives ``logging.DEBUG``, so the scattered debug traces
      appear too.
    """
    tier_step = 3
    return (
        RESULT + 1
        if verbose <= 0
        else logging.DEBUG
        if verbose >= 6
        else RESULT - tier_step * (verbose - 1)  # 1->25, 2->22, 3->19, 4->16, 5->13
    )
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def configure_verbosity(verbose: int) -> None:
    """Attach (or update) the console handler driven by ``Options.verbose``.

    For ``verbose >= 1``, make sure the single module-owned, tagged
    ``StreamHandler`` is attached to the ``"ipax"`` logger. Its level is set to
    ``verbosity_threshold(verbose)``. Repeated calls reuse that handler instead
    of duplicating output. The logger level itself is only ever lowered, never
    raised, so an application's own configuration is respected.

    ``verbose <= 0`` never adds a handler. If an earlier call already attached
    the owned handler, that handler's level is raised back to
    ``verbosity_threshold(verbose)``, which sits above every content tier.
    Without this, a previous, more verbose call would keep printing tiers to the
    console during "silent" runs. Warnings and errors still pass.

    Args:
        verbose: The ``Options.verbose`` level.
    """
    # Locate the handler this module created earlier (if any).
    owned = next(
        (h for h in logger.handlers if getattr(h, _VERBOSE_HANDLER_ATTR, False)),
        None,
    )
    threshold = verbosity_threshold(verbose)

    if verbose <= 0:
        # Silence our own console output again if a prior call enabled it;
        # otherwise leave logging completely untouched.
        if owned is not None:
            owned.setLevel(threshold)
        return

    if owned is None:
        owned = logging.StreamHandler()
        owned.setFormatter(logging.Formatter("%(message)s"))
        setattr(owned, _VERBOSE_HANDLER_ATTR, True)
        logger.addHandler(owned)
    # The handler level always follows the most recent call. A nested solve
    # with a different ``verbose`` therefore changes the console threshold for
    # the outer solve too.
    owned.setLevel(threshold)
    # Only lower the logger level (never raise it), so an application that
    # configured a more permissive level keeps it.
    if logger.level == logging.NOTSET or logger.level > threshold:
        logger.setLevel(threshold)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def format_header() -> str:
    """Return the IPOPT-style column header of the per-iteration table.

    The column widths line up with the rows produced by :func:`format_record`.
    """
    header: str = _HEADER
    return header
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def format_record(record: IterationRecord) -> str:
    """Render one row of the per-iteration table.

    Cells appear in header order, with right-aligned widths matching
    :func:`format_header`:

    * iteration, objective, ``theta`` (the "infeas" column), KKT error, ``mu``;
    * primal/dual step lengths and regularization;
    * problem-callback and step-solve times.
    """
    cells = (
        f"{record.iteration:>4d}",
        f"{record.objective:>15.7e}",
        f"{record.theta:>10.3e}",
        f"{record.kkt_error:>10.3e}",
        f"{record.mu:>10.3e}",
        f"{record.alpha_primal:>9.2e}",
        f"{record.alpha_dual:>9.2e}",
        f"{record.regularization:>9.2e}",
        f"{record.problem_time:>9.2e}",
        f"{record.step_solve_time:>9.2e}",
    )
    return " ".join(cells)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def format_problem(
    *,
    n_vars: int,
    n_ineq: int,
    n_eq_nonlinear: int,
    n_eq_linear: int,
    n_lower: int,
    n_upper: int,
) -> str:
    """Describe the problem's dimensions (verbosity tier 3).

    All arguments are keyword-only counts. The result is a multi-line block
    with no trailing newline.
    """
    rows = [
        "problem structure:",
        f" variables = {n_vars}",
        f" inequalities = {n_ineq}",
        f" equalities = {n_eq_nonlinear} nonlinear + {n_eq_linear} linear",
        f" bounded vars = {n_lower} lower, {n_upper} upper",
    ]
    return "\n".join(rows)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def format_solver(opts: Options, solver_name: str) -> str:
    """Describe the resolved solver configuration (verbosity tier 4).

    ``opts.scaling`` and ``opts.corrections`` may be either option groups with
    a ``method`` attribute or plain values. Either way, the method (or the
    value itself) is shown. The Krylov method line is only added when the
    linear solver is ``"krylov"`` or ``"auto"``.

    Args:
        opts: Resolved solver options.
        solver_name: Name of the concrete linear solver, shown in parentheses.
    """
    settings = [
        ("hessian", opts.hessian),
        ("linear solver", f"{opts.linsolve} ({solver_name})"),
        ("globalization", opts.globalization),
        ("mu schedule", opts.mu_schedule),
        ("scaling", getattr(opts.scaling, "method", opts.scaling)),
        ("corrections", getattr(opts.corrections, "method", opts.corrections)),
    ]
    if opts.linsolve in ("krylov", "auto"):
        settings.append(("krylov method", opts.krylov.method))
    body = "\n".join(f" {label} = {value}" for label, value in settings)
    return f"solver setup:\n{body}"
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def format_options(opts: Options) -> str:
    """Dump every option field, expanding sub-groups (verbosity tier 5).

    ``opts`` must be a dataclass instance. A field whose value is itself a
    dataclass *instance* is printed as a ``name:`` heading followed by one line
    per sub-field, expanded one level deep only. All other values, including
    dataclass *types*, are printed with ``str()`` formatting.
    """
    out = ["options:"]
    for fld in dataclasses.fields(opts):
        val = getattr(opts, fld.name)
        is_group = dataclasses.is_dataclass(val) and not isinstance(val, type)
        if not is_group:
            out.append(f" {fld.name} = {val}")
            continue
        out.append(f" {fld.name}:")
        out.extend(
            f" {sub.name} = {getattr(val, sub.name)}" for sub in dataclasses.fields(val)
        )
    return "\n".join(out)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def format_result(result: Result) -> str:
    """Render the final result summary (verbosity tier 1).

    The block shows:

    * status value and message;
    * objective and iteration count;
    * overall KKT error and its dual/primal/complementarity parts;
    * constraint violation, solve time and linear solver;
    * the source of each derivative (gradient, Jacobians, Hessian).
    """
    sources = result.derivative_sources
    lines = (
        f"result: {result.status.value} - {result.message}",
        f" objective = {result.objective:.8e}",
        f" iterations = {result.n_iter}",
        f" kkt error = {result.kkt_error:.3e}",
        " kkt components= "
        + " ".join(
            (
                f"dual:{result.dual_infeasibility:.3e}",
                f"primal:{result.primal_infeasibility:.3e}",
                f"compl:{result.complementarity:.3e}",
            )
        ),
        f" infeasibility = {result.constraint_violation:.3e}",
        f" solve time = {result.solve_time:.3e}s",
        f" linear solver = {result.linear_solver}",
        " derivatives = "
        + " ".join(
            (
                f"grad:{sources.gradient}",
                f"eq_jac:{sources.eq_jacobian}",
                f"ineq_jac:{sources.ineq_jacobian}",
                f"hess:{sources.hessian}",
            )
        ),
    )
    return "\n".join(lines)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def format_timing(history: tuple[IterationRecord, ...]) -> str:
    """Summarize total problem-callback vs inner-solve time (verbosity tier 2).

    Each total is a plain ``sum`` over the per-iteration records, so an empty
    history reports zero for both.
    """
    totals = {
        attr: sum(getattr(rec, attr) for rec in history)
        for attr in ("problem_time", "step_solve_time")
    }
    return (
        f"timing: problem-callbacks = {totals['problem_time']:.3e}s, "
        f"inner-solve = {totals['step_solve_time']:.3e}s"
    )
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# Names exported by ``from ipax._logging import *``: the tier level constants,
# the package logger, verbosity configuration, and the text formatters used
# by the solver's diagnostics.
__all__ = [
    "ITERATION",
    "LOGGER_NAME",
    "OPTIONS",
    "PROBLEM",
    "RESULT",
    "SOLVER",
    "configure_verbosity",
    "format_header",
    "format_options",
    "format_problem",
    "format_record",
    "format_result",
    "format_solver",
    "format_timing",
    "logger",
    "verbosity_threshold",
]
|
ipax/backend/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright 2026 Niklas Wahl
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Backend layer: namespace resolution, operators, and sparse adapters."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from ipax.backend.namespace import array_namespace, capabilities
|
|
20
|
+
|
|
21
|
+
# Public names of the backend layer: namespace resolution and capability probing.
__all__ = ["array_namespace", "capabilities"]
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright 2026 Niklas Wahl
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Array-API namespace resolution and capability detection (§5.4).
|
|
16
|
+
|
|
17
|
+
This is the single place the core asks "what backend am I on, and what can it
|
|
18
|
+
do?". It does **not** import any concrete array library — it uses
|
|
19
|
+
``array-api-compat`` to resolve the namespace from the input arrays and probes
|
|
20
|
+
the resolved namespace for optional features.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from dataclasses import dataclass
|
|
26
|
+
from typing import TYPE_CHECKING
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from ipax.typing import Array, Namespace
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def array_namespace(*arrays: Array) -> Namespace:
    """Resolve the shared Array-API namespace of ``arrays``.

    This is the core's single point of contact with ``array_api_compat``. The
    import is deferred to call time so importing this module does not pull
    the compat shim in. Any error raised by
    ``array_api_compat.array_namespace`` (e.g. arrays from different backends)
    propagates unchanged.
    """
    import array_api_compat

    return array_api_compat.array_namespace(*arrays)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Function names defined by the Array-API ``linalg`` extension. ``capabilities``
# reports which of these the resolved namespace's ``xp.linalg`` really provides.
_ARRAY_API_LINALG_FUNCTIONS = frozenset(
    """
    cholesky cross det diagonal eigh eigvalsh inv matmul matrix_norm
    matrix_power matrix_rank matrix_transpose outer pinv qr slogdet solve
    svd svdvals tensordot trace vecdot vector_norm
    """.split()
)

# Canonical backend names (as produced by ``_namespace_name``) for which ipax
# ships a sparse adapter, and those with an autodiff adapter.
_SPARSE_ADAPTER_BACKENDS = frozenset(("numpy", "torch", "cupy", "jax"))
_AUTODIFF_BACKENDS = frozenset(("torch", "jax"))
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _namespace_name(xp: Namespace) -> str:
    """Derive the canonical backend name from the namespace's ``__name__``.

    The *leading* package identifies the backend, not the trailing one. JAX
    resolves to ``jax.numpy``, and a trailing-segment rule would misreport it
    as ``"numpy"``, losing its autodiff adapter. An ``array_api_compat.``
    wrapper prefix is dropped first. Examples:

    * ``jax.numpy`` -> ``jax``
    * ``array_api_compat.torch`` -> ``torch``
    * ``array_api_strict`` -> ``array_api_strict``

    A missing or empty name yields ``"unknown"``.
    """
    qualified = getattr(xp, "__name__", "").removeprefix("array_api_compat.")
    head, _sep, _rest = qualified.partition(".")
    return head if head else "unknown"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass(frozen=True, slots=True)
class Capabilities:
    """Immutable snapshot of what a resolved namespace/device supports (§5.4).

    Instances are produced by :func:`capabilities`. Fields are keyword- or
    position-constructible in the order declared below.
    """

    name: str  # canonical backend name: "numpy", "torch", "jax", "array_api_strict", ...
    has_linalg: bool  # True when the namespace exposes an ``xp.linalg`` attribute
    linalg_functions: frozenset[str]  # Array-API linalg function names present on ``xp.linalg``
    has_sparse_adapter: bool  # backend name is in ipax's sparse-adapter allow-list
    supports_autodiff: bool  # backend name is in ipax's autodiff allow-list (torch/jax)
    devices: tuple[str, ...]  # capabilities() currently always sets ("cpu",); not probed
    default_float: str  # "float64" if xp has it, else "float32", else "unknown"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def capabilities(xp: Namespace) -> Capabilities:
    """Probe ``xp`` for optional Array-API features and ipax adapter support.

    The returned snapshot covers:

    * whether ``xp.linalg`` exists and which standard linalg functions it has;
    * whether ipax has a sparse adapter or an autodiff adapter for the backend
      (decided by the canonical backend name);
    * the preferred floating dtype.

    ``devices`` is not probed yet and is always ``("cpu",)``. Missing standard
    pieces (triangular solve, ``lstsq``) are filled by labeled helpers
    elsewhere in ``backend``/``linalg``.
    """
    linalg = getattr(xp, "linalg", None)
    if linalg is None:
        available = frozenset()
    else:
        available = frozenset(
            fn for fn in _ARRAY_API_LINALG_FUNCTIONS if hasattr(linalg, fn)
        )

    backend = _namespace_name(xp)
    # Prefer double precision; fall back to single; otherwise report unknown.
    preferred_float = next(
        (dtype for dtype in ("float64", "float32") if hasattr(xp, dtype)),
        "unknown",
    )

    return Capabilities(
        name=backend,
        has_linalg=linalg is not None,
        linalg_functions=available,
        has_sparse_adapter=backend in _SPARSE_ADAPTER_BACKENDS,
        supports_autodiff=backend in _AUTODIFF_BACKENDS,
        devices=("cpu",),
        default_float=preferred_float,
    )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# Public names of this module; the name helper and backend allow-lists stay private.
__all__ = ["Capabilities", "array_namespace", "capabilities"]
|