numo-tiny_linalg 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
-   data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
+   metadata.gz: '023843b368ce62924768c3341587977e57edb6aca0e6d93b64279d5e7965bf0b'
+   data.tar.gz: fe00ae7b435a94a6ff2790320f89bfef078efafed110e253f8929553480b67d6
  SHA512:
-   metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
-   data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
+   metadata.gz: 36d907834c73b633ae3872ba5fa73f315dd01d79e99487e7fb4f6f2e3f37a7fbcbfa1b25d078b4d44cd66213da545a665c35817d6929d22a8a05fa023791bd2f
+   data.tar.gz: eaab6dbb2835d60ad3432525a830a1a5f19625c7b87d1010db81efeb480892c0807612d8d13cdc23c60abde1d7304594fc4562099102a6738b8cee1aeccb55a6
data/CHANGELOG.md CHANGED
@@ -1,5 +1,11 @@
  ## [Unreleased]

+ ## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
+ - Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
+ - Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
+ - Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
+ - Add eigh module function to TinyLinalg.
+
  ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
  - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
  - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
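The headline addition is the eigh module function, which solves the generalized eigenvalue problem A x = lambda B x for symmetric / Hermitian matrices. A minimal usage sketch (the matrices below are illustrative, not from the package):

    require 'numo/tiny_linalg'

    a = Numo::DFloat[[4.0, 1.0], [1.0, 3.0]]  # symmetric A
    b = Numo::DFloat[[2.0, 0.0], [0.0, 1.0]]  # symmetric positive-definite B
    vals, vecs = Numo::TinyLinalg.eigh(a, b)  # eigenvalues ascending, eigenvectors as columns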
data/ext/numo/tiny_linalg/lapack/hegv.hpp ADDED
@@ -0,0 +1,161 @@
+ namespace TinyLinalg {
+
+ struct ZHeGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_double* a,
+                   lapack_int lda, lapack_complex_double* b,
+                   lapack_int ldb, double* w) {
+     return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct CHeGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_float* a,
+                   lapack_int lda, lapack_complex_float* b,
+                   lapack_int ldb, float* w) {
+     return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
+ class HeGv {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
+   }
+
+ private:
+   struct hegv_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_hegv(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     RType* w = (RType*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     hegv_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.hegv does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
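Per the rb_ary_new3 call above, each simple-driver wrapper returns a four-element array [a, b, w, info]: a and b are overwritten in place (with jobz: 'V', a receives the eigenvectors, following LAPACK's *hegv convention), w holds the eigenvalues, and info is the LAPACK status code. A direct-call sketch with illustrative values:

    a = Numo::DComplex[[2.0, 1.0 - 1.0i], [1.0 + 1.0i, 3.0]]  # Hermitian A
    b = Numo::DComplex[[1.0, 0.0], [0.0, 1.0]]                # Hermitian positive-definite B
    a, b, w, info = Numo::TinyLinalg::Lapack.zhegv(a, b, jobz: 'V', uplo: 'U')
    # info == 0 on success; a now holds the eigenvectors, w the eigenvalues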
data/ext/numo/tiny_linalg/lapack/hegvd.hpp ADDED
@@ -0,0 +1,161 @@
+ namespace TinyLinalg {
+
+ struct ZHeGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_double* a,
+                   lapack_int lda, lapack_complex_double* b,
+                   lapack_int ldb, double* w) {
+     return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct CHeGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_float* a,
+                   lapack_int lda, lapack_complex_float* b,
+                   lapack_int ldb, float* w) {
+     return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
+ class HeGvd {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
+   }
+
+ private:
+   struct hegvd_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_hegvd(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     RType* w = (RType*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.hegvd does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/hegvx.hpp ADDED
@@ -0,0 +1,193 @@
+ namespace TinyLinalg {
+
+ struct ZHeGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
+                   double vl, double vu, lapack_int il, lapack_int iu,
+                   double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ struct CHeGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
+                   float vl, float vu, lapack_int il, lapack_int iu,
+                   float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
+ class HeGvx {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
+   }
+
+ private:
+   struct hegvx_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char range;
+     char uplo;
+     RType vl;
+     RType vu;
+     lapack_int il;
+     lapack_int iu;
+   };
+
+   static void iter_hegvx(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     int* m = (int*)NDL_PTR(lp, 2);
+     RType* w = (RType*)NDL_PTR(lp, 3);
+     DType* z = (DType*)NDL_PTR(lp, 4);
+     int* ifail = (int*)NDL_PTR(lp, 5);
+     int* info = (int*)NDL_PTR(lp, 6);
+     hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
+     const RType abstol = 0.0;
+     const lapack_int i = FncType().call(
+       opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
+       opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
+                        rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
+     VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
+     const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
+     const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
+     const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
+     const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
+     const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
+     const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t m = range != 'I' ? n : iu - il + 1;
+     size_t w_shape[1] = { m };
+     size_t z_shape[2] = { n, m };
+     size_t ifail_shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
+     hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
+                             rb_ary_entry(res, 3), rb_ary_entry(res, 4));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'N' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_range(VALUE val) {
+     const char range = NUM2CHR(val);
+
+     if (range != 'A' && range != 'V' && range != 'I') {
+       rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
+     }
+
+     return range;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'U' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/sygv.hpp ADDED
@@ -0,0 +1,158 @@
+ namespace TinyLinalg {
+
+ struct DSyGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, double* a, lapack_int lda,
+                   double* b, lapack_int ldb, double* w) {
+     return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct SSyGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, float* a, lapack_int lda,
+                   float* b, lapack_int ldb, float* w) {
+     return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGv {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
+   }
+
+ private:
+   struct sygv_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_sygv(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     DType* w = (DType*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     sygv_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/sygvd.hpp ADDED
@@ -0,0 +1,158 @@
+ namespace TinyLinalg {
+
+ struct DSyGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, double* a, lapack_int lda,
+                   double* b, lapack_int ldb, double* w) {
+     return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct SSyGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, float* a, lapack_int lda,
+                   float* b, lapack_int ldb, float* w) {
+     return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGvd {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
+   }
+
+ private:
+   struct sygvd_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_sygvd(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     DType* w = (DType*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/sygvx.hpp ADDED
@@ -0,0 +1,192 @@
+ namespace TinyLinalg {
+
+ struct DSyGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
+                   double vl, double vu, lapack_int il, lapack_int iu,
+                   double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_dsygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ struct SSyGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
+                   float vl, float vu, lapack_int il, lapack_int iu,
+                   float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_ssygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGvx {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
+   }
+
+ private:
+   struct sygvx_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char range;
+     char uplo;
+     DType vl;
+     DType vu;
+     lapack_int il;
+     lapack_int iu;
+   };
+
+   static void iter_sygvx(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     int* m = (int*)NDL_PTR(lp, 2);
+     DType* w = (DType*)NDL_PTR(lp, 3);
+     DType* z = (DType*)NDL_PTR(lp, 4);
+     int* ifail = (int*)NDL_PTR(lp, 5);
+     int* info = (int*)NDL_PTR(lp, 6);
+     sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
+     const DType abstol = 0.0;
+     const lapack_int i = FncType().call(
+       opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
+       opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
+                        rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
+     VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
+     const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
+     const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
+     const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
+     const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
+     const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
+     const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t m = range != 'I' ? n : iu - il + 1;
+     size_t w_shape[1] = { m };
+     size_t z_shape[2] = { n, m };
+     size_t ifail_shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
+     sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
+                             rb_ary_entry(res, 3), rb_ary_entry(res, 4));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'N' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   static char get_range(VALUE val) {
+     const char range = NUM2CHR(val);
+
+     if (range != 'A' && range != 'V' && range != 'I') {
+       rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
+     }
+
+     return range;
+   }
+
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'U' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+     case 'r':
+     case 'R':
+       break;
+     case 'c':
+     case 'C':
+       rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
+       break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
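Unlike the simple drivers, the expert (*gvx) wrappers return seven elements, mirroring the rb_ary_new3(7, ...) call above: [a, b, m, w, z, ifail, info], where m is the number of eigenvalues found and z holds the selected eigenvectors. A sketch selecting the two smallest eigenpairs with range: 'I' (il/iu are 1-based; the matrices are illustrative):

    a = Numo::DFloat[[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]
    b = Numo::DFloat.eye(3)
    _a, _b, m, w, z, _ifail, info = Numo::TinyLinalg::Lapack.dsygvx(
      a, b, jobz: 'V', range: 'I', il: 1, iu: 2
    )
    # m == 2; w holds the two smallest eigenvalues, z the corresponding eigenvectors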
data/ext/numo/tiny_linalg/tiny_linalg.cpp CHANGED
@@ -11,7 +11,13 @@
  #include "lapack/gesvd.hpp"
  #include "lapack/getrf.hpp"
  #include "lapack/getri.hpp"
+ #include "lapack/hegv.hpp"
+ #include "lapack/hegvd.hpp"
+ #include "lapack/hegvx.hpp"
  #include "lapack/orgqr.hpp"
+ #include "lapack/sygv.hpp"
+ #include "lapack/sygvd.hpp"
+ #include "lapack/sygvx.hpp"
  #include "lapack/ungqr.hpp"

  VALUE rb_mTinyLinalg;
@@ -271,6 +277,18 @@ extern "C" void Init_tiny_linalg(void) {
    TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
    TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
    TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
+   TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
+   TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
+   TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
+   TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
+   TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
+   TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
+   TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
+   TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
+   TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
+   TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
+   TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
+   TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");

    rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
    rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
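Each instantiation above pins a function name to a NArray dtype (dsygv to Numo::DFloat, ssygv to Numo::SFloat, zhegv to Numo::DComplex, chegv to Numo::SComplex); per the CLASS_OF / cast branch in the wrappers, other input types are cast on entry. A sketch relying on that cast (illustrative values):

    a = [[2.0, 1.0], [1.0, 2.0]]  # plain Ruby arrays are cast to Numo::DFloat by dsygv
    b = [[1.0, 0.0], [0.0, 1.0]]
    _a, _b, w, info = Numo::TinyLinalg::Lapack.dsygv(a, b)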
data/lib/numo/tiny_linalg/version.rb CHANGED
@@ -5,6 +5,6 @@ module Numo
    # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
    module TinyLinalg
      # The version of Numo::TinyLinalg you install.
-     VERSION = '0.0.3'
+     VERSION = '0.0.4'
    end
  end
data/lib/numo/tiny_linalg.rb CHANGED
@@ -10,6 +10,50 @@ module Numo
    module TinyLinalg # rubocop:disable Metrics/ModuleLength
      module_function

+     # Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
+     # by solving an ordinary or generalized eigenvalue problem.
+     #
+     # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
+     # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
+     # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
+     # @param vals_range [Range/Array]
+     #   The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
+     #   If nil, all eigenvalues and eigenvectors are computed.
+     # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
+     # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
+     # @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors.
+     def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
+       raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
+       raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
+
+       bchr = blas_char(a)
+       raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
+
+       unless b.nil?
+         raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
+         raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
+         raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
+       end
+
+       jobz = vals_only ? 'N' : 'V'
+       b = a.class.eye(a.shape[0]) if b.nil?
+       sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
+
+       if vals_range.nil?
+         sy_he_gv << 'd' if turbo
+         vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
+       else
+         sy_he_gv << 'x'
+         il = vals_range.first + 1
+         iu = vals_range.last + 1
+         _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
+           sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
+         )
+       end
+       vecs = nil if vals_only
+       [vals, vecs]
+     end
+
      # Computes the determinant of matrix.
      #
      # @param a [Numo::NArray] n-by-n square matrix.
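As the dispatch in eigh shows, turbo: true routes to the divide-and-conquer drivers (*sygvd / *hegvd) and vals_range routes to the expert drivers (*sygvx / *hegvx) with range: 'I'. A sketch of both paths (illustrative matrix; vals_range takes 0-based indices of an inclusive range):

    a = Numo::DFloat.new(5, 5).rand
    a = 0.5 * (a + a.transpose)                                  # symmetrize
    vals, vecs = Numo::TinyLinalg.eigh(a, turbo: true)           # dispatches to dsygvd
    vals, _ = Numo::TinyLinalg.eigh(a, vals_only: true,
                                    vals_range: 0..2)            # dispatches to dsygvx, three smallest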
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: numo-tiny_linalg
  version: !ruby/object:Gem::Version
-   version: 0.0.3
+   version: 0.0.4
  platform: ruby
  authors:
  - yoshoku
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2023-08-02 00:00:00.000000000 Z
+ date: 2023-08-06 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -52,7 +52,13 @@ files:
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
  - ext/numo/tiny_linalg/lapack/getrf.hpp
  - ext/numo/tiny_linalg/lapack/getri.hpp
+ - ext/numo/tiny_linalg/lapack/hegv.hpp
+ - ext/numo/tiny_linalg/lapack/hegvd.hpp
+ - ext/numo/tiny_linalg/lapack/hegvx.hpp
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
+ - ext/numo/tiny_linalg/lapack/sygv.hpp
+ - ext/numo/tiny_linalg/lapack/sygvd.hpp
+ - ext/numo/tiny_linalg/lapack/sygvx.hpp
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
  - ext/numo/tiny_linalg/tiny_linalg.cpp
  - ext/numo/tiny_linalg/tiny_linalg.hpp