numo-tiny_linalg 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '06379198010df5ee43b42c8d6776e551ac7870e0a81bb43ecfd776868464edf9'
4
- data.tar.gz: e0e38cd51c48332a496e8b9405e219cdcd9f7dcebd83e53d977daa6b82a5b1fc
3
+ metadata.gz: '023843b368ce62924768c3341587977e57edb6aca0e6d93b64279d5e7965bf0b'
4
+ data.tar.gz: fe00ae7b435a94a6ff2790320f89bfef078efafed110e253f8929553480b67d6
5
5
  SHA512:
6
- metadata.gz: 5288ec9be4365280fc177d4d57523e7187360e8d298bb5d60819a2776274f03979a2e0319aee1081532f0abfcc23e0c3a897dbca47d1cd1bde36df5ac87db011
7
- data.tar.gz: 4f03737044a25e81aab2fda5ce2bb555c8565a092bb6edc52776c54b758c1d76147ccdea845ea3e834f510b79a6254e8e9379f264e88421a1e446057f8bac058
6
+ metadata.gz: 36d907834c73b633ae3872ba5fa73f315dd01d79e99487e7fb4f6f2e3f37a7fbcbfa1b25d078b4d44cd66213da545a665c35817d6929d22a8a05fa023791bd2f
7
+ data.tar.gz: eaab6dbb2835d60ad3432525a830a1a5f19625c7b87d1010db81efeb480892c0807612d8d13cdc23c60abde1d7304594fc4562099102a6738b8cee1aeccb55a6
data/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
4
+ - Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
5
+ - Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
6
+ - Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
7
+ - Add eigh module function to TinyLinalg.
8
+
9
+ ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
10
+ - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
11
+ - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
12
+ - Add det module function to TinyLinalg.
13
+ - Add pinv module function to TinyLinalg.
14
+ - Add qr module function to TinyLinalg.
15
+
3
16
  ## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
4
17
  - Add automatic build of OpenBLAS if it is not found.
5
18
  - Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
