numo-tiny_linalg 0.1.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +3 -1
- data/ext/numo/tiny_linalg/lapack/heevr.hpp +18 -0
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +18 -0
- data/ext/numo/tiny_linalg/lapack/potrf.hpp +93 -0
- data/ext/numo/tiny_linalg/lapack/potrs.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/syevr.hpp +18 -0
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +18 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +10 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +231 -21
- metadata +4 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: dd5b3d90b9c0bc323420f2eed5363490ec7c2ec4a3b0950573f54f234d534eaf
|
|
4
|
+
data.tar.gz: cd0a71621d1e3e935faba0a40036b3950c0f6ac32a74cdaad46b0e313b5dfa2f
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: f904406ce9c883d59d93afc061bfc65c96073dc9660cb65e6eefa97e2be32cc8c79ac26970a8b181cd3acbdbd412cf97c08153017956c4cc4f2b29bf28bcda46
|
|
7
|
+
data.tar.gz: 8f3ee9b0e0e9c3789f61148db03d43c245dab66fc4cb428488cbe668dad00b0d7fb73ef825e5dc0c5772ba817d41bb90d5fce843ece1f1adf6fe81857b9917ff
|
data/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,15 @@
|
|
|
1
1
|
## [Unreleased]
|
|
2
2
|
|
|
3
|
+
## [[0.3.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.2.0...v0.3.0)] - 2023-08-13
|
|
4
|
+
- Add cholesky and cho_solve module functions to TinyLinalg.
|
|
5
|
+
|
|
6
|
+
**Breaking change**
|
|
7
|
+
- Raise NotImplementedError when calling a method that is not yet implemented in Numo::TinyLinalg.
|
|
8
|
+
|
|
9
|
+
## [[0.2.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.2...v0.2.0)] - 2023-08-11
|
|
10
|
+
**Breaking change**
|
|
11
|
+
- Change the LAPACK function that TinyLinalg.eigh calls when array b is not given.
|
|
12
|
+
|
|
3
13
|
## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
|
|
4
14
|
- Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
|
|
5
15
|
- Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
|
data/README.md
CHANGED
|
@@ -6,7 +6,9 @@
|
|
|
6
6
|
[](https://yoshoku.github.io/numo-tiny_linalg/doc/)
|
|
7
7
|
|
|
8
8
|
Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
|
|
9
|
-
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
|
|
9
|
+
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, cholesky, cho_solve, and svd.
|
|
10
|
+
|
|
11
|
+
Note that the version numbering of Numo::TinyLinalg does not follow that of Numo::Linalg.
|
|
10
12
|
|
|
11
13
|
## Installation
|
|
12
14
|
Unlike Numo::Linalg, Numo::TinyLinalg only supports OpenBLAS as a backend library for BLAS and LAPACK.
|
|
@@ -93,7 +93,25 @@ private:
|
|
|
93
93
|
return Qnil;
|
|
94
94
|
}
|
|
95
95
|
|
|
96
|
+
if (range == 'V' && vu <= vl) {
|
|
97
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
|
98
|
+
return Qnil;
|
|
99
|
+
}
|
|
100
|
+
|
|
96
101
|
const size_t n = NA_SHAPE(a_nary)[1];
|
|
102
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
|
103
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
|
104
|
+
return Qnil;
|
|
105
|
+
}
|
|
106
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
|
107
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
|
108
|
+
return Qnil;
|
|
109
|
+
}
|
|
110
|
+
if (range == 'I' && iu < il) {
|
|
111
|
+
rb_raise(rb_eArgError, "iu must be greater than or equal to il");
|
|
112
|
+
return Qnil;
|
|
113
|
+
}
|
|
114
|
+
|
|
97
115
|
size_t m = range != 'I' ? n : iu - il + 1;
|
|
98
116
|
size_t w_shape[1] = { m };
|
|
99
117
|
size_t z_shape[2] = { n, m };
|
|
@@ -114,7 +114,25 @@ private:
|
|
|
114
114
|
return Qnil;
|
|
115
115
|
}
|
|
116
116
|
|
|
117
|
+
if (range == 'V' && vu <= vl) {
|
|
118
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
|
119
|
+
return Qnil;
|
|
120
|
+
}
|
|
121
|
+
|
|
117
122
|
const size_t n = NA_SHAPE(a_nary)[1];
|
|
123
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
|
124
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
|
125
|
+
return Qnil;
|
|
126
|
+
}
|
|
127
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
|
128
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
|
129
|
+
return Qnil;
|
|
130
|
+
}
|
|
131
|
+
if (range == 'I' && iu < il) {
|
|
132
|
+
rb_raise(rb_eArgError, "il must be less than or equal to iu");
|
|
133
|
+
return Qnil;
|
|
134
|
+
}
|
|
135
|
+
|
|
118
136
|
size_t m = range != 'I' ? n : iu - il + 1;
|
|
119
137
|
size_t w_shape[1] = { m };
|
|
120
138
|
size_t z_shape[2] = { n, m };
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
namespace TinyLinalg {
|
|
2
|
+
|
|
3
|
+
struct DPoTrf {
|
|
4
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, double* a, lapack_int lda) {
|
|
5
|
+
return LAPACKE_dpotrf(matrix_layout, uplo, n, a, lda);
|
|
6
|
+
}
|
|
7
|
+
};
|
|
8
|
+
|
|
9
|
+
struct SPoTrf {
|
|
10
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, float* a, lapack_int lda) {
|
|
11
|
+
return LAPACKE_spotrf(matrix_layout, uplo, n, a, lda);
|
|
12
|
+
}
|
|
13
|
+
};
|
|
14
|
+
|
|
15
|
+
struct ZPoTrf {
|
|
16
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda) {
|
|
17
|
+
return LAPACKE_zpotrf(matrix_layout, uplo, n, a, lda);
|
|
18
|
+
}
|
|
19
|
+
};
|
|
20
|
+
|
|
21
|
+
struct CPoTrf {
|
|
22
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda) {
|
|
23
|
+
return LAPACKE_cpotrf(matrix_layout, uplo, n, a, lda);
|
|
24
|
+
}
|
|
25
|
+
};
|
|
26
|
+
|
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
|
28
|
+
class PoTrf {
|
|
29
|
+
public:
|
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrf), -1);
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
private:
|
|
35
|
+
struct potrf_opt {
|
|
36
|
+
int matrix_layout;
|
|
37
|
+
char uplo;
|
|
38
|
+
};
|
|
39
|
+
|
|
40
|
+
static void iter_potrf(na_loop_t* const lp) {
|
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
|
42
|
+
int* info = (int*)NDL_PTR(lp, 1);
|
|
43
|
+
potrf_opt* opt = (potrf_opt*)(lp->opt_ptr);
|
|
44
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
|
45
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[1];
|
|
46
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, a, lda);
|
|
47
|
+
*info = static_cast<int>(i);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
static VALUE tiny_linalg_potrf(int argc, VALUE* argv, VALUE self) {
|
|
51
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
|
52
|
+
|
|
53
|
+
VALUE a_vnary = Qnil;
|
|
54
|
+
VALUE kw_args = Qnil;
|
|
55
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
|
56
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
|
|
57
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
|
58
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
|
59
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
|
60
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
|
61
|
+
|
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
|
64
|
+
}
|
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
narray_t* a_nary = NULL;
|
|
70
|
+
GetNArray(a_vnary, a_nary);
|
|
71
|
+
if (NA_NDIM(a_nary) != 2) {
|
|
72
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
|
73
|
+
return Qnil;
|
|
74
|
+
}
|
|
75
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
|
76
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
|
77
|
+
return Qnil;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
|
82
|
+
ndfunc_t ndf = { iter_potrf, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
|
83
|
+
potrf_opt opt = { matrix_layout, uplo };
|
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
|
85
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
|
86
|
+
|
|
87
|
+
RB_GC_GUARD(a_vnary);
|
|
88
|
+
|
|
89
|
+
return ret;
|
|
90
|
+
}
|
|
91
|
+
};
|
|
92
|
+
|
|
93
|
+
} // namespace TinyLinalg
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
namespace TinyLinalg {
|
|
2
|
+
|
|
3
|
+
struct DPoTrs {
|
|
4
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
|
5
|
+
const double* a, lapack_int lda, double* b, lapack_int ldb) {
|
|
6
|
+
return LAPACKE_dpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
|
7
|
+
}
|
|
8
|
+
};
|
|
9
|
+
|
|
10
|
+
struct SPoTrs {
|
|
11
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
|
12
|
+
const float* a, lapack_int lda, float* b, lapack_int ldb) {
|
|
13
|
+
return LAPACKE_spotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
|
14
|
+
}
|
|
15
|
+
};
|
|
16
|
+
|
|
17
|
+
struct ZPoTrs {
|
|
18
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
|
19
|
+
const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
|
|
20
|
+
return LAPACKE_zpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
|
21
|
+
}
|
|
22
|
+
};
|
|
23
|
+
|
|
24
|
+
struct CPoTrs {
|
|
25
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
|
26
|
+
const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
|
|
27
|
+
return LAPACKE_cpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
|
28
|
+
}
|
|
29
|
+
};
|
|
30
|
+
|
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
|
32
|
+
class PoTrs {
|
|
33
|
+
public:
|
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrs), -1);
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
private:
|
|
39
|
+
struct potrs_opt {
|
|
40
|
+
int matrix_layout;
|
|
41
|
+
char uplo;
|
|
42
|
+
};
|
|
43
|
+
|
|
44
|
+
static void iter_potrs(na_loop_t* const lp) {
|
|
45
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
|
46
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
|
47
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
|
48
|
+
potrs_opt* opt = (potrs_opt*)(lp->opt_ptr);
|
|
49
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
|
50
|
+
const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
|
51
|
+
const lapack_int lda = n;
|
|
52
|
+
const lapack_int ldb = nrhs;
|
|
53
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, nrhs, a, lda, b, ldb);
|
|
54
|
+
*info = static_cast<int>(i);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
static VALUE tiny_linalg_potrs(int argc, VALUE* argv, VALUE self) {
|
|
58
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
|
59
|
+
|
|
60
|
+
VALUE a_vnary = Qnil;
|
|
61
|
+
VALUE b_vnary = Qnil;
|
|
62
|
+
VALUE kw_args = Qnil;
|
|
63
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
|
64
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
|
|
65
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
|
66
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
|
67
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
|
68
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
|
69
|
+
|
|
70
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
|
71
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
|
72
|
+
}
|
|
73
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
|
74
|
+
a_vnary = nary_dup(a_vnary);
|
|
75
|
+
}
|
|
76
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
|
77
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
|
78
|
+
}
|
|
79
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
|
80
|
+
b_vnary = nary_dup(b_vnary);
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
narray_t* a_nary = NULL;
|
|
84
|
+
GetNArray(a_vnary, a_nary);
|
|
85
|
+
if (NA_NDIM(a_nary) != 2) {
|
|
86
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
|
87
|
+
return Qnil;
|
|
88
|
+
}
|
|
89
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
|
90
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
|
91
|
+
return Qnil;
|
|
92
|
+
}
|
|
93
|
+
narray_t* b_nary = NULL;
|
|
94
|
+
GetNArray(b_vnary, b_nary);
|
|
95
|
+
const int b_n_dims = NA_NDIM(b_nary);
|
|
96
|
+
if (b_n_dims != 1 && b_n_dims != 2) {
|
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
|
|
98
|
+
return Qnil;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
lapack_int n = NA_SHAPE(a_nary)[0];
|
|
102
|
+
lapack_int nb = NA_SHAPE(b_nary)[0];
|
|
103
|
+
if (n != nb) {
|
|
104
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
|
|
108
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
|
109
|
+
ndfunc_t ndf = { iter_potrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
|
110
|
+
potrs_opt opt = { matrix_layout, uplo };
|
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
|
112
|
+
VALUE ret = rb_ary_new3(2, b_vnary, res);
|
|
113
|
+
|
|
114
|
+
RB_GC_GUARD(a_vnary);
|
|
115
|
+
RB_GC_GUARD(b_vnary);
|
|
116
|
+
|
|
117
|
+
return ret;
|
|
118
|
+
}
|
|
119
|
+
};
|
|
120
|
+
|
|
121
|
+
} // namespace TinyLinalg
|
|
@@ -90,7 +90,25 @@ private:
|
|
|
90
90
|
return Qnil;
|
|
91
91
|
}
|
|
92
92
|
|
|
93
|
+
if (range == 'V' && vu <= vl) {
|
|
94
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
|
95
|
+
return Qnil;
|
|
96
|
+
}
|
|
97
|
+
|
|
93
98
|
const size_t n = NA_SHAPE(a_nary)[1];
|
|
99
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
|
100
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
|
101
|
+
return Qnil;
|
|
102
|
+
}
|
|
103
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
|
104
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
|
105
|
+
return Qnil;
|
|
106
|
+
}
|
|
107
|
+
if (range == 'I' && iu < il) {
|
|
108
|
+
rb_raise(rb_eArgError, "iu must be greater than or equal to il");
|
|
109
|
+
return Qnil;
|
|
110
|
+
}
|
|
111
|
+
|
|
94
112
|
size_t m = range != 'I' ? n : iu - il + 1;
|
|
95
113
|
size_t w_shape[1] = { m };
|
|
96
114
|
size_t z_shape[2] = { n, m };
|
|
@@ -113,7 +113,25 @@ private:
|
|
|
113
113
|
return Qnil;
|
|
114
114
|
}
|
|
115
115
|
|
|
116
|
+
if (range == 'V' && vu <= vl) {
|
|
117
|
+
rb_raise(rb_eArgError, "vu must be greater than vl");
|
|
118
|
+
return Qnil;
|
|
119
|
+
}
|
|
120
|
+
|
|
116
121
|
const size_t n = NA_SHAPE(a_nary)[1];
|
|
122
|
+
if (range == 'I' && (il < 1 || il > n)) {
|
|
123
|
+
rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
|
|
124
|
+
return Qnil;
|
|
125
|
+
}
|
|
126
|
+
if (range == 'I' && (iu < 1 || iu > n)) {
|
|
127
|
+
rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
|
|
128
|
+
return Qnil;
|
|
129
|
+
}
|
|
130
|
+
if (range == 'I' && iu < il) {
|
|
131
|
+
rb_raise(rb_eArgError, "iu must be greater than or equal to il");
|
|
132
|
+
return Qnil;
|
|
133
|
+
}
|
|
134
|
+
|
|
117
135
|
size_t m = range != 'I' ? n : iu - il + 1;
|
|
118
136
|
size_t w_shape[1] = { m };
|
|
119
137
|
size_t z_shape[2] = { n, m };
|
|
@@ -51,6 +51,8 @@
|
|
|
51
51
|
#include "lapack/hegvd.hpp"
|
|
52
52
|
#include "lapack/hegvx.hpp"
|
|
53
53
|
#include "lapack/orgqr.hpp"
|
|
54
|
+
#include "lapack/potrf.hpp"
|
|
55
|
+
#include "lapack/potrs.hpp"
|
|
54
56
|
#include "lapack/syev.hpp"
|
|
55
57
|
#include "lapack/syevd.hpp"
|
|
56
58
|
#include "lapack/syevr.hpp"
|
|
@@ -310,6 +312,14 @@ extern "C" void Init_tiny_linalg(void) {
|
|
|
310
312
|
TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
|
311
313
|
TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
|
312
314
|
TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
|
315
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
|
|
316
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
|
|
317
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
|
|
318
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrf>::define_module_function(rb_mTinyLinalgLapack, "cpotrf");
|
|
319
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrs>::define_module_function(rb_mTinyLinalgLapack, "dpotrs");
|
|
320
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrs>::define_module_function(rb_mTinyLinalgLapack, "spotrs");
|
|
321
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrs>::define_module_function(rb_mTinyLinalgLapack, "zpotrs");
|
|
322
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrs>::define_module_function(rb_mTinyLinalgLapack, "cpotrs");
|
|
313
323
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
|
|
314
324
|
TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
|
|
315
325
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
|
data/lib/numo/tiny_linalg.rb
CHANGED
|
@@ -48,38 +48,143 @@ module Numo
|
|
|
48
48
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
|
49
49
|
# @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
|
|
50
50
|
# @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
|
|
51
|
-
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
|
|
51
|
+
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/ParameterLists, Metrics/PerceivedComplexity, Lint/UnusedMethodArgument
|
|
52
52
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
|
53
53
|
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
|
|
54
54
|
|
|
55
|
+
b_given = !b.nil?
|
|
56
|
+
raise ArgumentError, 'input array b must be 2-dimensional' if b_given && b.ndim != 2
|
|
57
|
+
raise ArgumentError, 'input array b must be square' if b_given && b.shape[0] != b.shape[1]
|
|
58
|
+
raise ArgumentError, "invalid array type: #{b.class}" if b_given && blas_char(b) == 'n'
|
|
59
|
+
|
|
55
60
|
bchr = blas_char(a)
|
|
56
61
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
|
57
62
|
|
|
58
|
-
unless b.nil?
|
|
59
|
-
raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
|
|
60
|
-
raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
|
|
61
|
-
raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
|
|
62
|
-
end
|
|
63
|
-
|
|
64
63
|
jobz = vals_only ? 'N' : 'V'
|
|
65
|
-
b = a.class.eye(a.shape[0]) if b.nil?
|
|
66
|
-
sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
|
|
67
64
|
|
|
68
|
-
if
|
|
69
|
-
|
|
70
|
-
|
|
65
|
+
if b_given
|
|
66
|
+
fnc = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
|
|
67
|
+
if vals_range.nil?
|
|
68
|
+
fnc << 'd' if turbo
|
|
69
|
+
vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
|
|
70
|
+
else
|
|
71
|
+
fnc << 'x'
|
|
72
|
+
il = vals_range.first(1)[0] + 1
|
|
73
|
+
iu = vals_range.last(1)[0] + 1
|
|
74
|
+
_a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
|
|
75
|
+
fnc.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
|
|
76
|
+
)
|
|
77
|
+
end
|
|
71
78
|
else
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
79
|
+
fnc = %w[d s].include?(bchr) ? "#{bchr}syev" : "#{bchr}heev"
|
|
80
|
+
if vals_range.nil?
|
|
81
|
+
fnc << 'd' if turbo
|
|
82
|
+
vecs, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, jobz: jobz)
|
|
83
|
+
else
|
|
84
|
+
fnc << 'r'
|
|
85
|
+
il = vals_range.first(1)[0] + 1
|
|
86
|
+
iu = vals_range.last(1)[0] + 1
|
|
87
|
+
_a, _m, vals, vecs, _isuppz, _info = Numo::TinyLinalg::Lapack.send(
|
|
88
|
+
fnc.to_sym, a.dup, jobz: jobz, range: 'I', il: il, iu: iu
|
|
89
|
+
)
|
|
90
|
+
end
|
|
78
91
|
end
|
|
92
|
+
|
|
79
93
|
vecs = nil if vals_only
|
|
94
|
+
|
|
80
95
|
[vals, vecs]
|
|
81
96
|
end
|
|
82
97
|
|
|
98
|
+
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   s = Numo::DFloat.new(3, 3).rand - 0.5
#   a = s.transpose.dot(s)
#   u = Numo::Linalg.cholesky(a)
#
#   pp u
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.532006, 0.338183, -0.18036],
#   #  [0, 0.325153, 0.011721],
#   #  [0, 0, 0.436738]]
#
#   pp (a - u.transpose.dot(u)).abs.max
#   # => 1.3877787807814457e-17
#
#   l = Numo::Linalg.cholesky(a, uplo: 'L')
#
#   pp l
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.532006, 0, 0],
#   #  [0.338183, 0.325153, 0],
#   #  [-0.18036, 0.011721, 0.436738]]
#
#   pp (a - l.dot(l.transpose)).abs.max
#   # => 1.3877787807814457e-17
#
# @param a [Numo::NArray] The n-by-n symmetric matrix.
# @param uplo [String] Whether to compute the upper- or lower-triangular Cholesky factor ('U' or 'L').
# @return [Numo::NArray] The upper- or lower-triangular Cholesky factor of a.
def cholesky(a, uplo: 'U')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  prefix = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if prefix == 'n'

  # The native ?potrf writes the factor over its input, so hand it a copy
  # and leave the caller's array untouched.
  # NOTE(review): the LAPACK info code is discarded, so a failed factorization
  # is not reported to the caller — consider checking it.
  factor, _info = Numo::TinyLinalg::Lapack.send(:"#{prefix}potrf", a.dup, uplo: uplo)

  # Keep only the requested triangle of the result.
  case uplo
  when 'U' then factor.triu
  when 'L' then factor.tril
  else raise ArgumentError, "invalid uplo: #{uplo}"
  end
