numo-tiny_linalg 0.2.0 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/numo/tiny_linalg/lapack/potrf.hpp +93 -0
- data/ext/numo/tiny_linalg/lapack/potrs.hpp +121 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +10 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +198 -3
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dd5b3d90b9c0bc323420f2eed5363490ec7c2ec4a3b0950573f54f234d534eaf
|
4
|
+
data.tar.gz: cd0a71621d1e3e935faba0a40036b3950c0f6ac32a74cdaad46b0e313b5dfa2f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f904406ce9c883d59d93afc061bfc65c96073dc9660cb65e6eefa97e2be32cc8c79ac26970a8b181cd3acbdbd412cf97c08153017956c4cc4f2b29bf28bcda46
|
7
|
+
data.tar.gz: 8f3ee9b0e0e9c3789f61148db03d43c245dab66fc4cb428488cbe668dad00b0d7fb73ef825e5dc0c5772ba817d41bb90d5fce843ece1f1adf6fe81857b9917ff
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,11 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.3.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.2.0...v0.3.0)] - 2023-08-13
|
4
|
+
- Add cholesky and cho_solve module functions to TinyLinalg.
|
5
|
+
|
6
|
+
**Breaking change**
|
7
|
+
- Change to raise NotImplementedError when calling a method not yet implemented in Numo::TinyLinalg.
|
8
|
+
|
3
9
|
## [[0.2.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.2...v0.2.0)] - 2023-08-11
|
4
10
|
**Breaking change**
|
5
11
|
- Change LAPACK function to call when array b is not given to TinyLinalg.eigh method.
|
data/README.md
CHANGED
@@ -6,7 +6,7 @@
|
|
6
6
|
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/numo-tiny_linalg/doc/)
|
7
7
|
|
8
8
|
Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
|
9
|
-
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
|
9
|
+
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, cholesky, cho_solve, and svd.
|
10
10
|
|
11
11
|
Note that the version numbering rule of Numo::TinyLinalg is not compatible with that of Numo::Linalg.
|
12
12
|
|
@@ -0,0 +1,93 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DPoTrf {
|
4
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, double* a, lapack_int lda) {
|
5
|
+
return LAPACKE_dpotrf(matrix_layout, uplo, n, a, lda);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SPoTrf {
|
10
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, float* a, lapack_int lda) {
|
11
|
+
return LAPACKE_spotrf(matrix_layout, uplo, n, a, lda);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
struct ZPoTrf {
|
16
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda) {
|
17
|
+
return LAPACKE_zpotrf(matrix_layout, uplo, n, a, lda);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct CPoTrf {
|
22
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda) {
|
23
|
+
return LAPACKE_cpotrf(matrix_layout, uplo, n, a, lda);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
28
|
+
class PoTrf {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrf), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct potrf_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
char uplo;
|
38
|
+
};
|
39
|
+
|
40
|
+
static void iter_potrf(na_loop_t* const lp) {
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
42
|
+
int* info = (int*)NDL_PTR(lp, 1);
|
43
|
+
potrf_opt* opt = (potrf_opt*)(lp->opt_ptr);
|
44
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
45
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[1];
|
46
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, a, lda);
|
47
|
+
*info = static_cast<int>(i);
|
48
|
+
}
|
49
|
+
|
50
|
+
static VALUE tiny_linalg_potrf(int argc, VALUE* argv, VALUE self) {
|
51
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
52
|
+
|
53
|
+
VALUE a_vnary = Qnil;
|
54
|
+
VALUE kw_args = Qnil;
|
55
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
56
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
|
57
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
58
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
59
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
60
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
|
69
|
+
narray_t* a_nary = NULL;
|
70
|
+
GetNArray(a_vnary, a_nary);
|
71
|
+
if (NA_NDIM(a_nary) != 2) {
|
72
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
73
|
+
return Qnil;
|
74
|
+
}
|
75
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
76
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
77
|
+
return Qnil;
|
78
|
+
}
|
79
|
+
|
80
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
82
|
+
ndfunc_t ndf = { iter_potrf, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
83
|
+
potrf_opt opt = { matrix_layout, uplo };
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
85
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
86
|
+
|
87
|
+
RB_GC_GUARD(a_vnary);
|
88
|
+
|
89
|
+
return ret;
|
90
|
+
}
|
91
|
+
};
|
92
|
+
|
93
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,121 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DPoTrs {
|
4
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
5
|
+
const double* a, lapack_int lda, double* b, lapack_int ldb) {
|
6
|
+
return LAPACKE_dpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SPoTrs {
|
11
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
12
|
+
const float* a, lapack_int lda, float* b, lapack_int ldb) {
|
13
|
+
return LAPACKE_spotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZPoTrs {
|
18
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
19
|
+
const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
|
20
|
+
return LAPACKE_zpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CPoTrs {
|
25
|
+
lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
|
26
|
+
const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
|
27
|
+
return LAPACKE_cpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
|
+
class PoTrs {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrs), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct potrs_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
char uplo;
|
42
|
+
};
|
43
|
+
|
44
|
+
static void iter_potrs(na_loop_t* const lp) {
|
45
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
46
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
47
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
48
|
+
potrs_opt* opt = (potrs_opt*)(lp->opt_ptr);
|
49
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
50
|
+
const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
51
|
+
const lapack_int lda = n;
|
52
|
+
const lapack_int ldb = nrhs;
|
53
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, nrhs, a, lda, b, ldb);
|
54
|
+
*info = static_cast<int>(i);
|
55
|
+
}
|
56
|
+
|
57
|
+
static VALUE tiny_linalg_potrs(int argc, VALUE* argv, VALUE self) {
|
58
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
59
|
+
|
60
|
+
VALUE a_vnary = Qnil;
|
61
|
+
VALUE b_vnary = Qnil;
|
62
|
+
VALUE kw_args = Qnil;
|
63
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
64
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
|
65
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
66
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
67
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
68
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
69
|
+
|
70
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
71
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
72
|
+
}
|
73
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
74
|
+
a_vnary = nary_dup(a_vnary);
|
75
|
+
}
|
76
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
77
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
78
|
+
}
|
79
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
80
|
+
b_vnary = nary_dup(b_vnary);
|
81
|
+
}
|
82
|
+
|
83
|
+
narray_t* a_nary = NULL;
|
84
|
+
GetNArray(a_vnary, a_nary);
|
85
|
+
if (NA_NDIM(a_nary) != 2) {
|
86
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
87
|
+
return Qnil;
|
88
|
+
}
|
89
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
90
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
91
|
+
return Qnil;
|
92
|
+
}
|
93
|
+
narray_t* b_nary = NULL;
|
94
|
+
GetNArray(b_vnary, b_nary);
|
95
|
+
const int b_n_dims = NA_NDIM(b_nary);
|
96
|
+
if (b_n_dims != 1 && b_n_dims != 2) {
|
97
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
|
98
|
+
return Qnil;
|
99
|
+
}
|
100
|
+
|
101
|
+
lapack_int n = NA_SHAPE(a_nary)[0];
|
102
|
+
lapack_int nb = NA_SHAPE(b_nary)[0];
|
103
|
+
if (n != nb) {
|
104
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
|
105
|
+
}
|
106
|
+
|
107
|
+
ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
|
108
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
109
|
+
ndfunc_t ndf = { iter_potrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
110
|
+
potrs_opt opt = { matrix_layout, uplo };
|
111
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
112
|
+
VALUE ret = rb_ary_new3(2, b_vnary, res);
|
113
|
+
|
114
|
+
RB_GC_GUARD(a_vnary);
|
115
|
+
RB_GC_GUARD(b_vnary);
|
116
|
+
|
117
|
+
return ret;
|
118
|
+
}
|
119
|
+
};
|
120
|
+
|
121
|
+
} // namespace TinyLinalg
|
@@ -51,6 +51,8 @@
|
|
51
51
|
#include "lapack/hegvd.hpp"
|
52
52
|
#include "lapack/hegvx.hpp"
|
53
53
|
#include "lapack/orgqr.hpp"
|
54
|
+
#include "lapack/potrf.hpp"
|
55
|
+
#include "lapack/potrs.hpp"
|
54
56
|
#include "lapack/syev.hpp"
|
55
57
|
#include "lapack/syevd.hpp"
|
56
58
|
#include "lapack/syevr.hpp"
|
@@ -310,6 +312,14 @@ extern "C" void Init_tiny_linalg(void) {
|
|
310
312
|
TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
311
313
|
TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
312
314
|
TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
315
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
|
316
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
|
317
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
|
318
|
+
TinyLinalg::PoTrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrf>::define_module_function(rb_mTinyLinalgLapack, "cpotrf");
|
319
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrs>::define_module_function(rb_mTinyLinalgLapack, "dpotrs");
|
320
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrs>::define_module_function(rb_mTinyLinalgLapack, "spotrs");
|
321
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrs>::define_module_function(rb_mTinyLinalgLapack, "zpotrs");
|
322
|
+
TinyLinalg::PoTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrs>::define_module_function(rb_mTinyLinalgLapack, "cpotrs");
|
313
323
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
|
314
324
|
TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
|
315
325
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -95,6 +95,96 @@ module Numo
|
|
95
95
|
[vals, vecs]
|
96
96
|
end
|
97
97
|
|
98
|
+
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   s = Numo::DFloat.new(3, 3).rand - 0.5
#   a = s.transpose.dot(s)
#   u = Numo::Linalg.cholesky(a)
#
#   pp u
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.532006, 0.338183, -0.18036],
#   #  [0, 0.325153, 0.011721],
#   #  [0, 0, 0.436738]]
#
#   pp (a - u.transpose.dot(u)).abs.max
#   # => 1.3877787807814457e-17
#
#   l = Numo::Linalg.cholesky(a, uplo: 'L')
#
#   pp l
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.532006, 0, 0],
#   #  [0.338183, 0.325153, 0],
#   #  [-0.18036, 0.011721, 0.436738]]
#
#   pp (a - l.dot(l.transpose)).abs.max
#   # => 1.3877787807814457e-17
#
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian positive-definite matrix.
# @param uplo [String] Whether to compute the upper- or lower-triangular Cholesky factor ('U' or 'L').
# @return [Numo::NArray] The upper- or lower-triangular Cholesky factor of a.
# @raise [ArgumentError] If a is not a square 2-dimensional array, has an unsupported type, or uplo is invalid.
# @raise [RuntimeError] If a is not positive definite, so the factorization could not be completed.
def cholesky(a, uplo: 'U')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
  # Reject a bad uplo before running the O(n^3) factorization rather than after it.
  raise ArgumentError, "invalid uplo: #{uplo}" unless %w[U L].include?(uplo)

  fnc = :"#{bchr}potrf"
  # ?potrf factorizes in place, so work on a copy to leave the caller's array intact.
  c, info = Numo::TinyLinalg::Lapack.send(fnc, a.dup, uplo: uplo)
  # A positive info is the order of the leading minor that is not positive definite;
  # the factor is incomplete in that case and must not be returned as if it were valid.
  if info.positive?
    raise "the leading minor of order #{info} is not positive definite, and the factorization could not be completed."
  end

  # Only the requested triangle holds the factor; zero out the other one.
  uplo == 'U' ? c.triu : c.tril
