ipax 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. ipax/__init__.py +66 -0
  2. ipax/_logging.py +244 -0
  3. ipax/backend/__init__.py +21 -0
  4. ipax/backend/namespace.py +139 -0
  5. ipax/backend/operators.py +429 -0
  6. ipax/backend/sparse/__init__.py +83 -0
  7. ipax/backend/sparse/_routing.py +146 -0
  8. ipax/backend/sparse/cupy.py +719 -0
  9. ipax/backend/sparse/jax.py +39 -0
  10. ipax/backend/sparse/numpy_scipy.py +529 -0
  11. ipax/backend/sparse/torch.py +39 -0
  12. ipax/ipm/__init__.py +21 -0
  13. ipax/ipm/barrier.py +54 -0
  14. ipax/ipm/breedveld_ls.py +95 -0
  15. ipax/ipm/corrections.py +338 -0
  16. ipax/ipm/driver.py +1404 -0
  17. ipax/ipm/filter_ls.py +158 -0
  18. ipax/ipm/hessian.py +218 -0
  19. ipax/ipm/init.py +178 -0
  20. ipax/ipm/kkt.py +523 -0
  21. ipax/ipm/restoration.py +146 -0
  22. ipax/ipm/step.py +111 -0
  23. ipax/ipm/termination.py +186 -0
  24. ipax/linalg/__init__.py +21 -0
  25. ipax/linalg/dense.py +101 -0
  26. ipax/linalg/krylov.py +481 -0
  27. ipax/linalg/regularize.py +68 -0
  28. ipax/linalg/solver.py +91 -0
  29. ipax/linalg/sparse.py +126 -0
  30. ipax/options.py +390 -0
  31. ipax/problem/__init__.py +22 -0
  32. ipax/problem/autodiff/__init__.py +60 -0
  33. ipax/problem/autodiff/jax.py +50 -0
  34. ipax/problem/autodiff/torch.py +62 -0
  35. ipax/problem/base.py +123 -0
  36. ipax/problem/derivatives.py +254 -0
  37. ipax/problem/finitediff.py +95 -0
  38. ipax/problem/function.py +326 -0
  39. ipax/problem/scaling.py +275 -0
  40. ipax/py.typed +0 -0
  41. ipax/result.py +218 -0
  42. ipax/solve.py +387 -0
  43. ipax/testing/__init__.py +33 -0
  44. ipax/testing/backends.py +60 -0
  45. ipax/testing/problems.py +675 -0
  46. ipax/typing.py +37 -0
  47. ipax-0.1.1.dist-info/METADATA +324 -0
  48. ipax-0.1.1.dist-info/RECORD +51 -0
  49. ipax-0.1.1.dist-info/WHEEL +4 -0
  50. ipax-0.1.1.dist-info/licenses/LICENSE +201 -0
  51. ipax-0.1.1.dist-info/licenses/NOTICE +28 -0
