numo-tiny_linalg 0.1.0 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -2
- data/ext/numo/tiny_linalg/lapack/heev.hpp +87 -0
- data/ext/numo/tiny_linalg/lapack/heevd.hpp +87 -0
- data/ext/numo/tiny_linalg/lapack/heevr.hpp +115 -0
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/syev.hpp +86 -0
- data/ext/numo/tiny_linalg/lapack/syevd.hpp +86 -0
- data/ext/numo/tiny_linalg/lapack/syevr.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +1 -1
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +18 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +16 -17
- metadata +8 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0aeb68c259f8d4c4dcd44c15caef6926cbe3f26f9dd99e9229721d02a939d7c1
|
4
|
+
data.tar.gz: 5465aebebc07612812b861932fcabf14b66ee0ec389603c9ffcf06694eb69ce4
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3e1ddd9a4f43d89635d0fc4e07fc3eab0e302b108648e00506deaec106f147f10132b6c28818f6ef054d5fa0ef52bb2c33287dcdeb413fc12268fb35edfde859
|
7
|
+
data.tar.gz: 78903f772d37936f73a69624b3241d43703632f08542a2f57e38a619b6dc1ac012d8885c79056459ef4fd50102d416f996ea54269b64b919ac91a3db46330d8c
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,14 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
|
4
|
+
- Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
|
5
|
+
- Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
|
6
|
+
- Add dsyevr, ssyevr, zheevr, and cheevr module functions to TinyLinalg::Lapack.
|
7
|
+
- Fix the check of whether the array b is a square matrix in TinyLinalg.eigh.
|
8
|
+
|
9
|
+
## [[0.1.1](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.0...v0.1.1)] - 2023-08-07
|
10
|
+
- Fix the method of getting the start and end of the eigenvalue range from the vals_range argument of TinyLinalg.eigh.
|
11
|
+
|
3
12
|
## [[0.1.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.4...v0.1.0)] - 2023-08-06
|
4
13
|
- Refactor codes and update documentations.
|
5
14
|
|
@@ -25,5 +34,4 @@
|
|
25
34
|
- Add inv module function to TinyLinalg.
|
26
35
|
|
27
36
|
## [0.0.1] - 2023-07-14
|
28
|
-
|
29
|
-
- Initial release
|
37
|
+
- Initial release.
|
@@ -0,0 +1,87 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeEv {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_zheev(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct CHeEv {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_cheev(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
16
|
+
class HeEv {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heev), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct heev_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_heev(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
rtype* w = (rtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
heev_opt* opt = (heev_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_heev(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
43
|
+
|
44
|
+
VALUE a_vnary = Qnil;
|
45
|
+
VALUE kw_args = Qnil;
|
46
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
47
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
48
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
49
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
50
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
51
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
52
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
|
61
|
+
narray_t* a_nary = nullptr;
|
62
|
+
GetNArray(a_vnary, a_nary);
|
63
|
+
if (NA_NDIM(a_nary) != 2) {
|
64
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
65
|
+
return Qnil;
|
66
|
+
}
|
67
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
68
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
69
|
+
return Qnil;
|
70
|
+
}
|
71
|
+
|
72
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
73
|
+
size_t shape[1] = { n };
|
74
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
75
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
76
|
+
ndfunc_t ndf = { iter_heev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
77
|
+
heev_opt opt = { matrix_layout, jobz, uplo };
|
78
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
79
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
80
|
+
|
81
|
+
RB_GC_GUARD(a_vnary);
|
82
|
+
|
83
|
+
return ret;
|
84
|
+
}
|
85
|
+
};
|
86
|
+
|
87
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,87 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeEvd {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_zheevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct CHeEvd {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_cheevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
16
|
+
class HeEvd {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevd), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct heevd_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_heevd(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
rtype* w = (rtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
heevd_opt* opt = (heevd_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_heevd(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
43
|
+
|
44
|
+
VALUE a_vnary = Qnil;
|
45
|
+
VALUE kw_args = Qnil;
|
46
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
47
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
48
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
49
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
50
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
51
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
52
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
|
61
|
+
narray_t* a_nary = nullptr;
|
62
|
+
GetNArray(a_vnary, a_nary);
|
63
|
+
if (NA_NDIM(a_nary) != 2) {
|
64
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
65
|
+
return Qnil;
|
66
|
+
}
|
67
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
68
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
69
|
+
return Qnil;
|
70
|
+
}
|
71
|
+
|
72
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
73
|
+
size_t shape[1] = { n };
|
74
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
75
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
76
|
+
ndfunc_t ndf = { iter_heevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
77
|
+
heevd_opt opt = { matrix_layout, jobz, uplo };
|
78
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
79
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
80
|
+
|
81
|
+
RB_GC_GUARD(a_vnary);
|
82
|
+
|
83
|
+
return ret;
|
84
|
+
}
|
85
|
+
};
|
86
|
+
|
87
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,115 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeEvr {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, lapack_complex_double* a, lapack_int lda, double vl, double vu, lapack_int il,
|
6
|
+
lapack_int iu, double abstol, lapack_int* m,
|
7
|
+
double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* isuppz) {
|
8
|
+
return LAPACKE_zheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeEvr {
|
13
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, lapack_complex_float* a, lapack_int lda, float vl, float vu, lapack_int il,
|
15
|
+
lapack_int iu, float abstol, lapack_int* m,
|
16
|
+
float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* isuppz) {
|
17
|
+
return LAPACKE_cheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
22
|
+
class HeEvr {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevr), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct heevr_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
char jobz;
|
32
|
+
char range;
|
33
|
+
char uplo;
|
34
|
+
rtype vl;
|
35
|
+
rtype vu;
|
36
|
+
lapack_int il;
|
37
|
+
lapack_int iu;
|
38
|
+
};
|
39
|
+
|
40
|
+
static void iter_heevr(na_loop_t* const lp) {
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
42
|
+
int* m = (int*)NDL_PTR(lp, 1);
|
43
|
+
rtype* w = (rtype*)NDL_PTR(lp, 2);
|
44
|
+
dtype* z = (dtype*)NDL_PTR(lp, 3);
|
45
|
+
int* isuppz = (int*)NDL_PTR(lp, 4);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 5);
|
47
|
+
heevr_opt* opt = (heevr_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
51
|
+
const rtype abstol = 0.0;
|
52
|
+
const lapack_int i = LapackFn().call(
|
53
|
+
opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
|
54
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
|
55
|
+
*info = static_cast<int>(i);
|
56
|
+
}
|
57
|
+
|
58
|
+
static VALUE tiny_linalg_heevr(int argc, VALUE* argv, VALUE self) {
|
59
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
60
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
61
|
+
|
62
|
+
VALUE a_vnary = Qnil;
|
63
|
+
VALUE kw_args = Qnil;
|
64
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
65
|
+
ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
66
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
67
|
+
VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
68
|
+
rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
|
69
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
70
|
+
const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
|
71
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
72
|
+
const rtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
|
73
|
+
const rtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
74
|
+
const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
|
75
|
+
const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
76
|
+
const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
|
77
|
+
|
78
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
79
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
80
|
+
}
|
81
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
82
|
+
a_vnary = nary_dup(a_vnary);
|
83
|
+
}
|
84
|
+
|
85
|
+
narray_t* a_nary = nullptr;
|
86
|
+
GetNArray(a_vnary, a_nary);
|
87
|
+
if (NA_NDIM(a_nary) != 2) {
|
88
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
92
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
93
|
+
return Qnil;
|
94
|
+
}
|
95
|
+
|
96
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
97
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
98
|
+
size_t w_shape[1] = { m };
|
99
|
+
size_t z_shape[2] = { n, m };
|
100
|
+
size_t isuppz_shape[1] = { 2 * m };
|
101
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
102
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
|
103
|
+
ndfunc_t ndf = { iter_heevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
|
104
|
+
heevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
|
105
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
106
|
+
VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
107
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
108
|
+
|
109
|
+
RB_GC_GUARD(a_vnary);
|
110
|
+
|
111
|
+
return ret;
|
112
|
+
}
|
113
|
+
};
|
114
|
+
|
115
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,86 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyEv {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_dsyev(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SSyEv {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_ssyev(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
16
|
+
class SyEv {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syev), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct syev_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_syev(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
dtype* w = (dtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
syev_opt* opt = (syev_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_syev(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
|
43
|
+
VALUE a_vnary = Qnil;
|
44
|
+
VALUE kw_args = Qnil;
|
45
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
46
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
47
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
48
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
49
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
50
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
51
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
52
|
+
|
53
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
54
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
55
|
+
}
|
56
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
57
|
+
a_vnary = nary_dup(a_vnary);
|
58
|
+
}
|
59
|
+
|
60
|
+
narray_t* a_nary = nullptr;
|
61
|
+
GetNArray(a_vnary, a_nary);
|
62
|
+
if (NA_NDIM(a_nary) != 2) {
|
63
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
64
|
+
return Qnil;
|
65
|
+
}
|
66
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
67
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
68
|
+
return Qnil;
|
69
|
+
}
|
70
|
+
|
71
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
72
|
+
size_t shape[1] = { n };
|
73
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
74
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
75
|
+
ndfunc_t ndf = { iter_syev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
76
|
+
syev_opt opt = { matrix_layout, jobz, uplo };
|
77
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
78
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
79
|
+
|
80
|
+
RB_GC_GUARD(a_vnary);
|
81
|
+
|
82
|
+
return ret;
|
83
|
+
}
|
84
|
+
};
|
85
|
+
|
86
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,86 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyEvd {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_dsyevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SSyEvd {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_ssyevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
16
|
+
class SyEvd {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevd), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct syevd_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_syevd(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
dtype* w = (dtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
syevd_opt* opt = (syevd_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_syevd(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
|
43
|
+
VALUE a_vnary = Qnil;
|
44
|
+
VALUE kw_args = Qnil;
|
45
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
46
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
47
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
48
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
49
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
50
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
51
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
52
|
+
|
53
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
54
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
55
|
+
}
|
56
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
57
|
+
a_vnary = nary_dup(a_vnary);
|
58
|
+
}
|
59
|
+
|
60
|
+
narray_t* a_nary = nullptr;
|
61
|
+
GetNArray(a_vnary, a_nary);
|
62
|
+
if (NA_NDIM(a_nary) != 2) {
|
63
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
64
|
+
return Qnil;
|
65
|
+
}
|
66
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
67
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
68
|
+
return Qnil;
|
69
|
+
}
|
70
|
+
|
71
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
72
|
+
size_t shape[1] = { n };
|
73
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
74
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
75
|
+
ndfunc_t ndf = { iter_syevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
76
|
+
syevd_opt opt = { matrix_layout, jobz, uplo };
|
77
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
78
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
79
|
+
|
80
|
+
RB_GC_GUARD(a_vnary);
|
81
|
+
|
82
|
+
return ret;
|
83
|
+
}
|
84
|
+
};
|
85
|
+
|
86
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,112 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyEvr {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, double* a, lapack_int lda, double vl, double vu, lapack_int il, lapack_int iu,
|
6
|
+
double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* isuppz) {
|
7
|
+
return LAPACKE_dsyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SSyEvr {
|
12
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
13
|
+
lapack_int n, float* a, lapack_int lda, float vl, float vu, lapack_int il, lapack_int iu,
|
14
|
+
float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* isuppz) {
|
15
|
+
return LAPACKE_ssyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
20
|
+
class SyEvr {
|
21
|
+
public:
|
22
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
23
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevr), -1);
|
24
|
+
}
|
25
|
+
|
26
|
+
private:
|
27
|
+
struct syevr_opt {
|
28
|
+
int matrix_layout;
|
29
|
+
char jobz;
|
30
|
+
char range;
|
31
|
+
char uplo;
|
32
|
+
dtype vl;
|
33
|
+
dtype vu;
|
34
|
+
lapack_int il;
|
35
|
+
lapack_int iu;
|
36
|
+
};
|
37
|
+
|
38
|
+
static void iter_syevr(na_loop_t* const lp) {
|
39
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
40
|
+
int* m = (int*)NDL_PTR(lp, 1);
|
41
|
+
dtype* w = (dtype*)NDL_PTR(lp, 2);
|
42
|
+
dtype* z = (dtype*)NDL_PTR(lp, 3);
|
43
|
+
int* isuppz = (int*)NDL_PTR(lp, 4);
|
44
|
+
int* info = (int*)NDL_PTR(lp, 5);
|
45
|
+
syevr_opt* opt = (syevr_opt*)(lp->opt_ptr);
|
46
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
47
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
48
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
49
|
+
const dtype abstol = 0.0;
|
50
|
+
const lapack_int i = LapackFn().call(
|
51
|
+
opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
|
52
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
|
53
|
+
*info = static_cast<int>(i);
|
54
|
+
}
|
55
|
+
|
56
|
+
static VALUE tiny_linalg_syevr(int argc, VALUE* argv, VALUE self) {
|
57
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
58
|
+
|
59
|
+
VALUE a_vnary = Qnil;
|
60
|
+
VALUE kw_args = Qnil;
|
61
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
62
|
+
ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
63
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
64
|
+
VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
65
|
+
rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
|
66
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
67
|
+
const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
69
|
+
const dtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
|
70
|
+
const dtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
71
|
+
const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
|
72
|
+
const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
73
|
+
const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
|
74
|
+
|
75
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
76
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
77
|
+
}
|
78
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
79
|
+
a_vnary = nary_dup(a_vnary);
|
80
|
+
}
|
81
|
+
|
82
|
+
narray_t* a_nary = nullptr;
|
83
|
+
GetNArray(a_vnary, a_nary);
|
84
|
+
if (NA_NDIM(a_nary) != 2) {
|
85
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
86
|
+
return Qnil;
|
87
|
+
}
|
88
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
89
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
90
|
+
return Qnil;
|
91
|
+
}
|
92
|
+
|
93
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
94
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
95
|
+
size_t w_shape[1] = { m };
|
96
|
+
size_t z_shape[2] = { n, m };
|
97
|
+
size_t isuppz_shape[1] = { 2 * m };
|
98
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
99
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
|
100
|
+
ndfunc_t ndf = { iter_syevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
|
101
|
+
syevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
|
102
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
103
|
+
VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
104
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
105
|
+
|
106
|
+
RB_GC_GUARD(a_vnary);
|
107
|
+
|
108
|
+
return ret;
|
109
|
+
}
|
110
|
+
};
|
111
|
+
|
112
|
+
} // namespace TinyLinalg
|
@@ -44,10 +44,16 @@
|
|
44
44
|
#include "lapack/gesvd.hpp"
|
45
45
|
#include "lapack/getrf.hpp"
|
46
46
|
#include "lapack/getri.hpp"
|
47
|
+
#include "lapack/heev.hpp"
|
48
|
+
#include "lapack/heevd.hpp"
|
49
|
+
#include "lapack/heevr.hpp"
|
47
50
|
#include "lapack/hegv.hpp"
|
48
51
|
#include "lapack/hegvd.hpp"
|
49
52
|
#include "lapack/hegvx.hpp"
|
50
53
|
#include "lapack/orgqr.hpp"
|
54
|
+
#include "lapack/syev.hpp"
|
55
|
+
#include "lapack/syevd.hpp"
|
56
|
+
#include "lapack/syevr.hpp"
|
51
57
|
#include "lapack/sygv.hpp"
|
52
58
|
#include "lapack/sygvd.hpp"
|
53
59
|
#include "lapack/sygvx.hpp"
|
@@ -312,6 +318,18 @@ extern "C" void Init_tiny_linalg(void) {
|
|
312
318
|
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
313
319
|
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
314
320
|
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
321
|
+
TinyLinalg::SyEv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEv>::define_module_function(rb_mTinyLinalgLapack, "dsyev");
|
322
|
+
TinyLinalg::SyEv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEv>::define_module_function(rb_mTinyLinalgLapack, "ssyev");
|
323
|
+
TinyLinalg::HeEv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEv>::define_module_function(rb_mTinyLinalgLapack, "zheev");
|
324
|
+
TinyLinalg::HeEv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEv>::define_module_function(rb_mTinyLinalgLapack, "cheev");
|
325
|
+
TinyLinalg::SyEvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvd>::define_module_function(rb_mTinyLinalgLapack, "dsyevd");
|
326
|
+
TinyLinalg::SyEvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvd>::define_module_function(rb_mTinyLinalgLapack, "ssyevd");
|
327
|
+
TinyLinalg::HeEvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvd>::define_module_function(rb_mTinyLinalgLapack, "zheevd");
|
328
|
+
TinyLinalg::HeEvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvd>::define_module_function(rb_mTinyLinalgLapack, "cheevd");
|
329
|
+
TinyLinalg::SyEvr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvr>::define_module_function(rb_mTinyLinalgLapack, "dsyevr");
|
330
|
+
TinyLinalg::SyEvr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvr>::define_module_function(rb_mTinyLinalgLapack, "ssyevr");
|
331
|
+
TinyLinalg::HeEvr<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvr>::define_module_function(rb_mTinyLinalgLapack, "zheevr");
|
332
|
+
TinyLinalg::HeEvr<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvr>::define_module_function(rb_mTinyLinalgLapack, "cheevr");
|
315
333
|
TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
|
316
334
|
TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
|
317
335
|
TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -39,15 +39,15 @@ module Numo
|
|
39
39
|
# pp (x - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
|
40
40
|
# # => 3.3306690738754696e-16
|
41
41
|
#
|
42
|
-
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
|
43
|
-
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
|
42
|
+
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
|
43
|
+
# @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
|
44
44
|
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
|
45
45
|
# @param vals_range [Range/Array]
|
46
46
|
# The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
|
47
47
|
# If nil, all eigenvalues and eigenvectors are computed.
|
48
48
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
49
49
|
# @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
|
50
|
-
# @return [Array<Numo::NArray
|
50
|
+
# @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
|
51
51
|
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
|
52
52
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
53
53
|
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
|
@@ -70,8 +70,8 @@ module Numo
|
|
70
70
|
vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
|
71
71
|
else
|
72
72
|
sy_he_gv << 'x'
|
73
|
-
il = vals_range.first + 1
|
74
|
-
iu = vals_range.last + 1
|
73
|
+
il = vals_range.first(1)[0] + 1
|
74
|
+
iu = vals_range.last(1)[0] + 1
|
75
75
|
_a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
|
76
76
|
sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
|
77
77
|
)
|
@@ -91,7 +91,7 @@ module Numo
|
|
91
91
|
# pp (3.0 - Numo::Linalg.det(a)).abs
|
92
92
|
# # => 1.3322676295501878e-15
|
93
93
|
#
|
94
|
-
# @param a [Numo::NArray] n-by-n square matrix.
|
94
|
+
# @param a [Numo::NArray] The n-by-n square matrix.
|
95
95
|
# @return [Float/Complex] The determinant of `a`.
|
96
96
|
def det(a)
|
97
97
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
@@ -132,7 +132,7 @@ module Numo
|
|
132
132
|
# pp inv_a.dot(a).sum
|
133
133
|
# # => 5.0
|
134
134
|
#
|
135
|
-
# @param a [Numo::NArray] n-by-n square matrix.
|
135
|
+
# @param a [Numo::NArray] The n-by-n square matrix.
|
136
136
|
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
137
137
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
138
138
|
# @return [Numo::NArray] The inverse matrix of `a`.
|
@@ -156,7 +156,7 @@ module Numo
|
|
156
156
|
end
|
157
157
|
end
|
158
158
|
|
159
|
-
#
|
159
|
+
# Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
|
160
160
|
#
|
161
161
|
# @example
|
162
162
|
# require 'numo/tiny_linalg'
|
@@ -174,7 +174,7 @@ module Numo
|
|
174
174
|
# # => 3.0
|
175
175
|
#
|
176
176
|
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
|
177
|
-
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
|
177
|
+
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
|
178
178
|
# @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
|
179
179
|
# @return [Numo::NArray] The pseudo-inverse of `a`.
|
180
180
|
def pinv(a, driver: 'svd', rcond: nil)
|
@@ -186,7 +186,7 @@ module Numo
|
|
186
186
|
u.dot(vh[0...rank, true]).conj.transpose
|
187
187
|
end
|
188
188
|
|
189
|
-
#
|
189
|
+
# Computes the QR decomposition of a matrix.
|
190
190
|
#
|
191
191
|
# @example
|
192
192
|
# require 'numo/tiny_linalg'
|
@@ -222,9 +222,8 @@ module Numo
|
|
222
222
|
# - "r" -- returns only R,
|
223
223
|
# - "economic" -- returns both Q [m, n] and R [n, n],
|
224
224
|
# - "raw" -- returns QR and TAU (LAPACK geqrf results).
|
225
|
-
# @return [Numo::NArray] if mode='r'
|
226
|
-
# @return [Array<Numo::NArray
|
227
|
-
# @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
|
225
|
+
# @return [Numo::NArray] if mode='r'.
|
226
|
+
# @return [Array<Numo::NArray>] if mode='reduce' or 'economic' or 'raw'.
|
228
227
|
def qr(a, mode: 'reduce')
|
229
228
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
230
229
|
raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)
|
@@ -295,7 +294,7 @@ module Numo
|
|
295
294
|
Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
|
296
295
|
end
|
297
296
|
|
298
|
-
#
|
297
|
+
# Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
299
298
|
#
|
300
299
|
# @example
|
301
300
|
# require 'numo/tiny_linalg'
|
@@ -328,9 +327,9 @@ module Numo
|
|
328
327
|
# # => 4.440892098500626e-16
|
329
328
|
#
|
330
329
|
# @param a [Numo::NArray] Matrix to be decomposed.
|
331
|
-
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
|
332
|
-
# @param job [String]
|
333
|
-
# @return [Array<Numo::NArray>]
|
330
|
+
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
|
331
|
+
# @param job [String] The job option ('A', 'S', or 'N').
|
332
|
+
# @return [Array<Numo::NArray>] The singular values and singular vectors ([s, u, vt]).
|
334
333
|
def svd(a, driver: 'svd', job: 'A')
|
335
334
|
raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
|
336
335
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -52,10 +52,16 @@ files:
|
|
52
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/heev.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/heevd.hpp
|
57
|
+
- ext/numo/tiny_linalg/lapack/heevr.hpp
|
55
58
|
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
56
59
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
57
60
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
58
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
62
|
+
- ext/numo/tiny_linalg/lapack/syev.hpp
|
63
|
+
- ext/numo/tiny_linalg/lapack/syevd.hpp
|
64
|
+
- ext/numo/tiny_linalg/lapack/syevr.hpp
|
59
65
|
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
60
66
|
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
61
67
|
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|