numo-tiny_linalg 0.1.0 → 0.1.2

This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7a298eaec7ee7338e4856ac50b22a06bf774073456a01f5ffb2a405cab632f7a
4
- data.tar.gz: a44ae42f723ee7c9a1af6c2b1a01e6257f1fc417a5cd2a22cb6be51ac0589c1c
3
+ metadata.gz: 0aeb68c259f8d4c4dcd44c15caef6926cbe3f26f9dd99e9229721d02a939d7c1
4
+ data.tar.gz: 5465aebebc07612812b861932fcabf14b66ee0ec389603c9ffcf06694eb69ce4
5
5
  SHA512:
6
- metadata.gz: fce33aa331257bc37d7a972e09a0cf1d9328046459c9de6b9feaf8341d0d4545d202b90d6bb0e5e6ee097b4c9c58a4bc76fc0224fd87dc10818c8f566259ab7e
7
- data.tar.gz: 6da516469d0fdfb47884fea2d9fc2483b5ea066063ce1681afc8e9319d1cfcc6863f59e96a4ce1d7cb22ceb159a8b2f8f6e41478958ea245a5eaaf209753ea2c
6
+ metadata.gz: 3e1ddd9a4f43d89635d0fc4e07fc3eab0e302b108648e00506deaec106f147f10132b6c28818f6ef054d5fa0ef52bb2c33287dcdeb413fc12268fb35edfde859
7
+ data.tar.gz: 78903f772d37936f73a69624b3241d43703632f08542a2f57e38a619b6dc1ac012d8885c79056459ef4fd50102d416f996ea54269b64b919ac91a3db46330d8c
data/CHANGELOG.md CHANGED
@@ -1,5 +1,14 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
4
+ - Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
5
+ - Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
6
+ - Add dsyevr, ssyevr, zheevr, and cheevr module functions to TinyLinalg::Lapack.
7
+ - Fix the confirmation process of whether the array b is a square matrix in TinyLinalg.eigh.
8
+
9
+ ## [[0.1.1](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.0...v0.1.1)] - 2023-08-07
10
+ - Fix the method of getting the start and end of the eigenvalue range from the vals_range argument of TinyLinalg.eigh.
11
+
3
12
  ## [[0.1.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.4...v0.1.0)] - 2023-08-06
4
13
  - Refactor codes and update documentations.
5
14
 
@@ -25,5 +34,4 @@
25
34
  - Add inv module function to TinyLinalg.
26
35
 
27
36
  ## [0.0.1] - 2023-07-14
28
-
29
- - Initial release
37
+ - Initial release.
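
As a rough usage sketch (not part of the gem's own documentation), the new symmetric eigensolver bindings listed in the changelog above could be called as follows; the [a, w, info] return order and the jobz/uplo keyword defaults follow the syev.hpp code later in this diff:

    require 'numo/tiny_linalg'

    x = Numo::DFloat.new(4, 4).rand - 0.5
    a = x.dot(x.transpose)  # symmetric matrix
    # With jobz: 'V' the input is overwritten by the eigenvectors,
    # so pass a copy if the original matrix is still needed.
    vecs, vals, info = Numo::TinyLinalg::Lapack.dsyev(a.dup, jobz: 'V', uplo: 'U')
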
data/ext/numo/tiny_linalg/lapack/heev.hpp ADDED
@@ -0,0 +1,87 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEv {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_zheev(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct CHeEv {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_cheev(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
16
+ class HeEv {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heev), -1);
20
+ }
21
+
22
+ private:
23
+ struct heev_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_heev(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ rtype* w = (rtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ heev_opt* opt = (heev_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_heev(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
43
+
44
+ VALUE a_vnary = Qnil;
45
+ VALUE kw_args = Qnil;
46
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
47
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
48
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
49
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
50
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
51
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
52
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+
61
+ narray_t* a_nary = nullptr;
62
+ GetNArray(a_vnary, a_nary);
63
+ if (NA_NDIM(a_nary) != 2) {
64
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
65
+ return Qnil;
66
+ }
67
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
68
+ rb_raise(rb_eArgError, "input array a must be square");
69
+ return Qnil;
70
+ }
71
+
72
+ const size_t n = NA_SHAPE(a_nary)[1];
73
+ size_t shape[1] = { n };
74
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
75
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
76
+ ndfunc_t ndf = { iter_heev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
77
+ heev_opt opt = { matrix_layout, jobz, uplo };
78
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
79
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
80
+
81
+ RB_GC_GUARD(a_vnary);
82
+
83
+ return ret;
84
+ }
85
+ };
86
+
87
+ } // namespace TinyLinalg
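
A similar sketch for the complex Hermitian variant defined above (zheev); the keyword names jobz, uplo, and order are taken from kw_table in the code, and the result is [a, w, info] as built by rb_ary_new3:

    require 'numo/tiny_linalg'

    # Small Hermitian matrix (equal to its own conjugate transpose).
    a = Numo::DComplex[[Complex(2, 0), Complex(0, 1)],
                       [Complex(0, -1), Complex(3, 0)]]
    vecs, vals, info = Numo::TinyLinalg::Lapack.zheev(a.dup, jobz: 'V', uplo: 'U')
    # vals is a real Numo::DFloat of eigenvalues; info is 0 on success.
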
data/ext/numo/tiny_linalg/lapack/heevd.hpp ADDED
@@ -0,0 +1,87 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEvd {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_zheevd(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct CHeEvd {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_cheevd(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
16
+ class HeEvd {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevd), -1);
20
+ }
21
+
22
+ private:
23
+ struct heevd_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_heevd(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ rtype* w = (rtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ heevd_opt* opt = (heevd_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_heevd(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
43
+
44
+ VALUE a_vnary = Qnil;
45
+ VALUE kw_args = Qnil;
46
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
47
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
48
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
49
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
50
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
51
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
52
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+
61
+ narray_t* a_nary = nullptr;
62
+ GetNArray(a_vnary, a_nary);
63
+ if (NA_NDIM(a_nary) != 2) {
64
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
65
+ return Qnil;
66
+ }
67
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
68
+ rb_raise(rb_eArgError, "input array a must be square");
69
+ return Qnil;
70
+ }
71
+
72
+ const size_t n = NA_SHAPE(a_nary)[1];
73
+ size_t shape[1] = { n };
74
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
75
+ ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
76
+ ndfunc_t ndf = { iter_heevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
77
+ heevd_opt opt = { matrix_layout, jobz, uplo };
78
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
79
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
80
+
81
+ RB_GC_GUARD(a_vnary);
82
+
83
+ return ret;
84
+ }
85
+ };
86
+
87
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/heevr.hpp ADDED
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZHeEvr {
4
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
5
+ lapack_int n, lapack_complex_double* a, lapack_int lda, double vl, double vu, lapack_int il,
6
+ lapack_int iu, double abstol, lapack_int* m,
7
+ double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* isuppz) {
8
+ return LAPACKE_zheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
9
+ }
10
+ };
11
+
12
+ struct CHeEvr {
13
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
14
+ lapack_int n, lapack_complex_float* a, lapack_int lda, float vl, float vu, lapack_int il,
15
+ lapack_int iu, float abstol, lapack_int* m,
16
+ float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* isuppz) {
17
+ return LAPACKE_cheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
18
+ }
19
+ };
20
+
21
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
22
+ class HeEvr {
23
+ public:
24
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
25
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevr), -1);
26
+ }
27
+
28
+ private:
29
+ struct heevr_opt {
30
+ int matrix_layout;
31
+ char jobz;
32
+ char range;
33
+ char uplo;
34
+ rtype vl;
35
+ rtype vu;
36
+ lapack_int il;
37
+ lapack_int iu;
38
+ };
39
+
40
+ static void iter_heevr(na_loop_t* const lp) {
41
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
42
+ int* m = (int*)NDL_PTR(lp, 1);
43
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
44
+ dtype* z = (dtype*)NDL_PTR(lp, 3);
45
+ int* isuppz = (int*)NDL_PTR(lp, 4);
46
+ int* info = (int*)NDL_PTR(lp, 5);
47
+ heevr_opt* opt = (heevr_opt*)(lp->opt_ptr);
48
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
51
+ const rtype abstol = 0.0;
52
+ const lapack_int i = LapackFn().call(
53
+ opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
54
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
55
+ *info = static_cast<int>(i);
56
+ }
57
+
58
+ static VALUE tiny_linalg_heevr(int argc, VALUE* argv, VALUE self) {
59
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
60
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
61
+
62
+ VALUE a_vnary = Qnil;
63
+ VALUE kw_args = Qnil;
64
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
65
+ ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
66
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
67
+ VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
68
+ rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
69
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
70
+ const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
71
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
72
+ const rtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
73
+ const rtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
74
+ const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
75
+ const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
76
+ const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
77
+
78
+ if (CLASS_OF(a_vnary) != nary_dtype) {
79
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
80
+ }
81
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
82
+ a_vnary = nary_dup(a_vnary);
83
+ }
84
+
85
+ narray_t* a_nary = nullptr;
86
+ GetNArray(a_vnary, a_nary);
87
+ if (NA_NDIM(a_nary) != 2) {
88
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
89
+ return Qnil;
90
+ }
91
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
92
+ rb_raise(rb_eArgError, "input array a must be square");
93
+ return Qnil;
94
+ }
95
+
96
+ const size_t n = NA_SHAPE(a_nary)[1];
97
+ size_t m = range != 'I' ? n : iu - il + 1;
98
+ size_t w_shape[1] = { m };
99
+ size_t z_shape[2] = { n, m };
100
+ size_t isuppz_shape[1] = { 2 * m };
101
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
102
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
103
+ ndfunc_t ndf = { iter_heevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
104
+ heevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
105
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
106
+ VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
107
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
108
+
109
+ RB_GC_GUARD(a_vnary);
110
+
111
+ return ret;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
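
The heevr wrapper above also accepts the vl/vu/il/iu keywords; a hedged sketch of selecting eigenvalues by value range (the [a, m, w, z, isuppz, info] return order mirrors the rb_ary_new3 call in the code):

    require 'numo/tiny_linalg'

    a = Numo::DComplex[[Complex(2, 0), Complex(0, 1)],
                       [Complex(0, -1), Complex(3, 0)]]
    # range: 'V' keeps only eigenvalues in the half-open interval (vl, vu];
    # m reports how many were found.
    _a, m, vals, vecs, _isuppz, info =
      Numo::TinyLinalg::Lapack.zheevr(a.dup, jobz: 'V', range: 'V', vl: 0.0, vu: 2.5)
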
data/ext/numo/tiny_linalg/lapack/hegv.hpp CHANGED
@@ -92,7 +92,7 @@ private:
92
92
  return Qnil;
93
93
  }
94
94
  narray_t* b_nary = nullptr;
95
- GetNArray(a_vnary, b_nary);
95
+ GetNArray(b_vnary, b_nary);
96
96
  if (NA_NDIM(b_nary) != 2) {
97
97
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
98
  return Qnil;
data/ext/numo/tiny_linalg/lapack/hegvd.hpp CHANGED
@@ -92,7 +92,7 @@ private:
92
92
  return Qnil;
93
93
  }
94
94
  narray_t* b_nary = nullptr;
95
- GetNArray(a_vnary, b_nary);
95
+ GetNArray(b_vnary, b_nary);
96
96
  if (NA_NDIM(b_nary) != 2) {
97
97
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
98
98
  return Qnil;
data/ext/numo/tiny_linalg/lapack/hegvx.hpp CHANGED
@@ -104,7 +104,7 @@ private:
104
104
  return Qnil;
105
105
  }
106
106
  narray_t* b_nary = nullptr;
107
- GetNArray(a_vnary, b_nary);
107
+ GetNArray(b_vnary, b_nary);
108
108
  if (NA_NDIM(b_nary) != 2) {
109
109
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
110
110
  return Qnil;
data/ext/numo/tiny_linalg/lapack/syev.hpp ADDED
@@ -0,0 +1,86 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEv {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_dsyev(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct SSyEv {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_ssyev(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class LapackFn>
16
+ class SyEv {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syev), -1);
20
+ }
21
+
22
+ private:
23
+ struct syev_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_syev(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* w = (dtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ syev_opt* opt = (syev_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_syev(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+
43
+ VALUE a_vnary = Qnil;
44
+ VALUE kw_args = Qnil;
45
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
46
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
47
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
48
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
49
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
50
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
51
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
52
+
53
+ if (CLASS_OF(a_vnary) != nary_dtype) {
54
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
55
+ }
56
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
57
+ a_vnary = nary_dup(a_vnary);
58
+ }
59
+
60
+ narray_t* a_nary = nullptr;
61
+ GetNArray(a_vnary, a_nary);
62
+ if (NA_NDIM(a_nary) != 2) {
63
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
64
+ return Qnil;
65
+ }
66
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
67
+ rb_raise(rb_eArgError, "input array a must be square");
68
+ return Qnil;
69
+ }
70
+
71
+ const size_t n = NA_SHAPE(a_nary)[1];
72
+ size_t shape[1] = { n };
73
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
74
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
75
+ ndfunc_t ndf = { iter_syev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
76
+ syev_opt opt = { matrix_layout, jobz, uplo };
77
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
78
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
79
+
80
+ RB_GC_GUARD(a_vnary);
81
+
82
+ return ret;
83
+ }
84
+ };
85
+
86
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/syevd.hpp ADDED
@@ -0,0 +1,86 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEvd {
4
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
5
+ return LAPACKE_dsyevd(matrix_layout, jobz, uplo, n, a, lda, w);
6
+ }
7
+ };
8
+
9
+ struct SSyEvd {
10
+ lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
11
+ return LAPACKE_ssyevd(matrix_layout, jobz, uplo, n, a, lda, w);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class LapackFn>
16
+ class SyEvd {
17
+ public:
18
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
19
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevd), -1);
20
+ }
21
+
22
+ private:
23
+ struct syevd_opt {
24
+ int matrix_layout;
25
+ char jobz;
26
+ char uplo;
27
+ };
28
+
29
+ static void iter_syevd(na_loop_t* const lp) {
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* w = (dtype*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ syevd_opt* opt = (syevd_opt*)(lp->opt_ptr);
34
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
35
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
36
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
37
+ *info = static_cast<int>(i);
38
+ }
39
+
40
+ static VALUE tiny_linalg_syevd(int argc, VALUE* argv, VALUE self) {
41
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
42
+
43
+ VALUE a_vnary = Qnil;
44
+ VALUE kw_args = Qnil;
45
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
46
+ ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
47
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
48
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
49
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
50
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
51
+ const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
52
+
53
+ if (CLASS_OF(a_vnary) != nary_dtype) {
54
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
55
+ }
56
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
57
+ a_vnary = nary_dup(a_vnary);
58
+ }
59
+
60
+ narray_t* a_nary = nullptr;
61
+ GetNArray(a_vnary, a_nary);
62
+ if (NA_NDIM(a_nary) != 2) {
63
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
64
+ return Qnil;
65
+ }
66
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
67
+ rb_raise(rb_eArgError, "input array a must be square");
68
+ return Qnil;
69
+ }
70
+
71
+ const size_t n = NA_SHAPE(a_nary)[1];
72
+ size_t shape[1] = { n };
73
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
74
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
75
+ ndfunc_t ndf = { iter_syevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
76
+ syevd_opt opt = { matrix_layout, jobz, uplo };
77
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
78
+ VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
79
+
80
+ RB_GC_GUARD(a_vnary);
81
+
82
+ return ret;
83
+ }
84
+ };
85
+
86
+ } // namespace TinyLinalg
data/ext/numo/tiny_linalg/lapack/syevr.hpp ADDED
@@ -0,0 +1,112 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DSyEvr {
4
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
5
+ lapack_int n, double* a, lapack_int lda, double vl, double vu, lapack_int il, lapack_int iu,
6
+ double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* isuppz) {
7
+ return LAPACKE_dsyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
8
+ }
9
+ };
10
+
11
+ struct SSyEvr {
12
+ lapack_int call(int matrix_layout, char jobz, char range, char uplo,
13
+ lapack_int n, float* a, lapack_int lda, float vl, float vu, lapack_int il, lapack_int iu,
14
+ float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* isuppz) {
15
+ return LAPACKE_ssyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
16
+ }
17
+ };
18
+
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
+ class SyEvr {
21
+ public:
22
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
23
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevr), -1);
24
+ }
25
+
26
+ private:
27
+ struct syevr_opt {
28
+ int matrix_layout;
29
+ char jobz;
30
+ char range;
31
+ char uplo;
32
+ dtype vl;
33
+ dtype vu;
34
+ lapack_int il;
35
+ lapack_int iu;
36
+ };
37
+
38
+ static void iter_syevr(na_loop_t* const lp) {
39
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
40
+ int* m = (int*)NDL_PTR(lp, 1);
41
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
42
+ dtype* z = (dtype*)NDL_PTR(lp, 3);
43
+ int* isuppz = (int*)NDL_PTR(lp, 4);
44
+ int* info = (int*)NDL_PTR(lp, 5);
45
+ syevr_opt* opt = (syevr_opt*)(lp->opt_ptr);
46
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
47
+ const lapack_int lda = NDL_SHAPE(lp, 0)[0];
48
+ const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
49
+ const dtype abstol = 0.0;
50
+ const lapack_int i = LapackFn().call(
51
+ opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
52
+ opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
53
+ *info = static_cast<int>(i);
54
+ }
55
+
56
+ static VALUE tiny_linalg_syevr(int argc, VALUE* argv, VALUE self) {
57
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
58
+
59
+ VALUE a_vnary = Qnil;
60
+ VALUE kw_args = Qnil;
61
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
62
+ ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
63
+ rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
64
+ VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
65
+ rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
66
+ const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
67
+ const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const dtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
70
+ const dtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
71
+ const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
72
+ const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
73
+ const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
74
+
75
+ if (CLASS_OF(a_vnary) != nary_dtype) {
76
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
77
+ }
78
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
79
+ a_vnary = nary_dup(a_vnary);
80
+ }
81
+
82
+ narray_t* a_nary = nullptr;
83
+ GetNArray(a_vnary, a_nary);
84
+ if (NA_NDIM(a_nary) != 2) {
85
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
86
+ return Qnil;
87
+ }
88
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
89
+ rb_raise(rb_eArgError, "input array a must be square");
90
+ return Qnil;
91
+ }
92
+
93
+ const size_t n = NA_SHAPE(a_nary)[1];
94
+ size_t m = range != 'I' ? n : iu - il + 1;
95
+ size_t w_shape[1] = { m };
96
+ size_t z_shape[2] = { n, m };
97
+ size_t isuppz_shape[1] = { 2 * m };
98
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
99
+ ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
100
+ ndfunc_t ndf = { iter_syevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
101
+ syevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
102
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
103
+ VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
104
+ rb_ary_entry(res, 3), rb_ary_entry(res, 4));
105
+
106
+ RB_GC_GUARD(a_vnary);
107
+
108
+ return ret;
109
+ }
110
+ };
111
+
112
+ } // namespace TinyLinalg
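
And a sketch of the index-range form for the real variant above (dsyevr); il and iu are 1-based as in LAPACK, and the wrapper sizes w and z from iu - il + 1 when range: 'I' is given:

    require 'numo/tiny_linalg'

    x = Numo::DFloat.new(5, 5).rand - 0.5
    a = x.dot(x.transpose)
    # Request only the two smallest eigenvalues and their eigenvectors.
    _a, m, vals, vecs, _isuppz, info =
      Numo::TinyLinalg::Lapack.dsyevr(a.dup, jobz: 'V', range: 'I', il: 1, iu: 2)
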
data/ext/numo/tiny_linalg/lapack/sygv.hpp CHANGED
@@ -83,7 +83,7 @@ private:
83
83
  return Qnil;
84
84
  }
85
85
  narray_t* b_nary = nullptr;
86
- GetNArray(a_vnary, b_nary);
86
+ GetNArray(b_vnary, b_nary);
87
87
  if (NA_NDIM(b_nary) != 2) {
88
88
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
89
  return Qnil;
data/ext/numo/tiny_linalg/lapack/sygvd.hpp CHANGED
@@ -83,7 +83,7 @@ private:
83
83
  return Qnil;
84
84
  }
85
85
  narray_t* b_nary = nullptr;
86
- GetNArray(a_vnary, b_nary);
86
+ GetNArray(b_vnary, b_nary);
87
87
  if (NA_NDIM(b_nary) != 2) {
88
88
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
89
89
  return Qnil;
data/ext/numo/tiny_linalg/lapack/sygvx.hpp CHANGED
@@ -103,7 +103,7 @@ private:
103
103
  return Qnil;
104
104
  }
105
105
  narray_t* b_nary = nullptr;
106
- GetNArray(a_vnary, b_nary);
106
+ GetNArray(b_vnary, b_nary);
107
107
  if (NA_NDIM(b_nary) != 2) {
108
108
  rb_raise(rb_eArgError, "input array b must be 2-dimensional");
109
109
  return Qnil;
data/ext/numo/tiny_linalg/tiny_linalg.cpp CHANGED
@@ -44,10 +44,16 @@
44
44
  #include "lapack/gesvd.hpp"
45
45
  #include "lapack/getrf.hpp"
46
46
  #include "lapack/getri.hpp"
47
+ #include "lapack/heev.hpp"
48
+ #include "lapack/heevd.hpp"
49
+ #include "lapack/heevr.hpp"
47
50
  #include "lapack/hegv.hpp"
48
51
  #include "lapack/hegvd.hpp"
49
52
  #include "lapack/hegvx.hpp"
50
53
  #include "lapack/orgqr.hpp"
54
+ #include "lapack/syev.hpp"
55
+ #include "lapack/syevd.hpp"
56
+ #include "lapack/syevr.hpp"
51
57
  #include "lapack/sygv.hpp"
52
58
  #include "lapack/sygvd.hpp"
53
59
  #include "lapack/sygvx.hpp"
@@ -312,6 +318,18 @@ extern "C" void Init_tiny_linalg(void) {
312
318
  TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
313
319
  TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
314
320
  TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
321
+ TinyLinalg::SyEv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEv>::define_module_function(rb_mTinyLinalgLapack, "dsyev");
322
+ TinyLinalg::SyEv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEv>::define_module_function(rb_mTinyLinalgLapack, "ssyev");
323
+ TinyLinalg::HeEv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEv>::define_module_function(rb_mTinyLinalgLapack, "zheev");
324
+ TinyLinalg::HeEv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEv>::define_module_function(rb_mTinyLinalgLapack, "cheev");
325
+ TinyLinalg::SyEvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvd>::define_module_function(rb_mTinyLinalgLapack, "dsyevd");
326
+ TinyLinalg::SyEvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvd>::define_module_function(rb_mTinyLinalgLapack, "ssyevd");
327
+ TinyLinalg::HeEvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvd>::define_module_function(rb_mTinyLinalgLapack, "zheevd");
328
+ TinyLinalg::HeEvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvd>::define_module_function(rb_mTinyLinalgLapack, "cheevd");
329
+ TinyLinalg::SyEvr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvr>::define_module_function(rb_mTinyLinalgLapack, "dsyevr");
330
+ TinyLinalg::SyEvr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvr>::define_module_function(rb_mTinyLinalgLapack, "ssyevr");
331
+ TinyLinalg::HeEvr<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvr>::define_module_function(rb_mTinyLinalgLapack, "zheevr");
332
+ TinyLinalg::HeEvr<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvr>::define_module_function(rb_mTinyLinalgLapack, "cheevr");
315
333
  TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
316
334
  TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
317
335
  TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
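
After the Init_tiny_linalg registrations above, the new solvers are plain module functions on Numo::TinyLinalg::Lapack, so their availability can be checked the usual way:

    require 'numo/tiny_linalg'

    Numo::TinyLinalg::Lapack.respond_to?(:dsyev)   # => true
    Numo::TinyLinalg::Lapack.respond_to?(:zheevr)  # => true
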
data/lib/numo/tiny_linalg/version.rb CHANGED
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.1.0'
8
+ VERSION = '0.1.2'
9
9
  end
10
10
  end
data/lib/numo/tiny_linalg.rb CHANGED
@@ -39,15 +39,15 @@ module Numo
39
39
  # pp (x - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
40
40
  # # => 3.3306690738754696e-16
41
41
  #
42
- # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
43
- # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
42
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
43
+ # @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
44
44
  # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
45
45
  # @param vals_range [Range/Array]
46
46
  # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
47
47
  # If nil, all eigenvalues and eigenvectors are computed.
48
48
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
49
49
  # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
50
- # @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors.
50
+ # @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
51
51
  def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
52
52
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
53
53
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
@@ -70,8 +70,8 @@ module Numo
70
70
  vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
71
71
  else
72
72
  sy_he_gv << 'x'
73
- il = vals_range.first + 1
74
- iu = vals_range.last + 1
73
+ il = vals_range.first(1)[0] + 1
74
+ iu = vals_range.last(1)[0] + 1
75
75
  _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
76
76
  sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
77
77
  )
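
The first(1)[0] / last(1)[0] change above matters when vals_range is an exclusive Range: Range#last returns the range's end value rather than its last element, while last(1) returns the last element for Ranges and Arrays alike. A quick illustration:

    (0...3).last       # => 3   (one past the last eigenvalue index requested)
    (0...3).last(1)    # => [2] (last element, as intended)
    [0, 1, 2].last(1)  # => [2] (same call also works when vals_range is an Array)
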
@@ -91,7 +91,7 @@ module Numo
91
91
  # pp (3.0 - Numo::Linalg.det(a)).abs
92
92
  # # => 1.3322676295501878e-15
93
93
  #
94
- # @param a [Numo::NArray] n-by-n square matrix.
94
+ # @param a [Numo::NArray] The n-by-n square matrix.
95
95
  # @return [Float/Complex] The determinant of `a`.
96
96
  def det(a)
97
97
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
@@ -132,7 +132,7 @@ module Numo
132
132
  # pp inv_a.dot(a).sum
133
133
  # # => 5.0
134
134
  #
135
- # @param a [Numo::NArray] n-by-n square matrix.
135
+ # @param a [Numo::NArray] The n-by-n square matrix.
136
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
137
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
138
138
  # @return [Numo::NArray] The inverse matrix of `a`.
@@ -156,7 +156,7 @@ module Numo
156
156
  end
157
157
  end
158
158
 
159
- # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
159
+ # Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
160
160
  #
161
161
  # @example
162
162
  # require 'numo/tiny_linalg'
@@ -174,7 +174,7 @@ module Numo
174
174
  # # => 3.0
175
175
  #
176
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
177
- # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
177
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
178
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
179
179
  # @return [Numo::NArray] The pseudo-inverse of `a`.
180
180
  def pinv(a, driver: 'svd', rcond: nil)
@@ -186,7 +186,7 @@ module Numo
186
186
  u.dot(vh[0...rank, true]).conj.transpose
187
187
  end
188
188
 
189
- # Compute QR decomposition of a matrix.
189
+ # Computes the QR decomposition of a matrix.
190
190
  #
191
191
  # @example
192
192
  # require 'numo/tiny_linalg'
@@ -222,9 +222,8 @@ module Numo
222
222
  # - "r" -- returns only R,
223
223
  # - "economic" -- returns both Q [m, n] and R [n, n],
224
224
  # - "raw" -- returns QR and TAU (LAPACK geqrf results).
225
- # @return [Numo::NArray] if mode='r'
226
- # @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
227
- # @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
225
+ # @return [Numo::NArray] if mode='r'.
226
+ # @return [Array<Numo::NArray>] if mode='reduce' or 'economic' or 'raw'.
228
227
  def qr(a, mode: 'reduce')
229
228
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
230
229
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)
@@ -295,7 +294,7 @@ module Numo
295
294
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
296
295
  end
297
296
 
298
- # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
297
+ # Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
299
298
  #
300
299
  # @example
301
300
  # require 'numo/tiny_linalg'
@@ -328,9 +327,9 @@ module Numo
328
327
  # # => 4.440892098500626e-16
329
328
  #
330
329
  # @param a [Numo::NArray] Matrix to be decomposed.
331
- # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
332
- # @param job [String] Job option ('A', 'S', or 'N').
333
- # @return [Array<Numo::NArray>] Singular values and singular vectors ([s, u, vt]).
330
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
331
+ # @param job [String] The job option ('A', 'S', or 'N').
332
+ # @return [Array<Numo::NArray>] The singular values and singular vectors ([s, u, vt]).
334
333
  def svd(a, driver: 'svd', job: 'A')
335
334
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
336
335
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-06 00:00:00.000000000 Z
11
+ date: 2023-08-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -52,10 +52,16 @@ files:
52
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
54
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/heev.hpp
56
+ - ext/numo/tiny_linalg/lapack/heevd.hpp
57
+ - ext/numo/tiny_linalg/lapack/heevr.hpp
55
58
  - ext/numo/tiny_linalg/lapack/hegv.hpp
56
59
  - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
60
  - ext/numo/tiny_linalg/lapack/hegvx.hpp
58
61
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
62
+ - ext/numo/tiny_linalg/lapack/syev.hpp
63
+ - ext/numo/tiny_linalg/lapack/syevd.hpp
64
+ - ext/numo/tiny_linalg/lapack/syevr.hpp
59
65
  - ext/numo/tiny_linalg/lapack/sygv.hpp
60
66
  - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
67
  - ext/numo/tiny_linalg/lapack/sygvx.hpp