ipax/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright 2026 Niklas Wahl
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """ipax — Array-API interior-point solver for nonlinear constrained optimization.
16
+
17
+ Public API. The package never imports a concrete array library at the top level
18
+ (invariant #1); the backend is inferred from the arrays a :class:`Problem`
19
+ returns.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from ipax.options import (
25
+ AcceptableStoppingOptions,
26
+ CorrectionsOptions,
27
+ OptimalityConditionOptions,
28
+ Options,
29
+ ScalingOptions,
30
+ )
31
+ from ipax.problem.base import Problem
32
+ from ipax.problem.function import FunctionProblem, LinearProblem, QuadraticProblem
33
+ from ipax.result import (
34
+ DerivativeSources,
35
+ IterationCallback,
36
+ IterationInfo,
37
+ IterationRecord,
38
+ KKTResiduals,
39
+ Result,
40
+ Status,
41
+ WarmStart,
42
+ )
43
+ from ipax.solve import solve
44
+
45
# Public names re-exported at package level (alphabetical, ``solve`` last).
__all__ = [
    "AcceptableStoppingOptions",
    "CorrectionsOptions",
    "DerivativeSources",
    "FunctionProblem",
    "IterationCallback",
    "IterationInfo",
    "IterationRecord",
    "KKTResiduals",
    "LinearProblem",
    "OptimalityConditionOptions",
    "Options",
    "Problem",
    "QuadraticProblem",
    "Result",
    "ScalingOptions",
    "Status",
    "WarmStart",
    "solve",
]

# Release version of this package.
__version__ = "0.1.1"
ipax/_logging.py ADDED
@@ -0,0 +1,244 @@
1
+ # Copyright 2026 Niklas Wahl
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Layered diagnostics for the solver.
16
+
17
+ ``ipax`` logs through a package logger carrying a :class:`logging.NullHandler`,
18
+ so importing the package never emits output on its own. ``Options.verbose`` opts
19
+ in to a console handler via :func:`configure_verbosity`; applications that
20
+ configure their own handlers on the ``"ipax"`` logger keep full control.
21
+
22
+ The verbosity ladder is expressed as **custom numeric logging levels**, one per
23
+ content tier, so a *single* handler/logger threshold selects what is shown and
24
+ downstream handlers (or pytest's ``caplog`` fixture) still receive every record regardless of
25
+ ``verbose``:
26
+
27
+ ====== ===================================== ==================
28
+ level content numeric log level
29
+ ====== ===================================== ==================
30
+ 0 (silent — only warnings/errors) —
31
+ 1 result summary ``RESULT`` (25)
32
+ 2 + per-iteration table & timing split ``ITERATION`` (22)
33
+ 3 + problem structure ``PROBLEM`` (19)
34
+ 4 + resolved solver setup ``SOLVER`` (16)
35
+ 5 + every sub-option ``OPTIONS`` (13)
36
+ ≥6 + debug diagnostics ``logging.DEBUG`` (10)
37
+ ====== ===================================== ==================
38
+
39
+ Higher verbosity → lower threshold → more tiers pass. Emitting code guards the
40
+ formatting cost with ``logger.isEnabledFor(LEVEL)``.
41
+
42
+ This module holds no solver state (invariant #5) — the logger is a process-wide
43
+ sink, not mutable algorithm state.
44
+ """
45
+
46
+ from __future__ import annotations
47
+
48
+ import dataclasses
49
+ import logging
50
+ from typing import TYPE_CHECKING
51
+
52
+ if TYPE_CHECKING:
53
+ from ipax.options import Options
54
+ from ipax.result import IterationRecord, Result
55
+
56
LOGGER_NAME = "ipax"

# One custom level per content tier. A more important tier carries a larger
# number, so it survives higher thresholds (i.e. smaller ``verbose``). The
# 3-step spacing keeps them distinct from DEBUG=10, INFO=20 and WARNING=30.
RESULT = 25
ITERATION = 22
PROBLEM = 19
SOLVER = 16
OPTIONS = 13

# Give the custom levels readable names for ``%(levelname)s``; ITERATION is
# registered under the shorter label "ITER".
for _level, _name in zip(
    (RESULT, ITERATION, PROBLEM, SOLVER, OPTIONS),
    ("RESULT", "ITER", "PROBLEM", "SOLVER", "OPTIONS"),
):
    logging.addLevelName(_level, _name)

# Package logger. The NullHandler keeps a bare ``import ipax`` from producing
# any console output on its own.
logger = logging.getLogger(LOGGER_NAME)
logger.addHandler(logging.NullHandler())

# Attribute set on the console handler this module creates, so later
# ``configure_verbosity`` calls find and reuse it instead of adding another.
_VERBOSE_HANDLER_ATTR = "_ipax_verbose_handler"

# Title row of the per-iteration table: right-aligned column names, widths
# matching the row formatter, separated by single spaces.
_HEADER = " ".join(
    f"{title:>{width}}"
    for title, width in (
        ("iter", 4),
        ("objective", 15),
        ("infeas", 10),
        ("kkt", 10),
        ("mu", 10),
        ("alpha_pr", 9),
        ("alpha_du", 9),
        ("reg", 9),
        ("prob_s", 9),
        ("step_s", 9),
    )
)
88
+
89
+
90
def verbosity_threshold(verbose: int) -> int:
    """Translate ``Options.verbose`` into a numeric logging threshold.

    * ``verbose >= 6`` returns ``logging.DEBUG`` so debug traces appear.
    * ``verbose <= 0`` returns one above ``RESULT``: no content tier passes,
      while WARNING/ERROR (30+) still do.
    * ``1..5`` step down by three per increment, selecting the tiers
      ``RESULT`` (25) through ``OPTIONS`` (13).
    """
    if verbose >= 6:
        return logging.DEBUG
    if verbose <= 0:
        return RESULT + 1
    # 1 -> 25, 2 -> 22, 3 -> 19, 4 -> 16, 5 -> 13
    return RESULT - 3 * (verbose - 1)
