numo-tiny_linalg 0.1.1 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 81e69278b32b0f565e3f5b5e1eae1738a57556773e9306bb3b0c8450c9fb75ea
4
- data.tar.gz: de6daa349cd0d16b14f9c133ad451fac8bacbcd42cca682be68509bd5d8fc416
3
+ metadata.gz: 911223ba40879d779bcc9bc77b840fb7b5e41ed68d6ccf7d26816913c42aed73
4
+ data.tar.gz: 9dce9c353da5466e1104c604c2749369131cb6640e14ddd33c697cee5d973563
5
5
  SHA512:
6
- metadata.gz: 72894605f074e09f47a37834ba1c2b3a3e504e149742df3290fd4c4e8042b1ac0c70d11bd50e4ab83e19b07181d8a656ab57afcaa47de53794fc28a298e1e768
7
- data.tar.gz: c896b561fcbeb0ff19196548a3a0f2f6a06098d706fbfe95b785ffd850b139e87ea5d62c1144cde4c456a7ef5bb2b79422b5c2e0ab909df76f5ea9fa4a4bd446
6
+ metadata.gz: 9d666bdd04f48132005a9aee455047c80d83184815fac5c688ec593412ff8e670f1f31ba3376b1ee2985b8e829680590c1dad6a0ec3bf5366e584a82b523d8e4
7
+ data.tar.gz: 22027c54b56bad7214060ba7d5939b986be660a75afd65d74de7fdae095a6699ceea35422eee56831b621f2132e5bce921cb6b70bc9709e77e3efdf6cfc8a098
data/CHANGELOG.md CHANGED
@@ -1,5 +1,15 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.2.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.2...v0.2.0)] - 2023-08-11
4
+ **Breaking change**
5
+ - Change the LAPACK function called by TinyLinalg.eigh when array b is not given.
6
+
7
+ ## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
8
+ - Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
9
+ - Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
10
+ - Add dsyevr, ssyevr, zheevr, and cheevr module functions to TinyLinalg::Lapack.
11
+ - Fix the check of whether the array b is a square matrix in TinyLinalg.eigh.
12
+
3
13
  ## [[0.1.1](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.0...v0.1.1)] - 2023-08-07
4
14
  - Fix the method of getting the start and end of the eigenvalue range from the vals_range argument of TinyLinalg.eigh.
5
15
 
@@ -28,5 +38,4 @@
28
38
  - Add inv module function to TinyLinalg.
29
39
 
30
40
  ## [0.0.1] - 2023-07-14
31
-
32
- - Initial release
41
+ - Initial release.
data/README.md CHANGED
@@ -8,6 +8,8 @@
8
8
  Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
9
9
  The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
10
10
 
11
+ Note that the version numbering rule of Numo::TinyLinalg is not compatible with that of Numo::Linalg.
12
+
11
13
  ## Installation
12
14
  Unlike Numo::Linalg, Numo::TinyLinalg only supports OpenBLAS as a backend library for BLAS and LAPACK.
13
15
 
