numo-tiny_linalg 0.1.0 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7a298eaec7ee7338e4856ac50b22a06bf774073456a01f5ffb2a405cab632f7a
4
- data.tar.gz: a44ae42f723ee7c9a1af6c2b1a01e6257f1fc417a5cd2a22cb6be51ac0589c1c
3
+ metadata.gz: 0aeb68c259f8d4c4dcd44c15caef6926cbe3f26f9dd99e9229721d02a939d7c1
4
+ data.tar.gz: 5465aebebc07612812b861932fcabf14b66ee0ec389603c9ffcf06694eb69ce4
5
5
  SHA512:
6
- metadata.gz: fce33aa331257bc37d7a972e09a0cf1d9328046459c9de6b9feaf8341d0d4545d202b90d6bb0e5e6ee097b4c9c58a4bc76fc0224fd87dc10818c8f566259ab7e
7
- data.tar.gz: 6da516469d0fdfb47884fea2d9fc2483b5ea066063ce1681afc8e9319d1cfcc6863f59e96a4ce1d7cb22ceb159a8b2f8f6e41478958ea245a5eaaf209753ea2c
6
+ metadata.gz: 3e1ddd9a4f43d89635d0fc4e07fc3eab0e302b108648e00506deaec106f147f10132b6c28818f6ef054d5fa0ef52bb2c33287dcdeb413fc12268fb35edfde859
7
+ data.tar.gz: 78903f772d37936f73a69624b3241d43703632f08542a2f57e38a619b6dc1ac012d8885c79056459ef4fd50102d416f996ea54269b64b919ac91a3db46330d8c
data/CHANGELOG.md CHANGED
@@ -1,5 +1,14 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
4
+ - Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
5
+ - Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
6
+ - Add dsyevr, ssyevr, zheevr, and cheevr module functions to TinyLinalg::Lapack.
7
+ - Fix the confirmation process whether the array b is a square matrix or not on TinyLinalg.eigh.
8
+
9
+ ## [[0.1.1](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.0...v0.1.1)] - 2023-08-07
10
+ - Fix method of getting start and end of eigenvalue range from vals_range argument of TinyLinalg.eigh.
11
+
3
12
  ## [[0.1.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.4...v0.1.0)] - 2023-08-06
4
13
  - Refactor codes and update documentations.
5
14
 
@@ -25,5 +34,4 @@
25
34
  - Add inv module function to TinyLinalg.
26
35
 
27
36
  ## [0.0.1] - 2023-07-14