102
+
103
+
104
def configure_verbosity(verbose: int) -> None:
    """Attach (or update) the module-owned console handler for ``Options.verbose``.

    The module owns at most one console handler, tagged with
    ``_VERBOSE_HANDLER_ATTR``. Repeated calls re-level that handler instead of
    stacking duplicates.

    ``verbose <= 0`` never creates the handler and never changes the logger
    level. If an earlier call already attached the handler (for example, a
    previous ``solve`` with a higher ``verbose``), its threshold is raised back
    to ``verbosity_threshold(0)``. ipax's content tiers then stop reaching the
    console, while warnings and errors still pass. Handlers the application
    configured itself are never modified.

    ``verbose >= 1`` attaches the handler on first use and sets it to
    ``verbosity_threshold(verbose)``. The logger level is only ever lowered
    (never raised), so a more permissive application configuration is kept.
    """
    threshold = verbosity_threshold(verbose)
    owned = next(
        (h for h in logger.handlers if getattr(h, _VERBOSE_HANDLER_ATTR, False)),
        None,
    )
    if verbose <= 0:
        # Previously, this path returned without touching a handler left at a
        # low threshold by an earlier verbose call, so "silent" solves kept
        # printing. Re-silence our own handler only.
        if owned is not None:
            owned.setLevel(threshold)
        return
    if owned is None:
        owned = logging.StreamHandler()
        owned.setFormatter(logging.Formatter("%(message)s"))
        setattr(owned, _VERBOSE_HANDLER_ATTR, True)
        logger.addHandler(owned)
    owned.setLevel(threshold)
    # Lower-only: never raise a level the application (or a more verbose
    # earlier call) has already set.
    if logger.level == logging.NOTSET or logger.level > threshold:
        logger.setLevel(threshold)
128
+
129
+
130
def format_header() -> str:
    """Return the title row of the IPOPT-style per-iteration table.

    Its column widths line up with the rows built by :func:`format_record`.
    """
    header = _HEADER
    return header
133
+
134
+
135
def format_record(record: IterationRecord) -> str:
    """Render one iteration-table row aligned under :func:`format_header`.

    Each cell is the record attribute formatted with its column spec; cells
    are separated by single spaces.
    """
    cells = (
        (record.iteration, ">4d"),
        (record.objective, ">15.7e"),
        (record.theta, ">10.3e"),
        (record.kkt_error, ">10.3e"),
        (record.mu, ">10.3e"),
        (record.alpha_primal, ">9.2e"),
        (record.alpha_dual, ">9.2e"),
        (record.regularization, ">9.2e"),
        (record.problem_time, ">9.2e"),
        (record.step_solve_time, ">9.2e"),
    )
    return " ".join(format(value, spec) for value, spec in cells)
144
+
145
+
146
def format_problem(
    *,
    n_vars: int,
    n_ineq: int,
    n_eq_nonlinear: int,
    n_eq_linear: int,
    n_lower: int,
    n_upper: int,
) -> str:
    """Render the problem-structure summary (verbosity tier 3).

    All counts are interpolated as given; the rows are newline-separated
    with no trailing newline.
    """
    rows = [
        "problem structure:",
        f" variables = {n_vars}",
        f" inequalities = {n_ineq}",
        f" equalities = {n_eq_nonlinear} nonlinear + {n_eq_linear} linear",
        f" bounded vars = {n_lower} lower, {n_upper} upper",
    ]
    return "\n".join(rows)
163
+
164
+
165
def format_solver(opts: Options, solver_name: str) -> str:
    """Render the resolved solver configuration (verbosity tier 4).

    ``opts.scaling`` and ``opts.corrections`` are shown through their
    ``method`` attribute when they have one, otherwise as-is. A
    ``krylov method`` row is added when ``opts.linsolve`` is ``"krylov"`` or
    ``"auto"``.
    """
    scaling_cfg = opts.scaling
    scaling_label = getattr(scaling_cfg, "method", scaling_cfg)
    corrections_cfg = opts.corrections
    corrections_label = getattr(corrections_cfg, "method", corrections_cfg)

    rows = ["solver setup:"]
    rows.append(f" hessian = {opts.hessian}")
    rows.append(f" linear solver = {opts.linsolve} ({solver_name})")
    rows.append(f" globalization = {opts.globalization}")
    rows.append(f" mu schedule = {opts.mu_schedule}")
    rows.append(f" scaling = {scaling_label}")
    rows.append(f" corrections = {corrections_label}")
    uses_krylov = opts.linsolve in ("krylov", "auto")
    if uses_krylov:
        rows.append(f" krylov method = {opts.krylov.method}")
    return "\n".join(rows)
