numo-tiny_linalg 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +167 -0
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +167 -0
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +193 -0
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +158 -0
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +158 -0
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +192 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +18 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +44 -0
- metadata +8 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: '023843b368ce62924768c3341587977e57edb6aca0e6d93b64279d5e7965bf0b'
|
4
|
+
data.tar.gz: fe00ae7b435a94a6ff2790320f89bfef078efafed110e253f8929553480b67d6
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 36d907834c73b633ae3872ba5fa73f315dd01d79e99487e7fb4f6f2e3f37a7fbcbfa1b25d078b4d44cd66213da545a665c35817d6929d22a8a05fa023791bd2f
|
7
|
+
data.tar.gz: eaab6dbb2835d60ad3432525a830a1a5f19625c7b87d1010db81efeb480892c0807612d8d13cdc23c60abde1d7304594fc4562099102a6738b8cee1aeccb55a6
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,11 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
|
4
|
+
- Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
|
5
|
+
- Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
|
6
|
+
- Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
|
7
|
+
- Add eigh module function to TinyLinalg.
|
8
|
+
|
3
9
|
## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
|
4
10
|
- Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
|
5
11
|
- Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
|
@@ -0,0 +1,167 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeGv {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, lapack_complex_double* a,
|
6
|
+
lapack_int lda, lapack_complex_double* b,
|
7
|
+
lapack_int ldb, double* w) {
|
8
|
+
return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeGv {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
14
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
15
|
+
float* b, lapack_int ldb, float* w) {
|
16
|
+
return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
17
|
+
}
|
18
|
+
|
19
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
20
|
+
char uplo, lapack_int n, lapack_complex_float* a,
|
21
|
+
lapack_int lda, lapack_complex_float* b,
|
22
|
+
lapack_int ldb, float* w) {
|
23
|
+
return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
28
|
+
class HeGv {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct hegv_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
lapack_int itype;
|
38
|
+
char jobz;
|
39
|
+
char uplo;
|
40
|
+
};
|
41
|
+
|
42
|
+
static void iter_hegv(na_loop_t* const lp) {
|
43
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
44
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
45
|
+
RType* w = (RType*)NDL_PTR(lp, 2);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
47
|
+
hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
58
|
+
|
59
|
+
VALUE a_vnary = Qnil;
|
60
|
+
VALUE b_vnary = Qnil;
|
61
|
+
VALUE kw_args = Qnil;
|
62
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
63
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
64
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
65
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
66
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
67
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
69
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
70
|
+
|
71
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
72
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
73
|
+
}
|
74
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
75
|
+
a_vnary = nary_dup(a_vnary);
|
76
|
+
}
|
77
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
78
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
79
|
+
}
|
80
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
81
|
+
b_vnary = nary_dup(b_vnary);
|
82
|
+
}
|
83
|
+
|
84
|
+
narray_t* a_nary = nullptr;
|
85
|
+
GetNArray(a_vnary, a_nary);
|
86
|
+
if (NA_NDIM(a_nary) != 2) {
|
87
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
91
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
92
|
+
return Qnil;
|
93
|
+
}
|
94
|
+
narray_t* b_nary = nullptr;
|
95
|
+
GetNArray(a_vnary, b_nary);
|
96
|
+
if (NA_NDIM(b_nary) != 2) {
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
101
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
102
|
+
return Qnil;
|
103
|
+
}
|
104
|
+
|
105
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
106
|
+
size_t shape[1] = { n };
|
107
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
108
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
109
|
+
ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
110
|
+
hegv_opt opt = { matrix_layout, itype, jobz, uplo };
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
112
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
113
|
+
|
114
|
+
RB_GC_GUARD(a_vnary);
|
115
|
+
RB_GC_GUARD(b_vnary);
|
116
|
+
|
117
|
+
return ret;
|
118
|
+
}
|
119
|
+
|
120
|
+
static lapack_int get_itype(VALUE val) {
|
121
|
+
const lapack_int itype = NUM2INT(val);
|
122
|
+
|
123
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
124
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
125
|
+
}
|
126
|
+
|
127
|
+
return itype;
|
128
|
+
}
|
129
|
+
|
130
|
+
static char get_jobz(VALUE val) {
|
131
|
+
const char jobz = NUM2CHR(val);
|
132
|
+
|
133
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
134
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
135
|
+
}
|
136
|
+
|
137
|
+
return jobz;
|
138
|
+
}
|
139
|
+
|
140
|
+
static char get_uplo(VALUE val) {
|
141
|
+
const char uplo = NUM2CHR(val);
|
142
|
+
|
143
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
144
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
145
|
+
}
|
146
|
+
|
147
|
+
return uplo;
|
148
|
+
}
|
149
|
+
|
150
|
+
static int get_matrix_layout(VALUE val) {
|
151
|
+
const char option = NUM2CHR(val);
|
152
|
+
|
153
|
+
switch (option) {
|
154
|
+
case 'r':
|
155
|
+
case 'R':
|
156
|
+
break;
|
157
|
+
case 'c':
|
158
|
+
case 'C':
|
159
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
|
160
|
+
break;
|
161
|
+
}
|
162
|
+
|
163
|
+
return LAPACK_ROW_MAJOR;
|
164
|
+
}
|
165
|
+
};
|
166
|
+
|
167
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,167 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeGvd {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, lapack_complex_double* a,
|
6
|
+
lapack_int lda, lapack_complex_double* b,
|
7
|
+
lapack_int ldb, double* w) {
|
8
|
+
return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeGvd {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
14
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
15
|
+
float* b, lapack_int ldb, float* w) {
|
16
|
+
return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
17
|
+
}
|
18
|
+
|
19
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
20
|
+
char uplo, lapack_int n, lapack_complex_float* a,
|
21
|
+
lapack_int lda, lapack_complex_float* b,
|
22
|
+
lapack_int ldb, float* w) {
|
23
|
+
return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
28
|
+
class HeGvd {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct hegvd_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
lapack_int itype;
|
38
|
+
char jobz;
|
39
|
+
char uplo;
|
40
|
+
};
|
41
|
+
|
42
|
+
static void iter_hegvd(na_loop_t* const lp) {
|
43
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
44
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
45
|
+
RType* w = (RType*)NDL_PTR(lp, 2);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
47
|
+
hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
58
|
+
|
59
|
+
VALUE a_vnary = Qnil;
|
60
|
+
VALUE b_vnary = Qnil;
|
61
|
+
VALUE kw_args = Qnil;
|
62
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
63
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
64
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
65
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
66
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
67
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
69
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
70
|
+
|
71
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
72
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
73
|
+
}
|
74
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
75
|
+
a_vnary = nary_dup(a_vnary);
|
76
|
+
}
|
77
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
78
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
79
|
+
}
|
80
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
81
|
+
b_vnary = nary_dup(b_vnary);
|
82
|
+
}
|
83
|
+
|
84
|
+
narray_t* a_nary = nullptr;
|
85
|
+
GetNArray(a_vnary, a_nary);
|
86
|
+
if (NA_NDIM(a_nary) != 2) {
|
87
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
91
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
92
|
+
return Qnil;
|
93
|
+
}
|
94
|
+
narray_t* b_nary = nullptr;
|
95
|
+
GetNArray(a_vnary, b_nary);
|
96
|
+
if (NA_NDIM(b_nary) != 2) {
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
101
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
102
|
+
return Qnil;
|
103
|
+
}
|
104
|
+
|
105
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
106
|
+
size_t shape[1] = { n };
|
107
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
108
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
109
|
+
ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
110
|
+
hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
112
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
113
|
+
|
114
|
+
RB_GC_GUARD(a_vnary);
|
115
|
+
RB_GC_GUARD(b_vnary);
|
116
|
+
|
117
|
+
return ret;
|
118
|
+
}
|
119
|
+
|
120
|
+
static lapack_int get_itype(VALUE val) {
|
121
|
+
const lapack_int itype = NUM2INT(val);
|
122
|
+
|
123
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
124
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
125
|
+
}
|
126
|
+
|
127
|
+
return itype;
|
128
|
+
}
|
129
|
+
|
130
|
+
static char get_jobz(VALUE val) {
|
131
|
+
const char jobz = NUM2CHR(val);
|
132
|
+
|
133
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
134
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
135
|
+
}
|
136
|
+
|
137
|
+
return jobz;
|
138
|
+
}
|
139
|
+
|
140
|
+
static char get_uplo(VALUE val) {
|
141
|
+
const char uplo = NUM2CHR(val);
|
142
|
+
|
143
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
144
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
145
|
+
}
|
146
|
+
|
147
|
+
return uplo;
|
148
|
+
}
|
149
|
+
|
150
|
+
static int get_matrix_layout(VALUE val) {
|
151
|
+
const char option = NUM2CHR(val);
|
152
|
+
|
153
|
+
switch (option) {
|
154
|
+
case 'r':
|
155
|
+
case 'R':
|
156
|
+
break;
|
157
|
+
case 'c':
|
158
|
+
case 'C':
|
159
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
|
160
|
+
break;
|
161
|
+
}
|
162
|
+
|
163
|
+
return LAPACK_ROW_MAJOR;
|
164
|
+
}
|
165
|
+
};
|
166
|
+
|
167
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,193 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeGvx {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
|
6
|
+
double vl, double vu, lapack_int il, lapack_int iu,
|
7
|
+
double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
|
8
|
+
return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeGvx {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
|
15
|
+
float vl, float vu, lapack_int il, lapack_int iu,
|
16
|
+
float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
|
17
|
+
return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
22
|
+
class HeGvx {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct hegvx_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
lapack_int itype;
|
32
|
+
char jobz;
|
33
|
+
char range;
|
34
|
+
char uplo;
|
35
|
+
RType vl;
|
36
|
+
RType vu;
|
37
|
+
lapack_int il;
|
38
|
+
lapack_int iu;
|
39
|
+
};
|
40
|
+
|
41
|
+
static void iter_hegvx(na_loop_t* const lp) {
|
42
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
43
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
44
|
+
int* m = (int*)NDL_PTR(lp, 2);
|
45
|
+
RType* w = (RType*)NDL_PTR(lp, 3);
|
46
|
+
DType* z = (DType*)NDL_PTR(lp, 4);
|
47
|
+
int* ifail = (int*)NDL_PTR(lp, 5);
|
48
|
+
int* info = (int*)NDL_PTR(lp, 6);
|
49
|
+
hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
|
50
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
51
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
53
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
54
|
+
const RType abstol = 0.0;
|
55
|
+
const lapack_int i = FncType().call(
|
56
|
+
opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
|
57
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
|
58
|
+
*info = static_cast<int>(i);
|
59
|
+
}
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
64
|
+
|
65
|
+
VALUE a_vnary = Qnil;
|
66
|
+
VALUE b_vnary = Qnil;
|
67
|
+
VALUE kw_args = Qnil;
|
68
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
69
|
+
ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
70
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
71
|
+
VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
72
|
+
rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
|
73
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
74
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
75
|
+
const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
|
76
|
+
const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
|
77
|
+
const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
78
|
+
const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
|
79
|
+
const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
80
|
+
const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
|
81
|
+
const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
82
|
+
|
83
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
84
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
85
|
+
}
|
86
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
87
|
+
a_vnary = nary_dup(a_vnary);
|
88
|
+
}
|
89
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
90
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
91
|
+
}
|
92
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
93
|
+
b_vnary = nary_dup(b_vnary);
|
94
|
+
}
|
95
|
+
|
96
|
+
narray_t* a_nary = nullptr;
|
97
|
+
GetNArray(a_vnary, a_nary);
|
98
|
+
if (NA_NDIM(a_nary) != 2) {
|
99
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
100
|
+
return Qnil;
|
101
|
+
}
|
102
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
103
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
104
|
+
return Qnil;
|
105
|
+
}
|
106
|
+
narray_t* b_nary = nullptr;
|
107
|
+
GetNArray(a_vnary, b_nary);
|
108
|
+
if (NA_NDIM(b_nary) != 2) {
|
109
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
110
|
+
return Qnil;
|
111
|
+
}
|
112
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
113
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
114
|
+
return Qnil;
|
115
|
+
}
|
116
|
+
|
117
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
118
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
119
|
+
size_t w_shape[1] = { m };
|
120
|
+
size_t z_shape[2] = { n, m };
|
121
|
+
size_t ifail_shape[1] = { n };
|
122
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
123
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
|
124
|
+
ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
|
125
|
+
hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
|
126
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
127
|
+
VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
128
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
129
|
+
|
130
|
+
RB_GC_GUARD(a_vnary);
|
131
|
+
RB_GC_GUARD(b_vnary);
|
132
|
+
|
133
|
+
return ret;
|
134
|
+
}
|
135
|
+
|
136
|
+
static lapack_int get_itype(VALUE val) {
|
137
|
+
const lapack_int itype = NUM2INT(val);
|
138
|
+
|
139
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
140
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
141
|
+
}
|
142
|
+
|
143
|
+
return itype;
|
144
|
+
}
|
145
|
+
|
146
|
+
static char get_jobz(VALUE val) {
|
147
|
+
const char jobz = NUM2CHR(val);
|
148
|
+
|
149
|
+
if (jobz != 'N' && jobz != 'V') {
|
150
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
151
|
+
}
|
152
|
+
|
153
|
+
return jobz;
|
154
|
+
}
|
155
|
+
|
156
|
+
static char get_range(VALUE val) {
|
157
|
+
const char range = NUM2CHR(val);
|
158
|
+
|
159
|
+
if (range != 'A' && range != 'V' && range != 'I') {
|
160
|
+
rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
|
161
|
+
}
|
162
|
+
|
163
|
+
return range;
|
164
|
+
}
|
165
|
+
|
166
|
+
static char get_uplo(VALUE val) {
|
167
|
+
const char uplo = NUM2CHR(val);
|
168
|
+
|
169
|
+
if (uplo != 'U' && uplo != 'L') {
|
170
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
171
|
+
}
|
172
|
+
|
173
|
+
return uplo;
|
174
|
+
}
|
175
|
+
|
176
|
+
static int get_matrix_layout(VALUE val) {
|
177
|
+
const char option = NUM2CHR(val);
|
178
|
+
|
179
|
+
switch (option) {
|
180
|
+
case 'r':
|
181
|
+
case 'R':
|
182
|
+
break;
|
183
|
+
case 'c':
|
184
|
+
case 'C':
|
185
|
+
rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
|
186
|
+
break;
|
187
|
+
}
|
188
|
+
|
189
|
+
return LAPACK_ROW_MAJOR;
|
190
|
+
}
|
191
|
+
};
|
192
|
+
|
193
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,158 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGv {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, double* a, lapack_int lda,
|
6
|
+
double* b, lapack_int ldb, double* w) {
|
7
|
+
return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SSyGv {
|
12
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
13
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
14
|
+
float* b, lapack_int ldb, float* w) {
|
15
|
+
return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
20
|
+
class SyGv {
|
21
|
+
public:
|
22
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
23
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
|
24
|
+
}
|
25
|
+
|
26
|
+
private:
|
27
|
+
struct sygv_opt {
|
28
|
+
int matrix_layout;
|
29
|
+
lapack_int itype;
|
30
|
+
char jobz;
|
31
|
+
char uplo;
|
32
|
+
};
|
33
|
+
|
34
|
+
static void iter_sygv(na_loop_t* const lp) {
|
35
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
36
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
37
|
+
DType* w = (DType*)NDL_PTR(lp, 2);
|
38
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
39
|
+
sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
|
40
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
41
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
42
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
43
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
44
|
+
*info = static_cast<int>(i);
|
45
|
+
}
|
46
|
+
|
47
|
+
static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
|
48
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
49
|
+
|
50
|
+
VALUE a_vnary = Qnil;
|
51
|
+
VALUE b_vnary = Qnil;
|
52
|
+
VALUE kw_args = Qnil;
|
53
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
54
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
55
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
56
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
57
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
58
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
59
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
60
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
69
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
70
|
+
}
|
71
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
72
|
+
b_vnary = nary_dup(b_vnary);
|
73
|
+
}
|
74
|
+
|
75
|
+
narray_t* a_nary = nullptr;
|
76
|
+
GetNArray(a_vnary, a_nary);
|
77
|
+
if (NA_NDIM(a_nary) != 2) {
|
78
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
79
|
+
return Qnil;
|
80
|
+
}
|
81
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
82
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
83
|
+
return Qnil;
|
84
|
+
}
|
85
|
+
narray_t* b_nary = nullptr;
|
86
|
+
GetNArray(a_vnary, b_nary);
|
87
|
+
if (NA_NDIM(b_nary) != 2) {
|
88
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
92
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
93
|
+
return Qnil;
|
94
|
+
}
|
95
|
+
|
96
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
97
|
+
size_t shape[1] = { n };
|
98
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
99
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
100
|
+
ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
101
|
+
sygv_opt opt = { matrix_layout, itype, jobz, uplo };
|
102
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
103
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
104
|
+
|
105
|
+
RB_GC_GUARD(a_vnary);
|
106
|
+
RB_GC_GUARD(b_vnary);
|
107
|
+
|
108
|
+
return ret;
|
109
|
+
}
|
110
|
+
|
111
|
+
static lapack_int get_itype(VALUE val) {
|
112
|
+
const lapack_int itype = NUM2INT(val);
|
113
|
+
|
114
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
115
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
116
|
+
}
|
117
|
+
|
118
|
+
return itype;
|
119
|
+
}
|
120
|
+
|
121
|
+
static char get_jobz(VALUE val) {
|
122
|
+
const char jobz = NUM2CHR(val);
|
123
|
+
|
124
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
125
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
126
|
+
}
|
127
|
+
|
128
|
+
return jobz;
|
129
|
+
}
|
130
|
+
|
131
|
+
static char get_uplo(VALUE val) {
|
132
|
+
const char uplo = NUM2CHR(val);
|
133
|
+
|
134
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
135
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
136
|
+
}
|
137
|
+
|
138
|
+
return uplo;
|
139
|
+
}
|
140
|
+
|
141
|
+
static int get_matrix_layout(VALUE val) {
|
142
|
+
const char option = NUM2CHR(val);
|
143
|
+
|
144
|
+
switch (option) {
|
145
|
+
case 'r':
|
146
|
+
case 'R':
|
147
|
+
break;
|
148
|
+
case 'c':
|
149
|
+
case 'C':
|
150
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
|
151
|
+
break;
|
152
|
+
}
|
153
|
+
|
154
|
+
return LAPACK_ROW_MAJOR;
|
155
|
+
}
|
156
|
+
};
|
157
|
+
|
158
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,158 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGvd {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, double* a, lapack_int lda,
|
6
|
+
double* b, lapack_int ldb, double* w) {
|
7
|
+
return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SSyGvd {
|
12
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
13
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
14
|
+
float* b, lapack_int ldb, float* w) {
|
15
|
+
return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
20
|
+
class SyGvd {
|
21
|
+
public:
|
22
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
23
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
|
24
|
+
}
|
25
|
+
|
26
|
+
private:
|
27
|
+
struct sygvd_opt {
|
28
|
+
int matrix_layout;
|
29
|
+
lapack_int itype;
|
30
|
+
char jobz;
|
31
|
+
char uplo;
|
32
|
+
};
|
33
|
+
|
34
|
+
static void iter_sygvd(na_loop_t* const lp) {
|
35
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
36
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
37
|
+
DType* w = (DType*)NDL_PTR(lp, 2);
|
38
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
39
|
+
sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
|
40
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
41
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
42
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
43
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
44
|
+
*info = static_cast<int>(i);
|
45
|
+
}
|
46
|
+
|
47
|
+
static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
|
48
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
49
|
+
|
50
|
+
VALUE a_vnary = Qnil;
|
51
|
+
VALUE b_vnary = Qnil;
|
52
|
+
VALUE kw_args = Qnil;
|
53
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
54
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
55
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
56
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
57
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
58
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
59
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
60
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
69
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
70
|
+
}
|
71
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
72
|
+
b_vnary = nary_dup(b_vnary);
|
73
|
+
}
|
74
|
+
|
75
|
+
narray_t* a_nary = nullptr;
|
76
|
+
GetNArray(a_vnary, a_nary);
|
77
|
+
if (NA_NDIM(a_nary) != 2) {
|
78
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
79
|
+
return Qnil;
|
80
|
+
}
|
81
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
82
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
83
|
+
return Qnil;
|
84
|
+
}
|
85
|
+
narray_t* b_nary = nullptr;
|
86
|
+
GetNArray(a_vnary, b_nary);
|
87
|
+
if (NA_NDIM(b_nary) != 2) {
|
88
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
92
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
93
|
+
return Qnil;
|
94
|
+
}
|
95
|
+
|
96
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
97
|
+
size_t shape[1] = { n };
|
98
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
99
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
100
|
+
ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
101
|
+
sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
|
102
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
103
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
104
|
+
|
105
|
+
RB_GC_GUARD(a_vnary);
|
106
|
+
RB_GC_GUARD(b_vnary);
|
107
|
+
|
108
|
+
return ret;
|
109
|
+
}
|
110
|
+
|
111
|
+
static lapack_int get_itype(VALUE val) {
|
112
|
+
const lapack_int itype = NUM2INT(val);
|
113
|
+
|
114
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
115
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
116
|
+
}
|
117
|
+
|
118
|
+
return itype;
|
119
|
+
}
|
120
|
+
|
121
|
+
static char get_jobz(VALUE val) {
|
122
|
+
const char jobz = NUM2CHR(val);
|
123
|
+
|
124
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
125
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
126
|
+
}
|
127
|
+
|
128
|
+
return jobz;
|
129
|
+
}
|
130
|
+
|
131
|
+
static char get_uplo(VALUE val) {
|
132
|
+
const char uplo = NUM2CHR(val);
|
133
|
+
|
134
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
135
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
136
|
+
}
|
137
|
+
|
138
|
+
return uplo;
|
139
|
+
}
|
140
|
+
|
141
|
+
static int get_matrix_layout(VALUE val) {
|
142
|
+
const char option = NUM2CHR(val);
|
143
|
+
|
144
|
+
switch (option) {
|
145
|
+
case 'r':
|
146
|
+
case 'R':
|
147
|
+
break;
|
148
|
+
case 'c':
|
149
|
+
case 'C':
|
150
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
|
151
|
+
break;
|
152
|
+
}
|
153
|
+
|
154
|
+
return LAPACK_ROW_MAJOR;
|
155
|
+
}
|
156
|
+
};
|
157
|
+
|
158
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,192 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGvx {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
|
6
|
+
double vl, double vu, lapack_int il, lapack_int iu,
|
7
|
+
double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
|
8
|
+
return LAPACKE_dsygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct SSyGvx {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
|
15
|
+
float vl, float vu, lapack_int il, lapack_int iu,
|
16
|
+
float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
|
17
|
+
return LAPACKE_ssygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
22
|
+
class SyGvx {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct sygvx_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
lapack_int itype;
|
32
|
+
char jobz;
|
33
|
+
char range;
|
34
|
+
char uplo;
|
35
|
+
DType vl;
|
36
|
+
DType vu;
|
37
|
+
lapack_int il;
|
38
|
+
lapack_int iu;
|
39
|
+
};
|
40
|
+
|
41
|
+
static void iter_sygvx(na_loop_t* const lp) {
|
42
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
43
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
44
|
+
int* m = (int*)NDL_PTR(lp, 2);
|
45
|
+
DType* w = (DType*)NDL_PTR(lp, 3);
|
46
|
+
DType* z = (DType*)NDL_PTR(lp, 4);
|
47
|
+
int* ifail = (int*)NDL_PTR(lp, 5);
|
48
|
+
int* info = (int*)NDL_PTR(lp, 6);
|
49
|
+
sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
|
50
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
51
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
53
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
54
|
+
const DType abstol = 0.0;
|
55
|
+
const lapack_int i = FncType().call(
|
56
|
+
opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
|
57
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
|
58
|
+
*info = static_cast<int>(i);
|
59
|
+
}
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
|
64
|
+
VALUE a_vnary = Qnil;
|
65
|
+
VALUE b_vnary = Qnil;
|
66
|
+
VALUE kw_args = Qnil;
|
67
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
68
|
+
ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
69
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
70
|
+
VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
71
|
+
rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
|
72
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
73
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
74
|
+
const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
|
75
|
+
const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
|
76
|
+
const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
77
|
+
const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
|
78
|
+
const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
79
|
+
const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
|
80
|
+
const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
81
|
+
|
82
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
83
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
84
|
+
}
|
85
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
86
|
+
a_vnary = nary_dup(a_vnary);
|
87
|
+
}
|
88
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
89
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
90
|
+
}
|
91
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
92
|
+
b_vnary = nary_dup(b_vnary);
|
93
|
+
}
|
94
|
+
|
95
|
+
narray_t* a_nary = nullptr;
|
96
|
+
GetNArray(a_vnary, a_nary);
|
97
|
+
if (NA_NDIM(a_nary) != 2) {
|
98
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
99
|
+
return Qnil;
|
100
|
+
}
|
101
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
102
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
103
|
+
return Qnil;
|
104
|
+
}
|
105
|
+
narray_t* b_nary = nullptr;
|
106
|
+
GetNArray(a_vnary, b_nary);
|
107
|
+
if (NA_NDIM(b_nary) != 2) {
|
108
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
109
|
+
return Qnil;
|
110
|
+
}
|
111
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
112
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
113
|
+
return Qnil;
|
114
|
+
}
|
115
|
+
|
116
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
117
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
118
|
+
size_t w_shape[1] = { m };
|
119
|
+
size_t z_shape[2] = { n, m };
|
120
|
+
size_t ifail_shape[1] = { n };
|
121
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
122
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
|
123
|
+
ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
|
124
|
+
sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
|
125
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
126
|
+
VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
127
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
128
|
+
|
129
|
+
RB_GC_GUARD(a_vnary);
|
130
|
+
RB_GC_GUARD(b_vnary);
|
131
|
+
|
132
|
+
return ret;
|
133
|
+
}
|
134
|
+
|
135
|
+
static lapack_int get_itype(VALUE val) {
|
136
|
+
const lapack_int itype = NUM2INT(val);
|
137
|
+
|
138
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
139
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
140
|
+
}
|
141
|
+
|
142
|
+
return itype;
|
143
|
+
}
|
144
|
+
|
145
|
+
static char get_jobz(VALUE val) {
|
146
|
+
const char jobz = NUM2CHR(val);
|
147
|
+
|
148
|
+
if (jobz != 'N' && jobz != 'V') {
|
149
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
150
|
+
}
|
151
|
+
|
152
|
+
return jobz;
|
153
|
+
}
|
154
|
+
|
155
|
+
static char get_range(VALUE val) {
|
156
|
+
const char range = NUM2CHR(val);
|
157
|
+
|
158
|
+
if (range != 'A' && range != 'V' && range != 'I') {
|
159
|
+
rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
|
160
|
+
}
|
161
|
+
|
162
|
+
return range;
|
163
|
+
}
|
164
|
+
|
165
|
+
static char get_uplo(VALUE val) {
|
166
|
+
const char uplo = NUM2CHR(val);
|
167
|
+
|
168
|
+
if (uplo != 'U' && uplo != 'L') {
|
169
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
170
|
+
}
|
171
|
+
|
172
|
+
return uplo;
|
173
|
+
}
|
174
|
+
|
175
|
+
static int get_matrix_layout(VALUE val) {
|
176
|
+
const char option = NUM2CHR(val);
|
177
|
+
|
178
|
+
switch (option) {
|
179
|
+
case 'r':
|
180
|
+
case 'R':
|
181
|
+
break;
|
182
|
+
case 'c':
|
183
|
+
case 'C':
|
184
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
|
185
|
+
break;
|
186
|
+
}
|
187
|
+
|
188
|
+
return LAPACK_ROW_MAJOR;
|
189
|
+
}
|
190
|
+
};
|
191
|
+
|
192
|
+
} // namespace TinyLinalg
|
@@ -11,7 +11,13 @@
|
|
11
11
|
#include "lapack/gesvd.hpp"
|
12
12
|
#include "lapack/getrf.hpp"
|
13
13
|
#include "lapack/getri.hpp"
|
14
|
+
#include "lapack/hegv.hpp"
|
15
|
+
#include "lapack/hegvd.hpp"
|
16
|
+
#include "lapack/hegvx.hpp"
|
14
17
|
#include "lapack/orgqr.hpp"
|
18
|
+
#include "lapack/sygv.hpp"
|
19
|
+
#include "lapack/sygvd.hpp"
|
20
|
+
#include "lapack/sygvx.hpp"
|
15
21
|
#include "lapack/ungqr.hpp"
|
16
22
|
|
17
23
|
// Ruby module object for TinyLinalg; presumably assigned in Init_tiny_linalg — confirm.
VALUE rb_mTinyLinalg;
|
@@ -271,6 +277,18 @@ extern "C" void Init_tiny_linalg(void) {
|
|
271
277
|
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
272
278
|
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
273
279
|
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
280
|
+
TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
|
281
|
+
TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
|
282
|
+
TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
|
283
|
+
TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
|
284
|
+
TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
|
285
|
+
TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
|
286
|
+
TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
|
287
|
+
TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
|
288
|
+
TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
|
289
|
+
TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
|
290
|
+
TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
|
291
|
+
TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
|
274
292
|
|
275
293
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
276
294
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -10,6 +10,50 @@ module Numo
|
|
10
10
|
module TinyLinalg # rubocop:disable Metrics/ModuleLength
|
11
11
|
module_function
|
12
12
|
|
13
|
+
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary or generalized eigenvalue problem.
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range/Array]
#   The range of zero-based indices of the eigenvalues (in ascending order) and corresponding eigenvectors
#   to be returned. An exclusive Range such as 0...3 selects indices 0, 1 and 2.
#   If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for API compatibility with Numo::Linalg, and is not used.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors (eigenvectors are nil if vals_only is true).
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  unless b.nil?
    raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
    raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
  end

  jobz = vals_only ? 'N' : 'V'
  b = a.class.eye(a.shape[0]) if b.nil?
  sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"

  if vals_range.nil?
    # Non-mutating concatenation: interpolated literals are frozen under
    # frozen_string_literal on Ruby < 3.0.
    sy_he_gv += 'd' if turbo
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
  else
    sy_he_gv += 'x'
    # Convert zero-based indices to LAPACK's one-based inclusive il/iu.
    # Range#last returns the excluded end for exclusive ranges, so step back one.
    il = vals_range.first + 1
    iu = vals_range.last + 1
    iu -= 1 if vals_range.is_a?(Range) && vals_range.exclude_end?
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
    )
  end
  vecs = nil if vals_only
  [vals, vecs]
end
|
56
|
+
|
13
57
|
# Computes the determinant of matrix.
|
14
58
|
#
|
15
59
|
# @param a [Numo::NArray] n-by-n square matrix.
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -52,7 +52,13 @@ files:
|
|
52
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
57
|
+
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
55
58
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
59
|
+
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
60
|
+
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
61
|
+
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|
56
62
|
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
57
63
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
58
64
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|