numo-tiny_linalg 0.0.3 → 0.0.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +167 -0
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +167 -0
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +193 -0
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +158 -0
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +158 -0
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +192 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +18 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +44 -0
- metadata +8 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: '023843b368ce62924768c3341587977e57edb6aca0e6d93b64279d5e7965bf0b'
|
4
|
+
data.tar.gz: fe00ae7b435a94a6ff2790320f89bfef078efafed110e253f8929553480b67d6
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 36d907834c73b633ae3872ba5fa73f315dd01d79e99487e7fb4f6f2e3f37a7fbcbfa1b25d078b4d44cd66213da545a665c35817d6929d22a8a05fa023791bd2f
|
7
|
+
data.tar.gz: eaab6dbb2835d60ad3432525a830a1a5f19625c7b87d1010db81efeb480892c0807612d8d13cdc23c60abde1d7304594fc4562099102a6738b8cee1aeccb55a6
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,11 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
|
4
|
+
- Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
|
5
|
+
- Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
|
6
|
+
- Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
|
7
|
+
- Add eigh module function to TinyLinalg.
|
8
|
+
|
3
9
|
## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
|
4
10
|
- Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
|
5
11
|
- Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
|
@@ -0,0 +1,167 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeGv {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, lapack_complex_double* a,
|
6
|
+
lapack_int lda, lapack_complex_double* b,
|
7
|
+
lapack_int ldb, double* w) {
|
8
|
+
return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeGv {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
14
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
15
|
+
float* b, lapack_int ldb, float* w) {
|
16
|
+
return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
17
|
+
}
|
18
|
+
|
19
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
20
|
+
char uplo, lapack_int n, lapack_complex_float* a,
|
21
|
+
lapack_int lda, lapack_complex_float* b,
|
22
|
+
lapack_int ldb, float* w) {
|
23
|
+
return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
28
|
+
class HeGv {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct hegv_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
lapack_int itype;
|
38
|
+
char jobz;
|
39
|
+
char uplo;
|
40
|
+
};
|
41
|
+
|
42
|
+
static void iter_hegv(na_loop_t* const lp) {
|
43
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
44
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
45
|
+
RType* w = (RType*)NDL_PTR(lp, 2);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
47
|
+
hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
58
|
+
|
59
|
+
VALUE a_vnary = Qnil;
|
60
|
+
VALUE b_vnary = Qnil;
|
61
|
+
VALUE kw_args = Qnil;
|
62
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
63
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
64
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
65
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
66
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
67
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
69
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
70
|
+
|
71
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
72
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
73
|
+
}
|
74
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
75
|
+
a_vnary = nary_dup(a_vnary);
|
76
|
+
}
|
77
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
78
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
79
|
+
}
|
80
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
81
|
+
b_vnary = nary_dup(b_vnary);
|
82
|
+
}
|
83
|
+
|
84
|
+
narray_t* a_nary = nullptr;
|
85
|
+
GetNArray(a_vnary, a_nary);
|
86
|
+
if (NA_NDIM(a_nary) != 2) {
|
87
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
91
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
92
|
+
return Qnil;
|
93
|
+
}
|
94
|
+
narray_t* b_nary = nullptr;
|
95
|
+
GetNArray(a_vnary, b_nary);
|
96
|
+
if (NA_NDIM(b_nary) != 2) {
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
101
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
102
|
+
return Qnil;
|
103
|
+
}
|
104
|
+
|
105
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
106
|
+
size_t shape[1] = { n };
|
107
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
108
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
109
|
+
ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
110
|
+
hegv_opt opt = { matrix_layout, itype, jobz, uplo };
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
112
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
113
|
+
|
114
|
+
RB_GC_GUARD(a_vnary);
|
115
|
+
RB_GC_GUARD(b_vnary);
|
116
|
+
|
117
|
+
return ret;
|
118
|
+
}
|
119
|
+
|
120
|
+
static lapack_int get_itype(VALUE val) {
|
121
|
+
const lapack_int itype = NUM2INT(val);
|
122
|
+
|
123
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
124
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
125
|
+
}
|
126
|
+
|
127
|
+
return itype;
|
128
|
+
}
|
129
|
+
|
130
|
+
static char get_jobz(VALUE val) {
|
131
|
+
const char jobz = NUM2CHR(val);
|
132
|
+
|
133
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
134
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
135
|
+
}
|
136
|
+
|
137
|
+
return jobz;
|
138
|
+
}
|
139
|
+
|
140
|
+
static char get_uplo(VALUE val) {
|
141
|
+
const char uplo = NUM2CHR(val);
|
142
|
+
|
143
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
144
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
145
|
+
}
|
146
|
+
|
147
|
+
return uplo;
|
148
|
+
}
|
149
|
+
|
150
|
+
static int get_matrix_layout(VALUE val) {
|
151
|
+
const char option = NUM2CHR(val);
|
152
|
+
|
153
|
+
switch (option) {
|
154
|
+
case 'r':
|
155
|
+
case 'R':
|
156
|
+
break;
|
157
|
+
case 'c':
|
158
|
+
case 'C':
|
159
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
|
160
|
+
break;
|
161
|
+
}
|
162
|
+
|
163
|
+
return LAPACK_ROW_MAJOR;
|
164
|
+
}
|
165
|
+
};
|
166
|
+
|
167
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,167 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeGvd {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, lapack_complex_double* a,
|
6
|
+
lapack_int lda, lapack_complex_double* b,
|
7
|
+
lapack_int ldb, double* w) {
|
8
|
+
return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeGvd {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
14
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
15
|
+
float* b, lapack_int ldb, float* w) {
|
16
|
+
return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
17
|
+
}
|
18
|
+
|
19
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
20
|
+
char uplo, lapack_int n, lapack_complex_float* a,
|
21
|
+
lapack_int lda, lapack_complex_float* b,
|
22
|
+
lapack_int ldb, float* w) {
|
23
|
+
return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
28
|
+
class HeGvd {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct hegvd_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
lapack_int itype;
|
38
|
+
char jobz;
|
39
|
+
char uplo;
|
40
|
+
};
|
41
|
+
|
42
|
+
static void iter_hegvd(na_loop_t* const lp) {
|
43
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
44
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
45
|
+
RType* w = (RType*)NDL_PTR(lp, 2);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
47
|
+
hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
58
|
+
|
59
|
+
VALUE a_vnary = Qnil;
|
60
|
+
VALUE b_vnary = Qnil;
|
61
|
+
VALUE kw_args = Qnil;
|
62
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
63
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
64
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
65
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
66
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
67
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
69
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
70
|
+
|
71
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
72
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
73
|
+
}
|
74
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
75
|
+
a_vnary = nary_dup(a_vnary);
|
76
|
+
}
|
77
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
78
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
79
|
+
}
|
80
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
81
|
+
b_vnary = nary_dup(b_vnary);
|
82
|
+
}
|
83
|
+
|
84
|
+
narray_t* a_nary = nullptr;
|
85
|
+
GetNArray(a_vnary, a_nary);
|
86
|
+
if (NA_NDIM(a_nary) != 2) {
|
87
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
88
|
+
return Qnil;
|
89
|
+
}
|
90
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
91
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
92
|
+
return Qnil;
|
93
|
+
}
|
94
|
+
narray_t* b_nary = nullptr;
|
95
|
+
GetNArray(a_vnary, b_nary);
|
96
|
+
if (NA_NDIM(b_nary) != 2) {
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
101
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
102
|
+
return Qnil;
|
103
|
+
}
|
104
|
+
|
105
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
106
|
+
size_t shape[1] = { n };
|
107
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
108
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
109
|
+
ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
110
|
+
hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
112
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
113
|
+
|
114
|
+
RB_GC_GUARD(a_vnary);
|
115
|
+
RB_GC_GUARD(b_vnary);
|
116
|
+
|
117
|
+
return ret;
|
118
|
+
}
|
119
|
+
|
120
|
+
static lapack_int get_itype(VALUE val) {
|
121
|
+
const lapack_int itype = NUM2INT(val);
|
122
|
+
|
123
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
124
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
125
|
+
}
|
126
|
+
|
127
|
+
return itype;
|
128
|
+
}
|
129
|
+
|
130
|
+
static char get_jobz(VALUE val) {
|
131
|
+
const char jobz = NUM2CHR(val);
|
132
|
+
|
133
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
134
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
135
|
+
}
|
136
|
+
|
137
|
+
return jobz;
|
138
|
+
}
|
139
|
+
|
140
|
+
static char get_uplo(VALUE val) {
|
141
|
+
const char uplo = NUM2CHR(val);
|
142
|
+
|
143
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
144
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
145
|
+
}
|
146
|
+
|
147
|
+
return uplo;
|
148
|
+
}
|
149
|
+
|
150
|
+
static int get_matrix_layout(VALUE val) {
|
151
|
+
const char option = NUM2CHR(val);
|
152
|
+
|
153
|
+
switch (option) {
|
154
|
+
case 'r':
|
155
|
+
case 'R':
|
156
|
+
break;
|
157
|
+
case 'c':
|
158
|
+
case 'C':
|
159
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
|
160
|
+
break;
|
161
|
+
}
|
162
|
+
|
163
|
+
return LAPACK_ROW_MAJOR;
|
164
|
+
}
|
165
|
+
};
|
166
|
+
|
167
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,193 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeGvx {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
|
6
|
+
double vl, double vu, lapack_int il, lapack_int iu,
|
7
|
+
double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
|
8
|
+
return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeGvx {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
|
15
|
+
float vl, float vu, lapack_int il, lapack_int iu,
|
16
|
+
float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
|
17
|
+
return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
22
|
+
class HeGvx {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct hegvx_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
lapack_int itype;
|
32
|
+
char jobz;
|
33
|
+
char range;
|
34
|
+
char uplo;
|
35
|
+
RType vl;
|
36
|
+
RType vu;
|
37
|
+
lapack_int il;
|
38
|
+
lapack_int iu;
|
39
|
+
};
|
40
|
+
|
41
|
+
static void iter_hegvx(na_loop_t* const lp) {
|
42
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
43
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
44
|
+
int* m = (int*)NDL_PTR(lp, 2);
|
45
|
+
RType* w = (RType*)NDL_PTR(lp, 3);
|
46
|
+
DType* z = (DType*)NDL_PTR(lp, 4);
|
47
|
+
int* ifail = (int*)NDL_PTR(lp, 5);
|
48
|
+
int* info = (int*)NDL_PTR(lp, 6);
|
49
|
+
hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
|
50
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
51
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
53
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
54
|
+
const RType abstol = 0.0;
|
55
|
+
const lapack_int i = FncType().call(
|
56
|
+
opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
|
57
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
|
58
|
+
*info = static_cast<int>(i);
|
59
|
+
}
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
64
|
+
|
65
|
+
VALUE a_vnary = Qnil;
|
66
|
+
VALUE b_vnary = Qnil;
|
67
|
+
VALUE kw_args = Qnil;
|
68
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
69
|
+
ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
70
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
71
|
+
VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
72
|
+
rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
|
73
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
74
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
75
|
+
const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
|
76
|
+
const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
|
77
|
+
const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
78
|
+
const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
|
79
|
+
const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
80
|
+
const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
|
81
|
+
const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
82
|
+
|
83
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
84
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
85
|
+
}
|
86
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
87
|
+
a_vnary = nary_dup(a_vnary);
|
88
|
+
}
|
89
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
90
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
91
|
+
}
|
92
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
93
|
+
b_vnary = nary_dup(b_vnary);
|
94
|
+
}
|
95
|
+
|
96
|
+
narray_t* a_nary = nullptr;
|
97
|
+
GetNArray(a_vnary, a_nary);
|
98
|
+
if (NA_NDIM(a_nary) != 2) {
|
99
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
100
|
+
return Qnil;
|
101
|
+
}
|
102
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
103
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
104
|
+
return Qnil;
|
105
|
+
}
|
106
|
+
narray_t* b_nary = nullptr;
|
107
|
+
GetNArray(a_vnary, b_nary);
|
108
|
+
if (NA_NDIM(b_nary) != 2) {
|
109
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
110
|
+
return Qnil;
|
111
|
+
}
|
112
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
113
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
114
|
+
return Qnil;
|
115
|
+
}
|
116
|
+
|
117
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
118
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
119
|
+
size_t w_shape[1] = { m };
|
120
|
+
size_t z_shape[2] = { n, m };
|
121
|
+
size_t ifail_shape[1] = { n };
|
122
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
123
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
|
124
|
+
ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
|
125
|
+
hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
|
126
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
127
|
+
VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
128
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
129
|
+
|
130
|
+
RB_GC_GUARD(a_vnary);
|
131
|
+
RB_GC_GUARD(b_vnary);
|
132
|
+
|
133
|
+
return ret;
|
134
|
+
}
|
135
|
+
|
136
|
+
static lapack_int get_itype(VALUE val) {
|
137
|
+
const lapack_int itype = NUM2INT(val);
|
138
|
+
|
139
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
140
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
141
|
+
}
|
142
|
+
|
143
|
+
return itype;
|
144
|
+
}
|
145
|
+
|
146
|
+
static char get_jobz(VALUE val) {
|
147
|
+
const char jobz = NUM2CHR(val);
|
148
|
+
|
149
|
+
if (jobz != 'N' && jobz != 'V') {
|
150
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
151
|
+
}
|
152
|
+
|
153
|
+
return jobz;
|
154
|
+
}
|
155
|
+
|
156
|
+
static char get_range(VALUE val) {
|
157
|
+
const char range = NUM2CHR(val);
|
158
|
+
|
159
|
+
if (range != 'A' && range != 'V' && range != 'I') {
|
160
|
+
rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
|
161
|
+
}
|
162
|
+
|
163
|
+
return range;
|
164
|
+
}
|
165
|
+
|
166
|
+
static char get_uplo(VALUE val) {
|
167
|
+
const char uplo = NUM2CHR(val);
|
168
|
+
|
169
|
+
if (uplo != 'U' && uplo != 'L') {
|
170
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
171
|
+
}
|
172
|
+
|
173
|
+
return uplo;
|
174
|
+
}
|
175
|
+
|
176
|
+
static int get_matrix_layout(VALUE val) {
|
177
|
+
const char option = NUM2CHR(val);
|
178
|
+
|
179
|
+
switch (option) {
|
180
|
+
case 'r':
|
181
|
+
case 'R':
|
182
|
+
break;
|
183
|
+
case 'c':
|
184
|
+
case 'C':
|
185
|
+
rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
|
186
|
+
break;
|
187
|
+
}
|
188
|
+
|
189
|
+
return LAPACK_ROW_MAJOR;
|
190
|
+
}
|
191
|
+
};
|
192
|
+
|
193
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,158 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGv {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, double* a, lapack_int lda,
|
6
|
+
double* b, lapack_int ldb, double* w) {
|
7
|
+
return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SSyGv {
|
12
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
13
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
14
|
+
float* b, lapack_int ldb, float* w) {
|
15
|
+
return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
20
|
+
class SyGv {
|
21
|
+
public:
|
22
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
23
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
|
24
|
+
}
|
25
|
+
|
26
|
+
private:
|
27
|
+
struct sygv_opt {
|
28
|
+
int matrix_layout;
|
29
|
+
lapack_int itype;
|
30
|
+
char jobz;
|
31
|
+
char uplo;
|
32
|
+
};
|
33
|
+
|
34
|
+
static void iter_sygv(na_loop_t* const lp) {
|
35
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
36
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
37
|
+
DType* w = (DType*)NDL_PTR(lp, 2);
|
38
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
39
|
+
sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
|
40
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
41
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
42
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
43
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
44
|
+
*info = static_cast<int>(i);
|
45
|
+
}
|
46
|
+
|
47
|
+
static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
|
48
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
49
|
+
|
50
|
+
VALUE a_vnary = Qnil;
|
51
|
+
VALUE b_vnary = Qnil;
|
52
|
+
VALUE kw_args = Qnil;
|
53
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
54
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
55
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
56
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
57
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
58
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
59
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
60
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
69
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
70
|
+
}
|
71
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
72
|
+
b_vnary = nary_dup(b_vnary);
|
73
|
+
}
|
74
|
+
|
75
|
+
narray_t* a_nary = nullptr;
|
76
|
+
GetNArray(a_vnary, a_nary);
|
77
|
+
if (NA_NDIM(a_nary) != 2) {
|
78
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
79
|
+
return Qnil;
|
80
|
+
}
|
81
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
82
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
83
|
+
return Qnil;
|
84
|
+
}
|
85
|
+
narray_t* b_nary = nullptr;
|
86
|
+
GetNArray(a_vnary, b_nary);
|
87
|
+
if (NA_NDIM(b_nary) != 2) {
|
88
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
92
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
93
|
+
return Qnil;
|
94
|
+
}
|
95
|
+
|
96
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
97
|
+
size_t shape[1] = { n };
|
98
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
99
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
100
|
+
ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
101
|
+
sygv_opt opt = { matrix_layout, itype, jobz, uplo };
|
102
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
103
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
104
|
+
|
105
|
+
RB_GC_GUARD(a_vnary);
|
106
|
+
RB_GC_GUARD(b_vnary);
|
107
|
+
|
108
|
+
return ret;
|
109
|
+
}
|
110
|
+
|
111
|
+
static lapack_int get_itype(VALUE val) {
|
112
|
+
const lapack_int itype = NUM2INT(val);
|
113
|
+
|
114
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
115
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
116
|
+
}
|
117
|
+
|
118
|
+
return itype;
|
119
|
+
}
|
120
|
+
|
121
|
+
static char get_jobz(VALUE val) {
|
122
|
+
const char jobz = NUM2CHR(val);
|
123
|
+
|
124
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
125
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
126
|
+
}
|
127
|
+
|
128
|
+
return jobz;
|
129
|
+
}
|
130
|
+
|
131
|
+
static char get_uplo(VALUE val) {
|
132
|
+
const char uplo = NUM2CHR(val);
|
133
|
+
|
134
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
135
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
136
|
+
}
|
137
|
+
|
138
|
+
return uplo;
|
139
|
+
}
|
140
|
+
|
141
|
+
static int get_matrix_layout(VALUE val) {
|
142
|
+
const char option = NUM2CHR(val);
|
143
|
+
|
144
|
+
switch (option) {
|
145
|
+
case 'r':
|
146
|
+
case 'R':
|
147
|
+
break;
|
148
|
+
case 'c':
|
149
|
+
case 'C':
|
150
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
|
151
|
+
break;
|
152
|
+
}
|
153
|
+
|
154
|
+
return LAPACK_ROW_MAJOR;
|
155
|
+
}
|
156
|
+
};
|
157
|
+
|
158
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,158 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGvd {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
5
|
+
char uplo, lapack_int n, double* a, lapack_int lda,
|
6
|
+
double* b, lapack_int ldb, double* w) {
|
7
|
+
return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SSyGvd {
|
12
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz,
|
13
|
+
char uplo, lapack_int n, float* a, lapack_int lda,
|
14
|
+
float* b, lapack_int ldb, float* w) {
|
15
|
+
return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
20
|
+
class SyGvd {
|
21
|
+
public:
|
22
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
23
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
|
24
|
+
}
|
25
|
+
|
26
|
+
private:
|
27
|
+
struct sygvd_opt {
|
28
|
+
int matrix_layout;
|
29
|
+
lapack_int itype;
|
30
|
+
char jobz;
|
31
|
+
char uplo;
|
32
|
+
};
|
33
|
+
|
34
|
+
static void iter_sygvd(na_loop_t* const lp) {
|
35
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
36
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
37
|
+
DType* w = (DType*)NDL_PTR(lp, 2);
|
38
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
39
|
+
sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
|
40
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
41
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
42
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
43
|
+
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
44
|
+
*info = static_cast<int>(i);
|
45
|
+
}
|
46
|
+
|
47
|
+
static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
|
48
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
49
|
+
|
50
|
+
VALUE a_vnary = Qnil;
|
51
|
+
VALUE b_vnary = Qnil;
|
52
|
+
VALUE kw_args = Qnil;
|
53
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
54
|
+
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
55
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
56
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
57
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
58
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
59
|
+
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
60
|
+
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
69
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
70
|
+
}
|
71
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
72
|
+
b_vnary = nary_dup(b_vnary);
|
73
|
+
}
|
74
|
+
|
75
|
+
narray_t* a_nary = nullptr;
|
76
|
+
GetNArray(a_vnary, a_nary);
|
77
|
+
if (NA_NDIM(a_nary) != 2) {
|
78
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
79
|
+
return Qnil;
|
80
|
+
}
|
81
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
82
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
83
|
+
return Qnil;
|
84
|
+
}
|
85
|
+
narray_t* b_nary = nullptr;
|
86
|
+
GetNArray(a_vnary, b_nary);
|
87
|
+
if (NA_NDIM(b_nary) != 2) {
|
88
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
92
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
93
|
+
return Qnil;
|
94
|
+
}
|
95
|
+
|
96
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
97
|
+
size_t shape[1] = { n };
|
98
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
99
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
100
|
+
ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
101
|
+
sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
|
102
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
103
|
+
VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
104
|
+
|
105
|
+
RB_GC_GUARD(a_vnary);
|
106
|
+
RB_GC_GUARD(b_vnary);
|
107
|
+
|
108
|
+
return ret;
|
109
|
+
}
|
110
|
+
|
111
|
+
static lapack_int get_itype(VALUE val) {
|
112
|
+
const lapack_int itype = NUM2INT(val);
|
113
|
+
|
114
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
115
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
116
|
+
}
|
117
|
+
|
118
|
+
return itype;
|
119
|
+
}
|
120
|
+
|
121
|
+
static char get_jobz(VALUE val) {
|
122
|
+
const char jobz = NUM2CHR(val);
|
123
|
+
|
124
|
+
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
125
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
126
|
+
}
|
127
|
+
|
128
|
+
return jobz;
|
129
|
+
}
|
130
|
+
|
131
|
+
static char get_uplo(VALUE val) {
|
132
|
+
const char uplo = NUM2CHR(val);
|
133
|
+
|
134
|
+
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
135
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
136
|
+
}
|
137
|
+
|
138
|
+
return uplo;
|
139
|
+
}
|
140
|
+
|
141
|
+
static int get_matrix_layout(VALUE val) {
|
142
|
+
const char option = NUM2CHR(val);
|
143
|
+
|
144
|
+
switch (option) {
|
145
|
+
case 'r':
|
146
|
+
case 'R':
|
147
|
+
break;
|
148
|
+
case 'c':
|
149
|
+
case 'C':
|
150
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
|
151
|
+
break;
|
152
|
+
}
|
153
|
+
|
154
|
+
return LAPACK_ROW_MAJOR;
|
155
|
+
}
|
156
|
+
};
|
157
|
+
|
158
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,192 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGvx {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
|
6
|
+
double vl, double vu, lapack_int il, lapack_int iu,
|
7
|
+
double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
|
8
|
+
return LAPACKE_dsygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct SSyGvx {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
|
15
|
+
float vl, float vu, lapack_int il, lapack_int iu,
|
16
|
+
float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
|
17
|
+
return LAPACKE_ssygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
22
|
+
class SyGvx {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct sygvx_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
lapack_int itype;
|
32
|
+
char jobz;
|
33
|
+
char range;
|
34
|
+
char uplo;
|
35
|
+
DType vl;
|
36
|
+
DType vu;
|
37
|
+
lapack_int il;
|
38
|
+
lapack_int iu;
|
39
|
+
};
|
40
|
+
|
41
|
+
static void iter_sygvx(na_loop_t* const lp) {
|
42
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
43
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
44
|
+
int* m = (int*)NDL_PTR(lp, 2);
|
45
|
+
DType* w = (DType*)NDL_PTR(lp, 3);
|
46
|
+
DType* z = (DType*)NDL_PTR(lp, 4);
|
47
|
+
int* ifail = (int*)NDL_PTR(lp, 5);
|
48
|
+
int* info = (int*)NDL_PTR(lp, 6);
|
49
|
+
sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
|
50
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
51
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
53
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
54
|
+
const DType abstol = 0.0;
|
55
|
+
const lapack_int i = FncType().call(
|
56
|
+
opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
|
57
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
|
58
|
+
*info = static_cast<int>(i);
|
59
|
+
}
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
|
64
|
+
VALUE a_vnary = Qnil;
|
65
|
+
VALUE b_vnary = Qnil;
|
66
|
+
VALUE kw_args = Qnil;
|
67
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
68
|
+
ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
69
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
70
|
+
VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
71
|
+
rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
|
72
|
+
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
73
|
+
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
74
|
+
const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
|
75
|
+
const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
|
76
|
+
const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
77
|
+
const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
|
78
|
+
const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
79
|
+
const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
|
80
|
+
const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
81
|
+
|
82
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
83
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
84
|
+
}
|
85
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
86
|
+
a_vnary = nary_dup(a_vnary);
|
87
|
+
}
|
88
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
89
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
90
|
+
}
|
91
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
92
|
+
b_vnary = nary_dup(b_vnary);
|
93
|
+
}
|
94
|
+
|
95
|
+
narray_t* a_nary = nullptr;
|
96
|
+
GetNArray(a_vnary, a_nary);
|
97
|
+
if (NA_NDIM(a_nary) != 2) {
|
98
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
99
|
+
return Qnil;
|
100
|
+
}
|
101
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
102
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
103
|
+
return Qnil;
|
104
|
+
}
|
105
|
+
narray_t* b_nary = nullptr;
|
106
|
+
GetNArray(a_vnary, b_nary);
|
107
|
+
if (NA_NDIM(b_nary) != 2) {
|
108
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
109
|
+
return Qnil;
|
110
|
+
}
|
111
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
112
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
113
|
+
return Qnil;
|
114
|
+
}
|
115
|
+
|
116
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
117
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
118
|
+
size_t w_shape[1] = { m };
|
119
|
+
size_t z_shape[2] = { n, m };
|
120
|
+
size_t ifail_shape[1] = { n };
|
121
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
122
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
|
123
|
+
ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
|
124
|
+
sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
|
125
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
126
|
+
VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
127
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
128
|
+
|
129
|
+
RB_GC_GUARD(a_vnary);
|
130
|
+
RB_GC_GUARD(b_vnary);
|
131
|
+
|
132
|
+
return ret;
|
133
|
+
}
|
134
|
+
|
135
|
+
static lapack_int get_itype(VALUE val) {
|
136
|
+
const lapack_int itype = NUM2INT(val);
|
137
|
+
|
138
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
139
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
140
|
+
}
|
141
|
+
|
142
|
+
return itype;
|
143
|
+
}
|
144
|
+
|
145
|
+
static char get_jobz(VALUE val) {
|
146
|
+
const char jobz = NUM2CHR(val);
|
147
|
+
|
148
|
+
if (jobz != 'N' && jobz != 'V') {
|
149
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
150
|
+
}
|
151
|
+
|
152
|
+
return jobz;
|
153
|
+
}
|
154
|
+
|
155
|
+
static char get_range(VALUE val) {
|
156
|
+
const char range = NUM2CHR(val);
|
157
|
+
|
158
|
+
if (range != 'A' && range != 'V' && range != 'I') {
|
159
|
+
rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
|
160
|
+
}
|
161
|
+
|
162
|
+
return range;
|
163
|
+
}
|
164
|
+
|
165
|
+
static char get_uplo(VALUE val) {
|
166
|
+
const char uplo = NUM2CHR(val);
|
167
|
+
|
168
|
+
if (uplo != 'U' && uplo != 'L') {
|
169
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
170
|
+
}
|
171
|
+
|
172
|
+
return uplo;
|
173
|
+
}
|
174
|
+
|
175
|
+
static int get_matrix_layout(VALUE val) {
|
176
|
+
const char option = NUM2CHR(val);
|
177
|
+
|
178
|
+
switch (option) {
|
179
|
+
case 'r':
|
180
|
+
case 'R':
|
181
|
+
break;
|
182
|
+
case 'c':
|
183
|
+
case 'C':
|
184
|
+
rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
|
185
|
+
break;
|
186
|
+
}
|
187
|
+
|
188
|
+
return LAPACK_ROW_MAJOR;
|
189
|
+
}
|
190
|
+
};
|
191
|
+
|
192
|
+
} // namespace TinyLinalg
|
@@ -11,7 +11,13 @@
|
|
11
11
|
#include "lapack/gesvd.hpp"
|
12
12
|
#include "lapack/getrf.hpp"
|
13
13
|
#include "lapack/getri.hpp"
|
14
|
+
#include "lapack/hegv.hpp"
|
15
|
+
#include "lapack/hegvd.hpp"
|
16
|
+
#include "lapack/hegvx.hpp"
|
14
17
|
#include "lapack/orgqr.hpp"
|
18
|
+
#include "lapack/sygv.hpp"
|
19
|
+
#include "lapack/sygvd.hpp"
|
20
|
+
#include "lapack/sygvx.hpp"
|
15
21
|
#include "lapack/ungqr.hpp"
|
16
22
|
|
17
23
|
VALUE rb_mTinyLinalg;
|
@@ -271,6 +277,18 @@ extern "C" void Init_tiny_linalg(void) {
|
|
271
277
|
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
272
278
|
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
273
279
|
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
280
|
+
TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
|
281
|
+
TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
|
282
|
+
TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
|
283
|
+
TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
|
284
|
+
TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
|
285
|
+
TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
|
286
|
+
TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
|
287
|
+
TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
|
288
|
+
TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
|
289
|
+
TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
|
290
|
+
TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
|
291
|
+
TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
|
274
292
|
|
275
293
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
276
294
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -10,6 +10,50 @@ module Numo
|
|
10
10
|
module TinyLinalg # rubocop:disable Metrics/ModuleLength
|
11
11
|
module_function
|
12
12
|
|
13
|
+
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary or generalized eigenvalue problem.
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range/Array]
#   The range of zero-based indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
#   Both inclusive (e.g. 0..2) and exclusive (e.g. 0...3) ranges are accepted; a two-element Array [first, last] is inclusive.
#   If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
# @raise [ArgumentError] If a or b is not a square 2-dimensional array of a supported type, or their shapes differ.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors (eigenvectors are nil if vals_only is true).
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  unless b.nil?
    raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
    raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
    raise ArgumentError, 'input arrays a and b must have the same shape' if b.shape != a.shape
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
  end

  jobz = vals_only ? 'N' : 'V'
  b = a.class.eye(a.shape[0]) if b.nil?
  # Build function names by interpolation rather than `<<`: interpolated literals are frozen
  # under `frozen_string_literal` on Ruby < 3.0, so in-place appends could raise FrozenError.
  sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"

  if vals_range.nil?
    fnc = turbo ? "#{sy_he_gv}d" : sy_he_gv
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
  else
    # LAPACK's il/iu are one-based and inclusive; an exclusive Range's `last` is already one past the end.
    il = vals_range.first + 1
    iu = vals_range.last
    iu += 1 unless vals_range.is_a?(Range) && vals_range.exclude_end?
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      :"#{sy_he_gv}x", a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
    )
  end
  vecs = nil if vals_only
  [vals, vecs]
end
|
56
|
+
|
13
57
|
# Computes the determinant of matrix.
|
14
58
|
#
|
15
59
|
# @param a [Numo::NArray] n-by-n square matrix.
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -52,7 +52,13 @@ files:
|
|
52
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
57
|
+
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
55
58
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
59
|
+
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
60
|
+
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
61
|
+
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|
56
62
|
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
57
63
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
58
64
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|