end
|
|
153
|
+
|
|
154
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` with the Cholesky factorization of `A`.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   s = Numo::DFloat.new(3, 3).rand - 0.5
#   a = s.transpose.dot(s)
#   u = Numo::Linalg.cholesky(a)
#
#   b = Numo::DFloat.new(3).rand
#   x = Numo::Linalg.cho_solve(u, b)
#
#   puts (b - a.dot(x)).abs.max
#   # => 0.0
#
# @param a [Numo::NArray] The n-by-n Cholesky factor of `A` (e.g. the result of `cholesky`).
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param uplo [String] Whether `a` is the upper- ('U') or lower- ('L') triangular Cholesky factor.
#   It must match the `uplo` that was used to compute the factor.
# @return [Numo::NArray] The solution vector or matrix `X`.
def cho_solve(a, b, uplo: 'U')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
  raise ArgumentError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}" if a.shape[0] != b.shape[0]

  # blas_char inspects both arrays, so either one may be the unsupported type; report both.
  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  fnc = "#{bchr}potrs".to_sym
  # The native ?potrs overwrites the right-hand side with the solution,
  # so pass a copy of b. The factor a is only read.
  x, _info = Numo::TinyLinalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
  x
end
|
|
187
|
+
|
|
83
188
|
# Computes the determinant of matrix.
|
|
84
189
|
#
|
|
85
190
|
# @example
|
|
@@ -256,7 +361,7 @@ module Numo
|
|
|
256
361
|
[q, r]
|
|
257
362
|
end
|
|
258
363
|
|
|
259
|
-
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `
|
|
364
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `A`.
|
|
260
365
|
#
|
|
261
366
|
# @example
|
|
262
367
|
# require 'numo/tiny_linalg'
|
|
@@ -279,10 +384,10 @@ module Numo
|
|
|
279
384
|
# # => 2.1081041547796492e-16
|
|
280
385
|
#
|
|
281
386
|
# @param a [Numo::NArray] The n-by-n square matrix.
|
|
282
|
-
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix
|
|
387
|
+
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
|
|
283
388
|
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
|
284
389
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
|
285
|
-
# @return [Numo::NArray] The solusion vector / matrix `
|
|
390
|
+
# @return [Numo::NArray] The solution vector / matrix `X`.
|
|
286
391
|
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
|
|
287
392
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
|
288
393
|
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
|
|
@@ -353,5 +458,110 @@ module Numo
|
|
|
353
458
|
|
|
354
459
|
[s, u, vt]
|
|
355
460
|
end
|
|
461
|
+
|
|
462
|
+
# Placeholders for Numo::Linalg functions that Numo::TinyLinalg does not provide yet.
# Each one accepts any arguments and raises NotImplementedError naming the method that was called.

# @!visibility private
def matmul(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def matrix_power(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def svdvals(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def orth(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def null_space(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu_fact(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu_inv(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu_solve(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def ldl(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cho_fact(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cho_inv(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def eig(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def eigvals(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def eigvalsh(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def norm(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cond(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def slogdet(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def matrix_rank(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lstsq(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def expm(*_args)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end
|
|
356
566
|
end
|
|
357
567
|
end
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: numo-tiny_linalg
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.
|
|
4
|
+
version: 0.3.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- yoshoku
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2023-08-
|
|
11
|
+
date: 2023-08-13 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: numo-narray
|
|
@@ -59,6 +59,8 @@ files:
|
|
|
59
59
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
|
60
60
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
|
61
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
|
62
|
+
- ext/numo/tiny_linalg/lapack/potrf.hpp
|
|
63
|
+
- ext/numo/tiny_linalg/lapack/potrs.hpp
|
|
62
64
|
- ext/numo/tiny_linalg/lapack/syev.hpp
|
|
63
65
|
- ext/numo/tiny_linalg/lapack/syevd.hpp
|
|
64
66
|
- ext/numo/tiny_linalg/lapack/syevr.hpp
|