sip-python 0.0.2__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sip_python
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Python bindings for the SIP solver.
5
5
  Author-email: João Sousa-Pinto <joaospinto@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/joaospinto/sip_python
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "sip_python"
7
- version = "0.0.2"
7
+ version = "0.0.3"
8
8
  description = "Python bindings for the SIP solver."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -102,10 +102,11 @@ class BuildBazelExtension(build_ext.build_ext):
102
102
  "run",
103
103
  ext.bazel_target,
104
104
  f"--symlink_prefix={temp_path / 'bazel-'}",
105
- f"--compilation_mode={'dbg' if self.debug else 'opt'}",
106
105
  f"--target_python_version={python_version}",
107
106
  ]
108
107
 
108
+ if self.debug:
109
+ bazel_argv += ["--config=debug"]
109
110
  if ext.py_limited_api:
110
111
  bazel_argv += ["--py_limited_api=cp312"]
111
112
  if ext.free_threaded:
@@ -0,0 +1,7 @@
1
+ from .sip_python_ext import *
2
+ from .helpers import (
3
+ get_K,
4
+ get_kkt_perm_inv,
5
+ get_kkt_and_L_nnzs,
6
+ get_kkt_perm_inv_and_nnzs,
7
+ )
@@ -0,0 +1,125 @@
1
+ import ctypes
2
+ import os
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from scipy import sparse as spa
7
+ from scipy.sparse.csgraph import reverse_cuthill_mckee
8
+
9
+ from .sip_python_ext import getLnnz
10
+
11
# Try to load SuiteSparse's libamd (as vendored inside the cvxopt wheel)
# directly, so AMD orderings can be computed without going through cvxopt's
# Python-level matrix types. `_libamd` stays None when no usable library is
# found; callers must then fall back to another ordering.


def _load_libamd():
    """Locate and load a libamd shared library bundled with cvxopt.

    Returns:
        A ``ctypes.CDLL`` whose ``amd_l_order`` prototype has been configured,
        or ``None`` when cvxopt cannot be imported or no loadable library
        exposing ``amd_l_order`` is found. The global is only ever bound to a
        fully configured library.
    """
    try:
        import cvxopt
    except Exception:
        # Best effort: any failure to import cvxopt just disables this path.
        return None

    pkg_file = getattr(cvxopt, "__file__", None)
    if not pkg_file:
        return None
    pkg_dir = os.path.dirname(pkg_file)

    # Wheel-repair tools vendor shared libraries in different places:
    # delocate (macOS) uses <pkg>/.dylibs, older auditwheel uses <pkg>/.libs,
    # newer auditwheel and delvewheel use a sibling <pkg>.libs directory.
    candidate_dirs = (
        os.path.join(pkg_dir, ".dylibs"),
        os.path.join(pkg_dir, ".libs"),
        pkg_dir + ".libs",
    )
    for lib_dir in candidate_dirs:
        try:
            names = sorted(os.listdir(lib_dir))
        except OSError:
            continue
        for name in names:
            if not name.startswith("libamd"):
                continue
            try:
                lib = ctypes.CDLL(os.path.join(lib_dir, name))
                amd_l_order = lib.amd_l_order
            except (OSError, AttributeError):
                # Not loadable, or lacks the 64-bit-index entry point; keep
                # looking at the remaining candidates.
                continue
            amd_l_order.restype = ctypes.c_int
            # The "l" AMD interface takes 64-bit indices. Use fixed-width
            # int64 for `n` (C `long` is only 32 bits on Windows) and untyped
            # pointers for the index arrays and the Control/Info arrays.
            amd_l_order.argtypes = [
                ctypes.c_int64,
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_void_p,
            ]
            return lib
    return None


_libamd = _load_libamd()
32
+
33
+
34
def _amd_order(K_csc):
    """Compute an approximate minimum degree (AMD) ordering via libamd.

    Only the sparsity pattern of ``K_csc`` is used. The pattern is symmetrized
    (pattern of K + K^T) before it is handed to ``amd_l_order``.

    Args:
        K_csc: Square scipy sparse matrix (any format accepted by
            ``scipy.sparse.csc_matrix``).

    Returns:
        ``np.intp`` array ``perm`` with the permutation computed by
        ``amd_l_order``.

    Raises:
        ValueError: If ``K_csc`` is not square.
        RuntimeError: If ``amd_l_order`` returns a negative (error) status.
    """
    n_rows, n_cols = K_csc.shape
    if n_rows != n_cols:
        raise ValueError(
            f"AMD ordering requires a square matrix, got shape {K_csc.shape}"
        )

    # Work on structure only: with every stored entry set to 1, adding K and
    # K^T can neither cancel entries nor prune explicitly stored zeros.
    pattern = spa.csc_matrix(K_csc, copy=True)
    pattern.data = np.ones_like(pattern.data, dtype=np.float64)
    K_sym = (pattern + pattern.T).tocsc()
    K_sym.sort_indices()
    n = K_sym.shape[0]

    # amd_l_order takes 64-bit index arrays; use fixed-width int64 rather
    # than C `long`, which is 32-bit on Windows.
    Ap = np.ascontiguousarray(K_sym.indptr, dtype=np.int64)
    Ai = np.ascontiguousarray(K_sym.indices, dtype=np.int64)
    perm = np.empty(n, dtype=np.int64)
    int64_ptr = ctypes.POINTER(ctypes.c_int64)
    ret = _libamd.amd_l_order(
        n,
        Ap.ctypes.data_as(int64_ptr),
        Ai.ctypes.data_as(int64_ptr),
        perm.ctypes.data_as(int64_ptr),
        None,
        None,
    )
    # AMD status codes: 0 = AMD_OK, 1 = AMD_OK_BUT_JUMBLED (still a valid
    # ordering), negative = error (out of memory / invalid input).
    if ret < 0:
        raise RuntimeError(f"amd_l_order failed with return code {ret}")
    return perm.astype(np.intp)
