numo-tiny_linalg 0.3.5 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/numo/tiny_linalg/extconf.rb +1 -2
- data/ext/numo/tiny_linalg/lapack/lange.hpp +89 -0
- data/ext/numo/tiny_linalg/lapack/trtrs.hpp +126 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +10 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +207 -5
- metadata +5 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c508a2c2f31a965c26824084f1eeb1a2caf53ea82a3d23e523980d6cf78e6259
|
4
|
+
data.tar.gz: 4bffa14a1a73bfb60230c65fa35c98eeb957e176b8af2003f9c3e5b8aaf89b19
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4ad7a9af39fdc0c5ffb9ba7fc51dcf037fe8ffe7d81c33c70a0c669b384127929efee3a9ceb8440b7d1ec10c79ee66b362aed7098befbe02a892d4490961f169
|
7
|
+
data.tar.gz: 91c9935714ef5b929718affc9cfae518f6e08f85063f673f26ffa7c137bb7fe62d41c20386704657dccb89feebf57bdaed5b182750c6f0fe3ad5b20e55cc4221
|
data/CHANGELOG.md
CHANGED
@@ -1,4 +1,12 @@
|
|
1
1
|
## [Unreleased]
|
2
|
+
## [[0.3.6](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.5...v0.3.6)] - 2024-01-28
|
3
|
+
|
4
|
+
- Add solve_triangular module function to TinyLinalg.
|
5
|
+
- solve_triangular is not implemented in Numo::Linalg, but I have implemented it because some machine learning algorithms use it.
|
6
|
+
- Add dtrtrs, strtrs, ztrtrs, and ctrtrs module functions to TinyLinalg::Lapack.
|
7
|
+
- Add norm module function to TinyLinalg.
|
8
|
+
- Add dlange, slange, zlange, and clange module functions to TinyLinalg::Lapack.
|
9
|
+
|
2
10
|
## [[0.3.5](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.4...v0.3.5)] - 2024-01-03
|
3
11
|
- Bump OpenBLAS to be downloaded from 0.3.25 to 0.3.26.
|
4
12
|
- Minor changes using RuboCop.
|
@@ -83,8 +83,7 @@ if build_openblas
|
|
83
83
|
end
|
84
84
|
end
|
85
85
|
|
86
|
-
|
87
|
-
abort('libopenblas is not found.') unless find_library('openblas', nil, libopenblas_dir)
|
86
|
+
abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
|
88
87
|
abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
|
89
88
|
abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
|
90
89
|
abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
|
@@ -0,0 +1,89 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DLanGe {
|
4
|
+
double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const double* a, lapack_int lda) {
|
5
|
+
return LAPACKE_dlange(matrix_layout, norm, m, n, a, lda);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SLanGe {
|
10
|
+
float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const float* a, lapack_int lda) {
|
11
|
+
return LAPACKE_slange(matrix_layout, norm, m, n, a, lda);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
struct ZLanGe {
|
16
|
+
double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_double* a, lapack_int lda) {
|
17
|
+
return LAPACKE_zlange(matrix_layout, norm, m, n, a, lda);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct CLanGe {
|
22
|
+
float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_float* a, lapack_int lda) {
|
23
|
+
return LAPACKE_clange(matrix_layout, norm, m, n, a, lda);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
28
|
+
class LanGe {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_lange), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct lange_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
char norm;
|
38
|
+
};
|
39
|
+
|
40
|
+
static void iter_lange(na_loop_t* const lp) {
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
42
|
+
dtype* d = (dtype*)NDL_PTR(lp, 1);
|
43
|
+
lange_opt* opt = (lange_opt*)(lp->opt_ptr);
|
44
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
45
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
46
|
+
const lapack_int lda = n;
|
47
|
+
*d = LapackFn().call(opt->matrix_layout, opt->norm, m, n, a, lda);
|
48
|
+
}
|
49
|
+
|
50
|
+
static VALUE tiny_linalg_lange(int argc, VALUE* argv, VALUE self) {
|
51
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
52
|
+
|
53
|
+
VALUE a_vnary = Qnil;
|
54
|
+
VALUE kw_args = Qnil;
|
55
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
56
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("norm") };
|
57
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
58
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
59
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
60
|
+
const char norm = kw_values[1] != Qundef ? NUM2CHR(kw_values[1]) : 'F';
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
|
69
|
+
narray_t* a_nary = NULL;
|
70
|
+
GetNArray(a_vnary, a_nary);
|
71
|
+
if (NA_NDIM(a_nary) != 2) {
|
72
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
73
|
+
return Qnil;
|
74
|
+
}
|
75
|
+
|
76
|
+
ndfunc_arg_in_t ain[1] = { { nary_dtype, 2 } };
|
77
|
+
size_t shape_out[1] = { 1 };
|
78
|
+
ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
|
79
|
+
ndfunc_t ndf = { iter_lange, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
80
|
+
lange_opt opt = { matrix_layout, norm };
|
81
|
+
VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
82
|
+
|
83
|
+
RB_GC_GUARD(a_vnary);
|
84
|
+
|
85
|
+
return ret;
|
86
|
+
}
|
87
|
+
};
|
88
|
+
|
89
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,126 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DTrTrs {
|
4
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
5
|
+
const double* a, lapack_int lda, double* b, lapack_int ldb) {
|
6
|
+
return LAPACKE_dtrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct STrTrs {
|
11
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
12
|
+
const float* a, lapack_int lda, float* b, lapack_int ldb) {
|
13
|
+
return LAPACKE_strtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZTrTrs {
|
18
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
19
|
+
const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
|
20
|
+
return LAPACKE_ztrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CTrTrs {
|
25
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
26
|
+
const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
|
27
|
+
return LAPACKE_ctrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
|
+
class TrTrs {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_trtrs), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct trtrs_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
char uplo;
|
42
|
+
char trans;
|
43
|
+
char diag;
|
44
|
+
};
|
45
|
+
|
46
|
+
static void iter_trtrs(na_loop_t* const lp) {
|
47
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
48
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
49
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
50
|
+
trtrs_opt* opt = (trtrs_opt*)(lp->opt_ptr);
|
51
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
53
|
+
const lapack_int lda = n;
|
54
|
+
const lapack_int ldb = nrhs;
|
55
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, opt->trans, opt->diag, n, nrhs, a, lda, b, ldb);
|
56
|
+
*info = static_cast<int>(i);
|
57
|
+
}
|
58
|
+
|
59
|
+
static VALUE tiny_linalg_trtrs(int argc, VALUE* argv, VALUE self) {
|
60
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
61
|
+
|
62
|
+
VALUE a_vnary = Qnil;
|
63
|
+
VALUE b_vnary = Qnil;
|
64
|
+
VALUE kw_args = Qnil;
|
65
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
66
|
+
ID kw_table[4] = { rb_intern("order"), rb_intern("uplo"), rb_intern("trans"), rb_intern("diag") };
|
67
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
68
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
69
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
70
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
71
|
+
const char trans = kw_values[2] != Qundef ? NUM2CHR(kw_values[2]) : 'N';
|
72
|
+
const char diag = kw_values[3] != Qundef ? NUM2CHR(kw_values[3]) : 'N';
|
73
|
+
|
74
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
75
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
76
|
+
}
|
77
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
78
|
+
a_vnary = nary_dup(a_vnary);
|
79
|
+
}
|
80
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
81
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
82
|
+
}
|
83
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
84
|
+
b_vnary = nary_dup(b_vnary);
|
85
|
+
}
|
86
|
+
|
87
|
+
narray_t* a_nary = NULL;
|
88
|
+
GetNArray(a_vnary, a_nary);
|
89
|
+
if (NA_NDIM(a_nary) != 2) {
|
90
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
91
|
+
return Qnil;
|
92
|
+
}
|
93
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
94
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
95
|
+
return Qnil;
|
96
|
+
}
|
97
|
+
|
98
|
+
narray_t* b_nary = NULL;
|
99
|
+
GetNArray(b_vnary, b_nary);
|
100
|
+
const int b_n_dims = NA_NDIM(b_nary);
|
101
|
+
if (b_n_dims != 1 && b_n_dims != 2) {
|
102
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
|
103
|
+
return Qnil;
|
104
|
+
}
|
105
|
+
|
106
|
+
lapack_int n = NA_SHAPE(a_nary)[0];
|
107
|
+
lapack_int nb = NA_SHAPE(b_nary)[0];
|
108
|
+
if (n != nb) {
|
109
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
|
110
|
+
}
|
111
|
+
|
112
|
+
ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
|
113
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
114
|
+
ndfunc_t ndf = { iter_trtrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
115
|
+
trtrs_opt opt = { matrix_layout, uplo, trans, diag };
|
116
|
+
VALUE info = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
117
|
+
VALUE ret = rb_ary_new3(2, b_vnary, info);
|
118
|
+
|
119
|
+
RB_GC_GUARD(a_vnary);
|
120
|
+
RB_GC_GUARD(b_vnary);
|
121
|
+
|
122
|
+
return ret;
|
123
|
+
}
|
124
|
+
};
|
125
|
+
|
126
|
+
} // namespace TinyLinalg
|
@@ -50,6 +50,7 @@
|
|
50
50
|
#include "lapack/hegv.hpp"
|
51
51
|
#include "lapack/hegvd.hpp"
|
52
52
|
#include "lapack/hegvx.hpp"
|
53
|
+
#include "lapack/lange.hpp"
|
53
54
|
#include "lapack/orgqr.hpp"
|
54
55
|
#include "lapack/potrf.hpp"
|
55
56
|
#include "lapack/potrs.hpp"
|
@@ -59,6 +60,7 @@
|
|
59
60
|
#include "lapack/sygv.hpp"
|
60
61
|
#include "lapack/sygvd.hpp"
|
61
62
|
#include "lapack/sygvx.hpp"
|
63
|
+
#include "lapack/trtrs.hpp"
|
62
64
|
#include "lapack/ungqr.hpp"
|
63
65
|
|
64
66
|
VALUE rb_mTinyLinalg;
|
@@ -314,6 +316,10 @@ extern "C" void Init_tiny_linalg(void) {
|
|
314
316
|
TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
315
317
|
TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
316
318
|
TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
319
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DTrTrs>::define_module_function(rb_mTinyLinalgLapack, "dtrtrs");
|
320
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::STrTrs>::define_module_function(rb_mTinyLinalgLapack, "strtrs");
|
321
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ztrtrs");
|
322
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ctrtrs");
|
317
323
|
TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
|
318
324
|
TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
|
319
325
|
TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
|
@@ -354,6 +360,10 @@ extern "C" void Init_tiny_linalg(void) {
|
|
354
360
|
TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
|
355
361
|
TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
|
356
362
|
TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
|
363
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DLanGe>::define_module_function(rb_mTinyLinalgLapack, "dlange");
|
364
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SLanGe>::define_module_function(rb_mTinyLinalgLapack, "slange");
|
365
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZLanGe>::define_module_function(rb_mTinyLinalgLapack, "zlange");
|
366
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CLanGe>::define_module_function(rb_mTinyLinalgLapack, "clange");
|
357
367
|
|
358
368
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
359
369
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -95,6 +95,172 @@ module Numo
|
|
95
95
|
[vals, vecs]
|
96
96
|
end
|
97
97
|
|
98
|
+
# Computes the matrix or vector norm.
#
# | ord   | matrix norm            | vector norm                 |
# | ----- | ---------------------- | --------------------------- |
# | nil   | Frobenius norm         | 2-norm                      |
# | 'fro' | Frobenius norm         | -                           |
# | 'nuc' | nuclear norm           | -                           |
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max                   |
# | 0     | -                      | (x.ne 0).sum                |
# | 1     | x.abs.sum(axis:-2).max | same as below               |
# | 2     | 2-norm (max sing_vals) | same as below               |
# | other | -                      | (x.abs**ord).sum**(1.0/ord) |
#
# @example
#   require 'numo/tiny_linalg'
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   # matrix norm
#   x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
#   pp Numo::Linalg.norm(x)
#   # => 10
#
#   # vector norm
#   x = Numo::DFloat[3, -4]
#   pp Numo::Linalg.norm(x)
#   # => 5
#
# @param a [Numo::NArray] The matrix or vector (>= 1-dimensional NArray)
# @param ord [String/Numeric] The order of the norm.
# @param axis [Integer/Array] The applied axes.
# @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
# @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
# @raise [ArgumentError] If ord or axis is not valid for the given array.
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

  return 0.0 if a.empty?

  # for compatibility with Numo::Linalg.norm
  if ord.is_a?(String)
    if ord == 'inf'
      ord = Float::INFINITY
    elsif ord == '-inf'
      ord = -Float::INFINITY
    end
  end

  # Fast paths that delegate the whole array to BLAS/LAPACK.
  if axis.nil?
    norm = case a.ndim
           when 1
             Numo::TinyLinalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
           when 2
             if ord.nil? || ord == 'fro'
               Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
             elsif ord.is_a?(Numeric)
               if ord == 1
                 Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
               elsif !ord.infinite?.nil? && ord.infinite?.positive?
                 Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
               end
             end
           else
             if ord.nil?
               b = a.flatten.dup
               Numo::TinyLinalg::Blas.send(:"#{blas_char(b)}nrm2", b)
             end
           end
    unless norm.nil?
      norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
      return norm
    end
  end

  if axis.nil?
    axis = Array.new(a.ndim) { |d| d }
  else
    case axis
    when Integer
      axis = [axis]
    when Array, Numo::NArray
      axis = axis.flatten.to_a
    else
      raise ArgumentError, "invalid axis: #{axis}"
    end
  end

  raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless axis.size == 1 || axis.size == 2
  raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

  if axis.size == 1
    ord ||= 2
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

    ord_inf = ord.infinite?
    if ord_inf.nil?
      case ord
      when 0
        a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
      when 1
        a.abs.sum(axis: axis, keepdims: keepdims)
      else
        (a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
      end
    elsif ord_inf.positive?
      a.abs.max(axis: axis, keepdims: keepdims)
    else
      a.abs.min(axis: axis, keepdims: keepdims)
    end
  else
    ord ||= 'fro'
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)

    r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }
    # Compare the normalized axes so that e.g. [1, -1] on a 2-D array is rejected as well.
    raise ArgumentError, "invalid axis: #{axis}" if r_axis == c_axis

    # With keepdims, only the two normed axes collapse to size one; the other axes are kept.
    # (Computed here because r_axis / c_axis are adjusted in some branches below.)
    keepdims_shape = a.shape.dup
    keepdims_shape[r_axis] = 1
    keepdims_shape[c_axis] = 1

    norm = if ord.is_a?(String)
             raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

             if ord == 'fro'
               Numo::NMath.sqrt((a.abs**2).sum(axis: axis))
             else
               b = a.transpose(c_axis, r_axis).dup
               gesvd = :"#{blas_char(b)}gesvd"
               s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
               s.sum(axis: -1)
             end
           else
             ord_inf = ord.infinite?
             if ord_inf.nil?
               case ord
               when -2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.min(axis: -1)
               when -1
                 # Summing over r_axis removes it, shifting later axes down by one.
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).min(axis: c_axis)
               when 1
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).max(axis: c_axis)
               when 2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.max(axis: -1)
               else
                 raise ArgumentError, "invalid ord: #{ord}"
               end
             else
               r_axis -= 1 if r_axis > c_axis
               if ord_inf.positive?
                 a.abs.sum(axis: c_axis).max(axis: r_axis)
               else
                 a.abs.sum(axis: c_axis).min(axis: r_axis)
               end
             end
           end
    if keepdims
      norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
      norm = norm.reshape(*keepdims_shape)
    end

    norm
  end
end
|
263
|
+
|
98
264
|
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
|
99
265
|
#
|
100
266
|
# @example
|
@@ -399,6 +565,47 @@ module Numo
|
|
399
565
|
Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
|
400
566
|
end
|
401
567
|
|
568
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   a = Numo::DFloat.new(3, 3).rand.triu
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve_triangular(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[16.1932, -52.0604, 30.5283],
#   #  [0, 8.61765, -17.9585],
#   #  [0, 0, 6.05735]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 4.071100642430302e-16
#
# @param a [Numo::NArray] The n-by-n triangular matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
# @raise [ArgumentError] If `a` is not a square matrix or the array types are not supported.
# @raise [RuntimeError] If `a` is singular (has a zero diagonal element) or LAPACK rejects an argument.
# @return [Numo::NArray] The solution vector / matrix `X`.
def solve_triangular(a, b, lower: false)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  trtrs = :"#{bchr}trtrs"
  uplo = lower ? 'L' : 'U'
  # b is duplicated because ?trtrs overwrites its right-hand side with the solution.
  x, info = Numo::TinyLinalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
  # info = -i means the i-th argument had an illegal value.
  raise "wrong value is given to the #{-info}-th argument of #{trtrs} used internally" if info.negative?
  # info = i > 0 means the i-th diagonal element of a is zero; ?trtrs then leaves b unsolved.
  raise "the #{info}-th diagonal element of a is zero, the matrix is singular and the solution has not been computed" if info.positive?

  x
end
|
608
|
+
|
402
609
|
# Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
403
610
|
#
|
404
611
|
# @example
|
@@ -534,11 +741,6 @@ module Numo
|
|
534
741
|
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
|
535
742
|
end
|
536
743
|
|
537
|
-
# @!visibility private
|
538
|
-
def norm(*args)
|
539
|
-
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
|
540
|
-
end
|
541
|
-
|
542
744
|
# @!visibility private
|
543
745
|
def cond(*args)
|
544
746
|
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-01-
|
11
|
+
date: 2024-01-28 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -57,6 +57,7 @@ files:
|
|
57
57
|
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
58
58
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
59
59
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
60
|
+
- ext/numo/tiny_linalg/lapack/lange.hpp
|
60
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
61
62
|
- ext/numo/tiny_linalg/lapack/potrf.hpp
|
62
63
|
- ext/numo/tiny_linalg/lapack/potrs.hpp
|
@@ -66,6 +67,7 @@ files:
|
|
66
67
|
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
67
68
|
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
68
69
|
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|
70
|
+
- ext/numo/tiny_linalg/lapack/trtrs.hpp
|
69
71
|
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
70
72
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
71
73
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|
@@ -97,7 +99,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
97
99
|
- !ruby/object:Gem::Version
|
98
100
|
version: '0'
|
99
101
|
requirements: []
|
100
|
-
rubygems_version: 3.4.
|
102
|
+
rubygems_version: 3.4.19
|
101
103
|
signing_key:
|
102
104
|
specification_version: 4
|
103
105
|
summary: Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of
|