numo-tiny_linalg 0.0.4 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +3 -3
- data/ext/numo/tiny_linalg/blas/gemm.hpp +3 -49
- data/ext/numo/tiny_linalg/blas/gemv.hpp +2 -48
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/gesdd.hpp +11 -11
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +10 -30
- data/ext/numo/tiny_linalg/lapack/gesvd.hpp +12 -12
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/getri.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +16 -72
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +16 -72
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +5 -25
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +56 -21
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +30 -6
- data/ext/numo/tiny_linalg/util.hpp +100 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +175 -52
- metadata +3 -2
@@ -24,7 +24,7 @@ struct CHeGv {
|
|
24
24
|
}
|
25
25
|
};
|
26
26
|
|
27
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
27
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
28
28
|
class HeGv {
|
29
29
|
public:
|
30
30
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -40,15 +40,15 @@ private:
|
|
40
40
|
};
|
41
41
|
|
42
42
|
static void iter_hegv(na_loop_t* const lp) {
|
43
|
-
|
44
|
-
|
45
|
-
|
43
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
44
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
45
|
+
rtype* w = (rtype*)NDL_PTR(lp, 2);
|
46
46
|
int* info = (int*)NDL_PTR(lp, 3);
|
47
47
|
hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
|
48
48
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
49
|
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
50
|
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
51
|
-
const lapack_int i =
|
51
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
52
52
|
*info = static_cast<int>(i);
|
53
53
|
}
|
54
54
|
|
@@ -63,10 +63,10 @@ private:
|
|
63
63
|
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
64
64
|
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
65
65
|
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
66
|
-
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
67
|
-
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
68
|
-
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
69
|
-
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
66
|
+
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
|
67
|
+
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
69
|
+
const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
70
70
|
|
71
71
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
72
72
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -116,52 +116,6 @@ private:
|
|
116
116
|
|
117
117
|
return ret;
|
118
118
|
}
|
119
|
-
|
120
|
-
static lapack_int get_itype(VALUE val) {
|
121
|
-
const lapack_int itype = NUM2INT(val);
|
122
|
-
|
123
|
-
if (itype != 1 && itype != 2 && itype != 3) {
|
124
|
-
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
125
|
-
}
|
126
|
-
|
127
|
-
return itype;
|
128
|
-
}
|
129
|
-
|
130
|
-
static char get_jobz(VALUE val) {
|
131
|
-
const char jobz = NUM2CHR(val);
|
132
|
-
|
133
|
-
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
134
|
-
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
135
|
-
}
|
136
|
-
|
137
|
-
return jobz;
|
138
|
-
}
|
139
|
-
|
140
|
-
static char get_uplo(VALUE val) {
|
141
|
-
const char uplo = NUM2CHR(val);
|
142
|
-
|
143
|
-
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
144
|
-
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
145
|
-
}
|
146
|
-
|
147
|
-
return uplo;
|
148
|
-
}
|
149
|
-
|
150
|
-
static int get_matrix_layout(VALUE val) {
|
151
|
-
const char option = NUM2CHR(val);
|
152
|
-
|
153
|
-
switch (option) {
|
154
|
-
case 'r':
|
155
|
-
case 'R':
|
156
|
-
break;
|
157
|
-
case 'c':
|
158
|
-
case 'C':
|
159
|
-
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
|
160
|
-
break;
|
161
|
-
}
|
162
|
-
|
163
|
-
return LAPACK_ROW_MAJOR;
|
164
|
-
}
|
165
119
|
};
|
166
120
|
|
167
121
|
} // namespace TinyLinalg
|
@@ -24,7 +24,7 @@ struct CHeGvd {
|
|
24
24
|
}
|
25
25
|
};
|
26
26
|
|
27
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
27
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
28
28
|
class HeGvd {
|
29
29
|
public:
|
30
30
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -40,15 +40,15 @@ private:
|
|
40
40
|
};
|
41
41
|
|
42
42
|
static void iter_hegvd(na_loop_t* const lp) {
|
43
|
-
|
44
|
-
|
45
|
-
|
43
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
44
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
45
|
+
rtype* w = (rtype*)NDL_PTR(lp, 2);
|
46
46
|
int* info = (int*)NDL_PTR(lp, 3);
|
47
47
|
hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
|
48
48
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
49
|
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
50
|
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
51
|
-
const lapack_int i =
|
51
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
52
52
|
*info = static_cast<int>(i);
|
53
53
|
}
|
54
54
|
|
@@ -63,10 +63,10 @@ private:
|
|
63
63
|
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
64
64
|
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
65
65
|
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
66
|
-
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
67
|
-
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
68
|
-
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
69
|
-
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
66
|
+
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
|
67
|
+
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
69
|
+
const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
70
70
|
|
71
71
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
72
72
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -116,52 +116,6 @@ private:
|
|
116
116
|
|
117
117
|
return ret;
|
118
118
|
}
|
119
|
-
|
120
|
-
static lapack_int get_itype(VALUE val) {
|
121
|
-
const lapack_int itype = NUM2INT(val);
|
122
|
-
|
123
|
-
if (itype != 1 && itype != 2 && itype != 3) {
|
124
|
-
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
125
|
-
}
|
126
|
-
|
127
|
-
return itype;
|
128
|
-
}
|
129
|
-
|
130
|
-
static char get_jobz(VALUE val) {
|
131
|
-
const char jobz = NUM2CHR(val);
|
132
|
-
|
133
|
-
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
134
|
-
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
135
|
-
}
|
136
|
-
|
137
|
-
return jobz;
|
138
|
-
}
|
139
|
-
|
140
|
-
static char get_uplo(VALUE val) {
|
141
|
-
const char uplo = NUM2CHR(val);
|
142
|
-
|
143
|
-
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
144
|
-
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
145
|
-
}
|
146
|
-
|
147
|
-
return uplo;
|
148
|
-
}
|
149
|
-
|
150
|
-
static int get_matrix_layout(VALUE val) {
|
151
|
-
const char option = NUM2CHR(val);
|
152
|
-
|
153
|
-
switch (option) {
|
154
|
-
case 'r':
|
155
|
-
case 'R':
|
156
|
-
break;
|
157
|
-
case 'c':
|
158
|
-
case 'C':
|
159
|
-
rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
|
160
|
-
break;
|
161
|
-
}
|
162
|
-
|
163
|
-
return LAPACK_ROW_MAJOR;
|
164
|
-
}
|
165
119
|
};
|
166
120
|
|
167
121
|
} // namespace TinyLinalg
|
@@ -18,7 +18,7 @@ struct CHeGvx {
|
|
18
18
|
}
|
19
19
|
};
|
20
20
|
|
21
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
21
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
22
22
|
class HeGvx {
|
23
23
|
public:
|
24
24
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -32,18 +32,18 @@ private:
|
|
32
32
|
char jobz;
|
33
33
|
char range;
|
34
34
|
char uplo;
|
35
|
-
|
36
|
-
|
35
|
+
rtype vl;
|
36
|
+
rtype vu;
|
37
37
|
lapack_int il;
|
38
38
|
lapack_int iu;
|
39
39
|
};
|
40
40
|
|
41
41
|
static void iter_hegvx(na_loop_t* const lp) {
|
42
|
-
|
43
|
-
|
42
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
43
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
44
44
|
int* m = (int*)NDL_PTR(lp, 2);
|
45
|
-
|
46
|
-
|
45
|
+
rtype* w = (rtype*)NDL_PTR(lp, 3);
|
46
|
+
dtype* z = (dtype*)NDL_PTR(lp, 4);
|
47
47
|
int* ifail = (int*)NDL_PTR(lp, 5);
|
48
48
|
int* info = (int*)NDL_PTR(lp, 6);
|
49
49
|
hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
|
@@ -51,8 +51,8 @@ private:
|
|
51
51
|
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
52
52
|
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
53
53
|
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
54
|
-
const
|
55
|
-
const lapack_int i =
|
54
|
+
const rtype abstol = 0.0;
|
55
|
+
const lapack_int i = LapackFn().call(
|
56
56
|
opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
|
57
57
|
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
|
58
58
|
*info = static_cast<int>(i);
|
@@ -70,15 +70,15 @@ private:
|
|
70
70
|
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
71
71
|
VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
72
72
|
rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
|
73
|
-
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
74
|
-
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
75
|
-
const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
|
76
|
-
const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
|
77
|
-
const
|
78
|
-
const
|
73
|
+
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
|
74
|
+
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
|
75
|
+
const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A';
|
76
|
+
const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U';
|
77
|
+
const rtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
78
|
+
const rtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
|
79
79
|
const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
80
80
|
const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
|
81
|
-
const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
81
|
+
const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
82
82
|
|
83
83
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
84
84
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -132,62 +132,6 @@ private:
|
|
132
132
|
|
133
133
|
return ret;
|
134
134
|
}
|
135
|
-
|
136
|
-
static lapack_int get_itype(VALUE val) {
|
137
|
-
const lapack_int itype = NUM2INT(val);
|
138
|
-
|
139
|
-
if (itype != 1 && itype != 2 && itype != 3) {
|
140
|
-
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
141
|
-
}
|
142
|
-
|
143
|
-
return itype;
|
144
|
-
}
|
145
|
-
|
146
|
-
static char get_jobz(VALUE val) {
|
147
|
-
const char jobz = NUM2CHR(val);
|
148
|
-
|
149
|
-
if (jobz != 'N' && jobz != 'V') {
|
150
|
-
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
151
|
-
}
|
152
|
-
|
153
|
-
return jobz;
|
154
|
-
}
|
155
|
-
|
156
|
-
static char get_range(VALUE val) {
|
157
|
-
const char range = NUM2CHR(val);
|
158
|
-
|
159
|
-
if (range != 'A' && range != 'V' && range != 'I') {
|
160
|
-
rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
|
161
|
-
}
|
162
|
-
|
163
|
-
return range;
|
164
|
-
}
|
165
|
-
|
166
|
-
static char get_uplo(VALUE val) {
|
167
|
-
const char uplo = NUM2CHR(val);
|
168
|
-
|
169
|
-
if (uplo != 'U' && uplo != 'L') {
|
170
|
-
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
171
|
-
}
|
172
|
-
|
173
|
-
return uplo;
|
174
|
-
}
|
175
|
-
|
176
|
-
static int get_matrix_layout(VALUE val) {
|
177
|
-
const char option = NUM2CHR(val);
|
178
|
-
|
179
|
-
switch (option) {
|
180
|
-
case 'r':
|
181
|
-
case 'R':
|
182
|
-
break;
|
183
|
-
case 'c':
|
184
|
-
case 'C':
|
185
|
-
rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
|
186
|
-
break;
|
187
|
-
}
|
188
|
-
|
189
|
-
return LAPACK_ROW_MAJOR;
|
190
|
-
}
|
191
135
|
};
|
192
136
|
|
193
137
|
} // namespace TinyLinalg
|
@@ -14,7 +14,7 @@ struct SOrgQr {
|
|
14
14
|
}
|
15
15
|
};
|
16
16
|
|
17
|
-
template <int nary_dtype_id, typename
|
17
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
18
18
|
class OrgQr {
|
19
19
|
public:
|
20
20
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -27,15 +27,15 @@ private:
|
|
27
27
|
};
|
28
28
|
|
29
29
|
static void iter_orgqr(na_loop_t* const lp) {
|
30
|
-
|
31
|
-
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
dtype* tau = (dtype*)NDL_PTR(lp, 1);
|
32
32
|
int* info = (int*)NDL_PTR(lp, 2);
|
33
33
|
orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
|
34
34
|
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
35
35
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
36
36
|
const lapack_int k = NDL_SHAPE(lp, 1)[0];
|
37
37
|
const lapack_int lda = n;
|
38
|
-
const lapack_int i =
|
38
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, k, a, lda, tau);
|
39
39
|
*info = static_cast<int>(i);
|
40
40
|
}
|
41
41
|
|
@@ -49,7 +49,7 @@ private:
|
|
49
49
|
ID kw_table[1] = { rb_intern("order") };
|
50
50
|
VALUE kw_values[1] = { Qundef };
|
51
51
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
52
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
52
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
53
53
|
|
54
54
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
55
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -90,26 +90,6 @@ private:
|
|
90
90
|
|
91
91
|
return ret;
|
92
92
|
}
|
93
|
-
|
94
|
-
static int get_matrix_layout(VALUE val) {
|
95
|
-
const char* option_str = StringValueCStr(val);
|
96
|
-
|
97
|
-
if (std::strlen(option_str) > 0) {
|
98
|
-
switch (option_str[0]) {
|
99
|
-
case 'r':
|
100
|
-
case 'R':
|
101
|
-
break;
|
102
|
-
case 'c':
|
103
|
-
case 'C':
|
104
|
-
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
105
|
-
break;
|
106
|
-
}
|
107
|
-
}
|
108
|
-
|
109
|
-
RB_GC_GUARD(val);
|
110
|
-
|
111
|
-
return LAPACK_ROW_MAJOR;
|
112
|
-
}
|
113
93
|
};
|
114
94
|
|
115
95
|
} // namespace TinyLinalg
|
@@ -16,7 +16,7 @@ struct SSyGv {
|
|
16
16
|
}
|
17
17
|
};
|
18
18
|
|
19
|
-
template <int nary_dtype_id, typename
|
19
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
20
20
|
class SyGv {
|
21
21
|
public:
|
22
22
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -32,15 +32,15 @@ private:
|
|
32
32
|
};
|
33
33
|
|
34
34
|
static void iter_sygv(na_loop_t* const lp) {
|
35
|
-
|
36
|
-
|
37
|
-
|
35
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
36
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
37
|
+
dtype* w = (dtype*)NDL_PTR(lp, 2);
|
38
38
|
int* info = (int*)NDL_PTR(lp, 3);
|
39
39
|
sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
|
40
40
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
41
41
|
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
42
42
|
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
43
|
-
const lapack_int i =
|
43
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
44
44
|
*info = static_cast<int>(i);
|
45
45
|
}
|
46
46
|
|
@@ -54,10 +54,10 @@ private:
|
|
54
54
|
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
55
55
|
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
56
56
|
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
57
|
-
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
58
|
-
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
59
|
-
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
60
|
-
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
57
|
+
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
|
58
|
+
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
|
59
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
60
|
+
const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
61
61
|
|
62
62
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
63
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -107,52 +107,6 @@ private:
|
|
107
107
|
|
108
108
|
return ret;
|
109
109
|
}
|
110
|
-
|
111
|
-
static lapack_int get_itype(VALUE val) {
|
112
|
-
const lapack_int itype = NUM2INT(val);
|
113
|
-
|
114
|
-
if (itype != 1 && itype != 2 && itype != 3) {
|
115
|
-
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
116
|
-
}
|
117
|
-
|
118
|
-
return itype;
|
119
|
-
}
|
120
|
-
|
121
|
-
static char get_jobz(VALUE val) {
|
122
|
-
const char jobz = NUM2CHR(val);
|
123
|
-
|
124
|
-
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
125
|
-
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
126
|
-
}
|
127
|
-
|
128
|
-
return jobz;
|
129
|
-
}
|
130
|
-
|
131
|
-
static char get_uplo(VALUE val) {
|
132
|
-
const char uplo = NUM2CHR(val);
|
133
|
-
|
134
|
-
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
135
|
-
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
136
|
-
}
|
137
|
-
|
138
|
-
return uplo;
|
139
|
-
}
|
140
|
-
|
141
|
-
static int get_matrix_layout(VALUE val) {
|
142
|
-
const char option = NUM2CHR(val);
|
143
|
-
|
144
|
-
switch (option) {
|
145
|
-
case 'r':
|
146
|
-
case 'R':
|
147
|
-
break;
|
148
|
-
case 'c':
|
149
|
-
case 'C':
|
150
|
-
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
|
151
|
-
break;
|
152
|
-
}
|
153
|
-
|
154
|
-
return LAPACK_ROW_MAJOR;
|
155
|
-
}
|
156
110
|
};
|
157
111
|
|
158
112
|
} // namespace TinyLinalg
|
@@ -16,7 +16,7 @@ struct SSyGvd {
|
|
16
16
|
}
|
17
17
|
};
|
18
18
|
|
19
|
-
template <int nary_dtype_id, typename
|
19
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
20
20
|
class SyGvd {
|
21
21
|
public:
|
22
22
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -32,15 +32,15 @@ private:
|
|
32
32
|
};
|
33
33
|
|
34
34
|
static void iter_sygvd(na_loop_t* const lp) {
|
35
|
-
|
36
|
-
|
37
|
-
|
35
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
36
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
37
|
+
dtype* w = (dtype*)NDL_PTR(lp, 2);
|
38
38
|
int* info = (int*)NDL_PTR(lp, 3);
|
39
39
|
sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
|
40
40
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
41
41
|
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
42
42
|
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
43
|
-
const lapack_int i =
|
43
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
|
44
44
|
*info = static_cast<int>(i);
|
45
45
|
}
|
46
46
|
|
@@ -54,10 +54,10 @@ private:
|
|
54
54
|
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
55
55
|
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
56
56
|
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
57
|
-
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
|
58
|
-
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
|
59
|
-
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
|
60
|
-
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
57
|
+
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
|
58
|
+
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
|
59
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
60
|
+
const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
|
61
61
|
|
62
62
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
63
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -107,52 +107,6 @@ private:
|
|
107
107
|
|
108
108
|
return ret;
|
109
109
|
}
|
110
|
-
|
111
|
-
static lapack_int get_itype(VALUE val) {
|
112
|
-
const lapack_int itype = NUM2INT(val);
|
113
|
-
|
114
|
-
if (itype != 1 && itype != 2 && itype != 3) {
|
115
|
-
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
116
|
-
}
|
117
|
-
|
118
|
-
return itype;
|
119
|
-
}
|
120
|
-
|
121
|
-
static char get_jobz(VALUE val) {
|
122
|
-
const char jobz = NUM2CHR(val);
|
123
|
-
|
124
|
-
if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
|
125
|
-
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
126
|
-
}
|
127
|
-
|
128
|
-
return jobz;
|
129
|
-
}
|
130
|
-
|
131
|
-
static char get_uplo(VALUE val) {
|
132
|
-
const char uplo = NUM2CHR(val);
|
133
|
-
|
134
|
-
if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
|
135
|
-
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
136
|
-
}
|
137
|
-
|
138
|
-
return uplo;
|
139
|
-
}
|
140
|
-
|
141
|
-
static int get_matrix_layout(VALUE val) {
|
142
|
-
const char option = NUM2CHR(val);
|
143
|
-
|
144
|
-
switch (option) {
|
145
|
-
case 'r':
|
146
|
-
case 'R':
|
147
|
-
break;
|
148
|
-
case 'c':
|
149
|
-
case 'C':
|
150
|
-
rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
|
151
|
-
break;
|
152
|
-
}
|
153
|
-
|
154
|
-
return LAPACK_ROW_MAJOR;
|
155
|
-
}
|
156
110
|
};
|
157
111
|
|
158
112
|
} // namespace TinyLinalg
|