53
+
54
+
55
+ def get_K(P, A, G):
56
+ # K = [ P + r1 I_x A.T G.T ]
57
+ # [ A -r2 * I_y 0 ]
58
+ # [ G 0 -r3 I_z ]
59
+
60
+ if isinstance(P, np.ndarray):
61
+ P = spa.csc_matrix(P)
62
+
63
+ if isinstance(A, np.ndarray):
64
+ A = spa.csr_matrix(A)
65
+
66
+ if isinstance(G, np.ndarray):
67
+ G = spa.csr_matrix(G)
68
+
69
+ x_dim = P.shape[0]
70
+ s_dim = G.shape[0]
71
+ y_dim = A.shape[0]
72
+
73
+ mod_P = spa.csc_matrix.copy(P)
74
+ mod_P.data[:] = 1.0
75
+
76
+ Z = spa.csc_matrix((y_dim, s_dim))
77
+
78
+ K = spa.block_array(
79
+ blocks=[
80
+ [mod_P + spa.eye(x_dim), A.T, G.T],
81
+ [A, -spa.eye(y_dim), Z],
82
+ [G, Z.T, -spa.eye(s_dim)],
83
+ ],
84
+ format="coo",
85
+ )
86
+
87
+ return K
88
+
89
+
90
def _get_kkt_perm(K, verbose):
    """Compute a fill-reducing symmetric permutation for the KKT matrix ``K``.

    Strategies, in order of preference:
      1. AMD through the directly loaded libamd (``_libamd``), if available.
      2. AMD through ``cvxopt.amd.order``, if cvxopt is importable. This
         covers installs where cvxopt exists but its bundled libamd could not
         be located.
      3. SciPy's reverse Cuthill-McKee (RCM), warning when ``verbose``.

    Args:
        K: Square sparse (or dense) KKT matrix; only its pattern matters.
        verbose: Whether to warn when falling back to RCM.

    Returns:
        Integer array ``perm`` (new position -> original index).
    """
    K_csc = spa.csc_matrix(K)
    if _libamd is not None:
        return _amd_order(K_csc)

    try:
        from cvxopt import amd, spmatrix
    except ImportError:
        pass
    else:
        # cvxopt's amd.order reads only one triangle of its argument (the
        # lower one by default). Order the symmetrized pattern of K + K^T so
        # the strictly-upper entries of the P block are not ignored.
        pattern = K_csc.copy()
        pattern.data = np.ones_like(pattern.data, dtype=np.float64)
        sym = (pattern + pattern.T).tocoo()
        K_cvxopt = spmatrix(
            sym.data.tolist(),
            sym.row.tolist(),
            sym.col.tolist(),
            size=sym.shape,
        )
        return np.array(list(amd.order(K_cvxopt)), dtype=np.intp)

    if verbose:
        warnings.warn(
            "cvxopt not installed; using reverse Cuthill-McKee (RCM) "
            "instead of approximate minimum degree (AMD).",
            stacklevel=3,
        )
    return reverse_cuthill_mckee(K_csc)
100
+
101
+
102
def get_kkt_perm_inv(K, verbose=True):
    """Return the inverse of the fill-reducing permutation chosen for ``K``.

    Args:
        K: KKT matrix, e.g. as produced by ``get_K``.
        verbose: Forwarded to the ordering routine; controls the warning
            emitted when it falls back to a non-AMD ordering.

    Returns:
        Integer array ``perm_inv`` with ``perm_inv[perm[i]] == i``, where
        ``perm`` is the computed permutation; it shares ``perm``'s dtype.
    """
    forward = _get_kkt_perm(K, verbose)
    # Scatter each new position i into slot forward[i] to invert the map.
    positions = np.arange(forward.shape[0])
    inverse = np.empty_like(forward)
    inverse[forward] = positions
    return inverse