end
|
153
|
+
|
154
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` with the Cholesky factorization of `A`.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   s = Numo::DFloat.new(3, 3).rand - 0.5
#   a = s.transpose.dot(s)
#   u = Numo::Linalg.cholesky(a)
#
#   b = Numo::DFloat.new(3).rand
#   x = Numo::Linalg.cho_solve(u, b)
#
#   puts (b - a.dot(x)).abs.max
#   # => 0.0
#
# @param a [Numo::NArray] The n-by-n Cholesky factor of `A` (e.g. the result of `cholesky`).
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param uplo [String] Whether the given Cholesky factor a is upper- ('U') or lower-triangular ('L');
#   it must match the uplo used when the factor was computed.
# @return [Numo::NArray] The solution vector or matrix `X`.
# @raise [ArgumentError] If the shapes of a and b are incompatible or the array type is unsupported.
def cho_solve(a, b, uplo: 'U')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
  # Check the rank of b before indexing b.shape[0], which is only meaningful for rank >= 1.
  raise ArgumentError, 'input array b must be 1- or 2-dimensional' unless [1, 2].include?(b.ndim)
  raise ArgumentError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}" if a.shape[0] != b.shape[0]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  fnc = :"#{bchr}potrs"
  # ?potrs overwrites its right-hand side with the solution, so pass a copy of b.
  x, _info = Numo::TinyLinalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
  x
end
|
187
|
+
|
98
188
|
# Computes the determinant of matrix.
|
99
189
|
#
|
100
190
|
# @example
|
@@ -271,7 +361,7 @@ module Numo
|
|
271
361
|
[q, r]
|
272
362
|
end
|
273
363
|
|
274
|
-
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `
|
364
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `A`.
|
275
365
|
#
|
276
366
|
# @example
|
277
367
|
# require 'numo/tiny_linalg'
|
@@ -294,10 +384,10 @@ module Numo
|
|
294
384
|
# # => 2.1081041547796492e-16
|
295
385
|
#
|
296
386
|
# @param a [Numo::NArray] The n-by-n square matrix.
|
297
|
-
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix
|
387
|
+
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
|
298
388
|
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
299
389
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
300
|
-
# @return [Numo::NArray] The solusion vector / matrix `
|
390
|
+
# @return [Numo::NArray] The solution vector / matrix `X`.
|
301
391
|
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
|
302
392
|
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
|
303
393
|
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
|
@@ -368,5 +458,110 @@ module Numo
|
|
368
458
|
|
369
459
|
[s, u, vt]
|
370
460
|
end
|
461
|
+
|
462
|
+
# Placeholders for methods that are not yet implemented in Numo::TinyLinalg.
# Each accepts any arguments and immediately raises NotImplementedError
# carrying the name of the called method.

# @!visibility private
def matmul(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def matrix_power(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def svdvals(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def orth(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def null_space(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu_fact(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu_inv(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lu_solve(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def ldl(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cho_fact(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cho_inv(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def eig(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def eigvals(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def eigvalsh(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def norm(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cond(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def slogdet(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def matrix_rank(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def lstsq(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def expm(*)
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end
|
371
566
|
end
|
372
567
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -59,6 +59,8 @@ files:
|
|
59
59
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
60
60
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
61
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
62
|
+
- ext/numo/tiny_linalg/lapack/potrf.hpp
|
63
|
+
- ext/numo/tiny_linalg/lapack/potrs.hpp
|
62
64
|
- ext/numo/tiny_linalg/lapack/syev.hpp
|
63
65
|
- ext/numo/tiny_linalg/lapack/syevd.hpp
|
64
66
|
- ext/numo/tiny_linalg/lapack/syevr.hpp
|