numo-tiny_linalg 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DOrgQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ double* a, lapack_int lda, const double* tau) {
6
+ return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct SOrgQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ float* a, lapack_int lda, const float* tau) {
13
+ return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class OrgQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct orgqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_orgqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ orgqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -0,0 +1,158 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyGv {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, double* a, lapack_int lda,
6
+ double* b, lapack_int ldb, double* w) {
7
+ return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
8
+ }
9
+ };
10
+
11
+ struct SSyGv {
12
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
13
+ char uplo, lapack_int n, float* a, lapack_int lda,
14
+ float* b, lapack_int ldb, float* w) {
15
+ return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename DType, typename FncType>
20
+ class SyGv {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
24
+ }
25
+
26
+ private:
27
+ struct sygv_opt {
28
+ int matrix_layout;
29
+ lapack_int itype;
30
+ char jobz;
31
+ char uplo;
32
+ };
33
+
34
+ static void iter_sygv(na_loop_t* const lp) {
35
+ DType* a = (DType*)NDL_PTR(lp, 0);
36
+ DType* b = (DType*)NDL_PTR(lp, 1);
37
+ DType* w = (DType*)NDL_PTR(lp, 2);
38
+ int* info = (int*)NDL_PTR(lp, 3);
39
+ sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
40
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
+ *info = static_cast<int>(i);
45
+ }
46
+
47
+ static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
48
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
49
+
50
+ VALUE a_vnary = Qnil;
51
+ VALUE b_vnary = Qnil;
52
+ VALUE kw_args = Qnil;
53
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
54
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+ if (CLASS_OF(b_vnary) != nary_dtype) {
69
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
70
+ }
71
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
72
+ b_vnary = nary_dup(b_vnary);
73
+ }
74
+
75
+ narray_t* a_nary = nullptr;
76
+ GetNArray(a_vnary, a_nary);
77
+ if (NA_NDIM(a_nary) != 2) {
78
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
79
+ return Qnil;
80
+ }
81
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
82
+ rb_raise(rb_eArgError, "input array a must be square");
83
+ return Qnil;
84
+ }
85
+ narray_t* b_nary = nullptr;
86
+ GetNArray(a_vnary, b_nary);
87
+ if (NA_NDIM(b_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array b must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ const size_t n = NA_SHAPE(a_nary)[1];
97
+ size_t shape[1] = { n };
98
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
99
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
100
+ ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
101
+ sygv_opt opt = { matrix_layout, itype, jobz, uplo };
102
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
103
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
104
+
105
+ RB_GC_GUARD(a_vnary);
106
+ RB_GC_GUARD(b_vnary);
107
+
108
+ return ret;
109
+ }
110
+
111
+ static lapack_int get_itype(VALUE val) {
112
+ const lapack_int itype = NUM2INT(val);
113
+
114
+ if (itype != 1 && itype != 2 && itype != 3) {
115
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
+ }
117
+
118
+ return itype;
119
+ }
120
+
121
+ static char get_jobz(VALUE val) {
122
+ const char jobz = NUM2CHR(val);
123
+
124
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
+ }
127
+
128
+ return jobz;
129
+ }
130
+
131
+ static char get_uplo(VALUE val) {
132
+ const char uplo = NUM2CHR(val);
133
+
134
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
+ }
137
+
138
+ return uplo;
139
+ }
140
+
141
+ static int get_matrix_layout(VALUE val) {
142
+ const char option = NUM2CHR(val);
143
+
144
+ switch (option) {
145
+ case 'r':
146
+ case 'R':
147
+ break;
148
+ case 'c':
149
+ case 'C':
150
+ rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
151
+ break;
152
+ }
153
+
154
+ return LAPACK_ROW_MAJOR;
155
+ }
156
+ };
157
+
158
+ } // namespace TinyLinalg
@@ -0,0 +1,158 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyGvd {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, double* a, lapack_int lda,
6
+ double* b, lapack_int ldb, double* w) {
7
+ return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
8
+ }
9
+ };
10
+
11
+ struct SSyGvd {
12
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
13
+ char uplo, lapack_int n, float* a, lapack_int lda,
14
+ float* b, lapack_int ldb, float* w) {
15
+ return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename DType, typename FncType>
20
+ class SyGvd {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
24
+ }
25
+
26
+ private:
27
+ struct sygvd_opt {
28
+ int matrix_layout;
29
+ lapack_int itype;
30
+ char jobz;
31
+ char uplo;
32
+ };
33
+
34
+ static void iter_sygvd(na_loop_t* const lp) {
35
+ DType* a = (DType*)NDL_PTR(lp, 0);
36
+ DType* b = (DType*)NDL_PTR(lp, 1);
37
+ DType* w = (DType*)NDL_PTR(lp, 2);
38
+ int* info = (int*)NDL_PTR(lp, 3);
39
+ sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
40
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
+ *info = static_cast<int>(i);
45
+ }
46
+
47
+ static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
48
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
49
+
50
+ VALUE a_vnary = Qnil;
51
+ VALUE b_vnary = Qnil;
52
+ VALUE kw_args = Qnil;
53
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
54
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+ if (CLASS_OF(b_vnary) != nary_dtype) {
69
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
70
+ }
71
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
72
+ b_vnary = nary_dup(b_vnary);
73
+ }
74
+
75
+ narray_t* a_nary = nullptr;
76
+ GetNArray(a_vnary, a_nary);
77
+ if (NA_NDIM(a_nary) != 2) {
78
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
79
+ return Qnil;
80
+ }
81
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
82
+ rb_raise(rb_eArgError, "input array a must be square");
83
+ return Qnil;
84
+ }
85
+ narray_t* b_nary = nullptr;
86
+ GetNArray(a_vnary, b_nary);
87
+ if (NA_NDIM(b_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array b must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ const size_t n = NA_SHAPE(a_nary)[1];
97
+ size_t shape[1] = { n };
98
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
99
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
100
+ ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
101
+ sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
102
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
103
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
104
+
105
+ RB_GC_GUARD(a_vnary);
106
+ RB_GC_GUARD(b_vnary);
107
+
108
+ return ret;
109
+ }
110
+
111
+ static lapack_int get_itype(VALUE val) {
112
+ const lapack_int itype = NUM2INT(val);
113
+
114
+ if (itype != 1 && itype != 2 && itype != 3) {
115
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
+ }
117
+
118
+ return itype;
119
+ }
120
+
121
+ static char get_jobz(VALUE val) {
122
+ const char jobz = NUM2CHR(val);
123
+
124
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
+ }
127
+
128
+ return jobz;
129
+ }
130
+
131
+ static char get_uplo(VALUE val) {
132
+ const char uplo = NUM2CHR(val);
133
+
134
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
+ }
137
+
138
+ return uplo;
139
+ }
140
+
141
+ static int get_matrix_layout(VALUE val) {
142
+ const char option = NUM2CHR(val);
143
+
144
+ switch (option) {
145
+ case 'r':
146
+ case 'R':
147
+ break;
148
+ case 'c':
149
+ case 'C':
150
+ rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
151
+ break;
152
+ }
153
+
154
+ return LAPACK_ROW_MAJOR;
155
+ }
156
+ };
157
+
158
+ } // namespace TinyLinalg
@@ -0,0 +1,192 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyGvx {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
5
+ lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
6
+ double vl, double vu, lapack_int il, lapack_int iu,
7
+ double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
8
+ return LAPACKE_dsygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
9
+ }
10
+ };
11
+
12
+ struct SSyGvx {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
14
+ lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
15
+ float vl, float vu, lapack_int il, lapack_int iu,
16
+ float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
17
+ return LAPACKE_ssygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, typename DType, typename FncType>
22
+ class SyGvx {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
26
+ }
27
+
28
+ private:
29
+ struct sygvx_opt {
30
+ int matrix_layout;
31
+ lapack_int itype;
32
+ char jobz;
33
+ char range;
34
+ char uplo;
35
+ DType vl;
36
+ DType vu;
37
+ lapack_int il;
38
+ lapack_int iu;
39
+ };
40
+
41
+ static void iter_sygvx(na_loop_t* const lp) {
42
+ DType* a = (DType*)NDL_PTR(lp, 0);
43
+ DType* b = (DType*)NDL_PTR(lp, 1);
44
+ int* m = (int*)NDL_PTR(lp, 2);
45
+ DType* w = (DType*)NDL_PTR(lp, 3);
46
+ DType* z = (DType*)NDL_PTR(lp, 4);
47
+ int* ifail = (int*)NDL_PTR(lp, 5);
48
+ int* info = (int*)NDL_PTR(lp, 6);
49
+ sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
50
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
51
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
+ const DType abstol = 0.0;
55
+ const lapack_int i = FncType().call(
56
+ opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
+ *info = static_cast<int>(i);
59
+ }
60
+
61
+ static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+
64
+ VALUE a_vnary = Qnil;
65
+ VALUE b_vnary = Qnil;
66
+ VALUE kw_args = Qnil;
67
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
68
+ ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
69
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
70
+ VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
71
+ rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
72
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
73
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
74
+ const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
75
+ const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
76
+ const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
77
+ const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
78
+ const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
79
+ const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
80
+ const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
81
+
82
+ if (CLASS_OF(a_vnary) != nary_dtype) {
83
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
84
+ }
85
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
86
+ a_vnary = nary_dup(a_vnary);
87
+ }
88
+ if (CLASS_OF(b_vnary) != nary_dtype) {
89
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
90
+ }
91
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
92
+ b_vnary = nary_dup(b_vnary);
93
+ }
94
+
95
+ narray_t* a_nary = nullptr;
96
+ GetNArray(a_vnary, a_nary);
97
+ if (NA_NDIM(a_nary) != 2) {
98
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
99
+ return Qnil;
100
+ }
101
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
102
+ rb_raise(rb_eArgError, "input array a must be square");
103
+ return Qnil;
104
+ }
105
+ narray_t* b_nary = nullptr;
106
+ GetNArray(a_vnary, b_nary);
107
+ if (NA_NDIM(b_nary) != 2) {
108
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
109
+ return Qnil;
110
+ }
111
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
112
+ rb_raise(rb_eArgError, "input array b must be square");
113
+ return Qnil;
114
+ }
115
+
116
+ const size_t n = NA_SHAPE(a_nary)[1];
117
+ size_t m = range != 'I' ? n : iu - il + 1;
118
+ size_t w_shape[1] = { m };
119
+ size_t z_shape[2] = { n, m };
120
+ size_t ifail_shape[1] = { n };
121
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
122
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
123
+ ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
124
+ sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
125
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
126
+ VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
127
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
128
+
129
+ RB_GC_GUARD(a_vnary);
130
+ RB_GC_GUARD(b_vnary);
131
+
132
+ return ret;
133
+ }
134
+
135
+ static lapack_int get_itype(VALUE val) {
136
+ const lapack_int itype = NUM2INT(val);
137
+
138
+ if (itype != 1 && itype != 2 && itype != 3) {
139
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
140
+ }
141
+
142
+ return itype;
143
+ }
144
+
145
+ static char get_jobz(VALUE val) {
146
+ const char jobz = NUM2CHR(val);
147
+
148
+ if (jobz != 'N' && jobz != 'V') {
149
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
150
+ }
151
+
152
+ return jobz;
153
+ }
154
+
155
+ static char get_range(VALUE val) {
156
+ const char range = NUM2CHR(val);
157
+
158
+ if (range != 'A' && range != 'V' && range != 'I') {
159
+ rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
160
+ }
161
+
162
+ return range;
163
+ }
164
+
165
+ static char get_uplo(VALUE val) {
166
+ const char uplo = NUM2CHR(val);
167
+
168
+ if (uplo != 'U' && uplo != 'L') {
169
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
170
+ }
171
+
172
+ return uplo;
173
+ }
174
+
175
+ static int get_matrix_layout(VALUE val) {
176
+ const char option = NUM2CHR(val);
177
+
178
+ switch (option) {
179
+ case 'r':
180
+ case 'R':
181
+ break;
182
+ case 'c':
183
+ case 'C':
184
+ rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
185
+ break;
186
+ }
187
+
188
+ return LAPACK_ROW_MAJOR;
189
+ }
190
+ };
191
+
192
+ } // namespace TinyLinalg