109
+
110
+
111
def get_kkt_and_L_nnzs(K, perm_inv):
    """Return the nonzero counts of ``K`` and of its permuted factor.

    Args:
        K: KKT matrix (e.g. from ``get_K``). Any scipy sparse format, or a
            dense array-like.
        perm_inv: Inverse permutation (array-like of ints). Entry ``i`` is
            the new position of row/column ``i`` of ``K``.

    Returns:
        ``(kkt_nnz, kkt_L_nnz)``: the number of stored entries of ``K``, and
        the value ``getLnnz`` reports for the upper triangle of the permuted
        matrix.
    """
    # Work on a private COO copy: its row/col arrays are overwritten below,
    # and K may arrive in any sparse format (or dense).
    if spa.issparse(K):
        permuted_K = K.tocoo(copy=True)
    else:
        permuted_K = spa.coo_matrix(np.asarray(K))
    kkt_nnz = permuted_K.nnz

    perm_inv = np.asarray(perm_inv)
    permuted_K.row = perm_inv[permuted_K.row]
    permuted_K.col = perm_inv[permuted_K.col]

    kkt_L_nnz = getLnnz(spa.triu(permuted_K))

    return kkt_nnz, kkt_L_nnz
119
+
120
+
121
def get_kkt_perm_inv_and_nnzs(P, A, G, verbose=True):
    """Build the KKT pattern and return its ordering plus the nnz counts.

    Convenience wrapper chaining ``get_K``, ``get_kkt_perm_inv`` and
    ``get_kkt_and_L_nnzs``.

    Args:
        P, A, G: Hessian and constraint Jacobian patterns (see ``get_K``).
        verbose: Forwarded to ``get_kkt_perm_inv``.

    Returns:
        ``(perm_inv, kkt_nnz, kkt_L_nnz)``.
    """
    kkt = get_K(P, A, G)
    kkt_perm_inv = get_kkt_perm_inv(kkt, verbose)
    nnz_full, nnz_factor = get_kkt_and_L_nnzs(kkt, kkt_perm_inv)
    return kkt_perm_inv, nnz_full, nnz_factor
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sip_python
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Python bindings for the SIP solver.
5
5
  Author-email: João Sousa-Pinto <joaospinto@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/joaospinto/sip_python
@@ -9,5 +9,7 @@ src/sip_python.egg-info/SOURCES.txt
9
9
  src/sip_python.egg-info/dependency_links.txt
10
10
  src/sip_python.egg-info/requires.txt
11
11
  src/sip_python.egg-info/top_level.txt
12
+ tests/test_simple_constrained_lqr.py
13
+ tests/test_simple_lqr.py
12
14
  tests/test_simple_nlp.py
13
15
  tests/test_simple_qp.py