@@ -0,0 +1,87 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEv {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_zheev(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct CHeEv {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_cheev(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
16
+ class HeEv {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heev), -1);
20
+ }
21
+
22
+ private:
23
+ struct heev_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_heev(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ rtype* w = (rtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ heev_opt* opt = (heev_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_heev(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
43
+
44
+ VALUE a_vnary = Qnil;
45
+ VALUE kw_args = Qnil;
46
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
47
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
48
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
49
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
50
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
51
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
52
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+
61
+ narray_t* a_nary = nullptr;
62
+ GetNArray(a_vnary, a_nary);
63
+ if (NA_NDIM(a_nary) != 2) {
64
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
65
+ return Qnil;
66
+ }
67
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
68
+ rb_raise(rb_eArgError, "input array a must be square");
69
+ return Qnil;
70
+ }
71
+
72
+ const size_t n = NA_SHAPE(a_nary)[1];
73
+ size_t shape[1] = { n };
74
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
75
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
76
+ ndfunc_t ndf = { iter_heev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
77
+ heev_opt opt = { matrix_layout, jobz, uplo };
78
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
79
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
80
+
81
+ RB_GC_GUARD(a_vnary);
82
+
83
+ return ret;
84
+ }
85
+ };
86
+
87
+ } // namespace TinyLinalg
@@ -0,0 +1,87 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEvd {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_zheevd(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct CHeEvd {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_cheevd(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
16
+ class HeEvd {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevd), -1);
20
+ }
21
+
22
+ private:
23
+ struct heevd_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_heevd(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ rtype* w = (rtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ heevd_opt* opt = (heevd_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_heevd(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
43
+
44
+ VALUE a_vnary = Qnil;
45
+ VALUE kw_args = Qnil;
46
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
47
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
48
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
49
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
50
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
51
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
52
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+
61
+ narray_t* a_nary = nullptr;
62
+ GetNArray(a_vnary, a_nary);
63
+ if (NA_NDIM(a_nary) != 2) {
64
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
65
+ return Qnil;
66
+ }
67
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
68
+ rb_raise(rb_eArgError, "input array a must be square");
69
+ return Qnil;
70
+ }
71
+
72
+ const size_t n = NA_SHAPE(a_nary)[1];
73
+ size_t shape[1] = { n };
74
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
75
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
76
+ ndfunc_t ndf = { iter_heevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
77
+ heevd_opt opt = { matrix_layout, jobz, uplo };
78
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
79
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
80
+
81
+ RB_GC_GUARD(a_vnary);
82
+
83
+ return ret;
84
+ }
85
+ };
86
+
87
+ } // namespace TinyLinalg
@@ -0,0 +1,133 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEvr {
4
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
5
+ lapack_int n, lapack_complex_double* a, lapack_int lda, double vl, double vu, lapack_int il,
6
+ lapack_int iu, double abstol, lapack_int* m,
7
+ double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* isuppz) {
8
+ return LAPACKE_zheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
9
+ }
10
+ };
11
+
12
+ struct CHeEvr {
13
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
14
+ lapack_int n, lapack_complex_float* a, lapack_int lda, float vl, float vu, lapack_int il,
15
+ lapack_int iu, float abstol, lapack_int* m,
16
+ float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* isuppz) {
17
+ return LAPACKE_cheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
22
+ class HeEvr {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevr), -1);
26
+ }
27
+
28
+ private:
29
+ struct heevr_opt {
30
+ int matrix_layout;
31
+ char jobz;
32
+ char range;
33
+ char uplo;
34
+ rtype vl;
35
+ rtype vu;
36
+ lapack_int il;
37
+ lapack_int iu;
38
+ };
39
+
40
+ static void iter_heevr(na_loop_t* const lp) {
41
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
42
+ int* m = (int*)NDL_PTR(lp, 1);
43
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
44
+ dtype* z = (dtype*)NDL_PTR(lp, 3);
45
+ int* isuppz = (int*)NDL_PTR(lp, 4);
46
+ int* info = (int*)NDL_PTR(lp, 5);
47
+ heevr_opt* opt = (heevr_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
51
+ const rtype abstol = 0.0;
52
+ const lapack_int i = LapackFn().call(
53
+ opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
54
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
55
+ *info = static_cast<int>(i);
56
+ }
57
+
58
+ static VALUE tiny_linalg_heevr(int argc, VALUE* argv, VALUE self) {
59
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
60
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
61
+
62
+ VALUE a_vnary = Qnil;
63
+ VALUE kw_args = Qnil;
64
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
65
+ ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
66
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
67
+ VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
68
+ rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
69
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
70
+ const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
71
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
72
+ const rtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
73
+ const rtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
74
+ const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
75
+ const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
76
+ const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
77
+
78
+ if (CLASS_OF(a_vnary) != nary_dtype) {
79
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
80
+ }
81
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
82
+ a_vnary = nary_dup(a_vnary);
83
+ }
84
+
85
+ narray_t* a_nary = nullptr;
86
+ GetNArray(a_vnary, a_nary);
87
+ if (NA_NDIM(a_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array a must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ if (range == 'V' && vu <= vl) {
97
+ rb_raise(rb_eArgError, "vu must be greater than vl");
98
+ return Qnil;
99
+ }
100
+
101
+ const size_t n = NA_SHAPE(a_nary)[1];
102
+ if (range == 'I' && (il < 1 || il > n)) {
103
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
104
+ return Qnil;
105
+ }
106
+ if (range == 'I' && (iu < 1 || iu > n)) {
107
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
108
+ return Qnil;
109
+ }
110
+ if (range == 'I' && iu < il) {
111
+ rb_raise(rb_eArgError, "iu must be greater than or equal to il");
112
+ return Qnil;
113
+ }
114
+
115
+ size_t m = range != 'I' ? n : iu - il + 1;
116
+ size_t w_shape[1] = { m };
117
+ size_t z_shape[2] = { n, m };
118
+ size_t isuppz_shape[1] = { 2 * m };
119
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
120
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
121
+ ndfunc_t ndf = { iter_heevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
122
+ heevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
123
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
124
+ VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
125
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
126
+
127
+ RB_GC_GUARD(a_vnary);
128
+
129
+ return ret;
130
+ }
131
+ };
132
+
133
+ } // namespace TinyLinalg
@@ -92,7 +92,7 @@ private:
92
92
  return Qnil;
93
93
  }
94
94
  narray_t* b_nary = nullptr;
95
- GetNArray(a_vnary, b_nary);
95
+ GetNArray(b_vnary, b_nary);
96
96
  if (NA_NDIM(b_nary) != 2) {
97
97
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
98
  return Qnil;
@@ -92,7 +92,7 @@ private:
92
92
  return Qnil;
93
93
  }
94
94
  narray_t* b_nary = nullptr;
95
- GetNArray(a_vnary, b_nary);
95
+ GetNArray(b_vnary, b_nary);
96
96
  if (NA_NDIM(b_nary) != 2) {
97
97
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
98
  return Qnil;
@@ -104,7 +104,7 @@ private:
104
104
  return Qnil;
105
105
  }
106
106
  narray_t* b_nary = nullptr;
107
- GetNArray(a_vnary, b_nary);
107
+ GetNArray(b_vnary, b_nary);
108
108
  if (NA_NDIM(b_nary) != 2) {
109
109
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
110
110
  return Qnil;
@@ -114,7 +114,25 @@ private:
114
114
  return Qnil;
115
115
  }
116
116
 
117
+ if (range == 'V' && vu <= vl) {
118
+ rb_raise(rb_eArgError, "vu must be greater than vl");
119
+ return Qnil;
120
+ }
121
+
117
122
  const size_t n = NA_SHAPE(a_nary)[1];
123
+ if (range == 'I' && (il < 1 || il > n)) {
124
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
125
+ return Qnil;
126
+ }
127
+ if (range == 'I' && (iu < 1 || iu > n)) {
128
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
129
+ return Qnil;
130
+ }
131
+ if (range == 'I' && iu < il) {
132
+ rb_raise(rb_eArgError, "il must be less than or equal to iu");
133
+ return Qnil;
134
+ }
135
+
118
136
  size_t m = range != 'I' ? n : iu - il + 1;
119
137
  size_t w_shape[1] = { m };
120
138
  size_t z_shape[2] = { n, m };
@@ -0,0 +1,86 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEv {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_dsyev(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct SSyEv {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_ssyev(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class LapackFn>
16
+ class SyEv {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syev), -1);
20
+ }
21
+
22
+ private:
23
+ struct syev_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_syev(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* w = (dtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ syev_opt* opt = (syev_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_syev(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+
43
+ VALUE a_vnary = Qnil;
44
+ VALUE kw_args = Qnil;
45
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
46
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
47
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
48
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
49
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
50
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
51
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
52
+
53
+ if (CLASS_OF(a_vnary) != nary_dtype) {
54
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
55
+ }
56
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
57
+ a_vnary = nary_dup(a_vnary);
58
+ }
59
+
60
+ narray_t* a_nary = nullptr;
61
+ GetNArray(a_vnary, a_nary);
62
+ if (NA_NDIM(a_nary) != 2) {
63
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
64
+ return Qnil;
65
+ }
66
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
67
+ rb_raise(rb_eArgError, "input array a must be square");
68
+ return Qnil;
69
+ }
70
+
71
+ const size_t n = NA_SHAPE(a_nary)[1];
72
+ size_t shape[1] = { n };
73
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
74
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
75
+ ndfunc_t ndf = { iter_syev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
76
+ syev_opt opt = { matrix_layout, jobz, uplo };
77
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
78
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
79
+
80
+ RB_GC_GUARD(a_vnary);
81
+
82
+ return ret;
83
+ }
84
+ };
85
+
86
+ } // namespace TinyLinalg
@@ -0,0 +1,86 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEvd {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_dsyevd(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct SSyEvd {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_ssyevd(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class LapackFn>
16
+ class SyEvd {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevd), -1);
20
+ }
21
+
22
+ private:
23
+ struct syevd_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_syevd(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* w = (dtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ syevd_opt* opt = (syevd_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_syevd(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+
43
+ VALUE a_vnary = Qnil;
44
+ VALUE kw_args = Qnil;
45
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
46
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
47
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
48
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
49
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
50
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
51
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
52
+
53
+ if (CLASS_OF(a_vnary) != nary_dtype) {
54
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
55
+ }
56
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
57
+ a_vnary = nary_dup(a_vnary);
58
+ }
59
+
60
+ narray_t* a_nary = nullptr;
61
+ GetNArray(a_vnary, a_nary);
62
+ if (NA_NDIM(a_nary) != 2) {
63
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
64
+ return Qnil;
65
+ }
66
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
67
+ rb_raise(rb_eArgError, "input array a must be square");
68
+ return Qnil;
69
+ }
70
+
71
+ const size_t n = NA_SHAPE(a_nary)[1];
72
+ size_t shape[1] = { n };
73
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
74
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
75
+ ndfunc_t ndf = { iter_syevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
76
+ syevd_opt opt = { matrix_layout, jobz, uplo };
77
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
78
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
79
+
80
+ RB_GC_GUARD(a_vnary);
81
+
82
+ return ret;
83
+ }
84
+ };
85
+
86
+ } // namespace TinyLinalg
@@ -0,0 +1,130 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEvr {
4
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
5
+ lapack_int n, double* a, lapack_int lda, double vl, double vu, lapack_int il, lapack_int iu,
6
+ double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* isuppz) {
7
+ return LAPACKE_dsyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
8
+ }
9
+ };
10
+
11
+ struct SSyEvr {
12
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
13
+ lapack_int n, float* a, lapack_int lda, float vl, float vu, lapack_int il, lapack_int iu,
14
+ float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* isuppz) {
15
+ return LAPACKE_ssyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
+ class SyEvr {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevr), -1);
24
+ }
25
+
26
+ private:
27
+ struct syevr_opt {
28
+ int matrix_layout;
29
+ char jobz;
30
+ char range;
31
+ char uplo;
32
+ dtype vl;
33
+ dtype vu;
34
+ lapack_int il;
35
+ lapack_int iu;
36
+ };
37
+
38
+ static void iter_syevr(na_loop_t* const lp) {
39
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
40
+ int* m = (int*)NDL_PTR(lp, 1);
41
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
42
+ dtype* z = (dtype*)NDL_PTR(lp, 3);
43
+ int* isuppz = (int*)NDL_PTR(lp, 4);
44
+ int* info = (int*)NDL_PTR(lp, 5);
45
+ syevr_opt* opt = (syevr_opt*)(lp->opt_ptr);
46
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
47
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
48
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
49
+ const dtype abstol = 0.0;
50
+ const lapack_int i = LapackFn().call(
51
+ opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
52
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
53
+ *info = static_cast<int>(i);
54
+ }
55
+
56
+ static VALUE tiny_linalg_syevr(int argc, VALUE* argv, VALUE self) {
57
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE kw_args = Qnil;
61
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
62
+ ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
63
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
64
+ VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
66
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
67
+ const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const dtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
70
+ const dtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
71
+ const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
72
+ const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
73
+ const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
74
+
75
+ if (CLASS_OF(a_vnary) != nary_dtype) {
76
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
77
+ }
78
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
79
+ a_vnary = nary_dup(a_vnary);
80
+ }
81
+
82
+ narray_t* a_nary = nullptr;
83
+ GetNArray(a_vnary, a_nary);
84
+ if (NA_NDIM(a_nary) != 2) {
85
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
86
+ return Qnil;
87
+ }
88
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
89
+ rb_raise(rb_eArgError, "input array a must be square");
90
+ return Qnil;
91
+ }
92
+
93
+ if (range == 'V' && vu <= vl) {
94
+ rb_raise(rb_eArgError, "vu must be greater than vl");
95
+ return Qnil;
96
+ }
97
+
98
+ const size_t n = NA_SHAPE(a_nary)[1];
99
+ if (range == 'I' && (il < 1 || il > n)) {
100
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
101
+ return Qnil;
102
+ }
103
+ if (range == 'I' && (iu < 1 || iu > n)) {
104
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
105
+ return Qnil;
106
+ }
107
+ if (range == 'I' && iu < il) {
108
+ rb_raise(rb_eArgError, "iu must be greater than or equal to il");
109
+ return Qnil;
110
+ }
111
+
112
+ size_t m = range != 'I' ? n : iu - il + 1;
113
+ size_t w_shape[1] = { m };
114
+ size_t z_shape[2] = { n, m };
115
+ size_t isuppz_shape[1] = { 2 * m };
116
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
117
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
118
+ ndfunc_t ndf = { iter_syevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
119
+ syevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
120
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
121
+ VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
122
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
123
+
124
+ RB_GC_GUARD(a_vnary);
125
+
126
+ return ret;
127
+ }
128
+ };
129
+
130
+ } // namespace TinyLinalg
@@ -83,7 +83,7 @@ private:
83
83
  return Qnil;
84
84
  }
85
85
  narray_t* b_nary = nullptr;
86
- GetNArray(a_vnary, b_nary);
86
+ GetNArray(b_vnary, b_nary);
87
87
  if (NA_NDIM(b_nary) != 2) {
88
88
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
89
  return Qnil;
@@ -83,7 +83,7 @@ private:
83
83
  return Qnil;
84
84
  }
85
85
  narray_t* b_nary = nullptr;
86
- GetNArray(a_vnary, b_nary);
86
+ GetNArray(b_vnary, b_nary);
87
87
  if (NA_NDIM(b_nary) != 2) {
88
88
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
89
  return Qnil;
@@ -103,7 +103,7 @@ private:
103
103
  return Qnil;
104
104
  }
105
105
  narray_t* b_nary = nullptr;
106
- GetNArray(a_vnary, b_nary);
106
+ GetNArray(b_vnary, b_nary);
107
107
  if (NA_NDIM(b_nary) != 2) {
108
108
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
109
109
  return Qnil;
@@ -113,7 +113,25 @@ private:
113
113
  return Qnil;
114
114
  }
115
115
 
116
+ if (range == 'V' && vu <= vl) {
117
+ rb_raise(rb_eArgError, "vu must be greater than vl");
118
+ return Qnil;
119
+ }
120
+
116
121
  const size_t n = NA_SHAPE(a_nary)[1];
122
+ if (range == 'I' && (il < 1 || il > n)) {
123
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
124
+ return Qnil;
125
+ }
126
+ if (range == 'I' && (iu < 1 || iu > n)) {
127
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
128
+ return Qnil;
129
+ }
130
+ if (range == 'I' && iu < il) {
131
+ rb_raise(rb_eArgError, "iu must be greater than or equal to il");
132
+ return Qnil;
133
+ }
134
+
117
135
  size_t m = range != 'I' ? n : iu - il + 1;
118
136
  size_t w_shape[1] = { m };
119
137
  size_t z_shape[2] = { n, m };
@@ -44,10 +44,16 @@
44
44
  #include "lapack/gesvd.hpp"
45
45
  #include "lapack/getrf.hpp"
46
46
  #include "lapack/getri.hpp"
47
+ #include "lapack/heev.hpp"
48
+ #include "lapack/heevd.hpp"
49
+ #include "lapack/heevr.hpp"
47
50
  #include "lapack/hegv.hpp"
48
51
  #include "lapack/hegvd.hpp"
49
52
  #include "lapack/hegvx.hpp"
50
53
  #include "lapack/orgqr.hpp"
54
+ #include "lapack/syev.hpp"
55
+ #include "lapack/syevd.hpp"
56
+ #include "lapack/syevr.hpp"
51
57
  #include "lapack/sygv.hpp"
52
58
  #include "lapack/sygvd.hpp"
53
59
  #include "lapack/sygvx.hpp"
@@ -312,6 +318,18 @@ extern "C" void Init_tiny_linalg(void) {
312
318
  TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
313
319
  TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
314
320
  TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
321
+ TinyLinalg::SyEv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEv>::define_module_function(rb_mTinyLinalgLapack, "dsyev");
322
+ TinyLinalg::SyEv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEv>::define_module_function(rb_mTinyLinalgLapack, "ssyev");
323
+ TinyLinalg::HeEv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEv>::define_module_function(rb_mTinyLinalgLapack, "zheev");
324
+ TinyLinalg::HeEv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEv>::define_module_function(rb_mTinyLinalgLapack, "cheev");
325
+ TinyLinalg::SyEvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvd>::define_module_function(rb_mTinyLinalgLapack, "dsyevd");
326
+ TinyLinalg::SyEvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvd>::define_module_function(rb_mTinyLinalgLapack, "ssyevd");
327
+ TinyLinalg::HeEvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvd>::define_module_function(rb_mTinyLinalgLapack, "zheevd");
328
+ TinyLinalg::HeEvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvd>::define_module_function(rb_mTinyLinalgLapack, "cheevd");
329
+ TinyLinalg::SyEvr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvr>::define_module_function(rb_mTinyLinalgLapack, "dsyevr");
330
+ TinyLinalg::SyEvr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvr>::define_module_function(rb_mTinyLinalgLapack, "ssyevr");
331
+ TinyLinalg::HeEvr<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvr>::define_module_function(rb_mTinyLinalgLapack, "zheevr");
332
+ TinyLinalg::HeEvr<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvr>::define_module_function(rb_mTinyLinalgLapack, "cheevr");
315
333
  TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
316
334
  TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
317
335
  TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.1.1'
8
+ VERSION = '0.2.0'
9
9
  end
10
10
  end
@@ -48,35 +48,50 @@ module Numo
48
48
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
49
49
  # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
50
50
  # @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
51
- def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
51
+ def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/ParameterLists, Metrics/PerceivedComplexity, Lint/UnusedMethodArgument
52
52
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
53
53
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
54
54
 
55
+ b_given = !b.nil?
56
+ raise ArgumentError, 'input array b must be 2-dimensional' if b_given && b.ndim != 2
57
+ raise ArgumentError, 'input array b must be square' if b_given && b.shape[0] != b.shape[1]
58
+ raise ArgumentError, "invalid array type: #{b.class}" if b_given && blas_char(b) == 'n'
59
+
55
60
  bchr = blas_char(a)
56
61
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
57
62
 
58
- unless b.nil?
59
- raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
60
- raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
61
- raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
62
- end
63
-
64
63
  jobz = vals_only ? 'N' : 'V'
65
- b = a.class.eye(a.shape[0]) if b.nil?
66
- sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
67
64
 
68
- if vals_range.nil?
69
- sy_he_gv << 'd' if turbo
70
- vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
65
+ if b_given
66
+ fnc = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
67
+ if vals_range.nil?
68
+ fnc << 'd' if turbo
69
+ vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
70
+ else
71
+ fnc << 'x'
72
+ il = vals_range.first(1)[0] + 1
73
+ iu = vals_range.last(1)[0] + 1
74
+ _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
75
+ fnc.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
76
+ )
77
+ end
71
78
  else
72
- sy_he_gv << 'x'
73
- il = vals_range.first(1)[0] + 1
74
- iu = vals_range.last(1)[0] + 1
75
- _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
76
- sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
77
- )
79
+ fnc = %w[d s].include?(bchr) ? "#{bchr}syev" : "#{bchr}heev"
80
+ if vals_range.nil?
81
+ fnc << 'd' if turbo
82
+ vecs, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, jobz: jobz)
83
+ else
84
+ fnc << 'r'
85
+ il = vals_range.first(1)[0] + 1
86
+ iu = vals_range.last(1)[0] + 1
87
+ _a, _m, vals, vecs, _isuppz, _info = Numo::TinyLinalg::Lapack.send(
88
+ fnc.to_sym, a.dup, jobz: jobz, range: 'I', il: il, iu: iu
89
+ )
90
+ end
78
91
  end
92
+
79
93
  vecs = nil if vals_only
94
+
80
95
  [vals, vecs]
81
96
  end
82
97
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-07 00:00:00.000000000 Z
11
+ date: 2023-08-10 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -52,10 +52,16 @@ files:
52
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
54
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/heev.hpp
56
+ - ext/numo/tiny_linalg/lapack/heevd.hpp
57
+ - ext/numo/tiny_linalg/lapack/heevr.hpp
55
58
  - ext/numo/tiny_linalg/lapack/hegv.hpp
56
59
  - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
60
  - ext/numo/tiny_linalg/lapack/hegvx.hpp
58
61
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
62
+ - ext/numo/tiny_linalg/lapack/syev.hpp
63
+ - ext/numo/tiny_linalg/lapack/syevd.hpp
64
+ - ext/numo/tiny_linalg/lapack/syevr.hpp
59
65
  - ext/numo/tiny_linalg/lapack/sygv.hpp
60
66
  - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
67
  - ext/numo/tiny_linalg/lapack/sygvx.hpp