numo-tiny_linalg 0.0.2 → 0.0.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '06379198010df5ee43b42c8d6776e551ac7870e0a81bb43ecfd776868464edf9'
4
- data.tar.gz: e0e38cd51c48332a496e8b9405e219cdcd9f7dcebd83e53d977daa6b82a5b1fc
3
+ metadata.gz: '023843b368ce62924768c3341587977e57edb6aca0e6d93b64279d5e7965bf0b'
4
+ data.tar.gz: fe00ae7b435a94a6ff2790320f89bfef078efafed110e253f8929553480b67d6
5
5
  SHA512:
6
- metadata.gz: 5288ec9be4365280fc177d4d57523e7187360e8d298bb5d60819a2776274f03979a2e0319aee1081532f0abfcc23e0c3a897dbca47d1cd1bde36df5ac87db011
7
- data.tar.gz: 4f03737044a25e81aab2fda5ce2bb555c8565a092bb6edc52776c54b758c1d76147ccdea845ea3e834f510b79a6254e8e9379f264e88421a1e446057f8bac058
6
+ metadata.gz: 36d907834c73b633ae3872ba5fa73f315dd01d79e99487e7fb4f6f2e3f37a7fbcbfa1b25d078b4d44cd66213da545a665c35817d6929d22a8a05fa023791bd2f
7
+ data.tar.gz: eaab6dbb2835d60ad3432525a830a1a5f19625c7b87d1010db81efeb480892c0807612d8d13cdc23c60abde1d7304594fc4562099102a6738b8cee1aeccb55a6
data/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
4
+ - Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
5
+ - Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
6
+ - Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
7
+ - Add eigh module function to TinyLinalg.
8
+
9
+ ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
10
+ - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
11
+ - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
12
+ - Add det module function to TinyLinalg.
13
+ - Add pinv module function to TinyLinalg.
14
+ - Add qr module function to TinyLinalg.
15
+
3
16
  ## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
4
17
  - Add automatic build of OpenBLAS if it is not found.
5
18
  - Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