@@ -0,0 +1,195 @@
1
+ import pytest
2
+ import jax
3
+ from jax import numpy as jnp
4
+
5
+ import numpy as np
6
+ from scipy import sparse as sp
7
+
8
+ from sip_python import (
9
+ get_kkt_perm_inv_and_nnzs,
10
+ ModelCallbackInput,
11
+ ModelCallbackOutput,
12
+ ProblemDimensions,
13
+ QDLDLSettings,
14
+ Settings,
15
+ Solver,
16
+ Status,
17
+ Variables,
18
+ )
19
+
20
+ jax.config.update("jax_enable_x64", True)
21
+
22
+
23
def test_simple_constrained_lqr():
    """Solve a control-bounded double-integrator LQR with a terminal constraint.

    The decision vector stacks (x_0, u_0, x_1, u_1, ..., x_{T-1}, u_{T-1}, x_T)
    with 2-dimensional states and scalar controls. Equality constraints pin
    the initial state, enforce the discrete dynamics and require zero
    terminal velocity; inequality residuals bound each control by +/-2.
    """
    ss = Settings()
    ss.max_kkt_violation = 1e-6
    ss.enable_elastics = True
    ss.elastic_var_cost_coeff = 1e6
    ss.assert_checks_pass = True
    ss.penalty_parameter_increase_factor = 2.0
    ss.mu_update_factor = 0.9
    ss.max_iterations = 200
    ss.max_ls_iterations = 1000
    ss.print_logs = False
    ss.print_line_search_logs = False
    ss.print_search_direction_logs = False
    ss.print_derivative_check_logs = False

    x_dim = 2
    u_dim = 1
    g_dim = 2
    c_dim = 1
    T = 100

    dt = 0.1

    @jax.jit
    def split(x):
        # Append a dummy final control so the vector reshapes into T + 1
        # (state, control) stages; the dummy control is then dropped.
        x = jnp.concatenate([x, jnp.zeros(u_dim)])
        x = x.reshape([T + 1, x_dim + u_dim])
        X = x[:, :x_dim]
        U = x[:T, x_dim:]
        return X, U

    @jax.jit
    def f(x):
        X, U = split(x)
        stagewise_costs = jax.vmap(
            lambda i: 0.5 * X[i, 0] ** 2
            + 0.5 * 0.1 * X[i, 1] ** 2
            + 0.5 * 0.1 * U[i, 0] ** 2
        )(jnp.arange(T))
        terminal_cost = 0.5 * X[T, 0] ** 2 + 0.5 * 0.1 * X[T, 1] ** 2
        return stagewise_costs.sum() + terminal_cost

    @jax.jit
    def c(x):
        x_0 = jnp.array([0.0, 10.0])
        A = jnp.array([[1.0, dt], [0.0, 1.0]])
        B = jnp.array([[dt**2 / 2], [dt]])
        X, U = split(x)

        # (x_dim + c_dim) entries per stage: stage 0 holds the initial-state
        # residual plus a zero padding entry; stage i + 1 holds the dynamics
        # residual plus the terminal-velocity residual (non-zero only on the
        # last stage).
        out = jnp.concatenate(
            [
                x_0 - X[0],
                jnp.array([0.0]),
                jax.vmap(
                    lambda i: jnp.concatenate(
                        [
                            (A @ X[i] + B @ U[i] - X[i + 1]),
                            jnp.array([jnp.where(i + 1 == T, X[i + 1, 1], 0.0)]),
                        ]
                    )
                )(jnp.arange(T)).flatten(),
            ]
        )
        return out

    @jax.jit
    def g(x):
        _X, U = split(x)
        # Two residuals per control (U - 2, -U - 2), i.e. |u_i| <= 2 assuming
        # the solver's g(x) <= 0 convention, then g_dim zero padding entries.
        return jnp.concatenate(
            [
                jax.vmap(lambda i: jnp.array([U[i, 0] - 2.0, -U[i, 0] - 2.0]))(
                    jnp.arange(T)
                ).flatten(),
                jnp.zeros(g_dim),
            ]
        )

    @jax.jit
    def grad_f(x):
        return jax.grad(f)(x)

    @jax.jit
    def approx_upp_hess_f(x):
        def proj_psd(Q, delta=1e-6):
            # Shift the spectrum so the smallest eigenvalue is >= delta.
            S, _V = jnp.linalg.eigh(Q)
            k = -jnp.minimum(jnp.min(S), 0.0) + delta
            return Q + k * jnp.eye(Q.shape[0])

        return jnp.triu(proj_psd(jax.hessian(f)(x)))

    @jax.jit
    def jac_c(x):
        return jax.jacfwd(c)(x)

    @jax.jit
    def jac_g(x):
        return jax.jacfwd(g)(x)

    pd = ProblemDimensions()
    pd.x_dim = T * (x_dim + u_dim) + x_dim
    pd.s_dim = (T + 1) * g_dim
    pd.y_dim = (T + 1) * (x_dim + c_dim)

    # Sparsity patterns are read off the derivatives at an arbitrary point;
    # the problem is linear-quadratic, so they do not depend on x.
    mock_x = jnp.ones([pd.x_dim])
    jac_c_nnz_pattern_sp = sp.csr_matrix(np.array(jac_c(mock_x)))
    jac_g_nnz_pattern_sp = sp.csr_matrix(np.array(jac_g(mock_x)))
    upper_L_hess_nnz_pattern_sp = sp.csc_matrix(np.array(approx_upp_hess_f(mock_x)))

    def storage_coords(mat):
        # (row, col) of every stored entry, in the same order as `mat.data`
        # (row-major for CSR, column-major for CSC).
        coo = mat.tocoo()
        return coo.row, coo.col

    jac_c_coords = storage_coords(jac_c_nnz_pattern_sp)
    jac_g_coords = storage_coords(jac_g_nnz_pattern_sp)
    upper_L_hess_coords = storage_coords(upper_L_hess_nnz_pattern_sp)

    qs = QDLDLSettings()
    qs.permute_kkt_system = True
    qs.kkt_pinv, pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_perm_inv_and_nnzs(
        P=upper_L_hess_nnz_pattern_sp,
        A=jac_c_nnz_pattern_sp,
        G=jac_g_nnz_pattern_sp,
    )

    pd.upper_hessian_lagrangian_nnz = upper_L_hess_nnz_pattern_sp.nnz
    pd.jacobian_c_nnz = jac_c_nnz_pattern_sp.nnz
    pd.jacobian_g_nnz = jac_g_nnz_pattern_sp.nnz
    pd.is_jacobian_c_transposed = True
    pd.is_jacobian_g_transposed = True

    def refresh(mat, coords, dense):
        # Overwrite the stored values of `mat` in place, keeping its pattern.
        # Gathering by (row, col) in storage order is correct for both CSR
        # and CSC; a boolean mask would always yield row-major order.
        rows, cols = coords
        mat.data = np.asarray(dense)[rows, cols]
        return mat

    def mc(mci: ModelCallbackInput) -> ModelCallbackOutput:
        mco = ModelCallbackOutput()

        mco.f = float(f(mci.x))
        mco.c = np.array(c(mci.x))
        mco.g = np.array(g(mci.x))

        mco.gradient_f = np.array(grad_f(mci.x))

        mco.jacobian_c = refresh(jac_c_nnz_pattern_sp, jac_c_coords, jac_c(mci.x))
        mco.jacobian_g = refresh(jac_g_nnz_pattern_sp, jac_g_coords, jac_g(mci.x))
        mco.upper_hessian_lagrangian = refresh(
            upper_L_hess_nnz_pattern_sp, upper_L_hess_coords, approx_upp_hess_f(mci.x)
        )

        return mco

    solver = Solver(ss, qs, pd, mc)

    variables = Variables(pd)
    variables.x[:] = 0.0
    variables.s[:] = 1.0
    variables.y[:] = 0.0
    variables.e[:] = 0.0
    variables.z[:] = 1.0

    output = solver.solve(variables)

    assert output.exit_status == Status.SOLVED
