numo-tiny_linalg 0.0.1 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/README.md +2 -0
- data/ext/numo/tiny_linalg/extconf.rb +64 -16
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +118 -0
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +148 -0
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +118 -0
- data/ext/numo/tiny_linalg/lapack/getri.hpp +127 -0
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +115 -0
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +115 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +33 -7
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +131 -1
- data/vendor/tmp/.gitkeep +0 -0
- metadata +17 -17
- data/.clang-format +0 -149
- data/.husky/commit-msg +0 -4
- data/.rubocop.yml +0 -47
- data/Gemfile +0 -15
- data/Rakefile +0 -30
- data/commitlint.config.js +0 -1
- data/numo-tiny_linalg.gemspec +0 -42
- data/package.json +0 -15
- /data/ext/numo/tiny_linalg/{dot.hpp → blas/dot.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{dot_sub.hpp → blas/dot_sub.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gemm.hpp → blas/gemm.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gemv.hpp → blas/gemv.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{nrm2.hpp → blas/nrm2.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gesdd.hpp → lapack/gesdd.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gesvd.hpp → lapack/gesvd.hpp} +0 -0
@@ -0,0 +1,115 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DOrgQr {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
5
|
+
double* a, lapack_int lda, const double* tau) {
|
6
|
+
return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SOrgQr {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
12
|
+
float* a, lapack_int lda, const float* tau) {
|
13
|
+
return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
18
|
+
class OrgQr {
|
19
|
+
public:
|
20
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
21
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
|
22
|
+
}
|
23
|
+
|
24
|
+
private:
|
25
|
+
struct orgqr_opt {
|
26
|
+
int matrix_layout;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_orgqr(na_loop_t* const lp) {
|
30
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
31
|
+
DType* tau = (DType*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
35
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
36
|
+
const lapack_int k = NDL_SHAPE(lp, 1)[0];
|
37
|
+
const lapack_int lda = n;
|
38
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
|
39
|
+
*info = static_cast<int>(i);
|
40
|
+
}
|
41
|
+
|
42
|
+
static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
|
43
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
44
|
+
|
45
|
+
VALUE a_vnary = Qnil;
|
46
|
+
VALUE tau_vnary = Qnil;
|
47
|
+
VALUE kw_args = Qnil;
|
48
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
|
49
|
+
ID kw_table[1] = { rb_intern("order") };
|
50
|
+
VALUE kw_values[1] = { Qundef };
|
51
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
52
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
if (CLASS_OF(tau_vnary) != nary_dtype) {
|
61
|
+
tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
|
62
|
+
}
|
63
|
+
if (!RTEST(nary_check_contiguous(tau_vnary))) {
|
64
|
+
tau_vnary = nary_dup(tau_vnary);
|
65
|
+
}
|
66
|
+
|
67
|
+
narray_t* a_nary = NULL;
|
68
|
+
GetNArray(a_vnary, a_nary);
|
69
|
+
if (NA_NDIM(a_nary) != 2) {
|
70
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
71
|
+
return Qnil;
|
72
|
+
}
|
73
|
+
narray_t* tau_nary = NULL;
|
74
|
+
GetNArray(tau_vnary, tau_nary);
|
75
|
+
if (NA_NDIM(tau_nary) != 1) {
|
76
|
+
rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
|
77
|
+
return Qnil;
|
78
|
+
}
|
79
|
+
|
80
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
82
|
+
ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
83
|
+
orgqr_opt opt = { matrix_layout };
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
|
85
|
+
|
86
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
87
|
+
|
88
|
+
RB_GC_GUARD(a_vnary);
|
89
|
+
RB_GC_GUARD(tau_vnary);
|
90
|
+
|
91
|
+
return ret;
|
92
|
+
}
|
93
|
+
|
94
|
+
static int get_matrix_layout(VALUE val) {
|
95
|
+
const char* option_str = StringValueCStr(val);
|
96
|
+
|
97
|
+
if (std::strlen(option_str) > 0) {
|
98
|
+
switch (option_str[0]) {
|
99
|
+
case 'r':
|
100
|
+
case 'R':
|
101
|
+
break;
|
102
|
+
case 'c':
|
103
|
+
case 'C':
|
104
|
+
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
105
|
+
break;
|
106
|
+
}
|
107
|
+
}
|
108
|
+
|
109
|
+
RB_GC_GUARD(val);
|
110
|
+
|
111
|
+
return LAPACK_ROW_MAJOR;
|
112
|
+
}
|
113
|
+
};
|
114
|
+
|
115
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,115 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZUngQr {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
5
|
+
lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
|
6
|
+
return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct CUngQr {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
12
|
+
lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
|
13
|
+
return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
18
|
+
class UngQr {
|
19
|
+
public:
|
20
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
21
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
|
22
|
+
}
|
23
|
+
|
24
|
+
private:
|
25
|
+
struct ungqr_opt {
|
26
|
+
int matrix_layout;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_ungqr(na_loop_t* const lp) {
|
30
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
31
|
+
DType* tau = (DType*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
35
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
36
|
+
const lapack_int k = NDL_SHAPE(lp, 1)[0];
|
37
|
+
const lapack_int lda = n;
|
38
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
|
39
|
+
*info = static_cast<int>(i);
|
40
|
+
}
|
41
|
+
|
42
|
+
static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
|
43
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
44
|
+
|
45
|
+
VALUE a_vnary = Qnil;
|
46
|
+
VALUE tau_vnary = Qnil;
|
47
|
+
VALUE kw_args = Qnil;
|
48
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
|
49
|
+
ID kw_table[1] = { rb_intern("order") };
|
50
|
+
VALUE kw_values[1] = { Qundef };
|
51
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
52
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
if (CLASS_OF(tau_vnary) != nary_dtype) {
|
61
|
+
tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
|
62
|
+
}
|
63
|
+
if (!RTEST(nary_check_contiguous(tau_vnary))) {
|
64
|
+
tau_vnary = nary_dup(tau_vnary);
|
65
|
+
}
|
66
|
+
|
67
|
+
narray_t* a_nary = NULL;
|
68
|
+
GetNArray(a_vnary, a_nary);
|
69
|
+
if (NA_NDIM(a_nary) != 2) {
|
70
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
71
|
+
return Qnil;
|
72
|
+
}
|
73
|
+
narray_t* tau_nary = NULL;
|
74
|
+
GetNArray(tau_vnary, tau_nary);
|
75
|
+
if (NA_NDIM(tau_nary) != 1) {
|
76
|
+
rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
|
77
|
+
return Qnil;
|
78
|
+
}
|
79
|
+
|
80
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
82
|
+
ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
83
|
+
ungqr_opt opt = { matrix_layout };
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
|
85
|
+
|
86
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
87
|
+
|
88
|
+
RB_GC_GUARD(a_vnary);
|
89
|
+
RB_GC_GUARD(tau_vnary);
|
90
|
+
|
91
|
+
return ret;
|
92
|
+
}
|
93
|
+
|
94
|
+
static int get_matrix_layout(VALUE val) {
|
95
|
+
const char* option_str = StringValueCStr(val);
|
96
|
+
|
97
|
+
if (std::strlen(option_str) > 0) {
|
98
|
+
switch (option_str[0]) {
|
99
|
+
case 'r':
|
100
|
+
case 'R':
|
101
|
+
break;
|
102
|
+
case 'c':
|
103
|
+
case 'C':
|
104
|
+
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
105
|
+
break;
|
106
|
+
}
|
107
|
+
}
|
108
|
+
|
109
|
+
RB_GC_GUARD(val);
|
110
|
+
|
111
|
+
return LAPACK_ROW_MAJOR;
|
112
|
+
}
|
113
|
+
};
|
114
|
+
|
115
|
+
} // namespace TinyLinalg
|
@@ -1,12 +1,18 @@
|
|
1
1
|
#include "tiny_linalg.hpp"
|
2
|
+
#include "blas/dot.hpp"
|
3
|
+
#include "blas/dot_sub.hpp"
|
4
|
+
#include "blas/gemm.hpp"
|
5
|
+
#include "blas/gemv.hpp"
|
6
|
+
#include "blas/nrm2.hpp"
|
2
7
|
#include "converter.hpp"
|
3
|
-
#include "
|
4
|
-
#include "
|
5
|
-
#include "
|
6
|
-
#include "
|
7
|
-
#include "
|
8
|
-
#include "
|
9
|
-
#include "
|
8
|
+
#include "lapack/geqrf.hpp"
|
9
|
+
#include "lapack/gesdd.hpp"
|
10
|
+
#include "lapack/gesv.hpp"
|
11
|
+
#include "lapack/gesvd.hpp"
|
12
|
+
#include "lapack/getrf.hpp"
|
13
|
+
#include "lapack/getri.hpp"
|
14
|
+
#include "lapack/orgqr.hpp"
|
15
|
+
#include "lapack/ungqr.hpp"
|
10
16
|
|
11
17
|
VALUE rb_mTinyLinalg;
|
12
18
|
VALUE rb_mTinyLinalgBlas;
|
@@ -237,6 +243,10 @@ extern "C" void Init_tiny_linalg(void) {
|
|
237
243
|
TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
|
238
244
|
TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
|
239
245
|
TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
|
246
|
+
TinyLinalg::GESV<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGESV>::define_module_function(rb_mTinyLinalgLapack, "dgesv");
|
247
|
+
TinyLinalg::GESV<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGESV>::define_module_function(rb_mTinyLinalgLapack, "sgesv");
|
248
|
+
TinyLinalg::GESV<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGESV>::define_module_function(rb_mTinyLinalgLapack, "zgesv");
|
249
|
+
TinyLinalg::GESV<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGESV>::define_module_function(rb_mTinyLinalgLapack, "cgesv");
|
240
250
|
TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
|
241
251
|
TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
|
242
252
|
TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
|
@@ -245,6 +255,22 @@ extern "C" void Init_tiny_linalg(void) {
|
|
245
255
|
TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
|
246
256
|
TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
|
247
257
|
TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
|
258
|
+
TinyLinalg::GETRF<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRF>::define_module_function(rb_mTinyLinalgLapack, "dgetrf");
|
259
|
+
TinyLinalg::GETRF<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRF>::define_module_function(rb_mTinyLinalgLapack, "sgetrf");
|
260
|
+
TinyLinalg::GETRF<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRF>::define_module_function(rb_mTinyLinalgLapack, "zgetrf");
|
261
|
+
TinyLinalg::GETRF<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRF>::define_module_function(rb_mTinyLinalgLapack, "cgetrf");
|
262
|
+
TinyLinalg::GETRI<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRI>::define_module_function(rb_mTinyLinalgLapack, "dgetri");
|
263
|
+
TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
264
|
+
TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
265
|
+
TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
266
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
|
267
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
|
268
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
|
269
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
|
270
|
+
TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
|
271
|
+
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
272
|
+
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
273
|
+
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
248
274
|
|
249
275
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
250
276
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -7,9 +7,139 @@ require_relative 'tiny_linalg/tiny_linalg'
|
|
7
7
|
# Ruby/Numo (NUmerical MOdules)
|
8
8
|
module Numo
|
9
9
|
# Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
10
|
-
module TinyLinalg
|
10
|
+
module TinyLinalg # rubocop:disable Metrics/ModuleLength
|
11
11
|
module_function
|
12
12
|
|
13
|
+
# Computes the determinant of matrix.
#
# The determinant is the product of the diagonal of the LU factor U, with the sign flipped
# once per row interchange recorded by getrf. A singular matrix (getrf info > 0, i.e. an
# exactly zero pivot in U) yields a determinant of zero instead of raising.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @raise [ArgumentError] if `a` is not a square 2-dimensional array of a supported type.
# @raise [RuntimeError] if getrf reports an illegal argument value.
# @return [Float/Complex] The determinant of `a`.
def det(a)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = "#{bchr}getrf".to_sym
  lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)
  # info > 0 means U(info, info) is exactly zero but the factorization still completed, so
  # the diagonal product below is zero: the correct determinant of a singular matrix.
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  det_u = lu.diagonal.prod
  # piv holds 1-based row indices; every position that was actually swapped flips the sign.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
|
38
|
+
|
39
|
+
# Computes the inverse of a square matrix by LU-factorizing it with getrf and then
# inverting the factors with getri.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @param driver [String] Accepted for compatibility with Numo::Linalg.inv; ignored.
# @param uplo [String] Accepted for compatibility with Numo::Linalg.inv; ignored.
# @raise [ArgumentError] if `a` is not a square 2-dimensional array of a supported type.
# @raise [RuntimeError] if getrf reports a singular U factor or an illegal argument value.
# @return [Numo::NArray] The inverse matrix of `a`.
def inv(a, driver: 'getrf', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise ArgumentError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  prefix = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if prefix == 'n'

  lu, piv, info = Numo::TinyLinalg::Lapack.send(:"#{prefix}getrf", a.dup)
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
  raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?

  Numo::TinyLinalg::Lapack.send(:"#{prefix}getri", lu, piv)[0]
end
|
64
|
+
|
65
|
+
# Computes the (Moore-Penrose) pseudo-inverse of a matrix from its thin SVD.
#
# Singular values not greater than `rcond` times the first (largest) singular value are
# treated as zero and dropped.
#
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
# @param rcond [Float] Relative cutoff for small singular values; defaults to `a.shape.max * EPS`.
# @return [Numo::NArray] The pseudo-inverse of `a`.
def pinv(a, driver: 'svd', rcond: nil)
  s, u, vh = svd(a, driver: driver, job: 'S')
  rcond = rcond.nil? ? a.shape.max * s.class::EPSILON : rcond

  threshold = rcond * s[0]
  rank = s.gt(threshold).count

  # Scale the kept left singular vectors by 1/s, then form (U S^-1 V^H)^H = V S^-1 U^H.
  u_scaled = u[true, 0...rank] / s[0...rank]
  u_scaled.dot(vh[0...rank, true]).conj.transpose
end
|
79
|
+
|
80
|
+
# Compute QR decomposition of a matrix.
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition.
#   - "reduce" -- returns both Q [m, m] and R [m, n],
#   - "r" -- returns only R,
#   - "economic" -- returns both Q [m, n] and R [n, n] (when m > n; otherwise same as "reduce"),
#   - "raw" -- returns QR and TAU (LAPACK geqrf results).
# @raise [ArgumentError] if `a` is not 2-dimensional, `mode` is unknown, or the type is unsupported.
# @return [Numo::NArray] if mode='r'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
def qr(a, mode: 'reduce')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  geqrf = "#{bchr}geqrf".to_sym
  qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)

  return [qr, tau] if mode == 'raw'

  m, n = qr.shape
  # 'raw' has already returned, so only 'economic' truncates R to its top n rows.
  r = m > n && mode == 'economic' ? qr[0...n, true].triu : qr.triu

  return r if mode == 'r'

  org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr".to_sym : "#{bchr}ungqr".to_sym

  q = if m < n
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
      elsif mode == 'economic'
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr, tau)[0]
      else
        # Pad the reflectors to m-by-m. Allocate with qr's class (the array geqrf returned)
        # rather than a's, so values are not truncated when a's class differs from it.
        qqr = qr.class.zeros(m, m)
        qqr[0...m, 0...n] = qr
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qqr, tau)[0]
      end

  [q, r]
end
|
122
|
+
|
123
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
#
# @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
# @param driver [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
# @raise [ArgumentError] if the data types of `a` and `b` are not supported.
# @return [Numo::NArray] The solution vector / matrix `x`.
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  case blas_char(a, b)
  when 'd'
    Lapack.dgesv(a.dup, b.dup)[1]
  when 's'
    Lapack.sgesv(a.dup, b.dup)[1]
  when 'z'
    Lapack.zgesv(a.dup, b.dup)[1]
  when 'c'
    Lapack.cgesv(a.dup, b.dup)[1]
  else
    # Previously fell through and returned nil silently; raise like the sibling methods do.
    raise ArgumentError, "invalid array types: #{a.class} and #{b.class}"
  end
end
|
142
|
+
|
13
143
|
# Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
14
144
|
#
|
15
145
|
# @param a [Numo::NArray] Matrix to be decomposed.
|
data/vendor/tmp/.gitkeep
ADDED
File without changes
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-08-02 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -35,31 +35,30 @@ extensions:
|
|
35
35
|
- ext/numo/tiny_linalg/extconf.rb
|
36
36
|
extra_rdoc_files: []
|
37
37
|
files:
|
38
|
-
- ".clang-format"
|
39
|
-
- ".husky/commit-msg"
|
40
|
-
- ".rubocop.yml"
|
41
38
|
- CHANGELOG.md
|
42
39
|
- CODE_OF_CONDUCT.md
|
43
|
-
- Gemfile
|
44
40
|
- LICENSE.txt
|
45
41
|
- README.md
|
46
|
-
-
|
47
|
-
-
|
42
|
+
- ext/numo/tiny_linalg/blas/dot.hpp
|
43
|
+
- ext/numo/tiny_linalg/blas/dot_sub.hpp
|
44
|
+
- ext/numo/tiny_linalg/blas/gemm.hpp
|
45
|
+
- ext/numo/tiny_linalg/blas/gemv.hpp
|
46
|
+
- ext/numo/tiny_linalg/blas/nrm2.hpp
|
48
47
|
- ext/numo/tiny_linalg/converter.hpp
|
49
|
-
- ext/numo/tiny_linalg/dot.hpp
|
50
|
-
- ext/numo/tiny_linalg/dot_sub.hpp
|
51
48
|
- ext/numo/tiny_linalg/extconf.rb
|
52
|
-
- ext/numo/tiny_linalg/
|
53
|
-
- ext/numo/tiny_linalg/
|
54
|
-
- ext/numo/tiny_linalg/
|
55
|
-
- ext/numo/tiny_linalg/gesvd.hpp
|
56
|
-
- ext/numo/tiny_linalg/
|
49
|
+
- ext/numo/tiny_linalg/lapack/geqrf.hpp
|
50
|
+
- ext/numo/tiny_linalg/lapack/gesdd.hpp
|
51
|
+
- ext/numo/tiny_linalg/lapack/gesv.hpp
|
52
|
+
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
|
+
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
|
+
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
57
57
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
58
58
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|
59
59
|
- lib/numo/tiny_linalg.rb
|
60
60
|
- lib/numo/tiny_linalg/version.rb
|
61
|
-
-
|
62
|
-
- package.json
|
61
|
+
- vendor/tmp/.gitkeep
|
63
62
|
homepage: https://github.com/yoshoku/numo-tiny_linalg
|
64
63
|
licenses:
|
65
64
|
- BSD-3-Clause
|
@@ -67,6 +66,7 @@ metadata:
|
|
67
66
|
homepage_uri: https://github.com/yoshoku/numo-tiny_linalg
|
68
67
|
source_code_uri: https://github.com/yoshoku/numo-tiny_linalg
|
69
68
|
changelog_uri: https://github.com/yoshoku/numo-tiny_linalg/blob/main/CHANGELOG.md
|
69
|
+
documentation_uri: https://yoshoku.github.io/numo-tiny_linalg/doc/
|
70
70
|
rubygems_mfa_required: 'true'
|
71
71
|
post_install_message:
|
72
72
|
rdoc_options: []
|
data/.clang-format
DELETED
@@ -1,149 +0,0 @@
|
|
1
|
-
---
|
2
|
-
Language: Cpp
|
3
|
-
# BasedOnStyle: LLVM
|
4
|
-
AccessModifierOffset: -2
|
5
|
-
AlignAfterOpenBracket: Align
|
6
|
-
AlignConsecutiveMacros: false
|
7
|
-
AlignConsecutiveAssignments: false
|
8
|
-
AlignConsecutiveBitFields: false
|
9
|
-
AlignConsecutiveDeclarations: false
|
10
|
-
AlignEscapedNewlines: Right
|
11
|
-
AlignOperands: Align
|
12
|
-
AlignTrailingComments: true
|
13
|
-
AllowAllArgumentsOnNextLine: true
|
14
|
-
AllowAllConstructorInitializersOnNextLine: true
|
15
|
-
AllowAllParametersOfDeclarationOnNextLine: true
|
16
|
-
AllowShortEnumsOnASingleLine: true
|
17
|
-
AllowShortBlocksOnASingleLine: Never
|
18
|
-
AllowShortCaseLabelsOnASingleLine: false
|
19
|
-
AllowShortFunctionsOnASingleLine: All
|
20
|
-
AllowShortLambdasOnASingleLine: All
|
21
|
-
AllowShortIfStatementsOnASingleLine: true
|
22
|
-
AllowShortLoopsOnASingleLine: true
|
23
|
-
AlwaysBreakAfterDefinitionReturnType: None
|
24
|
-
AlwaysBreakAfterReturnType: None
|
25
|
-
AlwaysBreakBeforeMultilineStrings: false
|
26
|
-
AlwaysBreakTemplateDeclarations: MultiLine
|
27
|
-
BinPackArguments: true
|
28
|
-
BinPackParameters: true
|
29
|
-
BraceWrapping:
|
30
|
-
AfterCaseLabel: false
|
31
|
-
AfterClass: false
|
32
|
-
AfterControlStatement: Never
|
33
|
-
AfterEnum: false
|
34
|
-
AfterFunction: false
|
35
|
-
AfterNamespace: false
|
36
|
-
AfterObjCDeclaration: false
|
37
|
-
AfterStruct: false
|
38
|
-
AfterUnion: false
|
39
|
-
AfterExternBlock: false
|
40
|
-
BeforeCatch: false
|
41
|
-
BeforeElse: false
|
42
|
-
BeforeLambdaBody: false
|
43
|
-
BeforeWhile: false
|
44
|
-
IndentBraces: false
|
45
|
-
SplitEmptyFunction: true
|
46
|
-
SplitEmptyRecord: true
|
47
|
-
SplitEmptyNamespace: true
|
48
|
-
BreakBeforeBinaryOperators: None
|
49
|
-
BreakBeforeBraces: Attach
|
50
|
-
BreakBeforeInheritanceComma: false
|
51
|
-
BreakInheritanceList: BeforeColon
|
52
|
-
BreakBeforeTernaryOperators: true
|
53
|
-
BreakConstructorInitializersBeforeComma: false
|
54
|
-
BreakConstructorInitializers: BeforeColon
|
55
|
-
BreakAfterJavaFieldAnnotations: false
|
56
|
-
BreakStringLiterals: true
|
57
|
-
ColumnLimit: 0
|
58
|
-
CommentPragmas: '^ IWYU pragma:'
|
59
|
-
CompactNamespaces: false
|
60
|
-
ConstructorInitializerAllOnOneLineOrOnePerLine: false
|
61
|
-
ConstructorInitializerIndentWidth: 4
|
62
|
-
ContinuationIndentWidth: 2
|
63
|
-
Cpp11BracedListStyle: false
|
64
|
-
DeriveLineEnding: true
|
65
|
-
DerivePointerAlignment: false
|
66
|
-
DisableFormat: false
|
67
|
-
ExperimentalAutoDetectBinPacking: false
|
68
|
-
FixNamespaceComments: true
|
69
|
-
ForEachMacros:
|
70
|
-
- foreach
|
71
|
-
- Q_FOREACH
|
72
|
-
- BOOST_FOREACH
|
73
|
-
IncludeBlocks: Preserve
|
74
|
-
IncludeCategories:
|
75
|
-
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
|
76
|
-
Priority: 2
|
77
|
-
SortPriority: 0
|
78
|
-
- Regex: '^(<|"(gtest|gmock|isl|json)/)'
|
79
|
-
Priority: 3
|
80
|
-
SortPriority: 0
|
81
|
-
- Regex: '.*'
|
82
|
-
Priority: 1
|
83
|
-
SortPriority: 0
|
84
|
-
IncludeIsMainRegex: '(Test)?$'
|
85
|
-
IncludeIsMainSourceRegex: ''
|
86
|
-
IndentCaseLabels: false
|
87
|
-
IndentCaseBlocks: false
|
88
|
-
IndentGotoLabels: true
|
89
|
-
IndentPPDirectives: None
|
90
|
-
IndentExternBlock: AfterExternBlock
|
91
|
-
IndentWidth: 2
|
92
|
-
IndentWrappedFunctionNames: false
|
93
|
-
InsertTrailingCommas: None
|
94
|
-
JavaScriptQuotes: Leave
|
95
|
-
JavaScriptWrapImports: true
|
96
|
-
KeepEmptyLinesAtTheStartOfBlocks: true
|
97
|
-
MacroBlockBegin: ''
|
98
|
-
MacroBlockEnd: ''
|
99
|
-
MaxEmptyLinesToKeep: 1
|
100
|
-
NamespaceIndentation: None
|
101
|
-
ObjCBinPackProtocolList: Auto
|
102
|
-
ObjCBlockIndentWidth: 2
|
103
|
-
ObjCBreakBeforeNestedBlockParam: true
|
104
|
-
ObjCSpaceAfterProperty: false
|
105
|
-
ObjCSpaceBeforeProtocolList: true
|
106
|
-
PenaltyBreakAssignment: 2
|
107
|
-
PenaltyBreakBeforeFirstCallParameter: 19
|
108
|
-
PenaltyBreakComment: 300
|
109
|
-
PenaltyBreakFirstLessLess: 120
|
110
|
-
PenaltyBreakString: 1000
|
111
|
-
PenaltyBreakTemplateDeclaration: 10
|
112
|
-
PenaltyExcessCharacter: 1000000
|
113
|
-
PenaltyReturnTypeOnItsOwnLine: 60
|
114
|
-
PointerAlignment: Left
|
115
|
-
ReflowComments: true
|
116
|
-
SortIncludes: true
|
117
|
-
SortUsingDeclarations: true
|
118
|
-
SpaceAfterCStyleCast: false
|
119
|
-
SpaceAfterLogicalNot: false
|
120
|
-
SpaceAfterTemplateKeyword: true
|
121
|
-
SpaceBeforeAssignmentOperators: true
|
122
|
-
SpaceBeforeCpp11BracedList: false
|
123
|
-
SpaceBeforeCtorInitializerColon: true
|
124
|
-
SpaceBeforeInheritanceColon: true
|
125
|
-
SpaceBeforeParens: ControlStatements
|
126
|
-
SpaceBeforeRangeBasedForLoopColon: true
|
127
|
-
SpaceInEmptyBlock: false
|
128
|
-
SpaceInEmptyParentheses: false
|
129
|
-
SpacesBeforeTrailingComments: 1
|
130
|
-
SpacesInAngles: false
|
131
|
-
SpacesInConditionalStatement: false
|
132
|
-
SpacesInContainerLiterals: true
|
133
|
-
SpacesInCStyleCastParentheses: false
|
134
|
-
SpacesInParentheses: false
|
135
|
-
SpacesInSquareBrackets: false
|
136
|
-
SpaceBeforeSquareBrackets: false
|
137
|
-
Standard: Latest
|
138
|
-
StatementMacros:
|
139
|
-
- Q_UNUSED
|
140
|
-
- QT_REQUIRE_VERSION
|
141
|
-
TabWidth: 8
|
142
|
-
UseCRLF: false
|
143
|
-
UseTab: Never
|
144
|
-
WhitespaceSensitiveMacros:
|
145
|
-
- STRINGIZE
|
146
|
-
- PP_STRINGIZE
|
147
|
-
- BOOST_PP_STRINGIZE
|
148
|
-
...
|
149
|
-
|
data/.husky/commit-msg
DELETED