@@ -0,0 +1,118 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGeQrf {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, double* tau) {
6
+ return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct SGeQrf {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
12
+ float* a, lapack_int lda, float* tau) {
13
+ return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ struct ZGeQrf {
18
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
19
+ lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
20
+ return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
21
+ }
22
+ };
23
+
24
+ struct CGeQrf {
25
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
26
+ lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
27
+ return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
28
+ }
29
+ };
30
+
31
+ template <int nary_dtype_id, typename DType, typename FncType>
32
+ class GeQrf {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
35
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
36
+ }
37
+
38
+ private:
39
+ struct geqrf_opt {
40
+ int matrix_layout;
41
+ };
42
+
43
+ static void iter_geqrf(na_loop_t* const lp) {
44
+ DType* a = (DType*)NDL_PTR(lp, 0);
45
+ DType* tau = (DType*)NDL_PTR(lp, 1);
46
+ int* info = (int*)NDL_PTR(lp, 2);
47
+ geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
48
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
49
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
50
+ const lapack_int lda = n;
51
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+
58
+ VALUE a_vnary = Qnil;
59
+ VALUE kw_args = Qnil;
60
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
61
+ ID kw_table[1] = { rb_intern("order") };
62
+ VALUE kw_values[1] = { Qundef };
63
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
64
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
65
+
66
+ if (CLASS_OF(a_vnary) != nary_dtype) {
67
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
68
+ }
69
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
70
+ a_vnary = nary_dup(a_vnary);
71
+ }
72
+
73
+ narray_t* a_nary = NULL;
74
+ GetNArray(a_vnary, a_nary);
75
+ const int n_dims = NA_NDIM(a_nary);
76
+ if (n_dims != 2) {
77
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
78
+ return Qnil;
79
+ }
80
+
81
+ size_t m = NA_SHAPE(a_nary)[0];
82
+ size_t n = NA_SHAPE(a_nary)[1];
83
+ size_t shape[1] = { m < n ? m : n };
84
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
85
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
86
+ ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
87
+ geqrf_opt opt = { matrix_layout };
88
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
89
+
90
+ VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
91
+
92
+ RB_GC_GUARD(a_vnary);
93
+
94
+ return ret;
95
+ }
96
+
97
+ static int get_matrix_layout(VALUE val) {
98
+ const char* option_str = StringValueCStr(val);
99
+
100
+ if (std::strlen(option_str) > 0) {
101
+ switch (option_str[0]) {
102
+ case 'r':
103
+ case 'R':
104
+ break;
105
+ case 'c':
106
+ case 'C':
107
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
108
+ break;
109
+ }
110
+ }
111
+
112
+ RB_GC_GUARD(val);
113
+
114
+ return LAPACK_ROW_MAJOR;
115
+ }
116
+ };
117
+
118
+ } // namespace TinyLinalg
@@ -0,0 +1,167 @@
1
+ namespace TinyLinalg {
2
+
3
// Forwards to LAPACKE_zhegv: generalized eigenproblem (A*x = lambda*B*x) for
// double-precision complex Hermitian matrices; eigenvalues land in w.
// NOTE(review): CHeGv below also carries a real (ssygv) overload, but no
// double/dsygv overload exists here — confirm dsygv is bound elsewhere
// (e.g. a separate DSyGv functor) or whether one should be added.
struct ZHeGv {
  lapack_int call(int matrix_layout, lapack_int itype, char jobz,
                  char uplo, lapack_int n, lapack_complex_double* a,
                  lapack_int lda, lapack_complex_double* b,
                  lapack_int ldb, double* w) {
    return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
  }
};
11
+
12
// Single-precision functor covering both the real (ssygv) and complex (chegv)
// generalized eigenproblem; the HeGv template's DType parameter selects the
// overload at instantiation time.
struct CHeGv {
  // Real symmetric path: forwards to LAPACKE_ssygv.
  lapack_int call(int matrix_layout, lapack_int itype, char jobz,
                  char uplo, lapack_int n, float* a, lapack_int lda,
                  float* b, lapack_int ldb, float* w) {
    return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
  }

  // Complex Hermitian path: forwards to LAPACKE_chegv.
  lapack_int call(int matrix_layout, lapack_int itype, char jobz,
                  char uplo, lapack_int n, lapack_complex_float* a,
                  lapack_int lda, lapack_complex_float* b,
                  lapack_int ldb, float* w) {
    return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
  }
};
26
+
27
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
28
+ class HeGv {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
32
+ }
33
+
34
+ private:
35
+ struct hegv_opt {
36
+ int matrix_layout;
37
+ lapack_int itype;
38
+ char jobz;
39
+ char uplo;
40
+ };
41
+
42
+ static void iter_hegv(na_loop_t* const lp) {
43
+ DType* a = (DType*)NDL_PTR(lp, 0);
44
+ DType* b = (DType*)NDL_PTR(lp, 1);
45
+ RType* w = (RType*)NDL_PTR(lp, 2);
46
+ int* info = (int*)NDL_PTR(lp, 3);
47
+ hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE b_vnary = Qnil;
61
+ VALUE kw_args = Qnil;
62
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
63
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
+
71
+ if (CLASS_OF(a_vnary) != nary_dtype) {
72
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
73
+ }
74
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
75
+ a_vnary = nary_dup(a_vnary);
76
+ }
77
+ if (CLASS_OF(b_vnary) != nary_dtype) {
78
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
79
+ }
80
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
81
+ b_vnary = nary_dup(b_vnary);
82
+ }
83
+
84
+ narray_t* a_nary = nullptr;
85
+ GetNArray(a_vnary, a_nary);
86
+ if (NA_NDIM(a_nary) != 2) {
87
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
88
+ return Qnil;
89
+ }
90
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
91
+ rb_raise(rb_eArgError, "input array a must be square");
92
+ return Qnil;
93
+ }
94
+ narray_t* b_nary = nullptr;
95
+ GetNArray(a_vnary, b_nary);
96
+ if (NA_NDIM(b_nary) != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
+ return Qnil;
99
+ }
100
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
101
+ rb_raise(rb_eArgError, "input array b must be square");
102
+ return Qnil;
103
+ }
104
+
105
+ const size_t n = NA_SHAPE(a_nary)[1];
106
+ size_t shape[1] = { n };
107
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
108
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
110
+ hegv_opt opt = { matrix_layout, itype, jobz, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+
120
+ static lapack_int get_itype(VALUE val) {
121
+ const lapack_int itype = NUM2INT(val);
122
+
123
+ if (itype != 1 && itype != 2 && itype != 3) {
124
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
+ }
126
+
127
+ return itype;
128
+ }
129
+
130
+ static char get_jobz(VALUE val) {
131
+ const char jobz = NUM2CHR(val);
132
+
133
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
+ }
136
+
137
+ return jobz;
138
+ }
139
+
140
+ static char get_uplo(VALUE val) {
141
+ const char uplo = NUM2CHR(val);
142
+
143
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
+ }
146
+
147
+ return uplo;
148
+ }
149
+
150
+ static int get_matrix_layout(VALUE val) {
151
+ const char option = NUM2CHR(val);
152
+
153
+ switch (option) {
154
+ case 'r':
155
+ case 'R':
156
+ break;
157
+ case 'c':
158
+ case 'C':
159
+ rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
160
+ break;
161
+ }
162
+
163
+ return LAPACK_ROW_MAJOR;
164
+ }
165
+ };
166
+
167
+ } // namespace TinyLinalg
@@ -0,0 +1,167 @@
1
+ namespace TinyLinalg {
2
+
3
// Forwards to LAPACKE_zhegvd: divide-and-conquer generalized eigenproblem for
// double-precision complex Hermitian matrices; eigenvalues land in w.
// NOTE(review): CHeGvd below also carries a real (ssygvd) overload, but no
// double/dsygvd overload exists here — confirm dsygvd is bound elsewhere
// or whether one should be added.
struct ZHeGvd {
  lapack_int call(int matrix_layout, lapack_int itype, char jobz,
                  char uplo, lapack_int n, lapack_complex_double* a,
                  lapack_int lda, lapack_complex_double* b,
                  lapack_int ldb, double* w) {
    return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
  }
};
11
+
12
// Single-precision functor covering both the real (ssygvd) and complex
// (chegvd) divide-and-conquer generalized eigenproblem; the HeGvd template's
// DType parameter selects the overload at instantiation time.
struct CHeGvd {
  // Real symmetric path: forwards to LAPACKE_ssygvd.
  lapack_int call(int matrix_layout, lapack_int itype, char jobz,
                  char uplo, lapack_int n, float* a, lapack_int lda,
                  float* b, lapack_int ldb, float* w) {
    return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
  }

  // Complex Hermitian path: forwards to LAPACKE_chegvd.
  lapack_int call(int matrix_layout, lapack_int itype, char jobz,
                  char uplo, lapack_int n, lapack_complex_float* a,
                  lapack_int lda, lapack_complex_float* b,
                  lapack_int ldb, float* w) {
    return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
  }
};
26
+
27
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
28
+ class HeGvd {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
32
+ }
33
+
34
+ private:
35
+ struct hegvd_opt {
36
+ int matrix_layout;
37
+ lapack_int itype;
38
+ char jobz;
39
+ char uplo;
40
+ };
41
+
42
+ static void iter_hegvd(na_loop_t* const lp) {
43
+ DType* a = (DType*)NDL_PTR(lp, 0);
44
+ DType* b = (DType*)NDL_PTR(lp, 1);
45
+ RType* w = (RType*)NDL_PTR(lp, 2);
46
+ int* info = (int*)NDL_PTR(lp, 3);
47
+ hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE b_vnary = Qnil;
61
+ VALUE kw_args = Qnil;
62
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
63
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
+
71
+ if (CLASS_OF(a_vnary) != nary_dtype) {
72
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
73
+ }
74
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
75
+ a_vnary = nary_dup(a_vnary);
76
+ }
77
+ if (CLASS_OF(b_vnary) != nary_dtype) {
78
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
79
+ }
80
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
81
+ b_vnary = nary_dup(b_vnary);
82
+ }
83
+
84
+ narray_t* a_nary = nullptr;
85
+ GetNArray(a_vnary, a_nary);
86
+ if (NA_NDIM(a_nary) != 2) {
87
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
88
+ return Qnil;
89
+ }
90
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
91
+ rb_raise(rb_eArgError, "input array a must be square");
92
+ return Qnil;
93
+ }
94
+ narray_t* b_nary = nullptr;
95
+ GetNArray(a_vnary, b_nary);
96
+ if (NA_NDIM(b_nary) != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
+ return Qnil;
99
+ }
100
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
101
+ rb_raise(rb_eArgError, "input array b must be square");
102
+ return Qnil;
103
+ }
104
+
105
+ const size_t n = NA_SHAPE(a_nary)[1];
106
+ size_t shape[1] = { n };
107
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
108
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
110
+ hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+
120
+ static lapack_int get_itype(VALUE val) {
121
+ const lapack_int itype = NUM2INT(val);
122
+
123
+ if (itype != 1 && itype != 2 && itype != 3) {
124
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
+ }
126
+
127
+ return itype;
128
+ }
129
+
130
+ static char get_jobz(VALUE val) {
131
+ const char jobz = NUM2CHR(val);
132
+
133
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
+ }
136
+
137
+ return jobz;
138
+ }
139
+
140
+ static char get_uplo(VALUE val) {
141
+ const char uplo = NUM2CHR(val);
142
+
143
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
+ }
146
+
147
+ return uplo;
148
+ }
149
+
150
+ static int get_matrix_layout(VALUE val) {
151
+ const char option = NUM2CHR(val);
152
+
153
+ switch (option) {
154
+ case 'r':
155
+ case 'R':
156
+ break;
157
+ case 'c':
158
+ case 'C':
159
+ rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
160
+ break;
161
+ }
162
+
163
+ return LAPACK_ROW_MAJOR;
164
+ }
165
+ };
166
+
167
+ } // namespace TinyLinalg
@@ -0,0 +1,193 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGvx {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
5
+ lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
6
+ double vl, double vu, lapack_int il, lapack_int iu,
7
+ double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
8
+ return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
9
+ }
10
+ };
11
+
12
+ struct CHeGvx {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
14
+ lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
15
+ float vl, float vu, lapack_int il, lapack_int iu,
16
+ float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
17
+ return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
22
+ class HeGvx {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
26
+ }
27
+
28
+ private:
29
+ struct hegvx_opt {
30
+ int matrix_layout;
31
+ lapack_int itype;
32
+ char jobz;
33
+ char range;
34
+ char uplo;
35
+ RType vl;
36
+ RType vu;
37
+ lapack_int il;
38
+ lapack_int iu;
39
+ };
40
+
41
+ static void iter_hegvx(na_loop_t* const lp) {
42
+ DType* a = (DType*)NDL_PTR(lp, 0);
43
+ DType* b = (DType*)NDL_PTR(lp, 1);
44
+ int* m = (int*)NDL_PTR(lp, 2);
45
+ RType* w = (RType*)NDL_PTR(lp, 3);
46
+ DType* z = (DType*)NDL_PTR(lp, 4);
47
+ int* ifail = (int*)NDL_PTR(lp, 5);
48
+ int* info = (int*)NDL_PTR(lp, 6);
49
+ hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
50
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
51
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
+ const RType abstol = 0.0;
55
+ const lapack_int i = FncType().call(
56
+ opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
+ *info = static_cast<int>(i);
59
+ }
60
+
61
+ static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
64
+
65
+ VALUE a_vnary = Qnil;
66
+ VALUE b_vnary = Qnil;
67
+ VALUE kw_args = Qnil;
68
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
69
+ ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
70
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
71
+ VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
72
+ rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
73
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
74
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
75
+ const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
76
+ const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
77
+ const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
+ const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
79
+ const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
80
+ const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
81
+ const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
82
+
83
+ if (CLASS_OF(a_vnary) != nary_dtype) {
84
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
85
+ }
86
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
87
+ a_vnary = nary_dup(a_vnary);
88
+ }
89
+ if (CLASS_OF(b_vnary) != nary_dtype) {
90
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
91
+ }
92
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
93
+ b_vnary = nary_dup(b_vnary);
94
+ }
95
+
96
+ narray_t* a_nary = nullptr;
97
+ GetNArray(a_vnary, a_nary);
98
+ if (NA_NDIM(a_nary) != 2) {
99
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
100
+ return Qnil;
101
+ }
102
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
103
+ rb_raise(rb_eArgError, "input array a must be square");
104
+ return Qnil;
105
+ }
106
+ narray_t* b_nary = nullptr;
107
+ GetNArray(a_vnary, b_nary);
108
+ if (NA_NDIM(b_nary) != 2) {
109
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
110
+ return Qnil;
111
+ }
112
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
113
+ rb_raise(rb_eArgError, "input array b must be square");
114
+ return Qnil;
115
+ }
116
+
117
+ const size_t n = NA_SHAPE(a_nary)[1];
118
+ size_t m = range != 'I' ? n : iu - il + 1;
119
+ size_t w_shape[1] = { m };
120
+ size_t z_shape[2] = { n, m };
121
+ size_t ifail_shape[1] = { n };
122
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
123
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
124
+ ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
125
+ hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
126
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
127
+ VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
128
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
129
+
130
+ RB_GC_GUARD(a_vnary);
131
+ RB_GC_GUARD(b_vnary);
132
+
133
+ return ret;
134
+ }
135
+
136
+ static lapack_int get_itype(VALUE val) {
137
+ const lapack_int itype = NUM2INT(val);
138
+
139
+ if (itype != 1 && itype != 2 && itype != 3) {
140
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
141
+ }
142
+
143
+ return itype;
144
+ }
145
+
146
+ static char get_jobz(VALUE val) {
147
+ const char jobz = NUM2CHR(val);
148
+
149
+ if (jobz != 'N' && jobz != 'V') {
150
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
151
+ }
152
+
153
+ return jobz;
154
+ }
155
+
156
+ static char get_range(VALUE val) {
157
+ const char range = NUM2CHR(val);
158
+
159
+ if (range != 'A' && range != 'V' && range != 'I') {
160
+ rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
161
+ }
162
+
163
+ return range;
164
+ }
165
+
166
+ static char get_uplo(VALUE val) {
167
+ const char uplo = NUM2CHR(val);
168
+
169
+ if (uplo != 'U' && uplo != 'L') {
170
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
171
+ }
172
+
173
+ return uplo;
174
+ }
175
+
176
+ static int get_matrix_layout(VALUE val) {
177
+ const char option = NUM2CHR(val);
178
+
179
+ switch (option) {
180
+ case 'r':
181
+ case 'R':
182
+ break;
183
+ case 'c':
184
+ case 'C':
185
+ rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
186
+ break;
187
+ }
188
+
189
+ return LAPACK_ROW_MAJOR;
190
+ }
191
+ };
192
+
193
+ } // namespace TinyLinalg