@@ -0,0 +1,171 @@
1
+ import pytest
2
+ import jax
3
+ from jax import numpy as jnp
4
+
5
+ import numpy as np
6
+ from scipy import sparse as sp
7
+
8
+ from sip_python import (
9
+ get_kkt_perm_inv_and_nnzs,
10
+ ModelCallbackInput,
11
+ ModelCallbackOutput,
12
+ ProblemDimensions,
13
+ QDLDLSettings,
14
+ Settings,
15
+ Solver,
16
+ Status,
17
+ Variables,
18
+ )
19
+
20
+ jax.config.update("jax_enable_x64", True)
21
+
22
+
23
def test_simple_lqr():
    """Solve a double-integrator LQR with only equality constraints.

    The decision vector stacks (x_0, u_0, x_1, u_1, ..., x_{T-1}, u_{T-1}, x_T)
    with 2-dimensional states and scalar controls. Equality constraints pin
    the initial state and enforce the discrete dynamics; there are no
    inequality constraints (s_dim == 0).
    """
    ss = Settings()
    ss.max_kkt_violation = 1e-6
    ss.enable_elastics = True
    ss.elastic_var_cost_coeff = 1e6
    ss.assert_checks_pass = True
    ss.penalty_parameter_increase_factor = 2.0
    ss.mu_update_factor = 0.9
    ss.print_logs = False
    ss.print_line_search_logs = False
    ss.print_search_direction_logs = False
    ss.print_derivative_check_logs = False

    x_dim = 2
    u_dim = 1
    g_dim = 0
    c_dim = 0
    T = 100

    dt = 0.1

    @jax.jit
    def split(x):
        # Append a dummy final control so the vector reshapes into T + 1
        # (state, control) stages; the dummy control is then dropped.
        x = jnp.concatenate([x, jnp.zeros(u_dim)])
        x = x.reshape([T + 1, x_dim + u_dim])
        X = x[:, :x_dim]
        U = x[:T, x_dim:]
        return X, U

    @jax.jit
    def f(x):
        X, U = split(x)
        stagewise_costs = jax.vmap(
            lambda i: 0.5 * X[i, 0] ** 2
            + 0.5 * 0.1 * X[i, 1] ** 2
            + 0.5 * 0.1 * U[i, 0] ** 2
        )(jnp.arange(T))
        terminal_cost = 0.5 * X[T, 0] ** 2 + 0.5 * 0.1 * X[T, 1] ** 2
        return stagewise_costs.sum() + terminal_cost

    @jax.jit
    def c(x):
        x_0 = jnp.array([0.0, 10.0])
        A = jnp.array([[1.0, dt], [0.0, 1.0]])
        B = jnp.array([[dt**2 / 2], [dt]])
        X, U = split(x)

        # Initial-state residual followed by one dynamics residual per step.
        out = jnp.concatenate(
            [
                x_0 - X[0],
                jax.vmap(lambda i: A @ X[i] + B @ U[i] - X[i + 1])(
                    jnp.arange(T)
                ).flatten(),
            ]
        )
        return out

    @jax.jit
    def g(x):
        # No inequality constraints.
        return jnp.array([])

    @jax.jit
    def grad_f(x):
        return jax.grad(f)(x)

    @jax.jit
    def approx_upp_hess_f(x):
        def proj_psd(Q, delta=1e-6):
            # Shift the spectrum so the smallest eigenvalue is >= delta.
            S, _V = jnp.linalg.eigh(Q)
            k = -jnp.minimum(jnp.min(S), 0.0) + delta
            return Q + k * jnp.eye(Q.shape[0])

        return jnp.triu(proj_psd(jax.hessian(f)(x)))

    @jax.jit
    def jac_c(x):
        return jax.jacfwd(c)(x)

    @jax.jit
    def jac_g(x):
        return jax.jacfwd(g)(x)

    pd = ProblemDimensions()
    pd.x_dim = T * (x_dim + u_dim) + x_dim
    pd.s_dim = (T + 1) * g_dim
    pd.y_dim = (T + 1) * (x_dim + c_dim)

    # Sparsity patterns are read off the derivatives at an arbitrary point;
    # the problem is linear-quadratic, so they do not depend on x.
    mock_x = jnp.ones([pd.x_dim])
    jac_c_nnz_pattern_sp = sp.csr_matrix(np.array(jac_c(mock_x)))
    jac_g_nnz_pattern_sp = sp.csr_matrix(np.array(jac_g(mock_x)))
    upper_L_hess_nnz_pattern_sp = sp.csc_matrix(np.array(approx_upp_hess_f(mock_x)))

    def storage_coords(mat):
        # (row, col) of every stored entry, in the same order as `mat.data`
        # (row-major for CSR, column-major for CSC).
        coo = mat.tocoo()
        return coo.row, coo.col

    jac_c_coords = storage_coords(jac_c_nnz_pattern_sp)
    jac_g_coords = storage_coords(jac_g_nnz_pattern_sp)
    upper_L_hess_coords = storage_coords(upper_L_hess_nnz_pattern_sp)

    qs = QDLDLSettings()
    qs.permute_kkt_system = True
    qs.kkt_pinv, pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_perm_inv_and_nnzs(
        P=upper_L_hess_nnz_pattern_sp,
        A=jac_c_nnz_pattern_sp,
        G=jac_g_nnz_pattern_sp,
    )

    pd.upper_hessian_lagrangian_nnz = upper_L_hess_nnz_pattern_sp.nnz
    pd.jacobian_c_nnz = jac_c_nnz_pattern_sp.nnz
    pd.jacobian_g_nnz = jac_g_nnz_pattern_sp.nnz
    pd.is_jacobian_c_transposed = True
    pd.is_jacobian_g_transposed = True

    def refresh(mat, coords, dense):
        # Overwrite the stored values of `mat` in place, keeping its pattern.
        # Gathering by (row, col) in storage order is correct for both CSR
        # and CSC; a boolean mask would always yield row-major order.
        rows, cols = coords
        mat.data = np.asarray(dense)[rows, cols]
        return mat

    def mc(mci: ModelCallbackInput) -> ModelCallbackOutput:
        mco = ModelCallbackOutput()

        mco.f = float(f(mci.x))
        mco.c = np.array(c(mci.x))
        mco.g = np.array(g(mci.x))

        mco.gradient_f = np.array(grad_f(mci.x))

        mco.jacobian_c = refresh(jac_c_nnz_pattern_sp, jac_c_coords, jac_c(mci.x))
        mco.jacobian_g = refresh(jac_g_nnz_pattern_sp, jac_g_coords, jac_g(mci.x))
        mco.upper_hessian_lagrangian = refresh(
            upper_L_hess_nnz_pattern_sp, upper_L_hess_coords, approx_upp_hess_f(mci.x)
        )

        return mco

    solver = Solver(ss, qs, pd, mc)

    variables = Variables(pd)
    variables.x[:] = 0.0
    variables.s[:] = 1.0
    variables.y[:] = 0.0
    variables.e[:] = 0.0
    variables.z[:] = 1.0

    output = solver.solve(variables)

    assert output.exit_status == Status.SOLVED
