sip-python 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 João Sousa Pinto
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: sip_python
3
+ Version: 0.0.1
4
+ Summary: Python bindings for the SIP solver.
5
+ Author-email: João Sousa-Pinto <joaospinto@gmail.com>
6
+ Project-URL: Homepage, https://github.com/joaospinto/sip_python
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Operating System :: OS Independent
9
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: numpy
14
+ Requires-Dist: scipy
15
+ Dynamic: license-file
16
+
17
+ # sip_python
18
+ [![pip](https://github.com/joaospinto/sip_python/actions/workflows/pip.yml/badge.svg)](https://github.com/joaospinto/sip_python/actions?query=workflow%3Apip)
19
+ [![wheels](https://github.com/joaospinto/sip_python/actions/workflows/wheels.yml/badge.svg)](https://github.com/joaospinto/sip_python/actions?query=workflow%3Awheels)
20
+
21
+ Python bindings for the [SIP](https://github.com/joaospinto/sip) solver.
@@ -0,0 +1,5 @@
1
+ # sip_python
2
+ [![pip](https://github.com/joaospinto/sip_python/actions/workflows/pip.yml/badge.svg)](https://github.com/joaospinto/sip_python/actions?query=workflow%3Apip)
3
+ [![wheels](https://github.com/joaospinto/sip_python/actions/workflows/wheels.yml/badge.svg)](https://github.com/joaospinto/sip_python/actions?query=workflow%3Awheels)
4
+
5
+ Python bindings for the [SIP](https://github.com/joaospinto/sip) solver.
@@ -0,0 +1,23 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sip_python"
7
+ version = "0.0.1"
8
+ description = "Python bindings for the SIP solver."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ authors = [{ name = "João Sousa-Pinto", email = "joaospinto@gmail.com" }]
12
+ classifiers = [
13
+ "Programming Language :: Python :: 3",
14
+ "Operating System :: OS Independent",
15
+ "Topic :: Scientific/Engineering :: Mathematics",
16
+ ]
17
+ dependencies = [
18
+ "numpy",
19
+ "scipy",
20
+ ]
21
+
22
+ [project.urls]
23
+ Homepage = "https://github.com/joaospinto/sip_python"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,163 @@
1
import contextlib
import os
import platform
import re
import shutil
import sys
import sysconfig
from collections.abc import Generator
from pathlib import Path

import setuptools
from setuptools.command import build_ext
12
+
13
# Free-threaded build option; requires Python 3.13+.
# `Py_GIL_DISABLED` is the build-configuration flag for free-threaded builds.
# Matching the "experimental free-threading build" text of `sys.version` is
# fragile: Python 3.14 dropped the "experimental" label from that string.
# Source: https://docs.python.org/3/howto/free-threading-python.html#identifying-free-threaded-python
free_threaded = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
# SABI-related options. Requires that each Python interpreter
# (hermetic or not) participating is of the same major-minor version.
# Cannot be used together with free-threading.
py_limited_api = sys.version_info >= (3, 12) and not free_threaded
options = {"bdist_wheel": {"py_limited_api": "cp312"}} if py_limited_api else {}
21
+
22
+
23
def is_cibuildwheel() -> bool:
    """Tell whether the build runs under cibuildwheel.

    cibuildwheel exports ``CIBUILDWHEEL`` into the build environment; any
    value, even an empty string, counts as "present".
    """
    return "CIBUILDWHEEL" in os.environ
25
+
26
+
27
@contextlib.contextmanager
def _maybe_patch_toolchains() -> Generator[None, None, None]:
    """
    Patch rules_python toolchains to ignore root user error
    when run in a Docker container on Linux in cibuildwheel.

    Every ``python.toolchain(...)`` call in ``MODULE.bazel`` gets an extra
    ``ignore_root_user_error = True`` argument for the duration of the
    ``with`` body. The original file contents are restored on exit, even if
    the body raises. Outside of cibuildwheel on Linux this is a no-op and
    ``MODULE.bazel`` is neither read nor written.
    """

    def fmt_toolchain_args(matchobj: re.Match) -> str:
        suffix = "ignore_root_user_error = True"
        callargs = matchobj.group(1)
        # The toolchain definition is broken over multiple lines.
        if callargs.endswith("\n"):
            callargs = callargs + " " + suffix + ",\n"
        # The toolchain definition is on one line.
        else:
            callargs = callargs + ", " + suffix
        return "python.toolchain(" + callargs + ")"

    if not (is_cibuildwheel() and platform.system() == "Linux"):
        # Nothing to patch: do not touch (or require) MODULE.bazel at all.
        yield
        return

    module_bazel = Path("MODULE.bazel")
    content: str = module_bazel.read_text()
    try:
        # The dot is escaped so only the literal `python.toolchain(` matches.
        module_bazel.write_text(
            re.sub(
                r"python\.toolchain\(([\w\"\s,.=]*)\)",
                fmt_toolchain_args,
                content,
            )
        )
        yield
    finally:
        # Restore the pristine file, also if patching or the build failed.
        module_bazel.write_text(content)
61
+
62
+
63
class BazelExtension(setuptools.Extension):
    """A C/C++ extension that is defined as a Bazel BUILD target."""

    def __init__(
        self,
        name: str,
        bazel_target: str,
        *,
        free_threaded: bool | None = None,
        **kwargs,
    ):
        """
        Args:
            name: Dotted module name of the extension (e.g. ``pkg.module``).
            bazel_target: Bazel label to build, e.g. ``//src:target``. The
                shorthand ``//src`` is the same label as ``//src:src``.
            free_threaded: Whether to build for a free-threaded interpreter.
                When omitted, the running interpreter is inspected.
            **kwargs: Forwarded to ``setuptools.Extension``
                (e.g. ``py_limited_api``).
        """
        # `free_threaded` is consumed here instead of being forwarded: it is
        # not a `setuptools.Extension` option, and forwarding it only produces
        # an "Unknown Extension options" warning.
        super().__init__(name=name, sources=[], **kwargs)

        if free_threaded is None:
            free_threaded = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
        self.free_threaded = free_threaded
        self.bazel_target = bazel_target

        # Drop any repository prefix ("@repo//" or "//") from the label.
        stripped_target = bazel_target.split("//")[-1]
        relpath, _, target_name = stripped_target.partition(":")
        self.relpath = relpath
        # Bazel shorthand: `//pkg/sub` means `//pkg/sub:sub`.
        self.target_name = target_name or relpath.rsplit("/", 1)[-1]
73
+
74
+
75
class BuildBazelExtension(build_ext.build_ext):
    """A command that runs Bazel to build a C/C++ extension."""

    def run(self):
        """Build every registered extension with Bazel."""
        try:
            for ext in self.extensions:
                self.bazel_build(ext)
        finally:
            # Explicitly call `bazel shutdown` for a graceful exit, also when
            # one of the builds failed.
            self.spawn(["bazel", "shutdown"])

    def copy_extensions_to_source(self):
        """
        Intentionally a no-op.

        ``bazel_build`` copies the built artifacts itself, so the copy step of
        the `build_ext` base class is not needed.
        """
        pass

    def bazel_build(self, ext: BazelExtension) -> None:
        """Run the Bazel build for *ext* and copy its artifacts into ``build_lib``."""
        temp_path = Path(self.build_temp)

        # Specifying only MAJOR.MINOR makes rules_python do an internal
        # lookup selecting the newest patch version.
        python_version = "{0}.{1}".format(*sys.version_info[:2])

        bazel_argv = [
            "bazel",
            "run",
            ext.bazel_target,
            f"--symlink_prefix={temp_path / 'bazel-'}",
            f"--compilation_mode={'dbg' if self.debug else 'opt'}",
            f"--target_python_version={python_version}",
        ]

        if ext.py_limited_api:
            bazel_argv += ["--py_limited_api=cp312"]
        if ext.free_threaded:
            bazel_argv += ["--free_threaded=yes"]

        with _maybe_patch_toolchains():
            self.spawn(bazel_argv)

        if platform.system() == "Windows":
            suffix = ".pyd"
        else:
            suffix = ".abi3.so" if ext.py_limited_api else ".so"

        # Copy the Bazel build artifacts into setuptools' libdir, from where
        # the wheel is built. The outputs live under the target's Bazel package
        # and go into the Python package that contains the extension module.
        srcdir = temp_path / "bazel-bin" / ext.relpath
        libdir = Path(self.build_lib, *ext.name.split(".")[:-1])
        for root, dirs, files in os.walk(srcdir, topdown=True):
            # Exclude runfiles directories and their children.
            dirs[:] = [d for d in dirs if "runfiles" not in d]

            root_path = Path(root)
            for f in files:
                if not self._should_copy(root_path, f, srcdir, suffix):
                    continue
                dstdir = libdir / os.path.relpath(root, srcdir)
                # `parents=True`: a nested output dir may sit below a
                # directory that itself had nothing to copy.
                dstdir.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(root_path / f, dstdir / f)

    @staticmethod
    def _should_copy(root: Path, filename: str, srcdir: Path, suffix: str) -> bool:
        """Return True if *filename* found in *root* belongs in the wheel."""
        fp = Path(filename)
        # We do not want the bare .so file included when building for ABI3,
        # so we require a full and exact match on the file extension.
        if "".join(fp.suffixes) == suffix:
            return True
        if fp.suffix == ".pyi":
            return True
        # Copy py.typed, but only at the package root.
        return root == srcdir and filename == "py.typed"
149
+
150
+
151
# The native extension module, built by Bazel rather than by setuptools.
_ext_modules = [
    BazelExtension(
        name="sip_python.sip_python_ext",
        bazel_target="//src:sip_python_ext_stubgen",
        free_threaded=free_threaded,
        py_limited_api=py_limited_api,
    ),
]

setuptools.setup(
    cmdclass={"build_ext": BuildBazelExtension},
    # Ship the typing marker and the generated stubs with the package.
    package_data={"sip_python": ["py.typed", "*.pyi", "**/*.pyi"]},
    ext_modules=_ext_modules,
    options=options,
)
@@ -0,0 +1,2 @@
1
+ from .sip_python_ext import *
2
+ from .helpers import get_kkt_perm_inv, get_kkt_and_L_nnzs
@@ -0,0 +1,89 @@
1
+ import warnings
2
+
3
+ import numpy as np
4
+ from scipy import sparse as spa
5
+
6
+ from .sip_python_ext import getLnnz
7
+
8
+
9
+ _cvxopt_available = False
10
+ try:
11
+ from cvxopt import amd, spmatrix
12
+
13
+ _cvxopt_available = True
14
+ except ImportError:
15
+ from scipy.sparse.csgraph import reverse_cuthill_mckee
16
+
17
+
18
+ def _get_K(P, A, G):
19
+ # K = [ P + r1 I_x A.T G.T ]
20
+ # [ A -r2 * I_y 0 ]
21
+ # [ G 0 -r3 I_z ]
22
+
23
+ if isinstance(P, np.ndarray):
24
+ P = spa.csc_matrix(P)
25
+
26
+ if isinstance(A, np.ndarray):
27
+ A = spa.csr_matrix(A)
28
+
29
+ if isinstance(G, np.ndarray):
30
+ G = spa.csr_matrix(G)
31
+
32
+ x_dim = P.shape[0]
33
+ s_dim = G.shape[0]
34
+ y_dim = A.shape[0]
35
+
36
+ mod_P = spa.csc_matrix.copy(P)
37
+ mod_P.data[:] = 1.0
38
+
39
+ Z = spa.csc_matrix((y_dim, s_dim))
40
+
41
+ K = spa.block_array(
42
+ blocks=[
43
+ [mod_P + spa.eye(x_dim), A.T, G.T],
44
+ [A, -spa.eye(y_dim), Z],
45
+ [G, Z.T, -spa.eye(s_dim)],
46
+ ],
47
+ format="coo",
48
+ )
49
+
50
+ return K
51
+
52
+
53
def _get_kkt_perm(P, A, G, verbose):
    """Compute a fill-reducing ordering of the KKT matrix built by ``_get_K``.

    Uses CVXOPT's approximate minimum degree (AMD) ordering when cvxopt is
    importable, and SciPy's reverse Cuthill-McKee (RCM) ordering otherwise.

    Args:
        P, A, G: Sparsity patterns forwarded to ``_get_K``.
        verbose: If true, emit a warning when falling back to RCM.

    Returns:
        A 1-D numpy integer array with the ordering ``perm``.
    """
    K = _get_K(P=P, A=A, G=G)

    if _cvxopt_available:
        # `amd.order` treats its argument as symmetric and, per the CVXOPT docs,
        # reads a single triangle of it. `P` may be supplied as its upper
        # triangle only, so that triangle of K could miss the Hessian's
        # off-diagonal coupling. Order the symmetrized pattern instead. Unit
        # data guarantees that no entry cancels in the sum.
        pattern = K.copy()
        pattern.data = np.ones_like(pattern.data)
        sym_pattern = (pattern + pattern.T).tocoo()
        K_cvxopt = spmatrix(
            I=sym_pattern.row,
            J=sym_pattern.col,
            V=sym_pattern.data,
        )
        return np.array(list(amd.order(K_cvxopt)))
    if verbose:
        warnings.warn(
            "cvxopt not installed; using reverse Cuthill-McKee (RCM) "
            "instead of approximate minimum degree (AMD)."
        )
    return reverse_cuthill_mckee(spa.csc_matrix(K))
69
+
70
+
71
def get_kkt_perm_inv(P, A, G, verbose=True):
    """Return the inverse of the fill-reducing ordering of the KKT matrix.

    Args:
        P, A, G: Sparsity patterns of the Hessian and constraint Jacobians.
        verbose: Forwarded to the ordering routine, which warns when it falls
            back to reverse Cuthill-McKee.

    Returns:
        An integer array ``perm_inv`` with the same dtype as the ordering
        ``perm`` and satisfying ``perm_inv[perm[i]] == i``.
    """
    ordering = _get_kkt_perm(P=P, A=A, G=G, verbose=verbose)

    # Scatter the positions 0..n-1 into the slots named by the ordering.
    inverse = np.zeros_like(ordering)
    inverse[ordering] = np.arange(len(ordering))
    return inverse
78
+
79
+
80
def get_kkt_and_L_nnzs(P, A, G, perm_inv):
    """Return ``(kkt_nnz, kkt_L_nnz)`` for the permuted KKT system.

    Args:
        P, A, G: Sparsity patterns of the Hessian and constraint Jacobians,
            as accepted by ``_get_K``.
        perm_inv: Inverse permutation (e.g. from ``get_kkt_perm_inv``); index
            ``i`` of the unpermuted KKT matrix moves to ``perm_inv[i]``.

    Returns:
        ``K.nnz``, the number of stored entries of the KKT skeleton built by
        ``_get_K``, and the value ``getLnnz`` reports for the upper triangle of
        the symmetrically permuted KKT pattern.
    """
    K = _get_K(P=P, A=A, G=G)

    # Symmetrize the sparsity pattern before permuting. `P` may be supplied as
    # its upper triangle only. Without this step, any of its off-diagonal
    # entries that the permutation moves below the diagonal would be dropped
    # by `triu`, although their mirror images belong to the upper triangle of
    # the permuted matrix. Unit data guarantees no entry cancels in the sum.
    pattern = K.copy()
    pattern.data = np.ones_like(pattern.data)
    permuted_K = (pattern + pattern.T).tocoo()

    permuted_K.row = perm_inv[permuted_K.row]
    permuted_K.col = perm_inv[permuted_K.col]

    kkt_L_nnz = getLnnz(spa.triu(permuted_K))

    return K.nnz, kkt_L_nnz
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: sip_python
3
+ Version: 0.0.1
4
+ Summary: Python bindings for the SIP solver.
5
+ Author-email: João Sousa-Pinto <joaospinto@gmail.com>
6
+ Project-URL: Homepage, https://github.com/joaospinto/sip_python
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Operating System :: OS Independent
9
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: numpy
14
+ Requires-Dist: scipy
15
+ Dynamic: license-file
16
+
17
+ # sip_python
18
+ [![pip](https://github.com/joaospinto/sip_python/actions/workflows/pip.yml/badge.svg)](https://github.com/joaospinto/sip_python/actions?query=workflow%3Apip)
19
+ [![wheels](https://github.com/joaospinto/sip_python/actions/workflows/wheels.yml/badge.svg)](https://github.com/joaospinto/sip_python/actions?query=workflow%3Awheels)
20
+
21
+ Python bindings for the [SIP](https://github.com/joaospinto/sip) solver.
@@ -0,0 +1,13 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ src/sip_python/__init__.py
6
+ src/sip_python/helpers.py
7
+ src/sip_python.egg-info/PKG-INFO
8
+ src/sip_python.egg-info/SOURCES.txt
9
+ src/sip_python.egg-info/dependency_links.txt
10
+ src/sip_python.egg-info/requires.txt
11
+ src/sip_python.egg-info/top_level.txt
12
+ tests/test_simple_nlp.py
13
+ tests/test_simple_qp.py
@@ -0,0 +1,2 @@
1
+ numpy
2
+ scipy
@@ -0,0 +1 @@
1
+ sip_python
@@ -0,0 +1,148 @@
1
+ from sip_python import (
2
+ get_kkt_and_L_nnzs,
3
+ get_kkt_perm_inv,
4
+ ModelCallbackInput,
5
+ ModelCallbackOutput,
6
+ ProblemDimensions,
7
+ QDLDLSettings,
8
+ Settings,
9
+ Solver,
10
+ Status,
11
+ Variables,
12
+ )
13
+
14
+ import pytest
15
+
16
+ import jax
17
+ from jax import numpy as jnp
18
+
19
+ jax.config.update("jax_enable_x64", True)
20
+
21
+ import numpy as np
22
+
23
+ from scipy import sparse as sp
24
+
25
+
26
def test_simple_nlp():
    """Solve a small nonconvex NLP with SIP and check the known optimum.

    Objective:              f(x) = x1 * (5 + x0)
    Inequality constraints: g(x) = [5 - x0 * x1, x0^2 + x1^2 - 20]
    There are no equality constraints (c(x) is empty).
    """
    ss = Settings()
    ss.max_aug_kkt_violation = 1e-12
    ss.num_iterative_refinement_steps = 1
    ss.penalty_parameter_increase_factor = 3.0
    ss.enable_elastics = True
    ss.elastic_var_cost_coeff = 1e6
    ss.assert_checks_pass = True

    @jax.jit
    def f(x):
        return x[1] * (5.0 + x[0])

    @jax.jit
    def c(x):
        return jnp.array([])

    @jax.jit
    def g(x):
        return jnp.array(
            [
                5.0 - x[0] * x[1],
                x[0] * x[0] + x[1] * x[1] - 20.0,
            ]
        )

    @jax.jit
    def grad_f(x):
        return jax.grad(f)(x)

    @jax.jit
    def approx_upp_hess_f(x):
        # Shift the Hessian by a multiple of the identity so that it becomes
        # positive definite, then keep only its upper triangle.
        def proj_psd(Q, delta=1e-6):
            S, _V = jnp.linalg.eigh(Q)
            k = -jnp.minimum(jnp.min(S), 0.0) + delta
            return Q + k * jnp.eye(Q.shape[0])

        return jnp.triu(proj_psd(jax.hessian(f)(x)))

    @jax.jit
    def jac_c(x):
        return jax.jacfwd(c)(x)

    @jax.jit
    def jac_g(x):
        return jax.jacfwd(g)(x)

    x_dim = 2

    # Evaluate the derivatives at a fixed point to discover their sparsity.
    mock_x = jnp.ones([x_dim])
    jac_c_nnz_pattern = np.array(jac_c(mock_x))
    jac_g_nnz_pattern = np.array(jac_g(mock_x))
    upper_L_hess_nnz_pattern = np.array(approx_upp_hess_f(mock_x))

    jac_c_nnz_pattern_sp = sp.csr_matrix(jac_c_nnz_pattern)
    jac_g_nnz_pattern_sp = sp.csr_matrix(jac_g_nnz_pattern)
    upper_L_hess_nnz_pattern_sp = sp.csc_matrix(upper_L_hess_nnz_pattern)

    qs = QDLDLSettings()
    qs.permute_kkt_system = True
    qs.kkt_pinv = get_kkt_perm_inv(
        P=upper_L_hess_nnz_pattern_sp,
        A=jac_c_nnz_pattern_sp,
        G=jac_g_nnz_pattern_sp,
    )

    pd = ProblemDimensions()
    pd.x_dim = x_dim
    pd.s_dim = 2
    pd.y_dim = 0
    pd.upper_hessian_lagrangian_nnz = upper_L_hess_nnz_pattern_sp.nnz
    pd.jacobian_c_nnz = jac_c_nnz_pattern_sp.nnz
    pd.jacobian_g_nnz = jac_g_nnz_pattern_sp.nnz
    pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_and_L_nnzs(
        P=upper_L_hess_nnz_pattern_sp,
        A=jac_c_nnz_pattern_sp,
        G=jac_g_nnz_pattern_sp,
        perm_inv=qs.kkt_pinv,
    )
    pd.is_jacobian_c_transposed = True
    pd.is_jacobian_g_transposed = True

    def mc(mci: ModelCallbackInput) -> ModelCallbackOutput:
        mco = ModelCallbackOutput()

        mco.f = f(mci.x)
        mco.c = np.array(c(mci.x))
        mco.g = np.array(g(mci.x))

        mco.gradient_f = np.array(grad_f(mci.x))

        # CSR stores entries row by row, which matches the row-major order
        # of boolean-mask indexing.
        C = np.array(jac_c(mci.x))
        jac_c_nnz_pattern_sp.data = C[jac_c_nnz_pattern != 0.0]
        mco.jacobian_c = jac_c_nnz_pattern_sp

        G = np.array(jac_g(mci.x))
        jac_g_nnz_pattern_sp.data = G[jac_g_nnz_pattern != 0.0]
        mco.jacobian_g = jac_g_nnz_pattern_sp

        # CSC stores entries column by column, so mask the transposes to
        # gather the values in column-major order.
        upp_hess_L = np.array(approx_upp_hess_f(mci.x))
        upper_L_hess_nnz_pattern_sp.data = upp_hess_L.T[
            upper_L_hess_nnz_pattern.T != 0.0
        ]
        mco.upper_hessian_lagrangian = upper_L_hess_nnz_pattern_sp

        return mco

    solver = Solver(ss, qs, pd, mc)

    variables = Variables(pd)
    variables.x[:] = 0.0
    variables.s[:] = 1.0
    variables.y[:] = 0.0
    variables.e[:] = 0.0
    variables.z[:] = 1.0

    output = solver.solve(variables)

    assert output.exit_status == Status.SOLVED
    assert variables.x[0] == pytest.approx(-1.15747396, abs=1e-6)
    assert variables.x[1] == pytest.approx(-4.31975162, abs=1e-6)
@@ -0,0 +1,145 @@
1
+ from sip_python import (
2
+ get_kkt_and_L_nnzs,
3
+ get_kkt_perm_inv,
4
+ ModelCallbackInput,
5
+ ModelCallbackOutput,
6
+ ProblemDimensions,
7
+ QDLDLSettings,
8
+ Settings,
9
+ Solver,
10
+ Status,
11
+ Variables,
12
+ )
13
+
14
+ import pytest
15
+
16
+ import jax
17
+ from jax import numpy as jnp
18
+
19
+ jax.config.update("jax_enable_x64", True)
20
+
21
+ import numpy as np
22
+
23
+ from scipy import sparse as sp
24
+
25
+
26
def test_simple_qp():
    """Solve a small convex QP with SIP and check the known optimum.

    Objective:              f(x) = 0.5 * x^T [[4, 1], [1, 2]] x + x0 + x1
    Equality constraint:    c(x) = [x0 + x1 - 1]
    Inequality constraints: g(x) = [x0 - 0.7, -x0, x1 - 0.7, -x1]
    """
    ss = Settings()
    ss.max_aug_kkt_violation = 1e-12
    ss.enable_elastics = True
    ss.elastic_var_cost_coeff = 1e6
    ss.assert_checks_pass = True

    @jax.jit
    def f(x):
        return (
            0.5 * (4.0 * x[0] * x[0] + 2.0 * x[0] * x[1] + 2.0 * x[1] * x[1])
            + x[0]
            + x[1]
        )

    @jax.jit
    def c(x):
        return jnp.array([x[0] + x[1] - 1.0])

    @jax.jit
    def g(x):
        return jnp.array([x[0] - 0.7, -x[0] - 0.0, x[1] - 0.7, -x[1] - 0.0])

    @jax.jit
    def grad_f(x):
        return jax.grad(f)(x)

    @jax.jit
    def approx_upp_hess_f(x):
        # Shift the Hessian by a multiple of the identity so that it becomes
        # positive definite, then keep only its upper triangle.
        def proj_psd(Q, delta=1e-6):
            S, _V = jnp.linalg.eigh(Q)
            k = -jnp.minimum(jnp.min(S), 0.0) + delta
            return Q + k * jnp.eye(Q.shape[0])

        return jnp.triu(proj_psd(jax.hessian(f)(x)))

    @jax.jit
    def jac_c(x):
        return jax.jacfwd(c)(x)

    @jax.jit
    def jac_g(x):
        return jax.jacfwd(g)(x)

    x_dim = 2

    # Evaluate the derivatives at a fixed point to discover their sparsity.
    mock_x = jnp.ones([x_dim])
    jac_c_nnz_pattern = np.array(jac_c(mock_x))
    jac_g_nnz_pattern = np.array(jac_g(mock_x))
    upper_L_hess_nnz_pattern = np.array(approx_upp_hess_f(mock_x))

    jac_c_nnz_pattern_sp = sp.csr_matrix(jac_c_nnz_pattern)
    jac_g_nnz_pattern_sp = sp.csr_matrix(jac_g_nnz_pattern)
    upper_L_hess_nnz_pattern_sp = sp.csc_matrix(upper_L_hess_nnz_pattern)

    qs = QDLDLSettings()
    qs.permute_kkt_system = True
    qs.kkt_pinv = get_kkt_perm_inv(
        P=upper_L_hess_nnz_pattern_sp,
        A=jac_c_nnz_pattern_sp,
        G=jac_g_nnz_pattern_sp,
    )

    pd = ProblemDimensions()
    pd.x_dim = x_dim
    pd.s_dim = 4
    pd.y_dim = 1
    pd.upper_hessian_lagrangian_nnz = upper_L_hess_nnz_pattern_sp.nnz
    pd.jacobian_c_nnz = jac_c_nnz_pattern_sp.nnz
    pd.jacobian_g_nnz = jac_g_nnz_pattern_sp.nnz
    pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_and_L_nnzs(
        P=upper_L_hess_nnz_pattern_sp,
        A=jac_c_nnz_pattern_sp,
        G=jac_g_nnz_pattern_sp,
        perm_inv=qs.kkt_pinv,
    )
    pd.is_jacobian_c_transposed = True
    pd.is_jacobian_g_transposed = True

    def mc(mci: ModelCallbackInput) -> ModelCallbackOutput:
        mco = ModelCallbackOutput()

        mco.f = f(mci.x)
        mco.c = np.array(c(mci.x))
        mco.g = np.array(g(mci.x))

        mco.gradient_f = np.array(grad_f(mci.x))

        # CSR stores entries row by row, which matches the row-major order
        # of boolean-mask indexing.
        C = np.array(jac_c(mci.x))
        jac_c_nnz_pattern_sp.data = C[jac_c_nnz_pattern != 0.0]
        mco.jacobian_c = jac_c_nnz_pattern_sp

        G = np.array(jac_g(mci.x))
        jac_g_nnz_pattern_sp.data = G[jac_g_nnz_pattern != 0.0]
        mco.jacobian_g = jac_g_nnz_pattern_sp

        # CSC stores entries column by column, so mask the transposes to
        # gather the values in column-major order.
        upp_hess_L = np.array(approx_upp_hess_f(mci.x))
        upper_L_hess_nnz_pattern_sp.data = upp_hess_L.T[
            upper_L_hess_nnz_pattern.T != 0.0
        ]
        mco.upper_hessian_lagrangian = upper_L_hess_nnz_pattern_sp

        return mco

    solver = Solver(ss, qs, pd, mc)

    variables = Variables(pd)
    variables.x[:] = 0.0
    variables.s[:] = 1.0
    variables.y[:] = 0.0
    variables.e[:] = 0.0
    variables.z[:] = 1.0

    output = solver.solve(variables)

    assert output.exit_status == Status.SOLVED
    assert variables.x[0] == pytest.approx(0.3, abs=1e-5)
    assert variables.x[1] == pytest.approx(0.7, abs=1e-5)