numo-tiny_linalg 0.0.3 → 0.0.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
4
- data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
3
+ metadata.gz: '023843b368ce62924768c3341587977e57edb6aca0e6d93b64279d5e7965bf0b'
4
+ data.tar.gz: fe00ae7b435a94a6ff2790320f89bfef078efafed110e253f8929553480b67d6
5
5
  SHA512:
6
- metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
7
- data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
6
+ metadata.gz: 36d907834c73b633ae3872ba5fa73f315dd01d79e99487e7fb4f6f2e3f37a7fbcbfa1b25d078b4d44cd66213da545a665c35817d6929d22a8a05fa023791bd2f
7
+ data.tar.gz: eaab6dbb2835d60ad3432525a830a1a5f19625c7b87d1010db81efeb480892c0807612d8d13cdc23c60abde1d7304594fc4562099102a6738b8cee1aeccb55a6
data/CHANGELOG.md CHANGED
@@ -1,5 +1,11 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
4
+ - Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
5
+ - Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
6
+ - Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
7
+ - Add eigh module function to TinyLinalg.
8
+
3
9
  ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
4
10
  - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
5
11
  - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
@@ -0,0 +1,167 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGv {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, lapack_complex_double* a,
6
+ lapack_int lda, lapack_complex_double* b,
7
+ lapack_int ldb, double* w) {
8
+ return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
9
+ }
10
+ };
11
+
12
+ struct CHeGv {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
14
+ char uplo, lapack_int n, float* a, lapack_int lda,
15
+ float* b, lapack_int ldb, float* w) {
16
+ return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
17
+ }
18
+
19
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
20
+ char uplo, lapack_int n, lapack_complex_float* a,
21
+ lapack_int lda, lapack_complex_float* b,
22
+ lapack_int ldb, float* w) {
23
+ return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
28
+ class HeGv {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
32
+ }
33
+
34
+ private:
35
+ struct hegv_opt {
36
+ int matrix_layout;
37
+ lapack_int itype;
38
+ char jobz;
39
+ char uplo;
40
+ };
41
+
42
+ static void iter_hegv(na_loop_t* const lp) {
43
+ DType* a = (DType*)NDL_PTR(lp, 0);
44
+ DType* b = (DType*)NDL_PTR(lp, 1);
45
+ RType* w = (RType*)NDL_PTR(lp, 2);
46
+ int* info = (int*)NDL_PTR(lp, 3);
47
+ hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE b_vnary = Qnil;
61
+ VALUE kw_args = Qnil;
62
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
63
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
+
71
+ if (CLASS_OF(a_vnary) != nary_dtype) {
72
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
73
+ }
74
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
75
+ a_vnary = nary_dup(a_vnary);
76
+ }
77
+ if (CLASS_OF(b_vnary) != nary_dtype) {
78
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
79
+ }
80
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
81
+ b_vnary = nary_dup(b_vnary);
82
+ }
83
+
84
+ narray_t* a_nary = nullptr;
85
+ GetNArray(a_vnary, a_nary);
86
+ if (NA_NDIM(a_nary) != 2) {
87
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
88
+ return Qnil;
89
+ }
90
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
91
+ rb_raise(rb_eArgError, "input array a must be square");
92
+ return Qnil;
93
+ }
94
+ narray_t* b_nary = nullptr;
95
+ GetNArray(a_vnary, b_nary);
96
+ if (NA_NDIM(b_nary) != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
+ return Qnil;
99
+ }
100
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
101
+ rb_raise(rb_eArgError, "input array b must be square");
102
+ return Qnil;
103
+ }
104
+
105
+ const size_t n = NA_SHAPE(a_nary)[1];
106
+ size_t shape[1] = { n };
107
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
108
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
110
+ hegv_opt opt = { matrix_layout, itype, jobz, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+
120
+ static lapack_int get_itype(VALUE val) {
121
+ const lapack_int itype = NUM2INT(val);
122
+
123
+ if (itype != 1 && itype != 2 && itype != 3) {
124
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
+ }
126
+
127
+ return itype;
128
+ }
129
+
130
+ static char get_jobz(VALUE val) {
131
+ const char jobz = NUM2CHR(val);
132
+
133
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
+ }
136
+
137
+ return jobz;
138
+ }
139
+
140
+ static char get_uplo(VALUE val) {
141
+ const char uplo = NUM2CHR(val);
142
+
143
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
+ }
146
+
147
+ return uplo;
148
+ }
149
+
150
+ static int get_matrix_layout(VALUE val) {
151
+ const char option = NUM2CHR(val);
152
+
153
+ switch (option) {
154
+ case 'r':
155
+ case 'R':
156
+ break;
157
+ case 'c':
158
+ case 'C':
159
+ rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
160
+ break;
161
+ }
162
+
163
+ return LAPACK_ROW_MAJOR;
164
+ }
165
+ };
166
+
167
+ } // namespace TinyLinalg
@@ -0,0 +1,167 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGvd {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, lapack_complex_double* a,
6
+ lapack_int lda, lapack_complex_double* b,
7
+ lapack_int ldb, double* w) {
8
+ return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
9
+ }
10
+ };
11
+
12
+ struct CHeGvd {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
14
+ char uplo, lapack_int n, float* a, lapack_int lda,
15
+ float* b, lapack_int ldb, float* w) {
16
+ return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
17
+ }
18
+
19
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
20
+ char uplo, lapack_int n, lapack_complex_float* a,
21
+ lapack_int lda, lapack_complex_float* b,
22
+ lapack_int ldb, float* w) {
23
+ return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
28
+ class HeGvd {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
32
+ }
33
+
34
+ private:
35
+ struct hegvd_opt {
36
+ int matrix_layout;
37
+ lapack_int itype;
38
+ char jobz;
39
+ char uplo;
40
+ };
41
+
42
+ static void iter_hegvd(na_loop_t* const lp) {
43
+ DType* a = (DType*)NDL_PTR(lp, 0);
44
+ DType* b = (DType*)NDL_PTR(lp, 1);
45
+ RType* w = (RType*)NDL_PTR(lp, 2);
46
+ int* info = (int*)NDL_PTR(lp, 3);
47
+ hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE b_vnary = Qnil;
61
+ VALUE kw_args = Qnil;
62
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
63
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
+
71
+ if (CLASS_OF(a_vnary) != nary_dtype) {
72
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
73
+ }
74
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
75
+ a_vnary = nary_dup(a_vnary);
76
+ }
77
+ if (CLASS_OF(b_vnary) != nary_dtype) {
78
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
79
+ }
80
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
81
+ b_vnary = nary_dup(b_vnary);
82
+ }
83
+
84
+ narray_t* a_nary = nullptr;
85
+ GetNArray(a_vnary, a_nary);
86
+ if (NA_NDIM(a_nary) != 2) {
87
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
88
+ return Qnil;
89
+ }
90
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
91
+ rb_raise(rb_eArgError, "input array a must be square");
92
+ return Qnil;
93
+ }
94
+ narray_t* b_nary = nullptr;
95
+ GetNArray(a_vnary, b_nary);
96
+ if (NA_NDIM(b_nary) != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
+ return Qnil;
99
+ }
100
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
101
+ rb_raise(rb_eArgError, "input array b must be square");
102
+ return Qnil;
103
+ }
104
+
105
+ const size_t n = NA_SHAPE(a_nary)[1];
106
+ size_t shape[1] = { n };
107
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
108
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
110
+ hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+
120
+ static lapack_int get_itype(VALUE val) {
121
+ const lapack_int itype = NUM2INT(val);
122
+
123
+ if (itype != 1 && itype != 2 && itype != 3) {
124
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
+ }
126
+
127
+ return itype;
128
+ }
129
+
130
+ static char get_jobz(VALUE val) {
131
+ const char jobz = NUM2CHR(val);
132
+
133
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
+ }
136
+
137
+ return jobz;
138
+ }
139
+
140
+ static char get_uplo(VALUE val) {
141
+ const char uplo = NUM2CHR(val);
142
+
143
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
+ }
146
+
147
+ return uplo;
148
+ }
149
+
150
+ static int get_matrix_layout(VALUE val) {
151
+ const char option = NUM2CHR(val);
152
+
153
+ switch (option) {
154
+ case 'r':
155
+ case 'R':
156
+ break;
157
+ case 'c':
158
+ case 'C':
159
+ rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
160
+ break;
161
+ }
162
+
163
+ return LAPACK_ROW_MAJOR;
164
+ }
165
+ };
166
+
167
+ } // namespace TinyLinalg
@@ -0,0 +1,193 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeGvx {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
5
+ lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
6
+ double vl, double vu, lapack_int il, lapack_int iu,
7
+ double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
8
+ return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
9
+ }
10
+ };
11
+
12
+ struct CHeGvx {
13
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
14
+ lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
15
+ float vl, float vu, lapack_int il, lapack_int iu,
16
+ float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
17
+ return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
22
+ class HeGvx {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
26
+ }
27
+
28
+ private:
29
+ struct hegvx_opt {
30
+ int matrix_layout;
31
+ lapack_int itype;
32
+ char jobz;
33
+ char range;
34
+ char uplo;
35
+ RType vl;
36
+ RType vu;
37
+ lapack_int il;
38
+ lapack_int iu;
39
+ };
40
+
41
+ static void iter_hegvx(na_loop_t* const lp) {
42
+ DType* a = (DType*)NDL_PTR(lp, 0);
43
+ DType* b = (DType*)NDL_PTR(lp, 1);
44
+ int* m = (int*)NDL_PTR(lp, 2);
45
+ RType* w = (RType*)NDL_PTR(lp, 3);
46
+ DType* z = (DType*)NDL_PTR(lp, 4);
47
+ int* ifail = (int*)NDL_PTR(lp, 5);
48
+ int* info = (int*)NDL_PTR(lp, 6);
49
+ hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
50
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
51
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
+ const RType abstol = 0.0;
55
+ const lapack_int i = FncType().call(
56
+ opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
+ *info = static_cast<int>(i);
59
+ }
60
+
61
+ static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
64
+
65
+ VALUE a_vnary = Qnil;
66
+ VALUE b_vnary = Qnil;
67
+ VALUE kw_args = Qnil;
68
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
69
+ ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
70
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
71
+ VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
72
+ rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
73
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
74
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
75
+ const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
76
+ const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
77
+ const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
+ const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
79
+ const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
80
+ const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
81
+ const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
82
+
83
+ if (CLASS_OF(a_vnary) != nary_dtype) {
84
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
85
+ }
86
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
87
+ a_vnary = nary_dup(a_vnary);
88
+ }
89
+ if (CLASS_OF(b_vnary) != nary_dtype) {
90
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
91
+ }
92
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
93
+ b_vnary = nary_dup(b_vnary);
94
+ }
95
+
96
+ narray_t* a_nary = nullptr;
97
+ GetNArray(a_vnary, a_nary);
98
+ if (NA_NDIM(a_nary) != 2) {
99
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
100
+ return Qnil;
101
+ }
102
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
103
+ rb_raise(rb_eArgError, "input array a must be square");
104
+ return Qnil;
105
+ }
106
+ narray_t* b_nary = nullptr;
107
+ GetNArray(a_vnary, b_nary);
108
+ if (NA_NDIM(b_nary) != 2) {
109
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
110
+ return Qnil;
111
+ }
112
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
113
+ rb_raise(rb_eArgError, "input array b must be square");
114
+ return Qnil;
115
+ }
116
+
117
+ const size_t n = NA_SHAPE(a_nary)[1];
118
+ size_t m = range != 'I' ? n : iu - il + 1;
119
+ size_t w_shape[1] = { m };
120
+ size_t z_shape[2] = { n, m };
121
+ size_t ifail_shape[1] = { n };
122
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
123
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
124
+ ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
125
+ hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
126
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
127
+ VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
128
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
129
+
130
+ RB_GC_GUARD(a_vnary);
131
+ RB_GC_GUARD(b_vnary);
132
+
133
+ return ret;
134
+ }
135
+
136
+ static lapack_int get_itype(VALUE val) {
137
+ const lapack_int itype = NUM2INT(val);
138
+
139
+ if (itype != 1 && itype != 2 && itype != 3) {
140
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
141
+ }
142
+
143
+ return itype;
144
+ }
145
+
146
+ static char get_jobz(VALUE val) {
147
+ const char jobz = NUM2CHR(val);
148
+
149
+ if (jobz != 'N' && jobz != 'V') {
150
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
151
+ }
152
+
153
+ return jobz;
154
+ }
155
+
156
+ static char get_range(VALUE val) {
157
+ const char range = NUM2CHR(val);
158
+
159
+ if (range != 'A' && range != 'V' && range != 'I') {
160
+ rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
161
+ }
162
+
163
+ return range;
164
+ }
165
+
166
+ static char get_uplo(VALUE val) {
167
+ const char uplo = NUM2CHR(val);
168
+
169
+ if (uplo != 'U' && uplo != 'L') {
170
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
171
+ }
172
+
173
+ return uplo;
174
+ }
175
+
176
+ static int get_matrix_layout(VALUE val) {
177
+ const char option = NUM2CHR(val);
178
+
179
+ switch (option) {
180
+ case 'r':
181
+ case 'R':
182
+ break;
183
+ case 'c':
184
+ case 'C':
185
+ rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
186
+ break;
187
+ }
188
+
189
+ return LAPACK_ROW_MAJOR;
190
+ }
191
+ };
192
+
193
+ } // namespace TinyLinalg
@@ -0,0 +1,158 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyGv {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, double* a, lapack_int lda,
6
+ double* b, lapack_int ldb, double* w) {
7
+ return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
8
+ }
9
+ };
10
+
11
+ struct SSyGv {
12
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
13
+ char uplo, lapack_int n, float* a, lapack_int lda,
14
+ float* b, lapack_int ldb, float* w) {
15
+ return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename DType, typename FncType>
20
+ class SyGv {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
24
+ }
25
+
26
+ private:
27
+ struct sygv_opt {
28
+ int matrix_layout;
29
+ lapack_int itype;
30
+ char jobz;
31
+ char uplo;
32
+ };
33
+
34
+ static void iter_sygv(na_loop_t* const lp) {
35
+ DType* a = (DType*)NDL_PTR(lp, 0);
36
+ DType* b = (DType*)NDL_PTR(lp, 1);
37
+ DType* w = (DType*)NDL_PTR(lp, 2);
38
+ int* info = (int*)NDL_PTR(lp, 3);
39
+ sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
40
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
+ *info = static_cast<int>(i);
45
+ }
46
+
47
+ static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
48
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
49
+
50
+ VALUE a_vnary = Qnil;
51
+ VALUE b_vnary = Qnil;
52
+ VALUE kw_args = Qnil;
53
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
54
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+ if (CLASS_OF(b_vnary) != nary_dtype) {
69
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
70
+ }
71
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
72
+ b_vnary = nary_dup(b_vnary);
73
+ }
74
+
75
+ narray_t* a_nary = nullptr;
76
+ GetNArray(a_vnary, a_nary);
77
+ if (NA_NDIM(a_nary) != 2) {
78
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
79
+ return Qnil;
80
+ }
81
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
82
+ rb_raise(rb_eArgError, "input array a must be square");
83
+ return Qnil;
84
+ }
85
+ narray_t* b_nary = nullptr;
86
+ GetNArray(a_vnary, b_nary);
87
+ if (NA_NDIM(b_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array b must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ const size_t n = NA_SHAPE(a_nary)[1];
97
+ size_t shape[1] = { n };
98
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
99
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
100
+ ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
101
+ sygv_opt opt = { matrix_layout, itype, jobz, uplo };
102
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
103
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
104
+
105
+ RB_GC_GUARD(a_vnary);
106
+ RB_GC_GUARD(b_vnary);
107
+
108
+ return ret;
109
+ }
110
+
111
+ static lapack_int get_itype(VALUE val) {
112
+ const lapack_int itype = NUM2INT(val);
113
+
114
+ if (itype != 1 && itype != 2 && itype != 3) {
115
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
+ }
117
+
118
+ return itype;
119
+ }
120
+
121
+ static char get_jobz(VALUE val) {
122
+ const char jobz = NUM2CHR(val);
123
+
124
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
+ }
127
+
128
+ return jobz;
129
+ }
130
+
131
+ static char get_uplo(VALUE val) {
132
+ const char uplo = NUM2CHR(val);
133
+
134
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
+ }
137
+
138
+ return uplo;
139
+ }
140
+
141
+ static int get_matrix_layout(VALUE val) {
142
+ const char option = NUM2CHR(val);
143
+
144
+ switch (option) {
145
+ case 'r':
146
+ case 'R':
147
+ break;
148
+ case 'c':
149
+ case 'C':
150
+ rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
151
+ break;
152
+ }
153
+
154
+ return LAPACK_ROW_MAJOR;
155
+ }
156
+ };
157
+
158
+ } // namespace TinyLinalg
@@ -0,0 +1,158 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyGvd {
4
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
5
+ char uplo, lapack_int n, double* a, lapack_int lda,
6
+ double* b, lapack_int ldb, double* w) {
7
+ return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
8
+ }
9
+ };
10
+
11
+ struct SSyGvd {
12
+ lapack_int call(int matrix_layout, lapack_int itype, char jobz,
13
+ char uplo, lapack_int n, float* a, lapack_int lda,
14
+ float* b, lapack_int ldb, float* w) {
15
+ return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename DType, typename FncType>
20
+ class SyGvd {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
24
+ }
25
+
26
+ private:
27
+ struct sygvd_opt {
28
+ int matrix_layout;
29
+ lapack_int itype;
30
+ char jobz;
31
+ char uplo;
32
+ };
33
+
34
+ static void iter_sygvd(na_loop_t* const lp) {
35
+ DType* a = (DType*)NDL_PTR(lp, 0);
36
+ DType* b = (DType*)NDL_PTR(lp, 1);
37
+ DType* w = (DType*)NDL_PTR(lp, 2);
38
+ int* info = (int*)NDL_PTR(lp, 3);
39
+ sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
40
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
+ const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
+ const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
+ *info = static_cast<int>(i);
45
+ }
46
+
47
+ static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
48
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
49
+
50
+ VALUE a_vnary = Qnil;
51
+ VALUE b_vnary = Qnil;
52
+ VALUE kw_args = Qnil;
53
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
54
+ ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
+ const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+ if (CLASS_OF(b_vnary) != nary_dtype) {
69
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
70
+ }
71
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
72
+ b_vnary = nary_dup(b_vnary);
73
+ }
74
+
75
+ narray_t* a_nary = nullptr;
76
+ GetNArray(a_vnary, a_nary);
77
+ if (NA_NDIM(a_nary) != 2) {
78
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
79
+ return Qnil;
80
+ }
81
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
82
+ rb_raise(rb_eArgError, "input array a must be square");
83
+ return Qnil;
84
+ }
85
+ narray_t* b_nary = nullptr;
86
+ GetNArray(a_vnary, b_nary);
87
+ if (NA_NDIM(b_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array b must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ const size_t n = NA_SHAPE(a_nary)[1];
97
+ size_t shape[1] = { n };
98
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
99
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
100
+ ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
101
+ sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
102
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
103
+ VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
104
+
105
+ RB_GC_GUARD(a_vnary);
106
+ RB_GC_GUARD(b_vnary);
107
+
108
+ return ret;
109
+ }
110
+
111
+ static lapack_int get_itype(VALUE val) {
112
+ const lapack_int itype = NUM2INT(val);
113
+
114
+ if (itype != 1 && itype != 2 && itype != 3) {
115
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
+ }
117
+
118
+ return itype;
119
+ }
120
+
121
+ static char get_jobz(VALUE val) {
122
+ const char jobz = NUM2CHR(val);
123
+
124
+ if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
+ }
127
+
128
+ return jobz;
129
+ }
130
+
131
+ static char get_uplo(VALUE val) {
132
+ const char uplo = NUM2CHR(val);
133
+
134
+ if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
+ }
137
+
138
+ return uplo;
139
+ }
140
+
141
+ static int get_matrix_layout(VALUE val) {
142
+ const char option = NUM2CHR(val);
143
+
144
+ switch (option) {
145
+ case 'r':
146
+ case 'R':
147
+ break;
148
+ case 'c':
149
+ case 'C':
150
+ rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
151
+ break;
152
+ }
153
+
154
+ return LAPACK_ROW_MAJOR;
155
+ }
156
+ };
157
+
158
+ } // namespace TinyLinalg
@@ -0,0 +1,192 @@
1
+ namespace TinyLinalg {
2
+
3
+ // Adapter that forwards its arguments unchanged to LAPACKE_dsygvx (double-typed buffers).
+ struct DSyGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
+                   double vl, double vu, lapack_int il, lapack_int iu,
+                   double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
+     const lapack_int info = LAPACKE_dsygvx(
+       matrix_layout, itype, jobz, range, uplo,
+       n, a, lda, b, ldb,
+       vl, vu, il, iu, abstol,
+       m, w, z, ldz, ifail);
+     return info;
+   }
+ };
11
+
12
+ // Adapter that forwards its arguments unchanged to LAPACKE_ssygvx (float-typed buffers).
+ struct SSyGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
+                   float vl, float vu, lapack_int il, lapack_int iu,
+                   float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
+     const lapack_int info = LAPACKE_ssygvx(
+       matrix_layout, itype, jobz, range, uplo,
+       n, a, lda, b, ldb,
+       vl, vu, il, iu, abstol,
+       m, w, z, ldz, ifail);
+     return info;
+   }
+ };
20
+
21
+ // Exposes the LAPACKE ?sygvx routine wrapped by FncType as a Ruby module function.
+ //
+ // The Ruby function takes two arrays (a, b) and the keywords itype (default 1),
+ // jobz (default 'V'), range (default 'A'), uplo (default 'U'), vl / vu (default 0.0),
+ // il / iu (default 0) and order (row major is always used).
+ // It returns [a, b, m, w, z, ifail, info], where a and b are the arrays actually
+ // handed to the routine (cast to the NArray class for nary_dtype_id and duplicated
+ // when not contiguous) and info is the routine's return value.
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class SyGvx {
+ public:
+   // Registers the binding on mLapack under fnc_name (variable-length argument list).
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
+   }
+
+ private:
+   // Scalar options carried from the Ruby entry point to iter_sygvx via na_ndloop3.
+   struct sygvx_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char range;
+     char uplo;
+     DType vl;
+     DType vu;
+     lapack_int il;
+     lapack_int iu;
+   };
+
+   // Loop body: pulls the two inputs and five outputs from the loop descriptor,
+   // calls the wrapped routine once and stores its return value in the info output.
+   static void iter_sygvx(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     int* m = (int*)NDL_PTR(lp, 2);
+     DType* w = (DType*)NDL_PTR(lp, 3);
+     DType* z = (DType*)NDL_PTR(lp, 4);
+     int* ifail = (int*)NDL_PTR(lp, 5);
+     int* info = (int*)NDL_PTR(lp, 6);
+     sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     // Must agree with the second dimension of the z output allocated by tiny_linalg_sygvx.
+     const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
+     const DType abstol = 0.0;
+     const lapack_int i = FncType().call(
+       opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
+       opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
+     *info = static_cast<int>(i);
+   }
+
+   // Ruby entry point: parses arguments, validates the inputs, allocates the outputs
+   // and runs iter_sygvx.
+   // Raises ArgumentError when a or b is not a square 2-dimensional array, when their
+   // shapes differ, when a keyword value is invalid, or when range is 'I' and il / iu
+   // do not select a valid index interval.
+   static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
+                        rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
+     VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
+     const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
+     const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
+     const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
+     const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
+     const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
+     const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
+     const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     // rb_raise does not return, so no statement is needed after it.
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+     }
+     // Inspect b itself; the checks below previously looked at a a second time.
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     // The routine is called with a single order n taken from a, so b must match it.
+     if (NA_SHAPE(b_nary)[0] != n) {
+       rb_raise(rb_eArgError, "input arrays a and b must have the same shape");
+     }
+     // With range 'I' the output sizes are derived from iu - il + 1; reject index
+     // pairs that would make that value non-positive or larger than n before it is
+     // converted to size_t.
+     if (range == 'I') {
+       const bool valid_indices = n > 0 ? (il >= 1 && il <= iu && static_cast<size_t>(iu) <= n)
+                                        : (il == 1 && iu == 0);
+       if (!valid_indices) {
+         rb_raise(rb_eArgError, "il and iu must satisfy 1 <= il <= iu <= n when range is 'I'");
+       }
+     }
+
+     size_t m = range != 'I' ? n : static_cast<size_t>(iu - il + 1);
+     size_t w_shape[1] = { m };
+     size_t z_shape[2] = { n, m };
+     size_t ifail_shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
+     sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
+                             rb_ary_entry(res, 3), rb_ary_entry(res, 4));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   // Converts the itype keyword; raises ArgumentError unless it is 1, 2 or 3.
+   static lapack_int get_itype(VALUE val) {
+     const lapack_int itype = NUM2INT(val);
+
+     if (itype != 1 && itype != 2 && itype != 3) {
+       rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
+     }
+
+     return itype;
+   }
+
+   // Converts the jobz keyword; raises ArgumentError unless it is 'N' or 'V' (upper case only).
+   static char get_jobz(VALUE val) {
+     const char jobz = NUM2CHR(val);
+
+     if (jobz != 'N' && jobz != 'V') {
+       rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
+     }
+
+     return jobz;
+   }
+
+   // Converts the range keyword; raises ArgumentError unless it is 'A', 'V' or 'I' (upper case only).
+   static char get_range(VALUE val) {
+     const char range = NUM2CHR(val);
+
+     if (range != 'A' && range != 'V' && range != 'I') {
+       rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
+     }
+
+     return range;
+   }
+
+   // Converts the uplo keyword; raises ArgumentError unless it is 'U' or 'L' (upper case only).
+   static char get_uplo(VALUE val) {
+     const char uplo = NUM2CHR(val);
+
+     if (uplo != 'U' && uplo != 'L') {
+       rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
+     }
+
+     return uplo;
+   }
+
+   // Interprets the order keyword. Row major is always returned; asking for
+   // column major ('c' / 'C') only emits a warning.
+   static int get_matrix_layout(VALUE val) {
+     const char option = NUM2CHR(val);
+
+     switch (option) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
+         break;
+     }
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
191
+
192
+ } // namespace TinyLinalg
@@ -11,7 +11,13 @@
11
11
  #include "lapack/gesvd.hpp"