182
+
183
+
184
def format_options(opts: Options) -> str:
    """Dump every option field, expanding nested option groups (verbosity tier 5).

    A field whose value is a dataclass *instance* is printed as a group
    header followed by one row per sub-field. Any other value, including a
    dataclass type, is printed on a single row.
    """
    out = ["options:"]
    for fld in dataclasses.fields(opts):
        val = getattr(opts, fld.name)
        is_group = dataclasses.is_dataclass(val) and not isinstance(val, type)
        if not is_group:
            out.append(f" {fld.name} = {val}")
            continue
        out.append(f" {fld.name}:")
        out.extend(
            f" {sub.name} = {getattr(val, sub.name)}"
            for sub in dataclasses.fields(val)
        )
    return "\n".join(out)
196
+
197
+
198
def format_result(result: Result) -> str:
    """Render the final result summary (verbosity tier 1).

    The output covers status and message, objective, iteration count, KKT
    error and its dual/primal/complementarity components, constraint
    violation, wall time, linear solver, and where each derivative came from.
    """
    sources = result.derivative_sources
    kkt_parts = (
        f"dual:{result.dual_infeasibility:.3e}",
        f"primal:{result.primal_infeasibility:.3e}",
        f"compl:{result.complementarity:.3e}",
    )
    derivative_parts = (
        f"grad:{sources.gradient}",
        f"eq_jac:{sources.eq_jacobian}",
        f"ineq_jac:{sources.ineq_jacobian}",
        f"hess:{sources.hessian}",
    )
    lines = [
        f"result: {result.status.value} - {result.message}",
        f" objective = {result.objective:.8e}",
        f" iterations = {result.n_iter}",
        f" kkt error = {result.kkt_error:.3e}",
        " kkt components= " + " ".join(kkt_parts),
        f" infeasibility = {result.constraint_violation:.3e}",
        f" solve time = {result.solve_time:.3e}s",
        f" linear solver = {result.linear_solver}",
        " derivatives = " + " ".join(derivative_parts),
    ]
    return "\n".join(lines)
215
+
216
+
217
def format_timing(history: tuple[IterationRecord, ...]) -> str:
    """Summarize total problem-callback time versus inner-solve time (tier 2).

    Both totals are plain ``sum`` reductions over ``history``; an empty
    history reports zero for each.
    """
    callbacks = sum(entry.problem_time for entry in history)
    inner = sum(entry.step_solve_time for entry in history)
    return "timing: problem-callbacks = {:.3e}s, inner-solve = {:.3e}s".format(
        callbacks, inner
    )
225
+
226
+
227
# Public surface of the logging helpers: level constants, the logger, the
# verbosity configuration, and the text formatters (alphabetical order).
__all__ = [
    "ITERATION",
    "LOGGER_NAME",
    "OPTIONS",
    "PROBLEM",
    "RESULT",
    "SOLVER",
    "configure_verbosity",
    "format_header",
    "format_options",
    "format_problem",
    "format_record",
    "format_result",
    "format_solver",
    "format_timing",
    "logger",
    "verbosity_threshold",
]
@@ -0,0 +1,21 @@
1
+ # Copyright 2026 Niklas Wahl
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Backend layer: namespace resolution, operators, and sparse adapters."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from ipax.backend.namespace import array_namespace, capabilities
20
+
21
# Re-exported backend entry points.
__all__ = [
    "array_namespace",
    "capabilities",
]
@@ -0,0 +1,139 @@
1
+ # Copyright 2026 Niklas Wahl
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Array-API namespace resolution and capability detection (§5.4).
16
+
17
+ This is the single place the core asks "what backend am I on, and what can it
18
+ do?". It does **not** import any concrete array library — it uses
19
+ ``array-api-compat`` to resolve the namespace from the input arrays and probes
20
+ the resolved namespace for optional features.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from dataclasses import dataclass
26
+ from typing import TYPE_CHECKING
27
+
28
+ if TYPE_CHECKING:
29
+ from ipax.typing import Array, Namespace
30
+
31
+
32
def array_namespace(*arrays: Array) -> Namespace:
    """Resolve the shared Array-API namespace of ``arrays``.

    The core routes every backend lookup through this function, so
    ``array-api-compat`` is touched in exactly one place. The compat library
    raises if the inputs belong to different backends.
    """
    # Imported lazily so that importing ipax does not load the compat shim.
    import array_api_compat

    return array_api_compat.array_namespace(*arrays)
