numo-tiny_linalg 0.0.2 → 0.0.4

@@ -0,0 +1,115 @@
+ namespace TinyLinalg {
+
+ struct DOrgQr {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
+                   double* a, lapack_int lda, const double* tau) {
+     return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
+   }
+ };
+
+ struct SOrgQr {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
+                   float* a, lapack_int lda, const float* tau) {
+     return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class OrgQr {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
+   }
+
+ private:
+   struct orgqr_opt {
+     int matrix_layout;
+   };
+
+   static void iter_orgqr(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* tau = (DType*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int k = NDL_SHAPE(lp, 1)[0];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE tau_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(tau_vnary) != nary_dtype) {
+       tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(tau_vnary))) {
+       tau_vnary = nary_dup(tau_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     narray_t* tau_nary = NULL;
+     GetNArray(tau_vnary, tau_nary);
+     if (NA_NDIM(tau_nary) != 1) {
+       rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
+       return Qnil;
+     }
+
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
+     ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
+     orgqr_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
+
+     VALUE ret = rb_ary_new3(2, a_vnary, res);
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(tau_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.orgqr does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
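
The OrgQr class above binds LAPACKE's orgqr routine, which assembles the explicit orthogonal matrix Q from the elementary reflectors produced by a QR factorization (geqrf): the factored matrix and tau are taken as positional arguments, an order: keyword selects the layout, the input matrix is overwritten with Q, and the return value is [q, info]. A minimal Ruby usage sketch; the method names dorgqr and dgeqrf are assumptions based on the gem's naming convention (the actual name is whatever fnc_name is passed to define_module_function), and the [qr, tau, info] return shape of geqrf is likewise assumed:

    require 'numo/tiny_linalg'

    a = Numo::DFloat.new(5, 3).rand
    # Factor A; qr holds R in its upper triangle and the Householder
    # reflectors below the diagonal, tau holds the reflector scalars.
    qr, tau, _info = Numo::TinyLinalg::Lapack.dgeqrf(a.dup)
    # Assemble Q in place from the reflectors; info is 0 on success.
    q, info = Numo::TinyLinalg::Lapack.dorgqr(qr, tau)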
@@ -0,0 +1,158 @@
+ namespace TinyLinalg {
+
+ struct DSyGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, double* a, lapack_int lda,
+                   double* b, lapack_int ldb, double* w) {
+     return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct SSyGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, float* a, lapack_int lda,
+                   float* b, lapack_int ldb, float* w) {
+     return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGv {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
+   }
+
+ private:
+   struct sygv_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_sygv(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     DType* w = (DType*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     sygv_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
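
The SyGv class binds sygv, the generalized symmetric-definite eigensolver: with the default itype: 1 it solves A x = λ B x (itype 2 and 3 select A B x = λ x and B A x = λ x, respectively), where a must be symmetric and b symmetric positive definite. On success a is overwritten with the eigenvectors when jobz: 'V', b with the Cholesky factor used in the reduction, and the return value is [a, b, w, info] with the eigenvalues in w in ascending order. A usage sketch, assuming the double-precision instantiation is registered as dsygv:

    require 'numo/tiny_linalg'

    a = Numo::DFloat.new(4, 4).rand
    a = 0.5 * (a + a.transpose)       # symmetric A
    b = Numo::DFloat.eye(4) * 2.0     # symmetric positive definite B
    v, _bf, w, info = Numo::TinyLinalg::Lapack.dsygv(
      a.dup, b.dup, itype: 1, jobz: 'V', uplo: 'U'
    )
    # info == 0 on success; info > n signals that B was not positive definite.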
@@ -0,0 +1,158 @@
+ namespace TinyLinalg {
+
+ struct DSyGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, double* a, lapack_int lda,
+                   double* b, lapack_int ldb, double* w) {
+     return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct SSyGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, float* a, lapack_int lda,
+                   float* b, lapack_int ldb, float* w) {
+     return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGvd {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
+   }
+
+ private:
+   struct sygvd_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_sygvd(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     DType* w = (DType*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
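
The SyGvd class is interface-compatible with SyGv but dispatches to sygvd, the divide-and-conquer driver, which is usually faster than sygv for large matrices when eigenvectors are requested, at the cost of more workspace. The sygv sketch above carries over unchanged with dsygv replaced by the (again assumed) name dsygvd:

    v, _bf, w, info = Numo::TinyLinalg::Lapack.dsygvd(a.dup, b.dup, jobz: 'V', uplo: 'U')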
@@ -0,0 +1,192 @@
+ namespace TinyLinalg {
+
+ struct DSyGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
+                   double vl, double vu, lapack_int il, lapack_int iu,
+                   double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_dsygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ struct SSyGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
+                   float vl, float vu, lapack_int il, lapack_int iu,
+                   float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_ssygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGvx {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
+   }
+
+ private:
+   struct sygvx_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char range;
+     char uplo;
+     DType vl;
+     DType vu;
+     lapack_int il;
+     lapack_int iu;
+   };
+
+   static void iter_sygvx(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     int* m = (int*)NDL_PTR(lp, 2);
+     DType* w = (DType*)NDL_PTR(lp, 3);
+     DType* z = (DType*)NDL_PTR(lp, 4);
+     int* ifail = (int*)NDL_PTR(lp, 5);
+     int* info = (int*)NDL_PTR(lp, 6);
+     sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
+     const DType abstol = 0.0;
+     const lapack_int i = FncType().call(
+       opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
+       opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
+                        rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
+     VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
+     const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
+     const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
+     const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
+     const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
+     const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
+     const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t m = range != 'I' ? n : iu - il + 1;
+     size_t w_shape[1] = { m };
+     size_t z_shape[2] = { n, m };
+     size_t ifail_shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
+     sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
+                             rb_ary_entry(res, 3), rb_ary_entry(res, 4));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'N' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_range(VALUE val) {
+     const char range = NUM2CHR(val);
+
+     if (range != 'A' && range != 'V' && range != 'I') {
+       rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
+     }
+
+     return range;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'U' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
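
The SyGvx class binds the expert driver sygvx, which can compute just a subset of eigenpairs: range: 'A' returns all of them, 'V' those with eigenvalues in the half-open interval (vl, vu], and 'I' the il-th through iu-th smallest. The binding fixes abstol at 0.0, which tells LAPACK to use its default tolerance, and returns [a, b, m, w, z, ifail, info], where m is the number of eigenvalues found, w and z hold the eigenvalues and (for jobz: 'V') eigenvectors, and ifail lists the indices of any eigenvectors that failed to converge. A sketch requesting the two smallest eigenpairs of the matrices from the sygv example, assuming the name dsygvx:

    _af, _bf, m, w, z, ifail, info = Numo::TinyLinalg::Lapack.dsygvx(
      a.dup, b.dup, jobz: 'V', range: 'I', il: 1, iu: 2
    )
    # m == 2 here; w holds the eigenvalues and the columns of z the eigenvectors.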