gsvd4py 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsvd4py-0.0.1/PKG-INFO +10 -0
- gsvd4py-0.0.1/README.md +104 -0
- gsvd4py-0.0.1/gsvd4py/__init__.py +5 -0
- gsvd4py-0.0.1/gsvd4py/_gsvd.py +453 -0
- gsvd4py-0.0.1/gsvd4py/_lapack.py +126 -0
- gsvd4py-0.0.1/gsvd4py.egg-info/PKG-INFO +10 -0
- gsvd4py-0.0.1/gsvd4py.egg-info/SOURCES.txt +11 -0
- gsvd4py-0.0.1/gsvd4py.egg-info/dependency_links.txt +1 -0
- gsvd4py-0.0.1/gsvd4py.egg-info/requires.txt +5 -0
- gsvd4py-0.0.1/gsvd4py.egg-info/top_level.txt +1 -0
- gsvd4py-0.0.1/pyproject.toml +17 -0
- gsvd4py-0.0.1/setup.cfg +4 -0
- gsvd4py-0.0.1/tests/test_gsvd.py +308 -0
gsvd4py-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gsvd4py
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Generalized SVD (GSVD) via LAPACK ?ggsvd3, using the same LAPACK library as SciPy
|
|
5
|
+
Author-email: Hayden Ringer <hjrrockies@gmail.com>
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Requires-Dist: scipy>=1.13
|
|
8
|
+
Requires-Dist: numpy>=2.0
|
|
9
|
+
Provides-Extra: test
|
|
10
|
+
Requires-Dist: pytest; extra == "test"
|
gsvd4py-0.0.1/README.md
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# gsvd4py
|
|
2
|
+
|
|
3
|
+
A lightweight Python wrapper for the LAPACK `?ggsvd3` routines, providing the Generalized Singular Value Decomposition (GSVD) in a style similar to `scipy.linalg`. It links to the same LAPACK library that SciPy uses on your machine — no separate LAPACK installation required.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install gsvd4py
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Requires SciPy >= 1.13 and NumPy >= 2.0.
|
|
12
|
+
|
|
13
|
+
## Background
|
|
14
|
+
|
|
15
|
+
The GSVD decomposes a pair of matrices `A` (m×p) and `B` (n×p) as:
|
|
16
|
+
|
|
17
|
+
```
|
|
18
|
+
A = U @ C @ X.conj().T
|
|
19
|
+
B = V @ S @ X.conj().T
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
where:
|
|
23
|
+
- `U` (m×m) and `V` (n×n) are unitary
|
|
24
|
+
- `C` (m×q) and `S` (n×q) are real diagonal with `C.T @ C + S.T @ S = I`
|
|
25
|
+
- `X` (p×q) is nonsingular
|
|
26
|
+
- `q = k + l` is the numerical rank of the stacked matrix `[A; B]`
|
|
27
|
+
|
|
28
|
+
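The generalized singular values are the column-wise ratios `c_j / s_j`, where `c_j` and `s_j` are the entries of column `j` of `C` and `S` (each column holds at most one nonzero entry, and `c_j**2 + s_j**2 = 1`); columns with `s_j = 0` correspond to infinite generalized singular values.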
|
|
29
|
+
|
|
30
|
+
## Usage
|
|
31
|
+
|
|
32
|
+
```python
|
|
33
|
+
import numpy as np
|
|
34
|
+
from gsvd4py import gsvd
|
|
35
|
+
|
|
36
|
+
A = np.random.randn(5, 6)
|
|
37
|
+
B = np.random.randn(4, 6)
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
### Full GSVD (default)
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
U, V, X, C, S = gsvd(A, B)
|
|
44
|
+
# U: (5,5), V: (4,4), X: (6,q), C: (5,q), S: (4,q)
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
### Economy GSVD
|
|
48
|
+
|
|
49
|
+
Truncates `U` and `V` to at most `q` columns:
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
U, V, X, C, S = gsvd(A, B, mode='econ')
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
### Raw LAPACK output
|
|
56
|
+
|
|
57
|
+
Returns the LAPACK decomposition `A = U @ D1 @ [0, R] @ Q.conj().T`, `B = V @ D2 @ [0, R] @ Q.conj().T` directly, together with `k` and `l`:
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
U, V, D1, D2, R, Q, k, l = gsvd(A, B, mode='separate')
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
### Skipping U and/or V
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
X, C, S = gsvd(A, B, compute_u=False, compute_v=False)
|
|
67
|
+
U, X, C, S = gsvd(A, B, compute_v=False)
|
|
68
|
+
V, X, C, S = gsvd(A, B, compute_u=False)
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
## API Reference
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
gsvd(a, b, mode='full', compute_u=True, compute_v=True,
|
|
75
|
+
overwrite_a=False, overwrite_b=False, lwork=None, check_finite=True)
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
| Parameter | Description |
|
|
79
|
+
|-----------|-------------|
|
|
80
|
+
| `a` | (m, p) array |
|
|
81
|
+
| `b` | (n, p) array |
|
|
82
|
+
| `mode` | `'full'` (default), `'econ'`, or `'separate'` |
|
|
83
|
+
| `compute_u` | Compute left singular vectors of `a` (default `True`) |
|
|
84
|
+
| `compute_v` | Compute left singular vectors of `b` (default `True`) |
|
|
85
|
+
| `overwrite_a` | Allow overwriting `a` to avoid a copy (default `False`) |
|
|
86
|
+
| `overwrite_b` | Allow overwriting `b` to avoid a copy (default `False`) |
|
|
87
|
+
| `lwork` | Work array size; `None` triggers an optimal workspace query |
|
|
88
|
+
| `check_finite` | Check inputs for non-finite values (default `True`) |
|
|
89
|
+
|
|
90
|
+
Supported dtypes: `float32`, `float64`, `complex64`, `complex128`. Integer inputs are upcast to `float64`.
|
|
91
|
+
|
|
92
|
+
## LAPACK backend
|
|
93
|
+
|
|
94
|
+
`gsvd4py` discovers the LAPACK library at runtime in the following order:
|
|
95
|
+
|
|
96
|
+
1. **Apple Accelerate** (macOS) — via `$NEWLAPACK` symbols
|
|
97
|
+
2. **scipy-openblas** — the `scipy-openblas32` / `scipy-openblas64` packages, if installed
|
|
98
|
+
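3. **System LAPACK** — `dggsvd3_` among the symbols already loaded into the process, then `liblapack`, `libopenblas` or `libflexiblas` located via `ctypes.util.find_library`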
|
|
99
|
+
|
|
100
|
+
No compilation is required.
|
|
101
|
+
|
|
102
|
+
## License
|
|
103
|
+
|
|
104
|
+
MIT
|
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core implementation of gsvd() using ctypes LAPACK calls.
|
|
3
|
+
|
|
4
|
+
LAPACK vs spec notation mapping
|
|
5
|
+
--------------------------------
|
|
6
|
+
The spec uses Matlab-style dimensions: A is m×p, B is n×p.
|
|
7
|
+
LAPACK dggsvd3 uses: A is M×N, B is P×N.
|
|
8
|
+
|
|
9
|
+
spec m → LAPACK M (rows of A)
|
|
10
|
+
spec n → LAPACK P (rows of B)
|
|
11
|
+
spec p → LAPACK N (columns, shared)
|
|
12
|
+
|
|
13
|
+
LAPACK decomposition (real case)
|
|
14
|
+
----------------------------------
|
|
15
|
+
A = U * D1 * [0, R] * Q^T
|
|
16
|
+
B = V * D2 * [0, R] * Q^T
|
|
17
|
+
|
|
18
|
+
where:
|
|
19
|
+
U M×M orthogonal (spec: m×m)
|
|
20
|
+
V P×P orthogonal (spec: n×n)
|
|
21
|
+
Q N×N orthogonal (spec: p×p)
|
|
22
|
+
D1 M×q "diagonal" (spec: m×q) q = K+L
|
|
23
|
+
D2 P×q "diagonal" (spec: n×q)
|
|
24
|
+
R q×q upper-triangular, stored inside A (and B if M < q)
|
|
25
|
+
[0, R] q×N block matrix
|
|
26
|
+
|
|
27
|
+
Matlab-style X (full / econ modes)
|
|
28
|
+
------------------------------------
|
|
29
|
+
Q2 = Q[:, p-q:] last q columns of Q, shape p×q
|
|
30
|
+
X = Q2 @ conj(R).T shape p×q
|
|
31
|
+
then A = U * C * X^H, B = V * S * X^H
|
|
32
|
+
|
|
33
|
+
D1 / D2 structure (ALPHA, BETA from LAPACK)
|
|
34
|
+
--------------------------------------------
|
|
35
|
+
Case m >= q (M >= K+L):
|
|
36
|
+
ALPHA[0:k] = 1, BETA[0:k] = 0 (infinite GSVs)
|
|
37
|
+
ALPHA[k:k+l] = C, BETA[k:k+l] = S (finite GSVs, C²+S²=I)
|
|
38
|
+
ALPHA[k+l:p] = 0, BETA[k+l:p] = 0
|
|
39
|
+
|
|
40
|
+
Case m < q (M < K+L, still K <= M):
|
|
41
|
+
ALPHA[0:k] = 1, BETA[0:k] = 0
|
|
42
|
+
ALPHA[k:m] = C, BETA[k:m] = S (first M-K pairs)
|
|
43
|
+
ALPHA[m:q] = 0, BETA[m:q] = 1 (identity block in D2)
|
|
44
|
+
ALPHA[q:p] = 0, BETA[q:p] = 0
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
import ctypes
|
|
48
|
+
|
|
49
|
+
import numpy as np
|
|
50
|
+
|
|
51
|
+
from ._lapack import get_ggsvd3
|
|
52
|
+
|
|
53
|
+
# ---------------------------------------------------------------------------
|
|
54
|
+
# dtype helpers
|
|
55
|
+
# ---------------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
# Map from a supported computation dtype to
# (LAPACK routine prefix, matching real dtype, is_complex).
_DTYPE_MAP = {
    np.dtype('float32'): ('s', np.dtype('float32'), False),
    np.dtype('float64'): ('d', np.dtype('float64'), False),
    np.dtype('complex64'): ('c', np.dtype('float32'), True),
    np.dtype('complex128'): ('z', np.dtype('float64'), True),
}


def _resolve_dtype(a, b):
    """Return (lapack_dtype, real_dtype, is_complex) for inputs a and b.

    The common dtype of ``a`` and ``b`` is mapped onto one of the four LAPACK
    precisions:

    * float32 / float64 / complex64 / complex128 are used as-is (in native
      byte order, so e.g. big-endian ``'>f8'`` behaves like ``'f8'``);
    * float16 is promoted to float32;
    * integers, booleans and other non-floating types become float64;
    * precisions LAPACK has no routines for (longdouble / clongdouble) are
      downcast to float64 / complex128, as SciPy's LAPACK lookup does.

    Returns
    -------
    tuple
        ``(prefix, real_dtype, is_complex)``, one of the ``_DTYPE_MAP`` values.
    """
    # Rebuilding from the scalar type drops any non-native byte order, which
    # would otherwise miss every key of _DTYPE_MAP.
    dtype = np.dtype(np.result_type(a, b).type)
    if np.issubdtype(dtype, np.complexfloating):
        fallback = np.dtype(np.complex128)
    elif np.issubdtype(dtype, np.floating):
        fallback = np.dtype(np.float64)
        if dtype == np.float16:
            dtype = np.dtype(np.float32)
    else:
        # Upcast integers / booleans (and anything else) to float64.
        dtype = fallback = np.dtype(np.float64)
    return _DTYPE_MAP.get(dtype, _DTYPE_MAP[fallback])
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# ---------------------------------------------------------------------------
|
|
79
|
+
# ctypes helpers
|
|
80
|
+
# ---------------------------------------------------------------------------
|
|
81
|
+
|
|
82
|
+
# ctypes types bound once at import time.
_c_int_p, _c_void_p = ctypes.POINTER(ctypes.c_int), ctypes.c_void_p


def _ptr(arr):
    """Expose the start address of *arr*'s data buffer as a ``c_void_p``."""
    as_pointer = arr.ctypes.data_as
    return as_pointer(_c_void_p)


def _iptr(val):
    """Box *val* in a fresh ``c_int`` and hand back a by-reference argument."""
    boxed = ctypes.c_int(val)
    return ctypes.byref(boxed)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# ---------------------------------------------------------------------------
|
|
97
|
+
# LAPACK call wrapper
|
|
98
|
+
# ---------------------------------------------------------------------------
|
|
99
|
+
|
|
100
|
+
def _call_ggsvd3(a_f, b_f, alpha, beta, u_f, v_f, q_f, iwork,
                 jobu, jobv, lwork, fn, is_complex, real_dtype,
                 uses_hidden_lengths):
    """Call ?ggsvd3 once (workspace query or actual computation).

    All array arguments must already be Fortran-contiguous and correctly
    typed; LAPACK writes its results into them in place.  ``u_f`` / ``v_f``
    may be None when the corresponding factor is not requested, in which case
    a 1×1 dummy buffer is passed instead.

    Leading dimensions are clamped to at least 1, because LAPACK requires
    LDA >= max(1, M), LDB >= max(1, P), LDQ >= max(1, N) (and likewise for
    LDU / LDV).  Without the clamp, inputs with an empty dimension would be
    rejected as illegal arguments.

    Parameters
    ----------
    a_f, b_f, alpha, beta, u_f, v_f, q_f, iwork : ndarray or None
        LAPACK input/output buffers (see the ?ggsvd3 argument list).
    jobu, jobv : str
        Single-character LAPACK job flags ('U' / 'V' or 'N').
    lwork : int or None
        Work array length; None performs an LWORK = -1 workspace query.
    fn : ctypes function
        The ?ggsvd3 entry point.
    is_complex : bool
        If True, pass the extra RWORK argument taken by c/zggsvd3.
    real_dtype : numpy dtype
        dtype of RWORK (only used when ``is_complex``).
    uses_hidden_lengths : bool
        If True, append three size_t(1) hidden character-length arguments
        after INFO.

    Returns
    -------
    int
        The optimal workspace size read from WORK(1), when ``lwork`` is None.
    tuple of int
        ``(k, l, info)`` otherwise.
    """
    m, p = a_f.shape      # LAPACK M (rows of A), N (shared columns)
    n = b_f.shape[0]      # LAPACK P (rows of B)

    query = lwork is None
    work = np.zeros(1 if query else max(1, lwork), dtype=a_f.dtype)
    lwork_c = ctypes.c_int(-1 if query else lwork)

    k_c = ctypes.c_int(0)
    l_c = ctypes.c_int(0)
    info_c = ctypes.c_int(0)

    # Dummy 1×1 array for when U or V is not computed
    dummy = np.zeros((1, 1), dtype=a_f.dtype, order='F')
    u_buf = dummy if u_f is None else u_f
    v_buf = dummy if v_f is None else v_f
    ldu = 1 if u_f is None else max(1, m)
    ldv = 1 if v_f is None else max(1, n)

    args = [
        jobu.encode(), jobv.encode(), b'Q',
        _iptr(m), _iptr(p), _iptr(n),
        ctypes.byref(k_c), ctypes.byref(l_c),
        _ptr(a_f), _iptr(max(1, m)),
        _ptr(b_f), _iptr(max(1, n)),
        _ptr(alpha), _ptr(beta),
        _ptr(u_buf), _iptr(ldu),
        _ptr(v_buf), _iptr(ldv),
        _ptr(q_f), _iptr(max(1, p)),
        _ptr(work), ctypes.byref(lwork_c),
    ]

    if is_complex:
        # c/zggsvd3 take RWORK (dimension 2*N) between LWORK and IWORK.
        rwork = np.zeros(max(1, 2 * p), dtype=real_dtype)
        args.append(_ptr(rwork))

    args += [_ptr(iwork), ctypes.byref(info_c)]

    if uses_hidden_lengths:
        # gfortran ABI: lengths of the JOBU, JOBV, JOBQ character arguments.
        one = ctypes.c_size_t(1)
        args += [one, one, one]

    fn(*args)

    if query:
        # workspace query: return optimal lwork from work[0]
        return int(work[0].real)

    return k_c.value, l_c.value, info_c.value
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# ---------------------------------------------------------------------------
|
|
166
|
+
# Output construction helpers
|
|
167
|
+
# ---------------------------------------------------------------------------
|
|
168
|
+
|
|
169
|
+
def _build_C_S(alpha, beta, m, n, k, l):
    """Assemble dense C (m×q) and S (n×q), q = k + l, from ALPHA / BETA.

    Layout (following the ?ggsvd3 description of D1 / D2):

    * C begins with a k×k identity block.
    * The (cos, sin) pairs alpha[j], beta[j] for j in [k, k + r) are placed
      at C[j, j] and S[j - k, j], where r = l if m >= q, else m - k.
    * If m < q, columns m..q-1 of S carry an identity block in rows
      m-k..l-1.

    Returns (C, S) as float64 arrays regardless of dtype (GSVs are real).
    """
    q = k + l
    C = np.zeros((m, q))
    S = np.zeros((n, q))

    C[:k, :k] = np.eye(k)

    n_pairs = l if m >= q else max(m - k, 0)
    cols = np.arange(k, k + n_pairs)
    C[cols, cols] = alpha[k:k + n_pairs]
    S[cols - k, cols] = beta[k:k + n_pairs]

    if m < q:
        # Columns that did not fit in D1 get BETA = 1 (identity in D2).
        cols = np.arange(m, q)
        S[cols - k, cols] = 1.0

    return C, S
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _extract_R(a_f, b_f, m, n, p, k, l):
    """Copy the (k+l)×(k+l) upper-triangular R out of the modified A (and B).

    ?ggsvd3 leaves R in A[0:k+l, p-k-l:p] (0-indexed).  When m < k+l only the
    first m rows fit in A; the trailing (k+l-m)×(k+l-m) block is stored in
    B(M-K+1:L, N+M-K-L+1:N) (Fortran, 1-indexed), i.e. ``b_f`` rows
    m-k : l and columns p-(k+l-m) : p.  ``n`` is accepted for signature
    symmetry but not used.
    """
    q = k + l
    first_col = p - q
    R = np.zeros((q, q), dtype=a_f.dtype)

    rows_in_a = min(m, q)
    R[:rows_in_a, :] = a_f[:rows_in_a, first_col:]

    overflow = q - m  # rows of R that spill over into B
    if overflow > 0:
        top = m - k
        R[m:, m:] = b_f[top:top + overflow, first_col + m:]

    return R
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# ---------------------------------------------------------------------------
|
|
225
|
+
# Public API
|
|
226
|
+
# ---------------------------------------------------------------------------
|
|
227
|
+
|
|
228
|
+
def gsvd(a, b, mode='full', compute_u=True, compute_v=True,
         overwrite_a=False, overwrite_b=False, lwork=None, check_finite=True):
    """Generalized Singular Value Decomposition.

    Computes the GSVD of the matrix pair (a, b) using the LAPACK routine
    ?ggsvd3 linked via the same LAPACK library as SciPy.

    Parameters
    ----------
    a : (m, p) array_like
    b : (n, p) array_like
    mode : {'full', 'econ', 'separate'}, default 'full'
        'full'     — Full Matlab-style: U (m×m), V (n×n), X (p×q), C (m×q),
                     S (n×q), where q = k+l is the numerical rank of [a; b].
        'econ'     — Economy Matlab-style: U (m×min(m,q)), V (n×min(n,q)),
                     X (p×q), C (min(m,q)×q), S (min(n,q)×q).
        'separate' — Raw LAPACK output (no rank truncation): U, V, D1, D2,
                     R, Q, k, l.
    compute_u : bool, default True
        Compute left singular vectors of a.
    compute_v : bool, default True
        Compute left singular vectors of b.
    overwrite_a : bool, default False
        Allow overwriting a (avoids a copy if True and a is already a
        writeable, Fortran-contiguous array with the computation dtype).
    overwrite_b : bool, default False
        Allow overwriting b (same conditions as overwrite_a).
    lwork : int or None, default None
        LAPACK work array size. None (or -1) triggers an optimal query;
        any other value must be >= 1.
    check_finite : bool, default True
        Check that a and b contain only finite values.

    Returns
    -------
    mode='full' or 'econ':
        If compute_u and compute_v:         U, V, X, C, S
        If compute_u and not compute_v:     U, X, C, S
        If not compute_u and compute_v:     V, X, C, S
        If not compute_u and not compute_v: X, C, S

    mode='separate':
        If compute_u and compute_v:         U, V, D1, D2, R, Q, k, l
        If compute_u and not compute_v:     U, D1, D2, R, Q, k, l
        If not compute_u and compute_v:     V, D1, D2, R, Q, k, l
        If not compute_u and not compute_v: D1, D2, R, Q, k, l

    Raises
    ------
    ValueError
        If ``mode`` or ``lwork`` is invalid, a or b is not 2-D, their column
        counts differ, they contain non-finite values (with
        ``check_finite``), or LAPACK reports an illegal argument.
    numpy.linalg.LinAlgError
        If ?ggsvd3 fails to converge.
    """
    # ------------------------------------------------------------------
    # Input validation
    # ------------------------------------------------------------------
    if mode not in ('full', 'econ', 'separate'):
        raise ValueError(f"mode must be 'full', 'econ', or 'separate', got {mode!r}")
    if lwork is not None and lwork != -1 and lwork < 1:
        raise ValueError(
            f"lwork must be a positive integer, None, or -1, got {lwork!r}"
        )

    a = np.asarray(a)
    b = np.asarray(b)

    if a.ndim != 2:
        raise ValueError(f"a must be 2-D, got shape {a.shape}")
    if b.ndim != 2:
        raise ValueError(f"b must be 2-D, got shape {b.shape}")

    m, p = a.shape
    n = b.shape[0]
    if b.shape[1] != p:
        raise ValueError(
            f"a and b must have the same number of columns: "
            f"{p} != {b.shape[1]}"
        )

    if check_finite:
        if not np.all(np.isfinite(a)):
            raise ValueError("Array a contains non-finite values.")
        if not np.all(np.isfinite(b)):
            raise ValueError("Array b contains non-finite values.")

    # ------------------------------------------------------------------
    # Dtype resolution + array preparation
    # ------------------------------------------------------------------
    dtype_char, real_dtype, is_complex = _resolve_dtype(a, b)
    dtype = np.dtype({'s': 'float32', 'd': 'float64',
                      'c': 'complex64', 'z': 'complex128'}[dtype_char])

    def _prep(arr, overwrite):
        # Hand the caller's buffer straight to LAPACK only if LAPACK can
        # legally write into it: right dtype, Fortran layout, and writeable.
        # (flags.f_contiguous, unlike np.isfortran, also accepts arrays that
        # are simultaneously C- and F-contiguous.)
        if (overwrite and arr.dtype == dtype and arr.flags.f_contiguous
                and arr.flags.writeable):
            return arr
        return np.array(arr, dtype=dtype, order='F', copy=True)

    a_f = _prep(a, overwrite_a)
    b_f = _prep(b, overwrite_b)

    # ------------------------------------------------------------------
    # Load LAPACK function
    # ------------------------------------------------------------------
    fn, uses_hidden_lengths = get_ggsvd3(dtype_char)

    # ------------------------------------------------------------------
    # Allocate output arrays
    # ------------------------------------------------------------------
    jobu_char = 'U' if compute_u else 'N'
    jobv_char = 'V' if compute_v else 'N'

    alpha = np.zeros(p, dtype=real_dtype)
    beta = np.zeros(p, dtype=real_dtype)
    iwork = np.zeros(p, dtype=np.int32)
    q_f = np.zeros((p, p), dtype=dtype, order='F')
    u_f = np.zeros((m, m), dtype=dtype, order='F') if compute_u else None
    v_f = np.zeros((n, n), dtype=dtype, order='F') if compute_v else None

    # ------------------------------------------------------------------
    # Workspace query
    # ------------------------------------------------------------------
    if lwork is None or lwork == -1:
        opt = _call_ggsvd3(
            a_f, b_f, alpha, beta, u_f, v_f, q_f, iwork,
            jobu_char, jobv_char,
            lwork=None, fn=fn, is_complex=is_complex,
            real_dtype=real_dtype, uses_hidden_lengths=uses_hidden_lengths,
        )
        lwork_use = max(opt, 1)
    else:
        lwork_use = lwork

    # ------------------------------------------------------------------
    # Actual LAPACK call
    # ------------------------------------------------------------------
    k, l, info = _call_ggsvd3(
        a_f, b_f, alpha, beta, u_f, v_f, q_f, iwork,
        jobu_char, jobv_char,
        lwork=lwork_use, fn=fn, is_complex=is_complex,
        real_dtype=real_dtype, uses_hidden_lengths=uses_hidden_lengths,
    )

    if info < 0:
        raise ValueError(
            f"Illegal argument #{-info} passed to {dtype_char}ggsvd3."
        )
    if info > 0:
        raise np.linalg.LinAlgError(
            f"LAPACK ?ggsvd3 failed to converge (info={info})."
        )

    q_rank = k + l  # effective numerical rank

    # ------------------------------------------------------------------
    # Post-processing
    # ------------------------------------------------------------------
    if mode == 'separate':
        return _build_separate(
            a_f, b_f, alpha, beta, u_f, v_f, q_f, iwork,
            m, n, p, k, l, q_rank,
            compute_u, compute_v, real_dtype,
        )
    return _build_matlab_style(
        a_f, b_f, alpha, beta, u_f, v_f, q_f,
        m, n, p, k, l, q_rank,
        mode, compute_u, compute_v,
    )
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# ---------------------------------------------------------------------------
|
|
388
|
+
# Post-processing: separate mode
|
|
389
|
+
# ---------------------------------------------------------------------------
|
|
390
|
+
|
|
391
|
+
def _build_separate(a_f, b_f, alpha, beta, u_f, v_f, q_f, iwork,
                    m, n, p, k, l, q_rank,
                    compute_u, compute_v, real_dtype):
    """Package the raw LAPACK factors for ``mode='separate'``.

    Returns the tuple ``(U, V, D1, D2, R, Q, k, l)``, omitting U and/or V
    when they were not requested.  Every returned array is C-contiguous.
    ``iwork``, ``q_rank`` and ``real_dtype`` are accepted but not used.
    """
    D1, D2 = _build_C_S(alpha, beta, m, n, k, l)
    R = _extract_R(a_f, b_f, m, n, p, k, l)

    requested = ((compute_u, u_f), (compute_v, v_f))
    factors = [np.ascontiguousarray(mat) for wanted, mat in requested if wanted]
    blocks = [np.ascontiguousarray(arr) for arr in (D1, D2, R, q_f)]
    return (*factors, *blocks, k, l)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
# ---------------------------------------------------------------------------
|
|
413
|
+
# Post-processing: full / econ modes
|
|
414
|
+
# ---------------------------------------------------------------------------
|
|
415
|
+
|
|
416
|
+
def _build_matlab_style(a_f, b_f, alpha, beta, u_f, v_f, q_f,
                        m, n, p, k, l, q_rank,
                        mode, compute_u, compute_v):
    """Convert LAPACK factors into the Matlab-style (U, V, X, C, S) form.

    X = Q[:, p-q:] @ R^H, so that A = U @ C @ X^H and B = V @ S @ X^H.
    In ``'econ'`` mode U / C are cut to min(m, q) columns / rows and V / S to
    min(n, q); in ``'full'`` mode they are returned at full size.  U and V
    are omitted from the result when not requested.
    """
    R = _extract_R(a_f, b_f, m, n, p, k, l)
    Q2 = np.asarray(q_f)[:, p - q_rank:]           # p × q_rank
    X = np.ascontiguousarray(Q2 @ np.conj(R).T)    # p × q_rank

    C, S = _build_C_S(alpha, beta, m, n, k, l)
    U, V = u_f, v_f
    if mode != 'full':  # 'econ'
        ru = min(m, q_rank)
        rv = min(n, q_rank)
        C, S = C[:ru, :], S[:rv, :]
        if compute_u:
            U = U[:, :ru]
        if compute_v:
            V = V[:, :rv]

    outputs = [np.ascontiguousarray(mat)
               for mat, wanted in ((U, compute_u), (V, compute_v)) if wanted]
    outputs += [X, np.ascontiguousarray(C), np.ascontiguousarray(S)]
    return tuple(outputs)
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LAPACK library discovery for gsvd4py.
|
|
3
|
+
|
|
4
|
+
Tries, in order:
|
|
5
|
+
1. Apple Accelerate (macOS) — symbols named ?ggsvd3$NEWLAPACK
|
|
6
|
+
2. scipy_openblas32 — symbols named scipy_?ggsvd3_
|
|
7
|
+
3. scipy_openblas64 — symbols named scipy_?ggsvd3_
|
|
8
|
+
4. CDLL(None) — all loaded symbols (works on Linux)
|
|
9
|
+
5. ctypes.util.find_library — system LAPACK / OpenBLAS
|
|
10
|
+
|
|
11
|
+
Calling conventions differ:
|
|
12
|
+
- Accelerate: pure C interface, no hidden Fortran char-length args
|
|
13
|
+
- gfortran LAPACK: three hidden size_t args (len_jobu, len_jobv, len_jobq)
|
|
14
|
+
appended after `info`
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import ctypes
|
|
18
|
+
import ctypes.util
|
|
19
|
+
import glob
|
|
20
|
+
import os
|
|
21
|
+
import sys
|
|
22
|
+
|
|
23
|
+
# Module-level cache
|
|
24
|
+
# Module-level cache populated by _load_lib():
#   _lib      - the ctypes.CDLL that provides ?ggsvd3
#   _lib_type - 'accelerate' | 'scipy_openblas' | 'system'
_lib = _lib_type = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _load_lib():
    """Find a LAPACK library exporting ?ggsvd3 and cache it module-wide.

    On success, sets the globals ``_lib`` (a ``ctypes.CDLL``) and
    ``_lib_type`` (``'accelerate'``, ``'scipy_openblas'`` or ``'system'``).
    Once a library is cached, later calls return immediately.  Strategies
    are tried in order.  A strategy that fails (library missing, symbol
    missing, or unsupported on this platform) falls through to the next.

    Raises
    ------
    ImportError
        If no strategy yields a library providing ``dggsvd3``.
    """
    global _lib, _lib_type

    if _lib is not None:
        return

    # --- Strategy 1: Apple Accelerate (macOS) ---
    if sys.platform == 'darwin':
        try:
            lib = ctypes.CDLL(
                '/System/Library/Frameworks/Accelerate.framework/Accelerate'
            )
            # ctypes reports a missing symbol with AttributeError (not
            # KeyError), e.g. on Accelerate builds without $NEWLAPACK.
            lib['dggsvd3$NEWLAPACK']
        except (OSError, AttributeError, KeyError):
            pass
        else:
            _lib = lib
            _lib_type = 'accelerate'
            return

    # --- Strategy 2 & 3: scipy_openblas32 / scipy_openblas64 ---
    pattern = '*.dylib' if sys.platform == 'darwin' else '*.so*'
    for pkg_name in ('scipy_openblas32', 'scipy_openblas64'):
        try:
            pkg = __import__(pkg_name)
        except ImportError:
            continue
        for path in glob.glob(os.path.join(pkg.get_lib_dir(), pattern)):
            try:
                lib = ctypes.CDLL(path)
                getattr(lib, 'scipy_dggsvd3_')
            except (OSError, AttributeError):
                continue
            _lib = lib
            _lib_type = 'scipy_openblas'
            return

    # --- Strategy 4: CDLL(None) — all loaded symbols (Linux) ---
    # CDLL(None) is a POSIX idiom; elsewhere (e.g. Windows) it may raise
    # TypeError or OSError, which must not abort the remaining strategy.
    try:
        lib = ctypes.CDLL(None)
        getattr(lib, 'dggsvd3_')
    except (OSError, TypeError, AttributeError):
        pass
    else:
        _lib = lib
        _lib_type = 'system'
        return

    # --- Strategy 5: find_library ---
    for name in ('lapack', 'openblas', 'flexiblas'):
        path = ctypes.util.find_library(name)
        if not path:
            continue
        try:
            lib = ctypes.CDLL(path)
            getattr(lib, 'dggsvd3_')
        except (OSError, AttributeError):
            continue
        _lib = lib
        _lib_type = 'system'
        return

    raise ImportError(
        "gsvd4py: Could not find a LAPACK library providing dggsvd3. "
        "Ensure scipy is installed (pip install scipy), or install "
        "scipy-openblas32 (pip install scipy-openblas32)."
    )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_ggsvd3(dtype_char):
    """Look up the ?ggsvd3 entry point in the discovered LAPACK library.

    Parameters
    ----------
    dtype_char : str
        LAPACK precision prefix: one of 'd', 's', 'z', 'c'.

    Returns
    -------
    fn : ctypes function object
        The routine, with ``restype`` set to None.
    uses_hidden_lengths : bool
        True when the routine follows the gfortran ABI, which appends hidden
        character-length arguments; False for Accelerate's $NEWLAPACK symbols.
    """
    _load_lib()

    if _lib_type == 'accelerate':
        fn = _lib[dtype_char + 'ggsvd3$NEWLAPACK']
        hidden = False
    else:
        # scipy-openblas prefixes its Fortran symbols with 'scipy_'.
        prefix = 'scipy_' if _lib_type == 'scipy_openblas' else ''
        fn = getattr(_lib, prefix + dtype_char + 'ggsvd3_')
        hidden = True

    fn.restype = None
    return fn, hidden
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gsvd4py
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Generalized SVD (GSVD) via LAPACK ?ggsvd3, using the same LAPACK library as SciPy
|
|
5
|
+
Author-email: Hayden Ringer <hjrrockies@gmail.com>
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Requires-Dist: scipy>=1.13
|
|
8
|
+
Requires-Dist: numpy>=2.0
|
|
9
|
+
Provides-Extra: test
|
|
10
|
+
Requires-Dist: pytest; extra == "test"
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
gsvd4py/__init__.py
|
|
4
|
+
gsvd4py/_gsvd.py
|
|
5
|
+
gsvd4py/_lapack.py
|
|
6
|
+
gsvd4py.egg-info/PKG-INFO
|
|
7
|
+
gsvd4py.egg-info/SOURCES.txt
|
|
8
|
+
gsvd4py.egg-info/dependency_links.txt
|
|
9
|
+
gsvd4py.egg-info/requires.txt
|
|
10
|
+
gsvd4py.egg-info/top_level.txt
|
|
11
|
+
tests/test_gsvd.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
gsvd4py
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "gsvd4py"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "Generalized SVD (GSVD) via LAPACK ?ggsvd3, using the same LAPACK library as SciPy"
|
|
9
|
+
requires-python = ">=3.9"
|
|
10
|
+
authors = [{name = "Hayden Ringer", email = "hjrrockies@gmail.com"}]
|
|
11
|
+
dependencies = [
|
|
12
|
+
"scipy >= 1.13",
|
|
13
|
+
"numpy >= 2.0",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
[project.optional-dependencies]
|
|
17
|
+
test = ["pytest"]
|
gsvd4py-0.0.1/setup.cfg
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for gsvd4py.
|
|
3
|
+
|
|
4
|
+
Validates:
|
|
5
|
+
- LAPACK library is found and loaded
|
|
6
|
+
- Reconstruction accuracy: A ≈ U @ C @ X.conj().T, B ≈ V @ S @ X.conj().T
|
|
7
|
+
- Unitarity of U and V
|
|
8
|
+
- All modes (full, econ, separate)
|
|
9
|
+
- All four dtypes (float32, float64, complex64, complex128)
|
|
10
|
+
- compute_u=False / compute_v=False short-tuple returns
|
|
11
|
+
- Various shapes (square, tall, wide, rank-deficient)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pytest
|
|
16
|
+
from numpy.testing import assert_allclose
|
|
17
|
+
|
|
18
|
+
from gsvd4py import gsvd
|
|
19
|
+
import gsvd4py._lapack as _lapack_mod
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ---------------------------------------------------------------------------
|
|
23
|
+
# Tolerances
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
# Relative tolerance per dtype: loose for single precision, tight for double.
_RTOL = dict(zip(
    (np.float32, np.float64, np.complex64, np.complex128),
    (1e-5, 1e-12, 1e-5, 1e-12),
))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ---------------------------------------------------------------------------
|
|
34
|
+
# Test: library loading
|
|
35
|
+
# ---------------------------------------------------------------------------
|
|
36
|
+
|
|
37
|
+
class TestLibraryLoading:
    """LAPACK discovery in gsvd4py._lapack succeeds and records its source."""

    def test_loads_without_error(self):
        _lapack_mod._load_lib()
        known_types = {'accelerate', 'scipy_openblas', 'system'}
        assert _lapack_mod._lib_type in known_types

    def test_lib_type_is_string(self):
        _lapack_mod._load_lib()
        lib_type = _lapack_mod._lib_type
        assert isinstance(lib_type, str)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# ---------------------------------------------------------------------------
|
|
48
|
+
# Helpers
|
|
49
|
+
# ---------------------------------------------------------------------------
|
|
50
|
+
|
|
51
|
+
def _random_matrix(rng, m, n, dtype):
    """Draw an m-by-n standard-normal matrix cast to *dtype*.

    For complex dtypes the real part is drawn first and the imaginary part
    second. Each part is drawn at the matching real precision (float32 for
    complex64, float64 otherwise) before the two are combined.
    """
    if not np.issubdtype(dtype, np.complexfloating):
        return rng.standard_normal((m, n)).astype(dtype)

    real_dtype = np.float64
    if dtype == np.complex64:
        real_dtype = np.float32
    re_part = rng.standard_normal((m, n)).astype(real_dtype)
    im_part = rng.standard_normal((m, n)).astype(real_dtype)
    return (re_part + 1j * im_part).astype(dtype)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _check_reconstruction(U, V, X, C, S, A, B, rtol):
    """Assert that the GSVD factors reproduce both input matrices.

    Checks A ≈ U @ C @ Xᴴ first, then B ≈ V @ S @ Xᴴ. *rtol* is used as the
    relative tolerance, and, scaled by the target's Frobenius norm, as the
    absolute tolerance.
    """
    x_adj = X.conj().T
    for left, middle, target in ((U, C, A), (V, S, B)):
        assert_allclose(left @ middle @ x_adj, target, rtol=rtol,
                        atol=rtol * np.linalg.norm(target))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _check_unitary(M, rtol):
    """Check that M has orthonormal columns, i.e. M.conj().T @ M ≈ I.

    For a square M this is exactly unitarity. For a tall M (more rows than
    columns) only column orthonormality holds; M @ M.conj().T is then a
    projector, not the identity, which is why the Gram matrix is taken in
    this order. *rtol* is scaled by 10 and used as an absolute tolerance.
    """
    n = M.shape[1]
    assert_allclose(M.conj().T @ M, np.eye(n), atol=rtol * 10)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ---------------------------------------------------------------------------
|
|
74
|
+
# Test: mode='full'
|
|
75
|
+
# ---------------------------------------------------------------------------
|
|
76
|
+
|
|
77
|
+
class TestFullMode:
    """mode='full': square U/V, reconstruction, and short-tuple returns."""

    @pytest.mark.parametrize("dtype", [np.float32, np.float64,
                                       np.complex64, np.complex128])
    @pytest.mark.parametrize("shape", [
        (5, 4, 6),  # m=5, n=4, p=6: m < p, n < p (wide A, wide B)
        (4, 4, 4),  # square
        (3, 5, 6),  # m < p, n < p with m < n (wide)
        (6, 3, 4),  # m > p, n < p (tall A, wide B)
        (7, 6, 4),  # m > p, n > p (tall A, tall B)
    ])
    def test_reconstruction(self, dtype, shape):
        m, n, p = shape
        rng = np.random.default_rng(42)
        A = _random_matrix(rng, m, p, dtype)
        B = _random_matrix(rng, n, p, dtype)
        rtol = _RTOL[dtype]

        U, V, X, C, S = gsvd(A, B, mode='full')

        assert U.shape == (m, m)
        assert V.shape == (n, n)
        assert X.shape[0] == p
        assert C.shape[0] == m
        assert S.shape[0] == n
        assert C.shape[1] == X.shape[1] == S.shape[1]  # same q

        _check_reconstruction(U, V, X, C, S, A, B, rtol)
        _check_unitary(U, rtol)
        _check_unitary(V, rtol)

    def test_rank_deficient(self):
        # A and B each have rank 2, so the stacked [A; B] has rank 4 < p = 6.
        # The factorization must still reproduce both inputs exactly and keep
        # U and V unitary.
        rng = np.random.default_rng(5)
        m, n, p = 5, 4, 6
        A = rng.standard_normal((m, 2)) @ rng.standard_normal((2, p))
        B = rng.standard_normal((n, 2)) @ rng.standard_normal((2, p))
        rtol = _RTOL[np.float64]

        U, V, X, C, S = gsvd(A, B, mode='full')

        assert U.shape == (m, m)
        assert V.shape == (n, n)
        assert X.shape[0] == p
        assert C.shape[1] == X.shape[1] == S.shape[1]  # same q

        _check_reconstruction(U, V, X, C, S, A, B, rtol)
        _check_unitary(U, rtol)
        _check_unitary(V, rtol)

    def test_no_u_no_v(self):
        rng = np.random.default_rng(0)
        A = rng.standard_normal((4, 5))
        B = rng.standard_normal((3, 5))
        result = gsvd(A, B, compute_u=False, compute_v=False)
        assert len(result) == 3
        X, C, S = result
        assert X.shape[0] == 5
        # Skipping U/V must not change the shapes of the remaining factors.
        assert C.shape[0] == 4
        assert S.shape[0] == 3
        assert C.shape[1] == X.shape[1] == S.shape[1]

    def test_no_u_with_v(self):
        rng = np.random.default_rng(0)
        A = rng.standard_normal((4, 5))
        B = rng.standard_normal((3, 5))
        result = gsvd(A, B, compute_u=False, compute_v=True)
        assert len(result) == 4
        V, X, C, S = result
        assert V.shape == (3, 3)

    def test_with_u_no_v(self):
        rng = np.random.default_rng(0)
        A = rng.standard_normal((4, 5))
        B = rng.standard_normal((3, 5))
        result = gsvd(A, B, compute_u=True, compute_v=False)
        assert len(result) == 4
        U, X, C, S = result
        assert U.shape == (4, 4)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# ---------------------------------------------------------------------------
|
|
135
|
+
# Test: mode='econ'
|
|
136
|
+
# ---------------------------------------------------------------------------
|
|
137
|
+
|
|
138
|
+
class TestEconMode:
    """mode='econ': truncated U/V/C/S that still reconstruct A and B."""

    @pytest.mark.parametrize("dtype", [np.float64, np.complex128])
    def test_reconstruction(self, dtype):
        rng = np.random.default_rng(7)
        m, n, p = 6, 4, 5
        A = _random_matrix(rng, m, p, dtype)
        B = _random_matrix(rng, n, p, dtype)
        rtol = _RTOL[dtype]

        U, V, X, C, S = gsvd(A, B, mode='econ')
        q = X.shape[1]

        assert U.shape == (m, min(m, q))
        assert V.shape == (n, min(n, q))
        assert C.shape == (min(m, q), q)
        assert S.shape == (min(n, q), q)

        _check_reconstruction(U, V, X, C, S, A, B, rtol)

    def test_econ_smaller_than_full(self):
        rng = np.random.default_rng(3)
        m, n = 8, 6
        A = rng.standard_normal((m, 5))
        B = rng.standard_normal((n, 5))
        U_f, V_f, _, _, _ = gsvd(A, B, mode='full')
        U_e, V_e, X_e, _, _ = gsvd(A, B, mode='econ')
        q = X_e.shape[1]

        # Economy U/V keep min(rows, q) columns ...
        assert U_e.shape == (m, min(m, q))
        assert V_e.shape == (n, min(n, q))
        assert U_e.shape[1] <= U_f.shape[1]
        assert V_e.shape[1] <= V_f.shape[1]
        # ... and they are the leading columns of the full U/V.
        assert_allclose(U_e, U_f[:, :U_e.shape[1]], atol=1e-12)
        assert_allclose(V_e, V_f[:, :V_e.shape[1]], atol=1e-12)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# ---------------------------------------------------------------------------
|
|
170
|
+
# Test: mode='separate'
|
|
171
|
+
# ---------------------------------------------------------------------------
|
|
172
|
+
|
|
173
|
+
class TestSeparateMode:
    """mode='separate': LAPACK-native (U, V, D1, D2, R, Q, k, l) output."""

    def test_full_return(self):
        rng = np.random.default_rng(11)
        A = rng.standard_normal((5, 6))
        B = rng.standard_normal((4, 6))
        result = gsvd(A, B, mode='separate')
        assert len(result) == 8
        U, V, D1, D2, R, Q, k, l = result
        assert U.shape == (5, 5)
        assert V.shape == (4, 4)
        assert Q.shape == (6, 6)
        assert R.shape == (k + l, k + l)
        assert D1.shape == (5, k + l)
        assert D2.shape == (4, k + l)

    def test_no_u(self):
        rng = np.random.default_rng(12)
        A = rng.standard_normal((4, 5))
        B = rng.standard_normal((3, 5))
        result = gsvd(A, B, mode='separate', compute_u=False)
        assert len(result) == 7
        V, D1, D2, R, Q, k, l = result
        assert V.shape == (3, 3)
        # Dropping U must leave the remaining factors' shapes intact.
        assert Q.shape == (5, 5)
        assert R.shape == (k + l, k + l)
        assert D1.shape == (4, k + l)
        assert D2.shape == (3, k + l)

    def test_no_u_no_v(self):
        rng = np.random.default_rng(13)
        A = rng.standard_normal((4, 5))
        B = rng.standard_normal((3, 5))
        result = gsvd(A, B, mode='separate', compute_u=False, compute_v=False)
        assert len(result) == 6
        D1, D2, R, Q, k, l = result
        assert Q.shape == (5, 5)
        assert R.shape == (k + l, k + l)
        assert D1.shape == (4, k + l)
        assert D2.shape == (3, k + l)

    def test_reconstruction_via_lapack_form(self):
        """Verify A ≈ U @ D1 @ [0, R] @ Q.T and B ≈ V @ D2 @ [0, R] @ Q.T."""
        rng = np.random.default_rng(20)
        A = rng.standard_normal((5, 6))
        B = rng.standard_normal((4, 6))
        U, V, D1, D2, R, Q, k, l = gsvd(A, B, mode='separate')

        q = k + l
        p = A.shape[1]
        zero_block = np.zeros((q, p - q))
        RQ_block = np.hstack([zero_block, R]) @ Q.T  # q×p

        A_rec = U @ D1 @ RQ_block
        B_rec = V @ D2 @ RQ_block
        assert_allclose(A_rec, A, rtol=1e-10, atol=1e-10 * np.linalg.norm(A))
        assert_allclose(B_rec, B, rtol=1e-10, atol=1e-10 * np.linalg.norm(B))
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# ---------------------------------------------------------------------------
|
|
224
|
+
# Test: overwrite and lwork options
|
|
225
|
+
# ---------------------------------------------------------------------------
|
|
226
|
+
|
|
227
|
+
class TestOptions:
    """overwrite_*, lwork and check_finite keyword handling."""

    def test_overwrite_a_b(self):
        rng = np.random.default_rng(99)
        A = np.asfortranarray(rng.standard_normal((4, 5)))
        B = np.asfortranarray(rng.standard_normal((3, 5)))
        # Should not raise
        gsvd(A, B, overwrite_a=True, overwrite_b=True)

    def test_explicit_lwork(self):
        rng = np.random.default_rng(100)
        A = rng.standard_normal((4, 5))
        B = rng.standard_normal((3, 5))
        _, _, _, C1, S1 = gsvd(A, B)
        _, _, _, C2, S2 = gsvd(A, B, lwork=500)
        # Workspace size may change LAPACK blocking (and hence vector signs),
        # but the generalized singular value pairs must agree.
        assert_allclose(C1, C2, rtol=1e-12)
        assert_allclose(S1, S2, rtol=1e-12, atol=1e-12)

    def test_check_finite_raises(self):
        A = np.array([[1.0, np.nan], [2.0, 3.0]])
        B = np.array([[1.0, 2.0]])
        with pytest.raises(ValueError, match="non-finite"):
            gsvd(A, B)

    def test_check_finite_skip(self):
        # check_finite=False only skips the NaN/inf scan. Feeding NaN to
        # LAPACK is undefined behaviour, so use finite input and check that
        # skipping the scan yields the same result as the default path.
        A = np.array([[1.0, 2.0], [3.0, 4.0]])
        B = np.array([[1.0, 2.0]])
        _, _, _, C1, S1 = gsvd(A, B)
        _, _, _, C2, S2 = gsvd(A, B, check_finite=False)  # should not raise
        assert_allclose(C1, C2, atol=1e-14)
        assert_allclose(S1, S2, atol=1e-14)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# ---------------------------------------------------------------------------
|
|
259
|
+
# Test: input validation
|
|
260
|
+
# ---------------------------------------------------------------------------
|
|
261
|
+
|
|
262
|
+
class TestValidation:
    """Argument validation and input coercion."""

    def test_bad_mode(self):
        A = np.eye(3)
        B = np.eye(3)
        with pytest.raises(ValueError, match="mode"):
            gsvd(A, B, mode='bad')

    def test_mismatched_columns(self):
        A = np.ones((3, 4))
        B = np.ones((2, 5))
        with pytest.raises(ValueError, match="columns"):
            gsvd(A, B)

    def test_1d_input(self):
        with pytest.raises(ValueError, match="2-D"):
            gsvd(np.ones(3), np.ones((2, 3)))

    def test_integer_input_upcasted(self):
        A = np.array([[1, 2, 3], [4, 5, 6]])
        B = np.array([[1, 2, 3]])
        # Should not raise; integers are upcast to float64 (real input must
        # not be promoted to a complex type).
        U, V, X, C, S = gsvd(A, B)
        assert U.dtype == np.float64
        assert V.dtype == np.float64
        assert X.dtype == np.float64
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# ---------------------------------------------------------------------------
|
|
288
|
+
# Test: dtype handling
|
|
289
|
+
# ---------------------------------------------------------------------------
|
|
290
|
+
|
|
291
|
+
class TestDtypes:
    """Output dtype follows the (promoted) input dtype."""

    @pytest.mark.parametrize(
        "dtype", [np.float32, np.float64, np.complex64, np.complex128]
    )
    def test_dtype_preserved(self, dtype):
        gen = np.random.default_rng(55)
        lhs = _random_matrix(gen, 4, 5, dtype)
        rhs = _random_matrix(gen, 3, 5, dtype)
        U, V, X, _C, _S = gsvd(lhs, rhs)
        for factor in (U, V, X):
            assert factor.dtype == dtype

    def test_mixed_real_float32_float64(self):
        gen = np.random.default_rng(56)
        lhs = gen.standard_normal((4, 5)).astype(np.float32)
        rhs = gen.standard_normal((3, 5)).astype(np.float64)
        U, _V, _X, _C, _S = gsvd(lhs, rhs)
        assert U.dtype == np.float64  # result_type promotes to float64
|