42
+
43
+
44
# Function names of the Array-API ``linalg`` extension; ``capabilities``
# records which of them the resolved namespace actually provides.
_ARRAY_API_LINALG_FUNCTIONS = frozenset(
    """
    cholesky cross det diagonal eigh eigvalsh inv matmul matrix_norm
    matrix_power matrix_rank matrix_transpose outer pinv qr slogdet solve
    svd svdvals tensordot trace vecdot vector_norm
    """.split()
)

# Backend names (as produced by ``_namespace_name``) with a sparse adapter.
_SPARSE_ADAPTER_BACKENDS = frozenset(("numpy", "torch", "cupy", "jax"))
# Backend names with an autodiff adapter.
_AUTODIFF_BACKENDS = frozenset(("torch", "jax"))
74
+
75
+
76
+ def _namespace_name(xp: Namespace) -> str:
77
+ """Canonical backend name from a namespace module's ``__name__``.
78
+
79
+ The backend is identified by the *leading* package, not the trailing one:
80
+ JAX resolves to the ``jax.numpy`` module (``__name__ == "jax.numpy"``), so a
81
+ trailing-segment rule would mislabel it as ``"numpy"`` and skip its autodiff
82
+ adapter. We strip an optional ``array_api_compat.`` wrapper prefix and take
83
+ the first component: ``jax.numpy`` → ``jax``, ``array_api_compat.torch`` →
84
+ ``torch``, ``array_api_strict`` → ``array_api_strict``.
85
+ """
86
+ module_name = getattr(xp, "__name__", "")
87
+ prefix = "array_api_compat."
88
+ if module_name.startswith(prefix):
89
+ module_name = module_name[len(prefix) :]
90
+ return module_name.split(".", 1)[0] or "unknown"
91
+
92
+
93
+ @dataclass(frozen=True, slots=True)
94
+ class Capabilities:
95
+ """What the current namespace/device supports (§5.4)."""
96
+
97
+ name: str # "numpy", "torch", "array_api_strict", ...
98
+ has_linalg: bool
99
+ linalg_functions: frozenset[str]
100
+ has_sparse_adapter: bool
101
+ supports_autodiff: bool
102
+ devices: tuple[str, ...]
103
+ default_float: str # "float64" preferred; read from inputs in practice
104
+
105
+
106
def capabilities(xp: Namespace) -> Capabilities:
    """Probe ``xp`` for optional Array-API features and adapter availability.

    Reports whether ``xp.linalg`` exists and which standard linalg functions
    it defines, whether ipax has sparse/autodiff adapters for the backend, a
    device list, and the widest available float dtype name. Missing standard
    pieces (triangular solve, ``lstsq``) are filled by labeled helpers
    elsewhere in ``backend``/``linalg``.
    """
    linalg = getattr(xp, "linalg", None)
    if linalg is None:
        present: frozenset[str] = frozenset()
    else:
        present = frozenset(
            fn for fn in _ARRAY_API_LINALG_FUNCTIONS if hasattr(linalg, fn)
        )
    backend = _namespace_name(xp)
    # Prefer double precision; fall back to single, else report "unknown".
    float_dtype = next(
        (dtype for dtype in ("float64", "float32") if hasattr(xp, dtype)),
        "unknown",
    )
    return Capabilities(
        name=backend,
        has_linalg=linalg is not None,
        linalg_functions=present,
        has_sparse_adapter=backend in _SPARSE_ADAPTER_BACKENDS,
        supports_autodiff=backend in _AUTODIFF_BACKENDS,
        devices=("cpu",),  # always "cpu"; no per-backend device probing here
        default_float=float_dtype,
    )
137
+
138
+
139
# Public names of the namespace/capability module.
__all__ = [
    "Capabilities",
    "array_namespace",
    "capabilities",
]