@@ -0,0 +1,118 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGeQrf {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, double* tau) {
6
+ return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct SGeQrf {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
12
+ float* a, lapack_int lda, float* tau) {
13
+ return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ struct ZGeQrf {
18
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
19
+ lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
20
+ return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
21
+ }
22
+ };
23
+
24
+ struct CGeQrf {
25
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
26
+ lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
27
+ return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
28
+ }
29
+ };
30
+
31
+ template <int nary_dtype_id, typename DType, typename FncType>
32
+ class GeQrf {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
35
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
36
+ }
37
+
38
+ private:
39
+ struct geqrf_opt {
40
+ int matrix_layout;
41
+ };
42
+
43
+ static void iter_geqrf(na_loop_t* const lp) {
44
+ DType* a = (DType*)NDL_PTR(lp, 0);
45
+ DType* tau = (DType*)NDL_PTR(lp, 1);
46
+ int* info = (int*)NDL_PTR(lp, 2);
47
+ geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
48
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
49
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
50
+ const lapack_int lda = n;
51
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+
58
+ VALUE a_vnary = Qnil;
59
+ VALUE kw_args = Qnil;
60
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
61
+ ID kw_table[1] = { rb_intern("order") };
62
+ VALUE kw_values[1] = { Qundef };
63
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
64
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
65
+
66
+ if (CLASS_OF(a_vnary) != nary_dtype) {
67
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
68
+ }
69
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
70
+ a_vnary = nary_dup(a_vnary);
71
+ }
72
+
73
+ narray_t* a_nary = NULL;
74
+ GetNArray(a_vnary, a_nary);
75
+ const int n_dims = NA_NDIM(a_nary);
76
+ if (n_dims != 2) {
77
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
78
+ return Qnil;
79
+ }
80
+
81
+ size_t m = NA_SHAPE(a_nary)[0];
82
+ size_t n = NA_SHAPE(a_nary)[1];
83
+ size_t shape[1] = { m < n ? m : n };
84
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
85
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
86
+ ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
87
+ geqrf_opt opt = { matrix_layout };
88
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
89
+
90
+ VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
91
+
92
+ RB_GC_GUARD(a_vnary);
93
+
94
+ return ret;
95
+ }
96
+
97
+ static int get_matrix_layout(VALUE val) {
98
+ const char* option_str = StringValueCStr(val);
99
+
100
+ if (std::strlen(option_str) > 0) {
101
+ switch (option_str[0]) {
102
+ case 'r':
103
+ case 'R':
104
+ break;
105
+ case 'c':
106
+ case 'C':
107
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
108
+ break;
109
+ }
110
+ }
111
+
112
+ RB_GC_GUARD(val);
113
+
114
+ return LAPACK_ROW_MAJOR;
115
+ }
116
+ };
117
+
118
+ } // namespace TinyLinalg
@@ -0,0 +1,167 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGv {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, lapack_complex_double* a,
6
+ lapack_int lda, lapack_complex_double* b,
7
+ lapack_int ldb, double* w) {
8
+ return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
9
+ }
10
+ };
11
+
12
+ struct CHeGv {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
14
+ char uplo, lapack_int n, float* a, lapack_int lda,
15
+ float* b, lapack_int ldb, float* w) {
16
+ return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
17
+ }
18
+
19
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
20
+ char uplo, lapack_int n, lapack_complex_float* a,
21
+ lapack_int lda, lapack_complex_float* b,
22
+ lapack_int ldb, float* w) {
23
+ return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
28
+ class HeGv {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
32
+ }
33
+
34
+ private:
35
+ struct hegv_opt {
36
+ int matrix_layout;
37
+ lapack_int itype;
38
+ char jobz;
39
+ char uplo;
40
+ };
41
+
42
+ static void iter_hegv(na_loop_t* const lp) {
43
+ DType* a = (DType*)NDL_PTR(lp, 0);
44
+ DType* b = (DType*)NDL_PTR(lp, 1);
45
+ RType* w = (RType*)NDL_PTR(lp, 2);
46
+ int* info = (int*)NDL_PTR(lp, 3);
47
+ hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE b_vnary = Qnil;
61
+ VALUE kw_args = Qnil;
62
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
63
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
+
71
+ if (CLASS_OF(a_vnary) != nary_dtype) {
72
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
73
+ }
74
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
75
+ a_vnary = nary_dup(a_vnary);
76
+ }
77
+ if (CLASS_OF(b_vnary) != nary_dtype) {
78
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
79
+ }
80
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
81
+ b_vnary = nary_dup(b_vnary);
82
+ }
83
+
84
+ narray_t* a_nary = nullptr;
85
+ GetNArray(a_vnary, a_nary);
86
+ if (NA_NDIM(a_nary) != 2) {
87
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
88
+ return Qnil;
89
+ }
90
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
91
+ rb_raise(rb_eArgError, "input array a must be square");
92
+ return Qnil;
93
+ }
94
+ narray_t* b_nary = nullptr;
95
+ GetNArray(a_vnary, b_nary);
96
+ if (NA_NDIM(b_nary) != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
+ return Qnil;
99
+ }
100
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
101
+ rb_raise(rb_eArgError, "input array b must be square");
102
+ return Qnil;
103
+ }
104
+
105
+ const size_t n = NA_SHAPE(a_nary)[1];
106
+ size_t shape[1] = { n };
107
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
108
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
110
+ hegv_opt opt = { matrix_layout, itype, jobz, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+
120
+ static lapack_int get_itype(VALUE val) {
121
+ const lapack_int itype = NUM2INT(val);
122
+
123
+ if (itype != 1 && itype != 2 && itype != 3) {
124
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
+ }
126
+
127
+ return itype;
128
+ }
129
+
130
+ static char get_jobz(VALUE val) {
131
+ const char jobz = NUM2CHR(val);
132
+
133
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
+ }
136
+
137
+ return jobz;
138
+ }
139
+
140
+ static char get_uplo(VALUE val) {
141
+ const char uplo = NUM2CHR(val);
142
+
143
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
+ }
146
+
147
+ return uplo;
148
+ }
149
+
150
+ static int get_matrix_layout(VALUE val) {
151
+ const char option = NUM2CHR(val);
152
+
153
+ switch (option) {
154
+ case 'r':
155
+ case 'R':
156
+ break;
157
+ case 'c':
158
+ case 'C':
159
+ rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
160
+ break;
161
+ }
162
+
163
+ return LAPACK_ROW_MAJOR;
164
+ }
165
+ };
166
+
167
+ } // namespace TinyLinalg
@@ -0,0 +1,167 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGvd {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, lapack_complex_double* a,
6
+ lapack_int lda, lapack_complex_double* b,
7
+ lapack_int ldb, double* w) {
8
+ return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
9
+ }
10
+ };
11
+
12
+ struct CHeGvd {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
14
+ char uplo, lapack_int n, float* a, lapack_int lda,
15
+ float* b, lapack_int ldb, float* w) {
16
+ return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
17
+ }
18
+
19
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
20
+ char uplo, lapack_int n, lapack_complex_float* a,
21
+ lapack_int lda, lapack_complex_float* b,
22
+ lapack_int ldb, float* w) {
23
+ return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
28
+ class HeGvd {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
32
+ }
33
+
34
+ private:
35
+ struct hegvd_opt {
36
+ int matrix_layout;
37
+ lapack_int itype;
38
+ char jobz;
39
+ char uplo;
40
+ };
41
+
42
+ static void iter_hegvd(na_loop_t* const lp) {
43
+ DType* a = (DType*)NDL_PTR(lp, 0);
44
+ DType* b = (DType*)NDL_PTR(lp, 1);
45
+ RType* w = (RType*)NDL_PTR(lp, 2);
46
+ int* info = (int*)NDL_PTR(lp, 3);
47
+ hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE b_vnary = Qnil;
61
+ VALUE kw_args = Qnil;
62
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
63
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
+
71
+ if (CLASS_OF(a_vnary) != nary_dtype) {
72
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
73
+ }
74
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
75
+ a_vnary = nary_dup(a_vnary);
76
+ }
77
+ if (CLASS_OF(b_vnary) != nary_dtype) {
78
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
79
+ }
80
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
81
+ b_vnary = nary_dup(b_vnary);
82
+ }
83
+
84
+ narray_t* a_nary = nullptr;
85
+ GetNArray(a_vnary, a_nary);
86
+ if (NA_NDIM(a_nary) != 2) {
87
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
88
+ return Qnil;
89
+ }
90
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
91
+ rb_raise(rb_eArgError, "input array a must be square");
92
+ return Qnil;
93
+ }
94
+ narray_t* b_nary = nullptr;
95
+ GetNArray(a_vnary, b_nary);
96
+ if (NA_NDIM(b_nary) != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
+ return Qnil;
99
+ }
100
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
101
+ rb_raise(rb_eArgError, "input array b must be square");
102
+ return Qnil;
103
+ }
104
+
105
+ const size_t n = NA_SHAPE(a_nary)[1];
106
+ size_t shape[1] = { n };
107
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
108
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
110
+ hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+
120
+ static lapack_int get_itype(VALUE val) {
121
+ const lapack_int itype = NUM2INT(val);
122
+
123
+ if (itype != 1 && itype != 2 && itype != 3) {
124
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
+ }
126
+
127
+ return itype;
128
+ }
129
+
130
+ static char get_jobz(VALUE val) {
131
+ const char jobz = NUM2CHR(val);
132
+
133
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
+ }
136
+
137
+ return jobz;
138
+ }
139
+
140
+ static char get_uplo(VALUE val) {
141
+ const char uplo = NUM2CHR(val);
142
+
143
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
+ }
146
+
147
+ return uplo;
148
+ }
149
+
150
+ static int get_matrix_layout(VALUE val) {
151
+ const char option = NUM2CHR(val);
152
+
153
+ switch (option) {
154
+ case 'r':
155
+ case 'R':
156
+ break;
157
+ case 'c':
158
+ case 'C':
159
+ rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
160
+ break;
161
+ }
162
+
163
+ return LAPACK_ROW_MAJOR;
164
+ }
165
+ };
166
+
167
+ } // namespace TinyLinalg
@@ -0,0 +1,193 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGvx {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
5
+ lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
6
+ double vl, double vu, lapack_int il, lapack_int iu,
7
+ double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
8
+ return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
9
+ }
10
+ };
11
+
12
+ struct CHeGvx {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
14
+ lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
15
+ float vl, float vu, lapack_int il, lapack_int iu,
16
+ float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
17
+ return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
22
+ class HeGvx {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
26
+ }
27
+
28
+ private:
29
+ struct hegvx_opt {
30
+ int matrix_layout;
31
+ lapack_int itype;
32
+ char jobz;
33
+ char range;
34
+ char uplo;
35
+ RType vl;
36
+ RType vu;
37
+ lapack_int il;
38
+ lapack_int iu;
39
+ };
40
+
41
+ static void iter_hegvx(na_loop_t* const lp) {
42
+ DType* a = (DType*)NDL_PTR(lp, 0);
43
+ DType* b = (DType*)NDL_PTR(lp, 1);
44
+ int* m = (int*)NDL_PTR(lp, 2);
45
+ RType* w = (RType*)NDL_PTR(lp, 3);
46
+ DType* z = (DType*)NDL_PTR(lp, 4);
47
+ int* ifail = (int*)NDL_PTR(lp, 5);
48
+ int* info = (int*)NDL_PTR(lp, 6);
49
+ hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
50
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
51
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
+ const RType abstol = 0.0;
55
+ const lapack_int i = FncType().call(
56
+ opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
+ *info = static_cast<int>(i);
59
+ }
60
+
61
+ static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
64
+
65
+ VALUE a_vnary = Qnil;
66
+ VALUE b_vnary = Qnil;
67
+ VALUE kw_args = Qnil;
68
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
69
+ ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
70
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
71
+ VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
72
+ rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
73
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
74
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
75
+ const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
76
+ const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
77
+ const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
+ const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
79
+ const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
80
+ const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
81
+ const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
82
+
83
+ if (CLASS_OF(a_vnary) != nary_dtype) {
84
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
85
+ }
86
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
87
+ a_vnary = nary_dup(a_vnary);
88
+ }
89
+ if (CLASS_OF(b_vnary) != nary_dtype) {
90
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
91
+ }
92
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
93
+ b_vnary = nary_dup(b_vnary);
94
+ }
95
+
96
+ narray_t* a_nary = nullptr;
97
+ GetNArray(a_vnary, a_nary);
98
+ if (NA_NDIM(a_nary) != 2) {
99
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
100
+ return Qnil;
101
+ }
102
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
103
+ rb_raise(rb_eArgError, "input array a must be square");
104
+ return Qnil;
105
+ }
106
+ narray_t* b_nary = nullptr;
107
+ GetNArray(a_vnary, b_nary);
108
+ if (NA_NDIM(b_nary) != 2) {
109
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
110
+ return Qnil;
111
+ }
112
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
113
+ rb_raise(rb_eArgError, "input array b must be square");
114
+ return Qnil;
115
+ }
116
+
117
+ const size_t n = NA_SHAPE(a_nary)[1];
118
+ size_t m = range != 'I' ? n : iu - il + 1;
119
+ size_t w_shape[1] = { m };
120
+ size_t z_shape[2] = { n, m };
121
+ size_t ifail_shape[1] = { n };
122
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
123
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
124
+ ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
125
+ hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
126
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
127
+ VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
128
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
129
+
130
+ RB_GC_GUARD(a_vnary);
131
+ RB_GC_GUARD(b_vnary);
132
+
133
+ return ret;
134
+ }
135
+
136
+ static lapack_int get_itype(VALUE val) {
137
+ const lapack_int itype = NUM2INT(val);
138
+
139
+ if (itype != 1 && itype != 2 && itype != 3) {
140
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
141
+ }
142
+
143
+ return itype;
144
+ }
145
+
146
+ static char get_jobz(VALUE val) {
147
+ const char jobz = NUM2CHR(val);
148
+
149
+ if (jobz != 'N' && jobz != 'V') {
150
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
151
+ }
152
+
153
+ return jobz;
154
+ }
155
+
156
+ static char get_range(VALUE val) {
157
+ const char range = NUM2CHR(val);
158
+
159
+ if (range != 'A' && range != 'V' && range != 'I') {
160
+ rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
161
+ }
162
+
163
+ return range;
164
+ }
165
+
166
+ static char get_uplo(VALUE val) {
167
+ const char uplo = NUM2CHR(val);
168
+
169
+ if (uplo != 'U' && uplo != 'L') {
170
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
171
+ }
172
+
173
+ return uplo;
174
+ }
175
+
176
+ static int get_matrix_layout(VALUE val) {
177
+ const char option = NUM2CHR(val);
178
+
179
+ switch (option) {
180
+ case 'r':
181
+ case 'R':
182
+ break;
183
+ case 'c':
184
+ case 'C':
185
+ rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
186
+ break;
187
+ }
188
+
189
+ return LAPACK_ROW_MAJOR;
190
+ }
191
+ };
192
+
193
+ } // namespace TinyLinalg