numo-tiny_linalg 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -2
- data/README.md +2 -0
- data/ext/numo/tiny_linalg/lapack/heev.hpp +87 -0
- data/ext/numo/tiny_linalg/lapack/heevd.hpp +87 -0
- data/ext/numo/tiny_linalg/lapack/heevr.hpp +133 -0
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +19 -1
- data/ext/numo/tiny_linalg/lapack/syev.hpp +86 -0
- data/ext/numo/tiny_linalg/lapack/syevd.hpp +86 -0
- data/ext/numo/tiny_linalg/lapack/syevr.hpp +130 -0
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +1 -1
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +19 -1
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +18 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +33 -18
- metadata +8 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 911223ba40879d779bcc9bc77b840fb7b5e41ed68d6ccf7d26816913c42aed73
|
4
|
+
data.tar.gz: 9dce9c353da5466e1104c604c2749369131cb6640e14ddd33c697cee5d973563
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9d666bdd04f48132005a9aee455047c80d83184815fac5c688ec593412ff8e670f1f31ba3376b1ee2985b8e829680590c1dad6a0ec3bf5366e584a82b523d8e4
|
7
|
+
data.tar.gz: 22027c54b56bad7214060ba7d5939b986be660a75afd65d74de7fdae095a6699ceea35422eee56831b621f2132e5bce921cb6b70bc9709e77e3efdf6cfc8a098
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,15 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.2.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.2...v0.2.0)] - 2023-08-11
|
4
|
+
**Breaking change**
|
5
|
+
- Change the LAPACK function called by TinyLinalg.eigh when array b is not given.
|
6
|
+
|
7
|
+
## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
|
8
|
+
- Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
|
9
|
+
- Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
|
10
|
+
- Add dsyevr, ssyevr, zheevr, and cheevr module functions to TinyLinalg::Lapack.
|
11
|
+
- Fix the check of whether array b is a square matrix in TinyLinalg.eigh.
|
12
|
+
|
3
13
|
## [[0.1.1](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.0...v0.1.1)] - 2023-08-07
|
4
14
|
- Fix how the start and end of the eigenvalue range are obtained from the vals_range argument of TinyLinalg.eigh.
|
5
15
|
|
@@ -28,5 +38,4 @@
|
|
28
38
|
- Add inv module function to TinyLinalg.
|
29
39
|
|
30
40
|
## [0.0.1] - 2023-07-14
|
31
|
-
|
32
|
-
- Initial release
|
41
|
+
- Initial release.
|
data/README.md
CHANGED
@@ -8,6 +8,8 @@
|
|
8
8
|
Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
|
9
9
|
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
|
10
10
|
|
11
|
+
Note that the version numbering rule of Numo::TinyLinalg is not compatible with that of Numo::Linalg.
|
12
|
+
|
11
13
|
## Installation
|
12
14
|
Unlike Numo::Linalg, Numo::TinyLinalg only supports OpenBLAS as a backend library for BLAS and LAPACK.
|
13
15
|
|
@@ -0,0 +1,87 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeEv {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_zheev(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct CHeEv {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_cheev(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
16
|
+
class HeEv {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heev), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct heev_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_heev(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
rtype* w = (rtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
heev_opt* opt = (heev_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_heev(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
43
|
+
|
44
|
+
VALUE a_vnary = Qnil;
|
45
|
+
VALUE kw_args = Qnil;
|
46
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
47
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
48
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
49
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
50
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
51
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
52
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
|
61
|
+
narray_t* a_nary = nullptr;
|
62
|
+
GetNArray(a_vnary, a_nary);
|
63
|
+
if (NA_NDIM(a_nary) != 2) {
|
64
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
65
|
+
return Qnil;
|
66
|
+
}
|
67
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
68
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
69
|
+
return Qnil;
|
70
|
+
}
|
71
|
+
|
72
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
73
|
+
size_t shape[1] = { n };
|
74
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
75
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
76
|
+
ndfunc_t ndf = { iter_heev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
77
|
+
heev_opt opt = { matrix_layout, jobz, uplo };
|
78
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
79
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
80
|
+
|
81
|
+
RB_GC_GUARD(a_vnary);
|
82
|
+
|
83
|
+
return ret;
|
84
|
+
}
|
85
|
+
};
|
86
|
+
|
87
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,87 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeEvd {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_zheevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct CHeEvd {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_cheevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
16
|
+
class HeEvd {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevd), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct heevd_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_heevd(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
rtype* w = (rtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
heevd_opt* opt = (heevd_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_heevd(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
43
|
+
|
44
|
+
VALUE a_vnary = Qnil;
|
45
|
+
VALUE kw_args = Qnil;
|
46
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
47
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
48
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
49
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
50
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
51
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
52
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
|
61
|
+
narray_t* a_nary = nullptr;
|
62
|
+
GetNArray(a_vnary, a_nary);
|
63
|
+
if (NA_NDIM(a_nary) != 2) {
|
64
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
65
|
+
return Qnil;
|
66
|
+
}
|
67
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
68
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
69
|
+
return Qnil;
|
70
|
+
}
|
71
|
+
|
72
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
73
|
+
size_t shape[1] = { n };
|
74
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
75
|
+
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
|
76
|
+
ndfunc_t ndf = { iter_heevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
77
|
+
heevd_opt opt = { matrix_layout, jobz, uplo };
|
78
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
79
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
80
|
+
|
81
|
+
RB_GC_GUARD(a_vnary);
|
82
|
+
|
83
|
+
return ret;
|
84
|
+
}
|
85
|
+
};
|
86
|
+
|
87
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,133 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZHeEvr {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, lapack_complex_double* a, lapack_int lda, double vl, double vu, lapack_int il,
|
6
|
+
lapack_int iu, double abstol, lapack_int* m,
|
7
|
+
double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* isuppz) {
|
8
|
+
return LAPACKE_zheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct CHeEvr {
|
13
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, lapack_complex_float* a, lapack_int lda, float vl, float vu, lapack_int il,
|
15
|
+
lapack_int iu, float abstol, lapack_int* m,
|
16
|
+
float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* isuppz) {
|
17
|
+
return LAPACKE_cheevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
22
|
+
class HeEvr {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevr), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct heevr_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
char jobz;
|
32
|
+
char range;
|
33
|
+
char uplo;
|
34
|
+
rtype vl;
|
35
|
+
rtype vu;
|
36
|
+
lapack_int il;
|
37
|
+
lapack_int iu;
|
38
|
+
};
|
39
|
+
|
40
|
+
static void iter_heevr(na_loop_t* const lp) {
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
42
|
+
int* m = (int*)NDL_PTR(lp, 1);
|
43
|
+
rtype* w = (rtype*)NDL_PTR(lp, 2);
|
44
|
+
dtype* z = (dtype*)NDL_PTR(lp, 3);
|
45
|
+
int* isuppz = (int*)NDL_PTR(lp, 4);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 5);
|
47
|
+
heevr_opt* opt = (heevr_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
49
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
51
|
+
const rtype abstol = 0.0;
|
52
|
+
const lapack_int i = LapackFn().call(
|
53
|
+
opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
|
54
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
|
55
|
+
*info = static_cast<int>(i);
|
56
|
+
}
|
57
|
+
|
58
|
+
static VALUE tiny_linalg_heevr(int argc, VALUE* argv, VALUE self) {
|
59
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
60
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
61
|
+
|
62
|
+
VALUE a_vnary = Qnil;
|
63
|
+
VALUE kw_args = Qnil;
|
64
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
65
|
+
ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
66
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
67
|
+
VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
68
|
+
rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
|
69
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
70
|
+
const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
|
71
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
72
|
+
const rtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
|
73
|
+
const rtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
74
|
+
const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
|
75
|
+
const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
76
|
+
const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
|
77
|
+
|
78
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
79
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
80
|
+
}
|
81
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
82
|
+
a_vnary = nary_dup(a_vnary);
|
83
|
+
}
|
84
|
+
|
85
|
+
narray_t* a_nary = nullptr;
|
86
|
+
GetNArray(a_vnary, a_nary);
|
87
|
+
if (NA_NDIM(a_nary) != 2) {
|
88
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
92
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
93
|
+
return Qnil;
|
94
|
+
}
|
95
|
+
|
96
|
+
if (range == 'V' && vu <= vl) {
|
97
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
|
101
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
102
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
103
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
104
|
+
return Qnil;
|
105
|
+
}
|
106
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
107
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
108
|
+
return Qnil;
|
109
|
+
}
|
110
|
+
if (range == 'I' && iu < il) {
|
111
|
+
rb_raise(rb_eArgError, "iu must be greater than or equal to il");
|
112
|
+
return Qnil;
|
113
|
+
}
|
114
|
+
|
115
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
116
|
+
size_t w_shape[1] = { m };
|
117
|
+
size_t z_shape[2] = { n, m };
|
118
|
+
size_t isuppz_shape[1] = { 2 * m };
|
119
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
120
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
|
121
|
+
ndfunc_t ndf = { iter_heevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
|
122
|
+
heevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
|
123
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
124
|
+
VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
125
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
126
|
+
|
127
|
+
RB_GC_GUARD(a_vnary);
|
128
|
+
|
129
|
+
return ret;
|
130
|
+
}
|
131
|
+
};
|
132
|
+
|
133
|
+
} // namespace TinyLinalg
|
@@ -104,7 +104,7 @@ private:
|
|
104
104
|
return Qnil;
|
105
105
|
}
|
106
106
|
narray_t* b_nary = nullptr;
|
107
|
-
GetNArray(
|
107
|
+
GetNArray(b_vnary, b_nary);
|
108
108
|
if (NA_NDIM(b_nary) != 2) {
|
109
109
|
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
110
110
|
return Qnil;
|
@@ -114,7 +114,25 @@ private:
|
|
114
114
|
return Qnil;
|
115
115
|
}
|
116
116
|
|
117
|
+
if (range == 'V' && vu <= vl) {
|
118
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
119
|
+
return Qnil;
|
120
|
+
}
|
121
|
+
|
117
122
|
const size_t n = NA_SHAPE(a_nary)[1];
|
123
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
124
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
125
|
+
return Qnil;
|
126
|
+
}
|
127
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
128
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
129
|
+
return Qnil;
|
130
|
+
}
|
131
|
+
if (range == 'I' && iu < il) {
|
132
|
+
rb_raise(rb_eArgError, "il must be less than or equal to iu");
|
133
|
+
return Qnil;
|
134
|
+
}
|
135
|
+
|
118
136
|
size_t m = range != 'I' ? n : iu - il + 1;
|
119
137
|
size_t w_shape[1] = { m };
|
120
138
|
size_t z_shape[2] = { n, m };
|
@@ -0,0 +1,86 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyEv {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_dsyev(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SSyEv {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_ssyev(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
16
|
+
class SyEv {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syev), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct syev_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_syev(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
dtype* w = (dtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
syev_opt* opt = (syev_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_syev(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
|
43
|
+
VALUE a_vnary = Qnil;
|
44
|
+
VALUE kw_args = Qnil;
|
45
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
46
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
47
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
48
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
49
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
50
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
51
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
52
|
+
|
53
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
54
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
55
|
+
}
|
56
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
57
|
+
a_vnary = nary_dup(a_vnary);
|
58
|
+
}
|
59
|
+
|
60
|
+
narray_t* a_nary = nullptr;
|
61
|
+
GetNArray(a_vnary, a_nary);
|
62
|
+
if (NA_NDIM(a_nary) != 2) {
|
63
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
64
|
+
return Qnil;
|
65
|
+
}
|
66
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
67
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
68
|
+
return Qnil;
|
69
|
+
}
|
70
|
+
|
71
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
72
|
+
size_t shape[1] = { n };
|
73
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
74
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
75
|
+
ndfunc_t ndf = { iter_syev, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
76
|
+
syev_opt opt = { matrix_layout, jobz, uplo };
|
77
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
78
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
79
|
+
|
80
|
+
RB_GC_GUARD(a_vnary);
|
81
|
+
|
82
|
+
return ret;
|
83
|
+
}
|
84
|
+
};
|
85
|
+
|
86
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,86 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyEvd {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
|
5
|
+
return LAPACKE_dsyevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SSyEvd {
|
10
|
+
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
|
11
|
+
return LAPACKE_ssyevd(matrix_layout, jobz, uplo, n, a, lda, w);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
16
|
+
class SyEvd {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
19
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevd), -1);
|
20
|
+
}
|
21
|
+
|
22
|
+
private:
|
23
|
+
struct syevd_opt {
|
24
|
+
int matrix_layout;
|
25
|
+
char jobz;
|
26
|
+
char uplo;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_syevd(na_loop_t* const lp) {
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
dtype* w = (dtype*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
syevd_opt* opt = (syevd_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
35
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
36
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
|
37
|
+
*info = static_cast<int>(i);
|
38
|
+
}
|
39
|
+
|
40
|
+
static VALUE tiny_linalg_syevd(int argc, VALUE* argv, VALUE self) {
|
41
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
42
|
+
|
43
|
+
VALUE a_vnary = Qnil;
|
44
|
+
VALUE kw_args = Qnil;
|
45
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
46
|
+
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
|
47
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
48
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
49
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
50
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
51
|
+
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;
|
52
|
+
|
53
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
54
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
55
|
+
}
|
56
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
57
|
+
a_vnary = nary_dup(a_vnary);
|
58
|
+
}
|
59
|
+
|
60
|
+
narray_t* a_nary = nullptr;
|
61
|
+
GetNArray(a_vnary, a_nary);
|
62
|
+
if (NA_NDIM(a_nary) != 2) {
|
63
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
64
|
+
return Qnil;
|
65
|
+
}
|
66
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
67
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
68
|
+
return Qnil;
|
69
|
+
}
|
70
|
+
|
71
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
72
|
+
size_t shape[1] = { n };
|
73
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
74
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
75
|
+
ndfunc_t ndf = { iter_syevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
76
|
+
syevd_opt opt = { matrix_layout, jobz, uplo };
|
77
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
78
|
+
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
|
79
|
+
|
80
|
+
RB_GC_GUARD(a_vnary);
|
81
|
+
|
82
|
+
return ret;
|
83
|
+
}
|
84
|
+
};
|
85
|
+
|
86
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,130 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyEvr {
|
4
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, double* a, lapack_int lda, double vl, double vu, lapack_int il, lapack_int iu,
|
6
|
+
double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* isuppz) {
|
7
|
+
return LAPACKE_dsyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SSyEvr {
|
12
|
+
lapack_int call(int matrix_layout, char jobz, char range, char uplo,
|
13
|
+
lapack_int n, float* a, lapack_int lda, float vl, float vu, lapack_int il, lapack_int iu,
|
14
|
+
float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* isuppz) {
|
15
|
+
return LAPACKE_ssyevr(matrix_layout, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
20
|
+
class SyEvr {
|
21
|
+
public:
|
22
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
23
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevr), -1);
|
24
|
+
}
|
25
|
+
|
26
|
+
private:
|
27
|
+
struct syevr_opt {
|
28
|
+
int matrix_layout;
|
29
|
+
char jobz;
|
30
|
+
char range;
|
31
|
+
char uplo;
|
32
|
+
dtype vl;
|
33
|
+
dtype vu;
|
34
|
+
lapack_int il;
|
35
|
+
lapack_int iu;
|
36
|
+
};
|
37
|
+
|
38
|
+
static void iter_syevr(na_loop_t* const lp) {
|
39
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
40
|
+
int* m = (int*)NDL_PTR(lp, 1);
|
41
|
+
dtype* w = (dtype*)NDL_PTR(lp, 2);
|
42
|
+
dtype* z = (dtype*)NDL_PTR(lp, 3);
|
43
|
+
int* isuppz = (int*)NDL_PTR(lp, 4);
|
44
|
+
int* info = (int*)NDL_PTR(lp, 5);
|
45
|
+
syevr_opt* opt = (syevr_opt*)(lp->opt_ptr);
|
46
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
47
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
48
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
49
|
+
const dtype abstol = 0.0;
|
50
|
+
const lapack_int i = LapackFn().call(
|
51
|
+
opt->matrix_layout, opt->jobz, opt->range, opt->uplo, n, a, lda,
|
52
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, isuppz);
|
53
|
+
*info = static_cast<int>(i);
|
54
|
+
}
|
55
|
+
|
56
|
+
static VALUE tiny_linalg_syevr(int argc, VALUE* argv, VALUE self) {
|
57
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
58
|
+
|
59
|
+
VALUE a_vnary = Qnil;
|
60
|
+
VALUE kw_args = Qnil;
|
61
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
62
|
+
ID kw_table[8] = { rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
63
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
64
|
+
VALUE kw_values[8] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
65
|
+
rb_get_kwargs(kw_args, kw_table, 0, 8, kw_values);
|
66
|
+
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
|
67
|
+
const char range = kw_values[1] != Qundef ? Util().get_range(kw_values[1]) : 'A';
|
68
|
+
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
|
69
|
+
const dtype vl = kw_values[3] != Qundef ? NUM2DBL(kw_values[3]) : 0.0;
|
70
|
+
const dtype vu = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
71
|
+
const lapack_int il = kw_values[5] != Qundef ? NUM2INT(kw_values[5]) : 0;
|
72
|
+
const lapack_int iu = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
73
|
+
const int matrix_layout = kw_values[7] != Qundef ? Util().get_matrix_layout(kw_values[7]) : LAPACK_ROW_MAJOR;
|
74
|
+
|
75
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
76
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
77
|
+
}
|
78
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
79
|
+
a_vnary = nary_dup(a_vnary);
|
80
|
+
}
|
81
|
+
|
82
|
+
narray_t* a_nary = nullptr;
|
83
|
+
GetNArray(a_vnary, a_nary);
|
84
|
+
if (NA_NDIM(a_nary) != 2) {
|
85
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
86
|
+
return Qnil;
|
87
|
+
}
|
88
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
89
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
90
|
+
return Qnil;
|
91
|
+
}
|
92
|
+
|
93
|
+
if (range == 'V' && vu <= vl) {
|
94
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
95
|
+
return Qnil;
|
96
|
+
}
|
97
|
+
|
98
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
99
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
100
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
101
|
+
return Qnil;
|
102
|
+
}
|
103
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
104
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
105
|
+
return Qnil;
|
106
|
+
}
|
107
|
+
if (range == 'I' && iu < il) {
|
108
|
+
rb_raise(rb_eArgError, "iu must be greater than or equal to il");
|
109
|
+
return Qnil;
|
110
|
+
}
|
111
|
+
|
112
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
113
|
+
size_t w_shape[1] = { m };
|
114
|
+
size_t z_shape[2] = { n, m };
|
115
|
+
size_t isuppz_shape[1] = { 2 * m };
|
116
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
117
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, isuppz_shape }, { numo_cInt32, 0 } };
|
118
|
+
ndfunc_t ndf = { iter_syevr, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };
|
119
|
+
syevr_opt opt = { matrix_layout, jobz, range, uplo, vl, vu, il, iu };
|
120
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
121
|
+
VALUE ret = rb_ary_new3(6, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
122
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
123
|
+
|
124
|
+
RB_GC_GUARD(a_vnary);
|
125
|
+
|
126
|
+
return ret;
|
127
|
+
}
|
128
|
+
};
|
129
|
+
|
130
|
+
} // namespace TinyLinalg
|
@@ -103,7 +103,7 @@ private:
|
|
103
103
|
return Qnil;
|
104
104
|
}
|
105
105
|
narray_t* b_nary = nullptr;
|
106
|
-
GetNArray(
|
106
|
+
GetNArray(b_vnary, b_nary);
|
107
107
|
if (NA_NDIM(b_nary) != 2) {
|
108
108
|
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
109
109
|
return Qnil;
|
@@ -113,7 +113,25 @@ private:
|
|
113
113
|
return Qnil;
|
114
114
|
}
|
115
115
|
|
116
|
+
if (range == 'V' && vu <= vl) {
|
117
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
118
|
+
return Qnil;
|
119
|
+
}
|
120
|
+
|
116
121
|
const size_t n = NA_SHAPE(a_nary)[1];
|
122
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
123
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
124
|
+
return Qnil;
|
125
|
+
}
|
126
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
127
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
128
|
+
return Qnil;
|
129
|
+
}
|
130
|
+
if (range == 'I' && iu < il) {
|
131
|
+
rb_raise(rb_eArgError, "iu must be greater than or equal to il");
|
132
|
+
return Qnil;
|
133
|
+
}
|
134
|
+
|
117
135
|
size_t m = range != 'I' ? n : iu - il + 1;
|
118
136
|
size_t w_shape[1] = { m };
|
119
137
|
size_t z_shape[2] = { n, m };
|
@@ -44,10 +44,16 @@
|
|
44
44
|
#include "lapack/gesvd.hpp"
|
45
45
|
#include "lapack/getrf.hpp"
|
46
46
|
#include "lapack/getri.hpp"
|
47
|
+
#include "lapack/heev.hpp"
|
48
|
+
#include "lapack/heevd.hpp"
|
49
|
+
#include "lapack/heevr.hpp"
|
47
50
|
#include "lapack/hegv.hpp"
|
48
51
|
#include "lapack/hegvd.hpp"
|
49
52
|
#include "lapack/hegvx.hpp"
|
50
53
|
#include "lapack/orgqr.hpp"
|
54
|
+
#include "lapack/syev.hpp"
|
55
|
+
#include "lapack/syevd.hpp"
|
56
|
+
#include "lapack/syevr.hpp"
|
51
57
|
#include "lapack/sygv.hpp"
|
52
58
|
#include "lapack/sygvd.hpp"
|
53
59
|
#include "lapack/sygvx.hpp"
|
@@ -312,6 +318,18 @@ extern "C" void Init_tiny_linalg(void) {
|
|
312
318
|
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
313
319
|
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
314
320
|
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
321
|
+
TinyLinalg::SyEv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEv>::define_module_function(rb_mTinyLinalgLapack, "dsyev");
|
322
|
+
TinyLinalg::SyEv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEv>::define_module_function(rb_mTinyLinalgLapack, "ssyev");
|
323
|
+
TinyLinalg::HeEv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEv>::define_module_function(rb_mTinyLinalgLapack, "zheev");
|
324
|
+
TinyLinalg::HeEv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEv>::define_module_function(rb_mTinyLinalgLapack, "cheev");
|
325
|
+
TinyLinalg::SyEvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvd>::define_module_function(rb_mTinyLinalgLapack, "dsyevd");
|
326
|
+
TinyLinalg::SyEvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvd>::define_module_function(rb_mTinyLinalgLapack, "ssyevd");
|
327
|
+
TinyLinalg::HeEvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvd>::define_module_function(rb_mTinyLinalgLapack, "zheevd");
|
328
|
+
TinyLinalg::HeEvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvd>::define_module_function(rb_mTinyLinalgLapack, "cheevd");
|
329
|
+
TinyLinalg::SyEvr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvr>::define_module_function(rb_mTinyLinalgLapack, "dsyevr");
|
330
|
+
TinyLinalg::SyEvr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvr>::define_module_function(rb_mTinyLinalgLapack, "ssyevr");
|
331
|
+
TinyLinalg::HeEvr<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvr>::define_module_function(rb_mTinyLinalgLapack, "zheevr");
|
332
|
+
TinyLinalg::HeEvr<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvr>::define_module_function(rb_mTinyLinalgLapack, "cheevr");
|
315
333
|
TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
|
316
334
|
TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
|
317
335
|
TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -48,35 +48,50 @@ module Numo
|
|
48
48
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
49
49
|
# @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
|
50
50
|
# @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
|
51
|
-
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
|
51
|
+
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/ParameterLists, Metrics/PerceivedComplexity, Lint/UnusedMethodArgument
|
52
52
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
53
53
|
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
|
54
54
|
|
55
|
+
b_given = !b.nil?
|
56
|
+
raise ArgumentError, 'input array b must be 2-dimensional' if b_given && b.ndim != 2
|
57
|
+
raise ArgumentError, 'input array b must be square' if b_given && b.shape[0] != b.shape[1]
|
58
|
+
raise ArgumentError, "invalid array type: #{b.class}" if b_given && blas_char(b) == 'n'
|
59
|
+
|
55
60
|
bchr = blas_char(a)
|
56
61
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
57
62
|
|
58
|
-
unless b.nil?
|
59
|
-
raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
|
60
|
-
raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
|
61
|
-
raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
|
62
|
-
end
|
63
|
-
|
64
63
|
jobz = vals_only ? 'N' : 'V'
|
65
|
-
b = a.class.eye(a.shape[0]) if b.nil?
|
66
|
-
sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
|
67
64
|
|
68
|
-
if
|
69
|
-
|
70
|
-
|
65
|
+
if b_given
|
66
|
+
fnc = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
|
67
|
+
if vals_range.nil?
|
68
|
+
fnc << 'd' if turbo
|
69
|
+
vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
|
70
|
+
else
|
71
|
+
fnc << 'x'
|
72
|
+
il = vals_range.first(1)[0] + 1
|
73
|
+
iu = vals_range.last(1)[0] + 1
|
74
|
+
_a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
|
75
|
+
fnc.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
|
76
|
+
)
|
77
|
+
end
|
71
78
|
else
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
79
|
+
fnc = %w[d s].include?(bchr) ? "#{bchr}syev" : "#{bchr}heev"
|
80
|
+
if vals_range.nil?
|
81
|
+
fnc << 'd' if turbo
|
82
|
+
vecs, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, jobz: jobz)
|
83
|
+
else
|
84
|
+
fnc << 'r'
|
85
|
+
il = vals_range.first(1)[0] + 1
|
86
|
+
iu = vals_range.last(1)[0] + 1
|
87
|
+
_a, _m, vals, vecs, _isuppz, _info = Numo::TinyLinalg::Lapack.send(
|
88
|
+
fnc.to_sym, a.dup, jobz: jobz, range: 'I', il: il, iu: iu
|
89
|
+
)
|
90
|
+
end
|
78
91
|
end
|
92
|
+
|
79
93
|
vecs = nil if vals_only
|
94
|
+
|
80
95
|
[vals, vecs]
|
81
96
|
end
|
82
97
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-10 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -52,10 +52,16 @@ files:
|
|
52
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/heev.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/heevd.hpp
|
57
|
+
- ext/numo/tiny_linalg/lapack/heevr.hpp
|
55
58
|
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
56
59
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
57
60
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
58
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
62
|
+
- ext/numo/tiny_linalg/lapack/syev.hpp
|
63
|
+
- ext/numo/tiny_linalg/lapack/syevd.hpp
|
64
|
+
- ext/numo/tiny_linalg/lapack/syevr.hpp
|
59
65
|
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
60
66
|
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
61
67
|
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|