@@ -1,6 +1,12 @@
1
+ import pytest
2
+ import jax
3
+ from jax import numpy as jnp
4
+
5
+ import numpy as np
6
+ from scipy import sparse as sp
7
+
1
8
  from sip_python import (
2
- get_kkt_and_L_nnzs,
3
- get_kkt_perm_inv,
9
+ get_kkt_perm_inv_and_nnzs,
4
10
  ModelCallbackInput,
5
11
  ModelCallbackOutput,
6
12
  ProblemDimensions,
@@ -11,17 +17,8 @@ from sip_python import (
11
17
  Variables,
12
18
  )
13
19
 
14
- import pytest
15
-
16
- import jax
17
- from jax import numpy as jnp
18
-
19
20
  jax.config.update("jax_enable_x64", True)
20
21
 
21
- import numpy as np
22
-
23
- from scipy import sparse as sp
24
-
25
22
 
26
23
  def test_simple_qp():
27
24
  ss = Settings()
@@ -29,6 +26,10 @@ def test_simple_qp():
29
26
  ss.enable_elastics = True
30
27
  ss.elastic_var_cost_coeff = 1e6
31
28
  ss.assert_checks_pass = True
29
+ ss.print_logs = False
30
+ ss.print_line_search_logs = False
31
+ ss.print_search_direction_logs = False
32
+ ss.print_derivative_check_logs = False
32
33
 
33
34
  @jax.jit
34
35
  def f(x):
@@ -83,27 +84,23 @@ def test_simple_qp():
83
84
  jac_g_nnz_pattern_sp = sp.csr_matrix(jac_g_nnz_pattern)
84
85
  upper_L_hess_nnz_pattern_sp = sp.csc_matrix(upper_L_hess_nnz_pattern)
85
86
 
87
+ pd = ProblemDimensions()
88
+ pd.x_dim = x_dim
89
+ pd.s_dim = jac_g_nnz_pattern_sp.shape[0]
90
+ pd.y_dim = jac_c_nnz_pattern_sp.shape[0]
91
+
86
92
  qs = QDLDLSettings()
87
93
  qs.permute_kkt_system = True
88
- qs.kkt_pinv = get_kkt_perm_inv(
94
+ qs.kkt_pinv, pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_perm_inv_and_nnzs(
89
95
  P=upper_L_hess_nnz_pattern_sp,
90
96
  A=jac_c_nnz_pattern_sp,
91
97
  G=jac_g_nnz_pattern_sp,
92
98
  )
93
99
 
94
- pd = ProblemDimensions()
95
- pd.x_dim = x_dim
96
- pd.s_dim = 2
97
- pd.y_dim = 0
98
100
  pd.upper_hessian_lagrangian_nnz = upper_L_hess_nnz_pattern_sp.nnz
99
101
  pd.jacobian_c_nnz = jac_c_nnz_pattern_sp.nnz
100
102
  pd.jacobian_g_nnz = jac_g_nnz_pattern_sp.nnz
101
- pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_and_L_nnzs(
102
- P=upper_L_hess_nnz_pattern_sp,
103
- A=jac_c_nnz_pattern_sp,
104
- G=jac_g_nnz_pattern_sp,
105
- perm_inv=qs.kkt_pinv,
106
- )
103
+
107
104
  pd.is_jacobian_c_transposed = True
108
105
  pd.is_jacobian_g_transposed = True
109
106
 
@@ -1,6 +1,12 @@
1
+ import pytest
2
+ import jax
3
+ from jax import numpy as jnp
4
+
5
+ import numpy as np
6
+ from scipy import sparse as sp
7
+
1
8
  from sip_python import (
2
- get_kkt_and_L_nnzs,
3
- get_kkt_perm_inv,
9
+ get_kkt_perm_inv_and_nnzs,
4
10
  ModelCallbackInput,
5
11
  ModelCallbackOutput,
6
12
  ProblemDimensions,
@@ -11,17 +17,8 @@ from sip_python import (
11
17
  Variables,
12
18
  )
13
19
 
14
- import pytest
15
-
16
- import jax
17
- from jax import numpy as jnp
18
-
19
20
  jax.config.update("jax_enable_x64", True)
20
21
 
21
- import numpy as np
22
-
23
- from scipy import sparse as sp
24
-
25
22
 
26
23
  def test_simple_qp():
27
24
  ss = Settings()
@@ -29,6 +26,10 @@ def test_simple_qp():
29
26
  ss.enable_elastics = True
30
27
  ss.elastic_var_cost_coeff = 1e6
31
28
  ss.assert_checks_pass = True
29
+ ss.print_logs = False
30
+ ss.print_line_search_logs = False
31
+ ss.print_search_direction_logs = False
32
+ ss.print_derivative_check_logs = False
32
33
 
33
34
  @jax.jit
34
35
  def f(x):
@@ -82,27 +83,23 @@ def test_simple_qp():
82
83
  jac_g_nnz_pattern_sp = sp.csr_matrix(jac_g_nnz_pattern)
83
84
  upper_L_hess_nnz_pattern_sp = sp.csc_matrix(upper_L_hess_nnz_pattern)
84
85
 
86
+ pd = ProblemDimensions()
87
+ pd.x_dim = x_dim
88
+ pd.s_dim = jac_g_nnz_pattern_sp.shape[0]
89
+ pd.y_dim = jac_c_nnz_pattern_sp.shape[0]
90
+
85
91
  qs = QDLDLSettings()
86
92
  qs.permute_kkt_system = True
87
- qs.kkt_pinv = get_kkt_perm_inv(
93
+ qs.kkt_pinv, pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_perm_inv_and_nnzs(
88
94
  P=upper_L_hess_nnz_pattern_sp,
89
95
  A=jac_c_nnz_pattern_sp,
90
96
  G=jac_g_nnz_pattern_sp,
91
97
  )
92
98
 
93
- pd = ProblemDimensions()
94
- pd.x_dim = x_dim
95
- pd.s_dim = 4
96
- pd.y_dim = 1
97
99
  pd.upper_hessian_lagrangian_nnz = upper_L_hess_nnz_pattern_sp.nnz
98
100
  pd.jacobian_c_nnz = jac_c_nnz_pattern_sp.nnz
99
101
  pd.jacobian_g_nnz = jac_g_nnz_pattern_sp.nnz
100
- pd.kkt_nnz, pd.kkt_L_nnz = get_kkt_and_L_nnzs(
101
- P=upper_L_hess_nnz_pattern_sp,
102
- A=jac_c_nnz_pattern_sp,
103
- G=jac_g_nnz_pattern_sp,
104
- perm_inv=qs.kkt_pinv,
105
- )
102
+
106
103
  pd.is_jacobian_c_transposed = True
107
104
  pd.is_jacobian_g_transposed = True
108
105
 
@@ -1,2 +0,0 @@
1
- from .sip_python_ext import *
2
- from .helpers import get_kkt_perm_inv, get_kkt_and_L_nnzs
@@ -1,89 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- from scipy import sparse as spa
5
-
6
- from .sip_python_ext import getLnnz
7
-
8
-
9
# Prefer cvxopt's approximate minimum degree (AMD) ordering. If cvxopt cannot
# be imported, record that and pull in SciPy's reverse Cuthill-McKee instead.
try:
    from cvxopt import amd, spmatrix
except ImportError:
    _cvxopt_available = False
    from scipy.sparse.csgraph import reverse_cuthill_mckee
else:
    _cvxopt_available = True
16
-
17
-
18
- def _get_K(P, A, G):
19
- # K = [ P + r1 I_x A.T G.T ]
20
- # [ A -r2 * I_y 0 ]
21
- # [ G 0 -r3 I_z ]
22
-
23
- if isinstance(P, np.ndarray):
24
- P = spa.csc_matrix(P)
25
-
26
- if isinstance(A, np.ndarray):
27
- A = spa.csr_matrix(A)
28
-
29
- if isinstance(G, np.ndarray):
30
- G = spa.csr_matrix(G)
31
-
32
- x_dim = P.shape[0]
33
- s_dim = G.shape[0]
34
- y_dim = A.shape[0]
35
-
36
- mod_P = spa.csc_matrix.copy(P)
37
- mod_P.data[:] = 1.0
38
-
39
- Z = spa.csc_matrix((y_dim, s_dim))
40
-
41
- K = spa.block_array(
42
- blocks=[
43
- [mod_P + spa.eye(x_dim), A.T, G.T],
44
- [A, -spa.eye(y_dim), Z],
45
- [G, Z.T, -spa.eye(s_dim)],
46
- ],
47
- format="coo",
48
- )
49
-
50
- return K
51
-
52
-
53
def _get_kkt_perm(P, A, G, verbose):
    """Compute a fill-reducing symmetric permutation of the KKT matrix.

    Uses cvxopt's AMD ordering when cvxopt is available, otherwise SciPy's
    reverse Cuthill-McKee (warning when ``verbose``).

    Args:
        P, A, G: Hessian and constraint Jacobian patterns (see ``_get_K``).
        verbose: Whether to warn when falling back to RCM.

    Returns:
        Integer array ``perm`` (new position -> original index).
    """
    K = _get_K(P=P, A=A, G=G)

    if _cvxopt_available:
        # amd.order only reads one triangle of its argument (the lower one by
        # default). K's (1, 1) block carries just the upper-triangular pattern
        # of P, so symmetrize the pattern first; otherwise the strictly upper
        # entries of P would be invisible to the ordering.
        pattern = spa.coo_matrix(K, copy=True)
        pattern.data = np.ones_like(pattern.data, dtype=np.float64)
        K_sym = (pattern + pattern.T).tocoo()
        K_cvxopt = spmatrix(
            K_sym.data.tolist(),
            K_sym.row.tolist(),
            K_sym.col.tolist(),
            size=K_sym.shape,
        )
        return np.array(list(amd.order(K_cvxopt)))
    if verbose:
        warnings.warn(
            "cvxopt not installed; using reverse Cuthill-McKee (RCM) "
            "instead of approximate minimum degree (AMD)."
        )
    return reverse_cuthill_mckee(spa.csc_matrix(K))
69
-
70
-
71
def get_kkt_perm_inv(P, A, G, verbose=True):
    """Return the inverse of the fill-reducing permutation of the KKT matrix.

    Args:
        P, A, G: Hessian and constraint Jacobian patterns (see ``_get_K``).
        verbose: Forwarded to the ordering routine.

    Returns:
        Integer array ``perm_inv`` with ``perm_inv[perm[i]] == i``; it shares
        ``perm``'s dtype.
    """
    forward = _get_kkt_perm(P=P, A=A, G=G, verbose=verbose)
    # Scatter each new position i into slot forward[i] to invert the map.
    positions = np.arange(forward.shape[0])
    inverse = np.empty_like(forward)
    inverse[forward] = positions
    return inverse
78
-
79
-
80
def get_kkt_and_L_nnzs(P, A, G, perm_inv):
    """Return the nonzero counts of the KKT matrix and of its permuted factor.

    Args:
        P, A, G: Hessian and constraint Jacobian patterns (see ``_get_K``).
        perm_inv: Inverse permutation array; entry ``i`` is the new position
            of row/column ``i``.

    Returns:
        ``(kkt_nnz, kkt_L_nnz)``: stored entries of the KKT matrix, and the
        value ``getLnnz`` reports for the upper triangle of the permuted
        matrix.
    """
    kkt = _get_K(P=P, A=A, G=G)

    # Relabel rows/columns of a private copy according to perm_inv.
    relabeled = spa.coo_matrix.copy(kkt)
    relabeled.row = perm_inv[relabeled.row]
    relabeled.col = perm_inv[relabeled.col]

    factor_nnz = getLnnz(spa.triu(relabeled))
    return kkt.nnz, factor_nnz
File without changes
File without changes
File without changes