28
-
29
- - Initial release
37
+ - Initial release.
@@ -0,0 +1,87 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEv {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_zheev(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct CHeEv {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_cheev(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
16
+ class HeEv {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heev), -1);
20
+ }
21
+
22
+ private:
23
+ struct heev_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_heev(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ rtype* w = (rtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ heev_opt* opt = (heev_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_heev(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
43
+
44
+ VALUE a_vnary = Qnil;
45
+ VALUE kw_args = Qnil;
46
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
47
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
48
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
49
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
50
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
51
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
52
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+
61
+ narray_t* a_nary = nullptr;
62
+ GetNArray(a_vnary, a_nary);
63
+ if (NA_NDIM(a_nary) != 2) {
64
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
65
+ return Qnil;
66
+ }
67
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
68
+ rb_raise(rb_eArgError, "input array a must be square");
69
+ return Qnil;
70
+ }
71
+
72
+ const size_t n = NA_SHAPE(a_nary)[1];
73
+ size_t shape[1] = { n };
74
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
75
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
76
+ ndfunc_t ndf = { iter_heev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
77
+ heev_opt opt = { matrix_layout, jobz, uplo };
78
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
79
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
80
+
81
+ RB_GC_GUARD(a_vnary);
82
+
83
+ return ret;
84
+ }
85
+ };
86
+
87
+ } // namespace TinyLinalg
@@ -0,0 +1,87 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEvd {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_zheevd(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct CHeEvd {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_cheevd(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
16
+ class HeEvd {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevd), -1);
20
+ }
21
+
22
+ private:
23
+ struct heevd_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_heevd(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ rtype* w = (rtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ heevd_opt* opt = (heevd_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_heevd(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
43
+
44
+ VALUE a_vnary = Qnil;
45
+ VALUE kw_args = Qnil;
46
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
47
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
48
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
49
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
50
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
51
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
52
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+
61
+ narray_t* a_nary = nullptr;
62
+ GetNArray(a_vnary, a_nary);
63
+ if (NA_NDIM(a_nary) != 2) {
64
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
65
+ return Qnil;
66
+ }
67
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
68
+ rb_raise(rb_eArgError, "input array a must be square");
69
+ return Qnil;
70
+ }
71
+
72
+ const size_t n = NA_SHAPE(a_nary)[1];
73
+ size_t shape[1] = { n };
74
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
75
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
76
+ ndfunc_t ndf = { iter_heevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
77
+ heevd_opt opt = { matrix_layout, jobz, uplo };
78
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
79
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
80
+
81
+ RB_GC_GUARD(a_vnary);
82
+
83
+ return ret;
84
+ }
85
+ };
86
+
87
+ } // namespace TinyLinalg
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEvr {
4
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
5
+ lapack_int n, lapack_complex_double* a, lapack_int lda, double vl, double vu, lapack_int il,
6
+ lapack_int iu, double abstol, lapack_int* m,
7
+ double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* isuppz) {
8
+ return LAPACKE_zheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
9
+ }
10
+ };
11
+
12
+ struct CHeEvr {
13
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
14
+ lapack_int n, lapack_complex_float* a, lapack_int lda, float vl, float vu, lapack_int il,
15
+ lapack_int iu, float abstol, lapack_int* m,
16
+ float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* isuppz) {
17
+ return LAPACKE_cheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
22
+ class HeEvr {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevr), -1);
26
+ }
27
+
28
+ private:
29
+ struct heevr_opt {
30
+ int matrix_layout;
31
+ char jobz;
32
+ char range;
33
+ char uplo;
34
+ rtype vl;
35
+ rtype vu;
36
+ lapack_int il;
37
+ lapack_int iu;
38
+ };
39
+
40
+ static void iter_heevr(na_loop_t* const lp) {
41
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
42
+ int* m = (int*)NDL_PTR(lp, 1);
43
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
44
+ dtype* z = (dtype*)NDL_PTR(lp, 3);
45
+ int* isuppz = (int*)NDL_PTR(lp, 4);
46
+ int* info = (int*)NDL_PTR(lp, 5);
47
+ heevr_opt* opt = (heevr_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
51
+ const rtype abstol = 0.0;
52
+ const lapack_int i = LapackFn().call(
53
+ opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
54
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
55
+ *info = static_cast<int>(i);
56
+ }
57
+
58
+ static VALUE tiny_linalg_heevr(int argc, VALUE* argv, VALUE self) {
59
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
60
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
61
+
62
+ VALUE a_vnary = Qnil;
63
+ VALUE kw_args = Qnil;
64
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
65
+ ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
66
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
67
+ VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
68
+ rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
69
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
70
+ const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
71
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
72
+ const rtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
73
+ const rtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
74
+ const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
75
+ const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
76
+ const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
77
+
78
+ if (CLASS_OF(a_vnary) != nary_dtype) {
79
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
80
+ }
81
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
82
+ a_vnary = nary_dup(a_vnary);
83
+ }
84
+
85
+ narray_t* a_nary = nullptr;
86
+ GetNArray(a_vnary, a_nary);
87
+ if (NA_NDIM(a_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array a must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ const size_t n = NA_SHAPE(a_nary)[1];
97
+ size_t m = range != 'I' ? n : iu - il + 1;
98
+ size_t w_shape[1] = { m };
99
+ size_t z_shape[2] = { n, m };
100
+ size_t isuppz_shape[1] = { 2 * m };
101
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
102
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
103
+ ndfunc_t ndf = { iter_heevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
104
+ heevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
105
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
106
+ VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
107
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
108
+
109
+ RB_GC_GUARD(a_vnary);
110
+
111
+ return ret;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -92,7 +92,7 @@ private:
92
92
  return Qnil;
93
93
  }
94
94
  narray_t* b_nary = nullptr;
95
- GetNArray(a_vnary, b_nary);
95
+ GetNArray(b_vnary, b_nary);
96
96
  if (NA_NDIM(b_nary) != 2) {
97
97
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
98
  return Qnil;
@@ -92,7 +92,7 @@ private:
92
92
  return Qnil;
93
93
  }
94
94
  narray_t* b_nary = nullptr;
95
- GetNArray(a_vnary, b_nary);
95
+ GetNArray(b_vnary, b_nary);
96
96
  if (NA_NDIM(b_nary) != 2) {
97
97
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
98
  return Qnil;
@@ -104,7 +104,7 @@ private:
104
104
  return Qnil;
105
105
  }
106
106
  narray_t* b_nary = nullptr;
107
- GetNArray(a_vnary, b_nary);
107
+ GetNArray(b_vnary, b_nary);
108
108
  if (NA_NDIM(b_nary) != 2) {
109
109
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
110
110
  return Qnil;
@@ -0,0 +1,86 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEv {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_dsyev(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct SSyEv {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_ssyev(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class LapackFn>
16
+ class SyEv {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syev), -1);
20
+ }
21
+
22
+ private:
23
+ struct syev_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_syev(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* w = (dtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ syev_opt* opt = (syev_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_syev(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+
43
+ VALUE a_vnary = Qnil;
44
+ VALUE kw_args = Qnil;
45
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
46
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
47
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
48
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
49
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
50
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
51
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
52
+
53
+ if (CLASS_OF(a_vnary) != nary_dtype) {
54
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
55
+ }
56
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
57
+ a_vnary = nary_dup(a_vnary);
58
+ }
59
+
60
+ narray_t* a_nary = nullptr;
61
+ GetNArray(a_vnary, a_nary);
62
+ if (NA_NDIM(a_nary) != 2) {
63
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
64
+ return Qnil;
65
+ }
66
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
67
+ rb_raise(rb_eArgError, "input array a must be square");
68
+ return Qnil;
69
+ }
70
+
71
+ const size_t n = NA_SHAPE(a_nary)[1];
72
+ size_t shape[1] = { n };
73
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
74
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
75
+ ndfunc_t ndf = { iter_syev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
76
+ syev_opt opt = { matrix_layout, jobz, uplo };
77
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
78
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
79
+
80
+ RB_GC_GUARD(a_vnary);
81
+
82
+ return ret;
83
+ }
84
+ };
85
+
86
+ } // namespace TinyLinalg
@@ -0,0 +1,86 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEvd {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_dsyevd(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct SSyEvd {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_ssyevd(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class LapackFn>
16
+ class SyEvd {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevd), -1);
20
+ }
21
+
22
+ private:
23
+ struct syevd_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_syevd(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* w = (dtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ syevd_opt* opt = (syevd_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_syevd(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+
43
+ VALUE a_vnary = Qnil;
44
+ VALUE kw_args = Qnil;
45
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
46
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
47
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
48
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
49
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
50
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
51
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
52
+
53
+ if (CLASS_OF(a_vnary) != nary_dtype) {
54
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
55
+ }
56
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
57
+ a_vnary = nary_dup(a_vnary);
58
+ }
59
+
60
+ narray_t* a_nary = nullptr;
61
+ GetNArray(a_vnary, a_nary);
62
+ if (NA_NDIM(a_nary) != 2) {
63
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
64
+ return Qnil;
65
+ }
66
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
67
+ rb_raise(rb_eArgError, "input array a must be square");
68
+ return Qnil;
69
+ }
70
+
71
+ const size_t n = NA_SHAPE(a_nary)[1];
72
+ size_t shape[1] = { n };
73
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
74
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
75
+ ndfunc_t ndf = { iter_syevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
76
+ syevd_opt opt = { matrix_layout, jobz, uplo };
77
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
78
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
79
+
80
+ RB_GC_GUARD(a_vnary);
81
+
82
+ return ret;
83
+ }
84
+ };
85
+
86
+ } // namespace TinyLinalg
@@ -0,0 +1,112 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEvr {
4
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
5
+ lapack_int n, double* a, lapack_int lda, double vl, double vu, lapack_int il, lapack_int iu,
6
+ double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* isuppz) {
7
+ return LAPACKE_dsyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
8
+ }
9
+ };
10
+
11
+ struct SSyEvr {
12
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
13
+ lapack_int n, float* a, lapack_int lda, float vl, float vu, lapack_int il, lapack_int iu,
14
+ float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* isuppz) {
15
+ return LAPACKE_ssyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
+ class SyEvr {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevr), -1);
24
+ }
25
+
26
+ private:
27
+ struct syevr_opt {
28
+ int matrix_layout;
29
+ char jobz;
30
+ char range;
31
+ char uplo;
32
+ dtype vl;
33
+ dtype vu;
34
+ lapack_int il;
35
+ lapack_int iu;
36
+ };
37
+
38
+ static void iter_syevr(na_loop_t* const lp) {
39
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
40
+ int* m = (int*)NDL_PTR(lp, 1);
41
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
42
+ dtype* z = (dtype*)NDL_PTR(lp, 3);
43
+ int* isuppz = (int*)NDL_PTR(lp, 4);
44
+ int* info = (int*)NDL_PTR(lp, 5);
45
+ syevr_opt* opt = (syevr_opt*)(lp->opt_ptr);
46
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
47
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
48
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
49
+ const dtype abstol = 0.0;
50
+ const lapack_int i = LapackFn().call(
51
+ opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
52
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
53
+ *info = static_cast<int>(i);
54
+ }
55
+
56
+ static VALUE tiny_linalg_syevr(int argc, VALUE* argv, VALUE self) {
57
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE kw_args = Qnil;
61
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
62
+ ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
63
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
64
+ VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
66
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
67
+ const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const dtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
70
+ const dtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
71
+ const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
72
+ const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
73
+ const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
74
+
75
+ if (CLASS_OF(a_vnary) != nary_dtype) {
76
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
77
+ }
78
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
79
+ a_vnary = nary_dup(a_vnary);
80
+ }
81
+
82
+ narray_t* a_nary = nullptr;
83
+ GetNArray(a_vnary, a_nary);
84
+ if (NA_NDIM(a_nary) != 2) {
85
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
86
+ return Qnil;
87
+ }
88
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
89
+ rb_raise(rb_eArgError, "input array a must be square");
90
+ return Qnil;
91
+ }
92
+
93
+ const size_t n = NA_SHAPE(a_nary)[1];
94
+ size_t m = range != 'I' ? n : iu - il + 1;
95
+ size_t w_shape[1] = { m };
96
+ size_t z_shape[2] = { n, m };
97
+ size_t isuppz_shape[1] = { 2 * m };
98
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
99
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
100
+ ndfunc_t ndf = { iter_syevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
101
+ syevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
102
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
103
+ VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
104
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
105
+
106
+ RB_GC_GUARD(a_vnary);
107
+
108
+ return ret;
109
+ }
110
+ };
111
+
112
+ } // namespace TinyLinalg
@@ -83,7 +83,7 @@ private:
83
83
  return Qnil;
84
84
  }
85
85
  narray_t* b_nary = nullptr;
86
- GetNArray(a_vnary, b_nary);
86
+ GetNArray(b_vnary, b_nary);
87
87
  if (NA_NDIM(b_nary) != 2) {
88
88
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
89
  return Qnil;
@@ -83,7 +83,7 @@ private:
83
83
  return Qnil;
84
84
  }
85
85
  narray_t* b_nary = nullptr;
86
- GetNArray(a_vnary, b_nary);
86
+ GetNArray(b_vnary, b_nary);
87
87
  if (NA_NDIM(b_nary) != 2) {
88
88
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
89
  return Qnil;
@@ -103,7 +103,7 @@ private:
103
103
  return Qnil;
104
104
  }
105
105
  narray_t* b_nary = nullptr;
106
- GetNArray(a_vnary, b_nary);
106
+ GetNArray(b_vnary, b_nary);
107
107
  if (NA_NDIM(b_nary) != 2) {
108
108
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
109
109
  return Qnil;
@@ -44,10 +44,16 @@
44
44
  #include "lapack/gesvd.hpp"
45
45
  #include "lapack/getrf.hpp"
46
46
  #include "lapack/getri.hpp"
47
+ #include "lapack/heev.hpp"
48
+ #include "lapack/heevd.hpp"
49
+ #include "lapack/heevr.hpp"
47
50
  #include "lapack/hegv.hpp"
48
51
  #include "lapack/hegvd.hpp"
49
52
  #include "lapack/hegvx.hpp"
50
53
  #include "lapack/orgqr.hpp"
54
+ #include "lapack/syev.hpp"
55
+ #include "lapack/syevd.hpp"
56
+ #include "lapack/syevr.hpp"
51
57
  #include "lapack/sygv.hpp"
52
58
  #include "lapack/sygvd.hpp"
53
59
  #include "lapack/sygvx.hpp"
@@ -312,6 +318,18 @@ extern "C" void Init_tiny_linalg(void) {
312
318
  TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
313
319
  TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
314
320
  TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
321
+ TinyLinalg::SyEv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEv>::define_module_function(rb_mTinyLinalgLapack, "dsyev");
322
+ TinyLinalg::SyEv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEv>::define_module_function(rb_mTinyLinalgLapack, "ssyev");
323
+ TinyLinalg::HeEv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEv>::define_module_function(rb_mTinyLinalgLapack, "zheev");
324
+ TinyLinalg::HeEv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEv>::define_module_function(rb_mTinyLinalgLapack, "cheev");
325
+ TinyLinalg::SyEvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvd>::define_module_function(rb_mTinyLinalgLapack, "dsyevd");
326
+ TinyLinalg::SyEvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvd>::define_module_function(rb_mTinyLinalgLapack, "ssyevd");
327
+ TinyLinalg::HeEvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvd>::define_module_function(rb_mTinyLinalgLapack, "zheevd");
328
+ TinyLinalg::HeEvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvd>::define_module_function(rb_mTinyLinalgLapack, "cheevd");
329
+ TinyLinalg::SyEvr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvr>::define_module_function(rb_mTinyLinalgLapack, "dsyevr");
330
+ TinyLinalg::SyEvr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvr>::define_module_function(rb_mTinyLinalgLapack, "ssyevr");
331
+ TinyLinalg::HeEvr<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvr>::define_module_function(rb_mTinyLinalgLapack, "zheevr");
332
+ TinyLinalg::HeEvr<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvr>::define_module_function(rb_mTinyLinalgLapack, "cheevr");
315
333
  TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
316
334
  TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
317
335
  TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.1.0'
8
+ VERSION = '0.1.2'
9
9
  end
10
10
  end
@@ -39,15 +39,15 @@ module Numo
39
39
  # pp (x - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
40
40
  # # => 3.3306690738754696e-16
41
41
  #
42
- # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
43
- # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
42
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
43
+ # @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
44
44
  # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
45
45
  # @param vals_range [Range/Array]
46
46
  # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
47
47
  # If nil, all eigenvalues and eigenvectors are computed.
48
48
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
49
49
  # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
50
- # @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors.
50
+ # @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
51
51
  def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
52
52
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
53
53
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
@@ -70,8 +70,8 @@ module Numo
70
70
  vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
71
71
  else
72
72
  sy_he_gv << 'x'
73
- il = vals_range.first + 1
74
- iu = vals_range.last + 1
73
+ il = vals_range.first(1)[0] + 1
74
+ iu = vals_range.last(1)[0] + 1
75
75
  _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
76
76
  sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
77
77
  )
@@ -91,7 +91,7 @@ module Numo
91
91
  # pp (3.0 - Numo::Linalg.det(a)).abs
92
92
  # # => 1.3322676295501878e-15
93
93
  #
94
- # @param a [Numo::NArray] n-by-n square matrix.
94
+ # @param a [Numo::NArray] The n-by-n square matrix.
95
95
  # @return [Float/Complex] The determinant of `a`.
96
96
  def det(a)
97
97
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
@@ -132,7 +132,7 @@ module Numo
132
132
  # pp inv_a.dot(a).sum
133
133
  # # => 5.0
134
134
  #
135
- # @param a [Numo::NArray] n-by-n square matrix.
135
+ # @param a [Numo::NArray] The n-by-n square matrix.
136
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
137
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
138
138
  # @return [Numo::NArray] The inverse matrix of `a`.
@@ -156,7 +156,7 @@ module Numo
156
156
  end
157
157
  end
158
158
 
159
- # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
159
+ # Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
160
160
  #
161
161
  # @example
162
162
  # require 'numo/tiny_linalg'
@@ -174,7 +174,7 @@ module Numo
174
174
  # # => 3.0
175
175
  #
176
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
177
- # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
177
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
178
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
179
179
  # @return [Numo::NArray] The pseudo-inverse of `a`.
180
180
  def pinv(a, driver: 'svd', rcond: nil)
@@ -186,7 +186,7 @@ module Numo
186
186
  u.dot(vh[0...rank, true]).conj.transpose
187
187
  end
188
188
 
189
- # Compute QR decomposition of a matrix.
189
+ # Computes the QR decomposition of a matrix.
190
190
  #
191
191
  # @example
192
192
  # require 'numo/tiny_linalg'
@@ -222,9 +222,8 @@ module Numo
222
222
  # - "r" -- returns only R,
223
223
  # - "economic" -- returns both Q [m, n] and R [n, n],
224
224
  # - "raw" -- returns QR and TAU (LAPACK geqrf results).
225
- # @return [Numo::NArray] if mode='r'
226
- # @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
227
- # @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
225
+ # @return [Numo::NArray] if mode='r'.
226
+ # @return [Array<Numo::NArray>] if mode='reduce' or 'economic' or 'raw'.
228
227
  def qr(a, mode: 'reduce')
229
228
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
230
229
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)
@@ -295,7 +294,7 @@ module Numo
295
294
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
296
295
  end
297
296
 
298
- # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
297
+ # Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
299
298
  #
300
299
  # @example
301
300
  # require 'numo/tiny_linalg'
@@ -328,9 +327,9 @@ module Numo
328
327
  # # => 4.440892098500626e-16
329
328
  #
330
329
  # @param a [Numo::NArray] Matrix to be decomposed.
331
- # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
332
- # @param job [String] Job option ('A', 'S', or 'N').
333
- # @return [Array<Numo::NArray>] Singular values and singular vectors ([s, u, vt]).
330
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
331
+ # @param job [String] The job option ('A', 'S', or 'N').
332
+ # @return [Array<Numo::NArray>] The singular values and singular vectors ([s, u, vt]).
334
333
  def svd(a, driver: 'svd', job: 'A')
335
334
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
336
335
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-06 00:00:00.000000000 Z
11
+ date: 2023-08-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -52,10 +52,16 @@ files:
52
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
54
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/heev.hpp
56
+ - ext/numo/tiny_linalg/lapack/heevd.hpp
57
+ - ext/numo/tiny_linalg/lapack/heevr.hpp
55
58
  - ext/numo/tiny_linalg/lapack/hegv.hpp
56
59
  - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
60
  - ext/numo/tiny_linalg/lapack/hegvx.hpp
58
61
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
62
+ - ext/numo/tiny_linalg/lapack/syev.hpp
63
+ - ext/numo/tiny_linalg/lapack/syevd.hpp
64
+ - ext/numo/tiny_linalg/lapack/syevr.hpp
59
65
  - ext/numo/tiny_linalg/lapack/sygv.hpp
60
66
  - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
67
  - ext/numo/tiny_linalg/lapack/sygvx.hpp