numo-tiny_linalg 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/numo/tiny_linalg/lapack/potrf.hpp +93 -0
- data/ext/numo/tiny_linalg/lapack/potrs.hpp +121 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +10 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +198 -3
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dd5b3d90b9c0bc323420f2eed5363490ec7c2ec4a3b0950573f54f234d534eaf
|
4
|
+
data.tar.gz: cd0a71621d1e3e935faba0a40036b3950c0f6ac32a74cdaad46b0e313b5dfa2f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f904406ce9c883d59d93afc061bfc65c96073dc9660cb65e6eefa97e2be32cc8c79ac26970a8b181cd3acbdbd412cf97c08153017956c4cc4f2b29bf28bcda46
|
7
|
+
data.tar.gz: 8f3ee9b0e0e9c3789f61148db03d43c245dab66fc4cb428488cbe668dad00b0d7fb73ef825e5dc0c5772ba817d41bb90d5fce843ece1f1adf6fe81857b9917ff
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,11 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.3.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.2.0...v0.3.0)] - 2023-08-13
|
4
|
+
- Add cholesky and cho_solve module functions to TinyLinalg.
|
5
|
+
|
6
|
+
**Breaking change**
|
7
|
+
- Change to raise NotImplementedError when calling a method not yet implemented in Numo::TinyLinalg.
|
8
|
+
|
3
9
|
## [[0.2.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.2...v0.2.0)] - 2023-08-11
|
4
10
|
**Breaking change**
|
5
11
|
- Change LAPACK function to call when array b is not given to TinyLinalg.eigh method.
|
data/README.md
CHANGED
@@ -6,7 +6,7 @@
|
|
6
6
|
[](https://yoshoku.github.io/numo-tiny_linalg/doc/)
|
7
7
|
|
8
8
|
Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
|
9
|
-
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
|
9
|
+
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, cholesky, cho_solve and svd.
|
10
10
|
|
11
11
|
Note that the version numbering rule of Numo::TinyLinalg is not compatible with that of Numo::Linalg.
|
12
12
|
|
@@ -0,0 +1,93 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DPoTrf {
|
4
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, double* a, lapack_int lda) {
|
5
|
+
return LAPACKE_dpotrf(matrix_layout, uplo, n, a, lda);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SPoTrf {
|
10
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, float* a, lapack_int lda) {
|
11
|
+
return LAPACKE_spotrf(matrix_layout, uplo, n, a, lda);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
struct ZPoTrf {
|
16
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda) {
|
17
|
+
return LAPACKE_zpotrf(matrix_layout, uplo, n, a, lda);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct CPoTrf {
|
22
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda) {
|
23
|
+
return LAPACKE_cpotrf(matrix_layout, uplo, n, a, lda);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
28
|
+
class PoTrf {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrf), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct potrf_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
char uplo;
|
38
|
+
};
|
39
|
+
|
40
|
+
static void iter_potrf(na_loop_t* const lp) {
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
42
|
+
int* info = (int*)NDL_PTR(lp, 1);
|
43
|
+
potrf_opt* opt = (potrf_opt*)(lp->opt_ptr);
|
44
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
45
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[1];
|
46
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, a, lda);
|
47
|
+
*info = static_cast<int>(i);
|
48
|
+
}
|
49
|
+
|
50
|
+
static VALUE tiny_linalg_potrf(int argc, VALUE* argv, VALUE self) {
|
51
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
52
|
+
|
53
|
+
VALUE a_vnary = Qnil;
|
54
|
+
VALUE kw_args = Qnil;
|
55
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
56
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
|
57
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
58
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
59
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
60
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
|
69
|
+
narray_t* a_nary = NULL;
|
70
|
+
GetNArray(a_vnary, a_nary);
|
71
|
+
if (NA_NDIM(a_nary) != 2) {
|
72
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
73
|
+
return Qnil;
|
74
|
+
}
|
75
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
76
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
77
|
+
return Qnil;
|
78
|
+
}
|
79
|
+
|
80
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
82
|
+
ndfunc_t ndf = { iter_potrf, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
83
|
+
potrf_opt opt = { matrix_layout, uplo };
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
85
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
86
|
+
|
87
|
+
RB_GC_GUARD(a_vnary);
|
88
|
+
|
89
|
+
return ret;
|
90
|
+
}
|
91
|
+
};
|
92
|
+
|
93
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,121 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DPoTrs {
|
4
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
5
|
+
const double* a, lapack_int lda, double* b, lapack_int ldb) {
|
6
|
+
return LAPACKE_dpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SPoTrs {
|
11
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
12
|
+
const float* a, lapack_int lda, float* b, lapack_int ldb) {
|
13
|
+
return LAPACKE_spotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZPoTrs {
|
18
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
19
|
+
const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
|
20
|
+
return LAPACKE_zpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CPoTrs {
|
25
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
26
|
+
const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
|
27
|
+
return LAPACKE_cpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
|
+
class PoTrs {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrs), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct potrs_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
char uplo;
|
42
|
+
};
|
43
|
+
|
44
|
+
static void iter_potrs(na_loop_t* const lp) {
|
45
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
46
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
47
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
48
|
+
potrs_opt* opt = (potrs_opt*)(lp->opt_ptr);
|
49
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
51
|
+
const lapack_int lda = n;
|
52
|
+
const lapack_int ldb = nrhs;
|
53
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, nrhs, a, lda, b, ldb);
|
54
|
+
*info = static_cast<int>(i);
|
55
|
+
}
|
56
|
+
|
57
|
+
static VALUE tiny_linalg_potrs(int argc, VALUE* argv, VALUE self) {
|
58
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
59
|
+
|
60
|
+
VALUE a_vnary = Qnil;
|
61
|
+
VALUE b_vnary = Qnil;
|
62
|
+
VALUE kw_args = Qnil;
|
63
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
64
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
|
65
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
66
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
67
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
68
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
69
|
+
|
70
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
71
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
72
|
+
}
|
73
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
74
|
+
a_vnary = nary_dup(a_vnary);
|
75
|
+
}
|
76
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
77
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
78
|
+
}
|
79
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
80
|
+
b_vnary = nary_dup(b_vnary);
|
81
|
+
}
|
82
|
+
|
83
|
+
narray_t* a_nary = NULL;
|
84
|
+
GetNArray(a_vnary, a_nary);
|
85
|
+
if (NA_NDIM(a_nary) != 2) {
|
86
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
87
|
+
return Qnil;
|
88
|
+
}
|
89
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
90
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
91
|
+
return Qnil;
|
92
|
+
}
|
93
|
+
narray_t* b_nary = NULL;
|
94
|
+
GetNArray(b_vnary, b_nary);
|
95
|
+
const int b_n_dims = NA_NDIM(b_nary);
|
96
|
+
if (b_n_dims != 1 && b_n_dims != 2) {
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
|
101
|
+
lapack_int n = NA_SHAPE(a_nary)[0];
|
102
|
+
lapack_int nb = NA_SHAPE(b_nary)[0];
|
103
|
+
if (n != nb) {
|
104
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
|
105
|
+
}
|
106
|
+
|
107
|
+
ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
|
108
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
109
|
+
ndfunc_t ndf = { iter_potrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
110
|
+
potrs_opt opt = { matrix_layout, uplo };
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
112
|
+
VALUE ret = rb_ary_new3(2, b_vnary, res);
|
113
|
+
|
114
|
+
RB_GC_GUARD(a_vnary);
|
115
|
+
RB_GC_GUARD(b_vnary);
|
116
|
+
|
117
|
+
return ret;
|
118
|
+
}
|
119
|
+
};
|
120
|
+
|
121
|
+
} // namespace TinyLinalg
|
@@ -51,6 +51,8 @@
|
|
51
51
|
#include "lapack/hegvd.hpp"
|
52
52
|
#include "lapack/hegvx.hpp"
|
53
53
|
#include "lapack/orgqr.hpp"
|
54
|
+
#include "lapack/potrf.hpp"
|
55
|
+
#include "lapack/potrs.hpp"
|
54
56
|
#include "lapack/syev.hpp"
|
55
57
|
#include "lapack/syevd.hpp"
|
56
58
|
#include "lapack/syevr.hpp"
|
@@ -310,6 +312,14 @@ extern "C" void Init_tiny_linalg(void) {
|
|
310
312
|
TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
311
313
|
TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
312
314
|
TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
315
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
|
316
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
|
317
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
|
318
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrf>::define_module_function(rb_mTinyLinalgLapack, "cpotrf");
|
319
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrs>::define_module_function(rb_mTinyLinalgLapack, "dpotrs");
|
320
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrs>::define_module_function(rb_mTinyLinalgLapack, "spotrs");
|
321
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrs>::define_module_function(rb_mTinyLinalgLapack, "zpotrs");
|
322
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrs>::define_module_function(rb_mTinyLinalgLapack, "cpotrs");
|
313
323
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
|
314
324
|
TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
|
315
325
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -95,6 +95,96 @@ module Numo
|
|
95
95
|
[vals, vecs]
|
96
96
|
end
|
97
97
|
|
98
|
+
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
#
# @example
# require 'numo/tiny_linalg'
#
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
# s = Numo::DFloat.new(3, 3).rand - 0.5
# a = s.transpose.dot(s)
# u = Numo::Linalg.cholesky(a)
#
# pp u
# # =>
# # Numo::DFloat#shape=[3,3]
# # [[0.532006, 0.338183, -0.18036],
# # [0, 0.325153, 0.011721],
# # [0, 0, 0.436738]]
#
# pp (a - u.transpose.dot(u)).abs.max
# # => 1.3877787807814457e-17
#
# l = Numo::Linalg.cholesky(a, uplo: 'L')
#
# pp l
# # =>
# # Numo::DFloat#shape=[3,3]
# # [[0.532006, 0, 0],
# # [0.338183, 0.325153, 0],
# # [-0.18036, 0.011721, 0.436738]]
#
# pp (a - l.dot(l.transpose)).abs.max
# # => 1.3877787807814457e-17
#
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian positive-definite matrix.
# @param uplo [String] Whether to compute the upper- or lower-triangular Cholesky factor ('U' or 'L').
# @raise [ArgumentError] If a is not a square 2-dimensional array of a supported type, or uplo is neither 'U' nor 'L'.
# @raise [RuntimeError] If a is not positive definite, so the factorization cannot be completed.
# @return [Numo::NArray] The upper- or lower-triangular Cholesky factor of a.
def cholesky(a, uplo: 'U')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
  # Reject a bad uplo before doing the O(n^3) factorization rather than after it.
  raise ArgumentError, "invalid uplo: #{uplo}" unless %w[U L].include?(uplo)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  fnc = "#{bchr}potrf".to_sym
  c, info = Numo::TinyLinalg::Lapack.send(fnc, a.dup, uplo: uplo)
  # ?potrf reports info = k > 0 when the leading minor of order k is not positive definite.
  # The partially factored matrix is not a valid Cholesky factor, so do not return it.
  if info.positive?
    raise "the leading minor of order #{info} is not positive definite, and the factorization could not be completed."
  end

  uplo == 'U' ? c.triu : c.tril
end
|
153
|
+
|
154
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` with the Cholesky factorization of `A`.
#
# @example
# require 'numo/tiny_linalg'
#
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
# s = Numo::DFloat.new(3, 3).rand - 0.5
# a = s.transpose.dot(s)
# u = Numo::Linalg.cholesky(a)
#
# b = Numo::DFloat.new(3).rand
# x = Numo::Linalg.cho_solve(u, b)
#
# puts (b - a.dot(x)).abs.max
# # => 0.0
#
# @param a [Numo::NArray] The n-by-n Cholesky factor of `A`, e.g. as returned by {cholesky}.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param uplo [String] Whether the given factor a is upper-triangular ('U', A = U^H * U) or lower-triangular ('L', A = L * L^H).
# @raise [ArgumentError] If a is not square and 2-dimensional, the leading dimensions of a and b differ,
#   uplo is neither 'U' nor 'L', or the array types are not supported.
# @return [Numo::NArray] The solution vector or matrix `X`.
def cho_solve(a, b, uplo: 'U')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
  raise ArgumentError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}" if a.shape[0] != b.shape[0]
  # Same accepted values as cholesky, so a factor from cholesky(a, uplo: x) pairs with cho_solve(..., uplo: x).
  raise ArgumentError, "invalid uplo: #{uplo}" unless %w[U L].include?(uplo)

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  fnc = "#{bchr}potrs".to_sym
  # b is copied because ?potrs overwrites its right-hand side with the solution.
  x, _info = Numo::TinyLinalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
  x
end
|
187
|
+
|
98
188
|
# Computes the determinant of matrix.
|
99
189
|
#
|
100
190
|
# @example
|
@@ -271,7 +361,7 @@ module Numo
|
|
271
361
|
[q, r]
|
272
362
|
end
|
273
363
|
|
274
|
-
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `
|
364
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `A`.
|
275
365
|
#
|
276
366
|
# @example
|
277
367
|
# require 'numo/tiny_linalg'
|
@@ -294,10 +384,10 @@ module Numo
|
|
294
384
|
# # => 2.1081041547796492e-16
|
295
385
|
#
|
296
386
|
# @param a [Numo::NArray] The n-by-n square matrix.
|
297
|
-
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix
|
387
|
+
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
|
298
388
|
# @param driver [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
|
299
389
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
|
300
|
-
# @return [Numo::NArray] The solusion vector / matrix `
|
390
|
+
# @return [Numo::NArray] The solution vector / matrix `X`.
|
301
391
|
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
|
302
392
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
303
393
|
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
|
@@ -368,5 +458,110 @@ module Numo
|
|
368
458
|
|
369
459
|
[s, u, vt]
|
370
460
|
end
|
461
|
+
|
462
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def matmul(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
466
|
+
|
467
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def matrix_power(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
471
|
+
|
472
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def svdvals(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
476
|
+
|
477
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def orth(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
481
|
+
|
482
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def null_space(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
486
|
+
|
487
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def lu(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
491
|
+
|
492
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def lu_fact(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
496
|
+
|
497
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def lu_inv(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
501
|
+
|
502
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def lu_solve(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
506
|
+
|
507
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def ldl(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
511
|
+
|
512
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def cho_fact(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
516
|
+
|
517
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def cho_inv(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
521
|
+
|
522
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def eig(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
526
|
+
|
527
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def eigvals(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
531
|
+
|
532
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def eigvalsh(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
536
|
+
|
537
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def norm(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
541
|
+
|
542
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def cond(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
546
|
+
|
547
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def slogdet(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
551
|
+
|
552
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def matrix_rank(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
556
|
+
|
557
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def lstsq(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
561
|
+
|
562
|
+
# @!visibility private
# Not yet implemented in Numo::TinyLinalg: every call raises NotImplementedError.
def expm(*args)
  raise NotImplementedError, format('%s is not yet implemented in Numo::TinyLinalg', __method__)
end
|
371
566
|
end
|
372
567
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -59,6 +59,8 @@ files:
|
|
59
59
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
60
60
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
61
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
62
|
+
- ext/numo/tiny_linalg/lapack/potrf.hpp
|
63
|
+
- ext/numo/tiny_linalg/lapack/potrs.hpp
|
62
64
|
- ext/numo/tiny_linalg/lapack/syev.hpp
|
63
65
|
- ext/numo/tiny_linalg/lapack/syevd.hpp
|
64
66
|
- ext/numo/tiny_linalg/lapack/syevr.hpp
|