12
12
  #include "lapack/getrf.hpp"
13
13
  #include "lapack/getri.hpp"
14
+ #include "lapack/hegv.hpp"
15
+ #include "lapack/hegvd.hpp"
16
+ #include "lapack/hegvx.hpp"
14
17
  #include "lapack/orgqr.hpp"
18
+ #include "lapack/sygv.hpp"
19
+ #include "lapack/sygvd.hpp"
20
+ #include "lapack/sygvx.hpp"
15
21
  #include "lapack/ungqr.hpp"
16
22
 
17
23
  VALUE rb_mTinyLinalg;
@@ -271,6 +277,18 @@ extern "C" void Init_tiny_linalg(void) {
271
277
  TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
272
278
  TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
273
279
  TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
280
+ TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
281
+ TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
282
+ TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
283
+ TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
284
+ TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
285
+ TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
286
+ TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
287
+ TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
288
+ TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
289
+ TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
290
+ TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
291
+ TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
274
292
 
275
293
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
276
294
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.0.3'
8
+ VERSION = '0.0.4'
9
9
  end
10
10
  end
@@ -10,6 +10,50 @@ module Numo
10
10
  module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
+ # Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
14
+ # by solving an ordinary or generalized eigenvalue problem.
15
+ #
16
+ # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
17
+ # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
18
+ # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
19
+ # @param vals_range [Range/Array]
20
+ # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
21
+ # If nil, all eigenvalues and eigenvectors are computed.
22
+ # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
23
+ # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
24
+ # @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors.
25
+ def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
26
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
27
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
28
+
29
+ bchr = blas_char(a)
30
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
31
+
32
+ unless b.nil?
33
+ raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
34
+ raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
35
+ raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
36
+ end
37
+
38
+ jobz = vals_only ? 'N' : 'V'
39
+ b = a.class.eye(a.shape[0]) if b.nil?
40
+ sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
41
+
42
+ if vals_range.nil?
43
+ sy_he_gv << 'd' if turbo
44
+ vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
45
+ else
46
+ sy_he_gv << 'x'
47
+ il = vals_range.first + 1
48
+ iu = vals_range.last + 1
49
+ _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
50
+ sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
51
+ )
52
+ end
53
+ vecs = nil if vals_only
54
+ [vals, vecs]
55
+ end
56
+
13
57
  # Computes the determinant of matrix.
14
58
  #
15
59
  # @param a [Numo::NArray] n-by-n square matrix.
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.3
4
+ version: 0.0.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-02 00:00:00.000000000 Z
11
+ date: 2023-08-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -52,7 +52,13 @@ files:
52
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
54
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/hegv.hpp
56
+ - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
+ - ext/numo/tiny_linalg/lapack/hegvx.hpp
55
58
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
59
+ - ext/numo/tiny_linalg/lapack/sygv.hpp
60
+ - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
+ - ext/numo/tiny_linalg/lapack/sygvx.hpp
56
62
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
57
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
58
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp