numo-tiny_linalg 0.3.4 → 0.3.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/LICENSE.txt +1 -1
- data/ext/numo/tiny_linalg/extconf.rb +3 -4
- data/ext/numo/tiny_linalg/lapack/lange.hpp +89 -0
- data/ext/numo/tiny_linalg/lapack/trtrs.hpp +126 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +11 -1
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +1 -1
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +217 -15
- metadata +5 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c508a2c2f31a965c26824084f1eeb1a2caf53ea82a3d23e523980d6cf78e6259
|
4
|
+
data.tar.gz: 4bffa14a1a73bfb60230c65fa35c98eeb957e176b8af2003f9c3e5b8aaf89b19
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4ad7a9af39fdc0c5ffb9ba7fc51dcf037fe8ffe7d81c33c70a0c669b384127929efee3a9ceb8440b7d1ec10c79ee66b362aed7098befbe02a892d4490961f169
|
7
|
+
data.tar.gz: 91c9935714ef5b929718affc9cfae518f6e08f85063f673f26ffa7c137bb7fe62d41c20386704657dccb89feebf57bdaed5b182750c6f0fe3ad5b20e55cc4221
|
data/CHANGELOG.md
CHANGED
@@ -1,4 +1,15 @@
|
|
1
1
|
## [Unreleased]
|
2
|
+
## [[0.3.6](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.5...v0.3.6)] - 2024-01-28
|
3
|
+
|
4
|
+
- Add solve_triangular module function to TinyLinalg.
|
5
|
+
- solve_triangular is not implemented in Numo::Linalg, but I have implemented it because some machine learning algorithms use it.
|
6
|
+
- Add dtrtrs, strtrs, ztrtrs, and ctrtrs module functions to TinyLinalg::Lapack.
|
7
|
+
- Add norm module function to TinyLinalg.
|
8
|
+
- Add dlange, slange, zlange, and clange module functions to TinyLinalg::Lapack.
|
9
|
+
|
10
|
+
## [[0.3.5](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.4...v0.3.5)] - 2024-01-03
|
11
|
+
- Bump OpenBLAS to be downloaded from 0.3.25 to 0.3.26.
|
12
|
+
- Minor changes using RuboCop.
|
2
13
|
|
3
14
|
## [[0.3.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.3...v0.3.4)] - 2023-11-19
|
4
15
|
- Bump OpenBLAS to be downloaded from 0.3.24 to 0.3.25.
|
data/LICENSE.txt
CHANGED
@@ -46,8 +46,8 @@ if build_openblas
|
|
46
46
|
|
47
47
|
VENDOR_DIR = File.expand_path("#{__dir__}/../../../vendor")
|
48
48
|
TINYLINALG_DIR = File.expand_path("#{__dir__}/../../../lib/numo/tiny_linalg")
|
49
|
-
OPENBLAS_VER = '0.3.
|
50
|
-
OPENBLAS_KEY = '
|
49
|
+
OPENBLAS_VER = '0.3.26'
|
50
|
+
OPENBLAS_KEY = 'bd496a1c81769ed19a161c1f8f904ccd'
|
51
51
|
OPENBLAS_URI = "https://github.com/OpenMathLib/OpenBLAS/archive/v#{OPENBLAS_VER}.tar.gz"
|
52
52
|
OPENBLAS_TGZ = "#{VENDOR_DIR}/tmp/openblas.tgz"
|
53
53
|
|
@@ -83,8 +83,7 @@ if build_openblas
|
|
83
83
|
end
|
84
84
|
end
|
85
85
|
|
86
|
-
|
87
|
-
abort('libopenblas is not found.') unless find_library('openblas', nil, libopenblas_dir)
|
86
|
+
abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
|
88
87
|
abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
|
89
88
|
abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
|
90
89
|
abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
|
@@ -0,0 +1,89 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DLanGe {
|
4
|
+
double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const double* a, lapack_int lda) {
|
5
|
+
return LAPACKE_dlange(matrix_layout, norm, m, n, a, lda);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SLanGe {
|
10
|
+
float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const float* a, lapack_int lda) {
|
11
|
+
return LAPACKE_slange(matrix_layout, norm, m, n, a, lda);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
struct ZLanGe {
|
16
|
+
double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_double* a, lapack_int lda) {
|
17
|
+
return LAPACKE_zlange(matrix_layout, norm, m, n, a, lda);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct CLanGe {
|
22
|
+
float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_float* a, lapack_int lda) {
|
23
|
+
return LAPACKE_clange(matrix_layout, norm, m, n, a, lda);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
28
|
+
class LanGe {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_lange), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct lange_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
char norm;
|
38
|
+
};
|
39
|
+
|
40
|
+
static void iter_lange(na_loop_t* const lp) {
|
41
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
42
|
+
dtype* d = (dtype*)NDL_PTR(lp, 1);
|
43
|
+
lange_opt* opt = (lange_opt*)(lp->opt_ptr);
|
44
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
45
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
46
|
+
const lapack_int lda = n;
|
47
|
+
*d = LapackFn().call(opt->matrix_layout, opt->norm, m, n, a, lda);
|
48
|
+
}
|
49
|
+
|
50
|
+
static VALUE tiny_linalg_lange(int argc, VALUE* argv, VALUE self) {
|
51
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
52
|
+
|
53
|
+
VALUE a_vnary = Qnil;
|
54
|
+
VALUE kw_args = Qnil;
|
55
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
56
|
+
ID kw_table[2] = { rb_intern("order"), rb_intern("norm") };
|
57
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
58
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
59
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
60
|
+
const char norm = kw_values[1] != Qundef ? NUM2CHR(kw_values[1]) : 'F';
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
|
69
|
+
narray_t* a_nary = NULL;
|
70
|
+
GetNArray(a_vnary, a_nary);
|
71
|
+
if (NA_NDIM(a_nary) != 2) {
|
72
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
73
|
+
return Qnil;
|
74
|
+
}
|
75
|
+
|
76
|
+
ndfunc_arg_in_t ain[1] = { { nary_dtype, 2 } };
|
77
|
+
size_t shape_out[1] = { 1 };
|
78
|
+
ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
|
79
|
+
ndfunc_t ndf = { iter_lange, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
80
|
+
lange_opt opt = { matrix_layout, norm };
|
81
|
+
VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
82
|
+
|
83
|
+
RB_GC_GUARD(a_vnary);
|
84
|
+
|
85
|
+
return ret;
|
86
|
+
}
|
87
|
+
};
|
88
|
+
|
89
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,126 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DTrTrs {
|
4
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
5
|
+
const double* a, lapack_int lda, double* b, lapack_int ldb) {
|
6
|
+
return LAPACKE_dtrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct STrTrs {
|
11
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
12
|
+
const float* a, lapack_int lda, float* b, lapack_int ldb) {
|
13
|
+
return LAPACKE_strtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZTrTrs {
|
18
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
19
|
+
const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
|
20
|
+
return LAPACKE_ztrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CTrTrs {
|
25
|
+
lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
|
26
|
+
const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
|
27
|
+
return LAPACKE_ctrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
|
+
class TrTrs {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_trtrs), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct trtrs_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
char uplo;
|
42
|
+
char trans;
|
43
|
+
char diag;
|
44
|
+
};
|
45
|
+
|
46
|
+
static void iter_trtrs(na_loop_t* const lp) {
|
47
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
48
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
49
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
50
|
+
trtrs_opt* opt = (trtrs_opt*)(lp->opt_ptr);
|
51
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
53
|
+
const lapack_int lda = n;
|
54
|
+
const lapack_int ldb = nrhs;
|
55
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, opt->trans, opt->diag, n, nrhs, a, lda, b, ldb);
|
56
|
+
*info = static_cast<int>(i);
|
57
|
+
}
|
58
|
+
|
59
|
+
static VALUE tiny_linalg_trtrs(int argc, VALUE* argv, VALUE self) {
|
60
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
61
|
+
|
62
|
+
VALUE a_vnary = Qnil;
|
63
|
+
VALUE b_vnary = Qnil;
|
64
|
+
VALUE kw_args = Qnil;
|
65
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
66
|
+
ID kw_table[4] = { rb_intern("order"), rb_intern("uplo"), rb_intern("trans"), rb_intern("diag") };
|
67
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
68
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
69
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
70
|
+
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
|
71
|
+
const char trans = kw_values[2] != Qundef ? NUM2CHR(kw_values[2]) : 'N';
|
72
|
+
const char diag = kw_values[3] != Qundef ? NUM2CHR(kw_values[3]) : 'N';
|
73
|
+
|
74
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
75
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
76
|
+
}
|
77
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
78
|
+
a_vnary = nary_dup(a_vnary);
|
79
|
+
}
|
80
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
81
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
82
|
+
}
|
83
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
84
|
+
b_vnary = nary_dup(b_vnary);
|
85
|
+
}
|
86
|
+
|
87
|
+
narray_t* a_nary = NULL;
|
88
|
+
GetNArray(a_vnary, a_nary);
|
89
|
+
if (NA_NDIM(a_nary) != 2) {
|
90
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
91
|
+
return Qnil;
|
92
|
+
}
|
93
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
94
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
95
|
+
return Qnil;
|
96
|
+
}
|
97
|
+
|
98
|
+
narray_t* b_nary = NULL;
|
99
|
+
GetNArray(b_vnary, b_nary);
|
100
|
+
const int b_n_dims = NA_NDIM(b_nary);
|
101
|
+
if (b_n_dims != 1 && b_n_dims != 2) {
|
102
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
|
103
|
+
return Qnil;
|
104
|
+
}
|
105
|
+
|
106
|
+
lapack_int n = NA_SHAPE(a_nary)[0];
|
107
|
+
lapack_int nb = NA_SHAPE(b_nary)[0];
|
108
|
+
if (n != nb) {
|
109
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
|
110
|
+
}
|
111
|
+
|
112
|
+
ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
|
113
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
114
|
+
ndfunc_t ndf = { iter_trtrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
115
|
+
trtrs_opt opt = { matrix_layout, uplo, trans, diag };
|
116
|
+
VALUE info = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
117
|
+
VALUE ret = rb_ary_new3(2, b_vnary, info);
|
118
|
+
|
119
|
+
RB_GC_GUARD(a_vnary);
|
120
|
+
RB_GC_GUARD(b_vnary);
|
121
|
+
|
122
|
+
return ret;
|
123
|
+
}
|
124
|
+
};
|
125
|
+
|
126
|
+
} // namespace TinyLinalg
|
@@ -1,5 +1,5 @@
|
|
1
1
|
/**
|
2
|
-
* Copyright (c) 2023 Atsushi Tatsuma
|
2
|
+
* Copyright (c) 2023-2024 Atsushi Tatsuma
|
3
3
|
* All rights reserved.
|
4
4
|
*
|
5
5
|
* Redistribution and use in source and binary forms, with or without
|
@@ -50,6 +50,7 @@
|
|
50
50
|
#include "lapack/hegv.hpp"
|
51
51
|
#include "lapack/hegvd.hpp"
|
52
52
|
#include "lapack/hegvx.hpp"
|
53
|
+
#include "lapack/lange.hpp"
|
53
54
|
#include "lapack/orgqr.hpp"
|
54
55
|
#include "lapack/potrf.hpp"
|
55
56
|
#include "lapack/potrs.hpp"
|
@@ -59,6 +60,7 @@
|
|
59
60
|
#include "lapack/sygv.hpp"
|
60
61
|
#include "lapack/sygvd.hpp"
|
61
62
|
#include "lapack/sygvx.hpp"
|
63
|
+
#include "lapack/trtrs.hpp"
|
62
64
|
#include "lapack/ungqr.hpp"
|
63
65
|
|
64
66
|
VALUE rb_mTinyLinalg;
|
@@ -314,6 +316,10 @@ extern "C" void Init_tiny_linalg(void) {
|
|
314
316
|
TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
315
317
|
TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
316
318
|
TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
319
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DTrTrs>::define_module_function(rb_mTinyLinalgLapack, "dtrtrs");
|
320
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::STrTrs>::define_module_function(rb_mTinyLinalgLapack, "strtrs");
|
321
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ztrtrs");
|
322
|
+
TinyLinalg::TrTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ctrtrs");
|
317
323
|
TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
|
318
324
|
TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
|
319
325
|
TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
|
@@ -354,6 +360,10 @@ extern "C" void Init_tiny_linalg(void) {
|
|
354
360
|
TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
|
355
361
|
TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
|
356
362
|
TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
|
363
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DLanGe>::define_module_function(rb_mTinyLinalgLapack, "dlange");
|
364
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SLanGe>::define_module_function(rb_mTinyLinalgLapack, "slange");
|
365
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZLanGe>::define_module_function(rb_mTinyLinalgLapack, "zlange");
|
366
|
+
TinyLinalg::LanGe<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CLanGe>::define_module_function(rb_mTinyLinalgLapack, "clange");
|
357
367
|
|
358
368
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
359
369
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -95,6 +95,172 @@ module Numo
|
|
95
95
|
[vals, vecs]
|
96
96
|
end
|
97
97
|
|
98
|
+
# Computes the matrix or vector norm.
#
# | ord   | matrix norm            | vector norm                 |
# | ----- | ---------------------- | --------------------------- |
# | nil   | Frobenius norm         | 2-norm                      |
# | 'fro' | Frobenius norm         | -                           |
# | 'nuc' | nuclear norm           | -                           |
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max                   |
# | 0     | -                      | (x.ne 0).sum                |
# | 1     | x.abs.sum(axis:-2).max | same as below               |
# | 2     | 2-norm (max sing_vals) | same as below               |
# | other | -                      | (x.abs**ord).sum**(1.0/ord) |
#
# @example
#   require 'numo/tiny_linalg'
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   # matrix norm
#   x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
#   pp Numo::Linalg.norm(x)
#   # => 10
#
#   # vector norm
#   x = Numo::DFloat[3, -4]
#   pp Numo::Linalg.norm(x)
#   # => 5
#
# @param a [Numo::NArray] The matrix or vector (>= 1-dimensional NArray)
# @param ord [String/Numeric] The order of the norm.
# @param axis [Integer/Array] The applied axes.
# @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
# @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
# @raise [ArgumentError] If ord or axis is invalid, including two axes that refer to the same dimension.
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

  return 0.0 if a.empty?

  # The fast paths below build BLAS/LAPACK method names from blas_char (e.g. :"dnrm2").
  # 'n' is the value other methods reject as an invalid array type and has no such
  # methods (:"nnrm2" would raise NoMethodError), so compute in double precision instead.
  a = Numo::DFloat.cast(a) if blas_char(a) == 'n'

  # for compatibility with Numo::Linalg.norm
  if ord.is_a?(String)
    if ord == 'inf'
      ord = Float::INFINITY
    elsif ord == '-inf'
      ord = -Float::INFINITY
    end
  end

  if axis.nil?
    # Whole-array norms that map directly onto a single BLAS/LAPACK call;
    # anything not handled here (norm stays nil) falls through to the general path.
    norm = case a.ndim
           when 1
             Numo::TinyLinalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
           when 2
             if ord.nil? || ord == 'fro'
               Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
             elsif ord.is_a?(Numeric)
               if ord == 1
                 Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
               elsif !ord.infinite?.nil? && ord.infinite?.positive?
                 Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
               end
             end
           else
             if ord.nil?
               b = a.flatten.dup
               Numo::TinyLinalg::Blas.send(:"#{blas_char(b)}nrm2", b)
             end
           end
    unless norm.nil?
      norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
      return norm
    end
  end

  if axis.nil?
    axis = Array.new(a.ndim) { |d| d }
  else
    case axis
    when Integer
      axis = [axis]
    when Array, Numo::NArray
      axis = axis.flatten.to_a
    else
      raise ArgumentError, "invalid axis: #{axis}"
    end
  end

  raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless axis.size == 1 || axis.size == 2
  raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

  if axis.size == 1
    # vector norm along a single axis
    ord ||= 2
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

    ord_inf = ord.infinite?
    if ord_inf.nil?
      case ord
      when 0
        a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
      when 1
        a.abs.sum(axis: axis, keepdims: keepdims)
      else
        (a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
      end
    elsif ord_inf.positive?
      a.abs.max(axis: axis, keepdims: keepdims)
    else
      a.abs.min(axis: axis, keepdims: keepdims)
    end
  else
    # matrix norm over a pair of (row, column) axes
    ord ||= 'fro'
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)

    # Normalize negative indices before checking for duplicates, so that e.g.
    # axis: [0, -2] on a 2-D array is rejected just like axis: [0, 0].
    r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }
    raise ArgumentError, "invalid axis: #{axis}" if r_axis == c_axis

    norm = if ord.is_a?(String)
             raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

             if ord == 'fro'
               Numo::NMath.sqrt((a.abs**2).sum(axis: axis))
             else
               # nuclear norm: sum of the singular values
               b = a.transpose(c_axis, r_axis).dup
               gesvd = :"#{blas_char(b)}gesvd"
               s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
               s.sum(axis: -1)
             end
           else
             ord_inf = ord.infinite?
             if ord_inf.nil?
               case ord
               when -2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.min(axis: -1)
               when -1
                 # the row axis disappears after the sum, shifting a later column axis down by one
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).min(axis: c_axis)
               when 1
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).max(axis: c_axis)
               when 2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.max(axis: -1)
               else
                 raise ArgumentError, "invalid ord: #{ord}"
               end
             else
               # the column axis disappears after the sum, shifting a later row axis down by one
               r_axis -= 1 if r_axis > c_axis
               if ord_inf.positive?
                 a.abs.sum(axis: c_axis).max(axis: r_axis)
               else
                 a.abs.sum(axis: c_axis).min(axis: r_axis)
               end
             end
           end
    if keepdims
      norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
      norm = norm.reshape(*([1] * a.ndim))
    end

    norm
  end
end
|
263
|
+
|
98
264
|
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
|
99
265
|
#
|
100
266
|
# @example
|
@@ -138,7 +304,7 @@ module Numo
|
|
138
304
|
bchr = blas_char(a)
|
139
305
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
140
306
|
|
141
|
-
fnc = "#{bchr}potrf"
|
307
|
+
fnc = :"#{bchr}potrf"
|
142
308
|
c, _info = Numo::TinyLinalg::Lapack.send(fnc, a.dup, uplo: uplo)
|
143
309
|
|
144
310
|
case uplo
|
@@ -180,7 +346,7 @@ module Numo
|
|
180
346
|
bchr = blas_char(a, b)
|
181
347
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
182
348
|
|
183
|
-
fnc = "#{bchr}potrs"
|
349
|
+
fnc = :"#{bchr}potrs"
|
184
350
|
x, _info = Numo::TinyLinalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
|
185
351
|
x
|
186
352
|
end
|
@@ -205,7 +371,7 @@ module Numo
|
|
205
371
|
bchr = blas_char(a)
|
206
372
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
207
373
|
|
208
|
-
getrf = "#{bchr}getrf"
|
374
|
+
getrf = :"#{bchr}getrf"
|
209
375
|
lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)
|
210
376
|
|
211
377
|
if info.zero?
|
@@ -248,8 +414,8 @@ module Numo
|
|
248
414
|
bchr = blas_char(a)
|
249
415
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
250
416
|
|
251
|
-
getrf = "#{bchr}getrf"
|
252
|
-
getri = "#{bchr}getri"
|
417
|
+
getrf = :"#{bchr}getrf"
|
418
|
+
getri = :"#{bchr}getri"
|
253
419
|
|
254
420
|
lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)
|
255
421
|
if info.zero?
|
@@ -336,7 +502,7 @@ module Numo
|
|
336
502
|
bchr = blas_char(a)
|
337
503
|
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
338
504
|
|
339
|
-
geqrf = "#{bchr}geqrf"
|
505
|
+
geqrf = :"#{bchr}geqrf"
|
340
506
|
qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)
|
341
507
|
|
342
508
|
return [qr, tau] if mode == 'raw'
|
@@ -346,7 +512,7 @@ module Numo
|
|
346
512
|
|
347
513
|
return r if mode == 'r'
|
348
514
|
|
349
|
-
org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr"
|
515
|
+
org_ung_qr = %w[d s].include?(bchr) ? :"#{bchr}orgqr" : :"#{bchr}ungqr"
|
350
516
|
|
351
517
|
q = if m < n
|
352
518
|
Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
|
@@ -395,10 +561,51 @@ module Numo
|
|
395
561
|
bchr = blas_char(a, b)
|
396
562
|
raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
|
397
563
|
|
398
|
-
gesv = "#{bchr}gesv"
|
564
|
+
gesv = :"#{bchr}gesv"
|
399
565
|
Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
|
400
566
|
end
|
401
567
|
|
568
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   a = Numo::DFloat.new(3, 3).rand.triu
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve_triangular(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[16.1932, -52.0604, 30.5283],
#   #  [0, 8.61765, -17.9585],
#   #  [0, 0, 6.05735]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 4.071100642430302e-16
#
# @param a [Numo::NArray] The n-by-n triangular matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
# @return [Numo::NArray] The solution vector / matrix `X`.
# @raise [ArgumentError] If `a` is not a square matrix or the array types are not supported.
# @raise [RuntimeError] If `a` has a zero on its diagonal (singular), or LAPACK rejects an argument.
def solve_triangular(a, b, lower: false)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  trtrs = :"#{bchr}trtrs"
  uplo = lower ? 'L' : 'U'
  # b is duplicated because ?trtrs overwrites its right-hand side with the solution.
  x, info = Numo::TinyLinalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
  # LAPACK reports an illegal i-th argument as info = -i, so negate it for the message.
  raise "wrong value is given to the #{-info}-th argument of #{trtrs} used internally" if info.negative?
  # info = i > 0 means the i-th diagonal element of a is zero: the matrix is singular
  # and LAPACK did not compute the solution, so x would not solve the system.
  raise "the #{info}-th diagonal element of a is zero, and the matrix is singular" if info.positive?

  x
end
|
608
|
+
|
402
609
|
# Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
403
610
|
#
|
404
611
|
# @example
|
@@ -443,10 +650,10 @@ module Numo
|
|
443
650
|
|
444
651
|
case driver.to_s
|
445
652
|
when 'sdd'
|
446
|
-
gesdd = "#{bchr}gesdd"
|
653
|
+
gesdd = :"#{bchr}gesdd"
|
447
654
|
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
|
448
655
|
when 'svd'
|
449
|
-
gesvd = "#{bchr}gesvd"
|
656
|
+
gesvd = :"#{bchr}gesvd"
|
450
657
|
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
|
451
658
|
else
|
452
659
|
raise ArgumentError, "invalid driver: #{driver}"
|
@@ -534,11 +741,6 @@ module Numo
|
|
534
741
|
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
|
535
742
|
end
|
536
743
|
|
537
|
-
# @!visibility private
|
538
|
-
def norm(*args)
|
539
|
-
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
|
540
|
-
end
|
541
|
-
|
542
744
|
# @!visibility private
|
543
745
|
def cond(*args)
|
544
746
|
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.4
|
4
|
+
version: 0.3.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-01-28 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -57,6 +57,7 @@ files:
|
|
57
57
|
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
58
58
|
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
59
59
|
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
60
|
+
- ext/numo/tiny_linalg/lapack/lange.hpp
|
60
61
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
61
62
|
- ext/numo/tiny_linalg/lapack/potrf.hpp
|
62
63
|
- ext/numo/tiny_linalg/lapack/potrs.hpp
|
@@ -66,6 +67,7 @@ files:
|
|
66
67
|
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
67
68
|
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
68
69
|
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|
70
|
+
- ext/numo/tiny_linalg/lapack/trtrs.hpp
|
69
71
|
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
70
72
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
71
73
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|
@@ -97,7 +99,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
97
99
|
- !ruby/object:Gem::Version
|
98
100
|
version: '0'
|
99
101
|
requirements: []
|
100
|
-
rubygems_version: 3.4.
|
102
|
+
rubygems_version: 3.4.19
|
101
103
|
signing_key:
|
102
104
|
specification_version: 4
|
103
105
|
summary: Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of
|