numo-tiny_linalg 0.0.2 → 0.0.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +118 -0
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +115 -0
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +115 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +11 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +85 -1
- metadata +5 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
|
4
|
+
data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
|
7
|
+
data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,12 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
|
4
|
+
- Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
|
5
|
+
- Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
|
6
|
+
- Add det module function to TinyLinalg.
|
7
|
+
- Add pinv module function to TinyLinalg.
|
8
|
+
- Add qr module function to TinyLinalg.
|
9
|
+
|
3
10
|
## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
|
4
11
|
- Add automatic build of OpenBLAS if it is not found.
|
5
12
|
- Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
|
@@ -0,0 +1,118 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGeQrf {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
5
|
+
double* a, lapack_int lda, double* tau) {
|
6
|
+
return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SGeQrf {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
12
|
+
float* a, lapack_int lda, float* tau) {
|
13
|
+
return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZGeQrf {
|
18
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
19
|
+
lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
|
20
|
+
return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CGeQrf {
|
25
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
26
|
+
lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
|
27
|
+
return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
32
|
+
class GeQrf {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct geqrf_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
};
|
42
|
+
|
43
|
+
static void iter_geqrf(na_loop_t* const lp) {
|
44
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
45
|
+
DType* tau = (DType*)NDL_PTR(lp, 1);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
47
|
+
geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
|
+
const lapack_int lda = n;
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
|
58
|
+
VALUE a_vnary = Qnil;
|
59
|
+
VALUE kw_args = Qnil;
|
60
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
61
|
+
ID kw_table[1] = { rb_intern("order") };
|
62
|
+
VALUE kw_values[1] = { Qundef };
|
63
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
|
+
|
66
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
68
|
+
}
|
69
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
70
|
+
a_vnary = nary_dup(a_vnary);
|
71
|
+
}
|
72
|
+
|
73
|
+
narray_t* a_nary = NULL;
|
74
|
+
GetNArray(a_vnary, a_nary);
|
75
|
+
const int n_dims = NA_NDIM(a_nary);
|
76
|
+
if (n_dims != 2) {
|
77
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
78
|
+
return Qnil;
|
79
|
+
}
|
80
|
+
|
81
|
+
size_t m = NA_SHAPE(a_nary)[0];
|
82
|
+
size_t n = NA_SHAPE(a_nary)[1];
|
83
|
+
size_t shape[1] = { m < n ? m : n };
|
84
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
85
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
86
|
+
ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
87
|
+
geqrf_opt opt = { matrix_layout };
|
88
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
89
|
+
|
90
|
+
VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
|
91
|
+
|
92
|
+
RB_GC_GUARD(a_vnary);
|
93
|
+
|
94
|
+
return ret;
|
95
|
+
}
|
96
|
+
|
97
|
+
static int get_matrix_layout(VALUE val) {
|
98
|
+
const char* option_str = StringValueCStr(val);
|
99
|
+
|
100
|
+
if (std::strlen(option_str) > 0) {
|
101
|
+
switch (option_str[0]) {
|
102
|
+
case 'r':
|
103
|
+
case 'R':
|
104
|
+
break;
|
105
|
+
case 'c':
|
106
|
+
case 'C':
|
107
|
+
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
108
|
+
break;
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
RB_GC_GUARD(val);
|
113
|
+
|
114
|
+
return LAPACK_ROW_MAJOR;
|
115
|
+
}
|
116
|
+
};
|
117
|
+
|
118
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,115 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DOrgQr {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
5
|
+
double* a, lapack_int lda, const double* tau) {
|
6
|
+
return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SOrgQr {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
12
|
+
float* a, lapack_int lda, const float* tau) {
|
13
|
+
return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
18
|
+
class OrgQr {
|
19
|
+
public:
|
20
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
21
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
|
22
|
+
}
|
23
|
+
|
24
|
+
private:
|
25
|
+
struct orgqr_opt {
|
26
|
+
int matrix_layout;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_orgqr(na_loop_t* const lp) {
|
30
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
31
|
+
DType* tau = (DType*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
35
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
36
|
+
const lapack_int k = NDL_SHAPE(lp, 1)[0];
|
37
|
+
const lapack_int lda = n;
|
38
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
|
39
|
+
*info = static_cast<int>(i);
|
40
|
+
}
|
41
|
+
|
42
|
+
static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
|
43
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
44
|
+
|
45
|
+
VALUE a_vnary = Qnil;
|
46
|
+
VALUE tau_vnary = Qnil;
|
47
|
+
VALUE kw_args = Qnil;
|
48
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
|
49
|
+
ID kw_table[1] = { rb_intern("order") };
|
50
|
+
VALUE kw_values[1] = { Qundef };
|
51
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
52
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
if (CLASS_OF(tau_vnary) != nary_dtype) {
|
61
|
+
tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
|
62
|
+
}
|
63
|
+
if (!RTEST(nary_check_contiguous(tau_vnary))) {
|
64
|
+
tau_vnary = nary_dup(tau_vnary);
|
65
|
+
}
|
66
|
+
|
67
|
+
narray_t* a_nary = NULL;
|
68
|
+
GetNArray(a_vnary, a_nary);
|
69
|
+
if (NA_NDIM(a_nary) != 2) {
|
70
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
71
|
+
return Qnil;
|
72
|
+
}
|
73
|
+
narray_t* tau_nary = NULL;
|
74
|
+
GetNArray(tau_vnary, tau_nary);
|
75
|
+
if (NA_NDIM(tau_nary) != 1) {
|
76
|
+
rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
|
77
|
+
return Qnil;
|
78
|
+
}
|
79
|
+
|
80
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
82
|
+
ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
83
|
+
orgqr_opt opt = { matrix_layout };
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
|
85
|
+
|
86
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
87
|
+
|
88
|
+
RB_GC_GUARD(a_vnary);
|
89
|
+
RB_GC_GUARD(tau_vnary);
|
90
|
+
|
91
|
+
return ret;
|
92
|
+
}
|
93
|
+
|
94
|
+
static int get_matrix_layout(VALUE val) {
|
95
|
+
const char* option_str = StringValueCStr(val);
|
96
|
+
|
97
|
+
if (std::strlen(option_str) > 0) {
|
98
|
+
switch (option_str[0]) {
|
99
|
+
case 'r':
|
100
|
+
case 'R':
|
101
|
+
break;
|
102
|
+
case 'c':
|
103
|
+
case 'C':
|
104
|
+
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
105
|
+
break;
|
106
|
+
}
|
107
|
+
}
|
108
|
+
|
109
|
+
RB_GC_GUARD(val);
|
110
|
+
|
111
|
+
return LAPACK_ROW_MAJOR;
|
112
|
+
}
|
113
|
+
};
|
114
|
+
|
115
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,115 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZUngQr {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
5
|
+
lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
|
6
|
+
return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct CUngQr {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
|
12
|
+
lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
|
13
|
+
return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
18
|
+
class UngQr {
|
19
|
+
public:
|
20
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
21
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
|
22
|
+
}
|
23
|
+
|
24
|
+
private:
|
25
|
+
struct ungqr_opt {
|
26
|
+
int matrix_layout;
|
27
|
+
};
|
28
|
+
|
29
|
+
static void iter_ungqr(na_loop_t* const lp) {
|
30
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
31
|
+
DType* tau = (DType*)NDL_PTR(lp, 1);
|
32
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
33
|
+
ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
|
34
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
35
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
36
|
+
const lapack_int k = NDL_SHAPE(lp, 1)[0];
|
37
|
+
const lapack_int lda = n;
|
38
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
|
39
|
+
*info = static_cast<int>(i);
|
40
|
+
}
|
41
|
+
|
42
|
+
static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
|
43
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
44
|
+
|
45
|
+
VALUE a_vnary = Qnil;
|
46
|
+
VALUE tau_vnary = Qnil;
|
47
|
+
VALUE kw_args = Qnil;
|
48
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
|
49
|
+
ID kw_table[1] = { rb_intern("order") };
|
50
|
+
VALUE kw_values[1] = { Qundef };
|
51
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
52
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
53
|
+
|
54
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
56
|
+
}
|
57
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
58
|
+
a_vnary = nary_dup(a_vnary);
|
59
|
+
}
|
60
|
+
if (CLASS_OF(tau_vnary) != nary_dtype) {
|
61
|
+
tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
|
62
|
+
}
|
63
|
+
if (!RTEST(nary_check_contiguous(tau_vnary))) {
|
64
|
+
tau_vnary = nary_dup(tau_vnary);
|
65
|
+
}
|
66
|
+
|
67
|
+
narray_t* a_nary = NULL;
|
68
|
+
GetNArray(a_vnary, a_nary);
|
69
|
+
if (NA_NDIM(a_nary) != 2) {
|
70
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
71
|
+
return Qnil;
|
72
|
+
}
|
73
|
+
narray_t* tau_nary = NULL;
|
74
|
+
GetNArray(tau_vnary, tau_nary);
|
75
|
+
if (NA_NDIM(tau_nary) != 1) {
|
76
|
+
rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
|
77
|
+
return Qnil;
|
78
|
+
}
|
79
|
+
|
80
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
|
81
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
82
|
+
ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
83
|
+
ungqr_opt opt = { matrix_layout };
|
84
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
|
85
|
+
|
86
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
87
|
+
|
88
|
+
RB_GC_GUARD(a_vnary);
|
89
|
+
RB_GC_GUARD(tau_vnary);
|
90
|
+
|
91
|
+
return ret;
|
92
|
+
}
|
93
|
+
|
94
|
+
static int get_matrix_layout(VALUE val) {
|
95
|
+
const char* option_str = StringValueCStr(val);
|
96
|
+
|
97
|
+
if (std::strlen(option_str) > 0) {
|
98
|
+
switch (option_str[0]) {
|
99
|
+
case 'r':
|
100
|
+
case 'R':
|
101
|
+
break;
|
102
|
+
case 'c':
|
103
|
+
case 'C':
|
104
|
+
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
105
|
+
break;
|
106
|
+
}
|
107
|
+
}
|
108
|
+
|
109
|
+
RB_GC_GUARD(val);
|
110
|
+
|
111
|
+
return LAPACK_ROW_MAJOR;
|
112
|
+
}
|
113
|
+
};
|
114
|
+
|
115
|
+
} // namespace TinyLinalg
|
@@ -5,11 +5,14 @@
|
|
5
5
|
#include "blas/gemv.hpp"
|
6
6
|
#include "blas/nrm2.hpp"
|
7
7
|
#include "converter.hpp"
|
8
|
+
#include "lapack/geqrf.hpp"
|
8
9
|
#include "lapack/gesdd.hpp"
|
9
10
|
#include "lapack/gesv.hpp"
|
10
11
|
#include "lapack/gesvd.hpp"
|
11
12
|
#include "lapack/getrf.hpp"
|
12
13
|
#include "lapack/getri.hpp"
|
14
|
+
#include "lapack/orgqr.hpp"
|
15
|
+
#include "lapack/ungqr.hpp"
|
13
16
|
|
14
17
|
VALUE rb_mTinyLinalg;
|
15
18
|
VALUE rb_mTinyLinalgBlas;
|
@@ -260,6 +263,14 @@ extern "C" void Init_tiny_linalg(void) {
|
|
260
263
|
TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
261
264
|
TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
262
265
|
TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
266
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
|
267
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
|
268
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
|
269
|
+
TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
|
270
|
+
TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
|
271
|
+
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
272
|
+
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
273
|
+
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
263
274
|
|
264
275
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
265
276
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -7,9 +7,35 @@ require_relative 'tiny_linalg/tiny_linalg'
|
|
7
7
|
# Ruby/Numo (NUmerical MOdules)
|
8
8
|
module Numo
|
9
9
|
# Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
10
|
-
module TinyLinalg
|
10
|
+
module TinyLinalg # rubocop:disable Metrics/ModuleLength
|
11
11
|
module_function
|
12
12
|
|
13
|
+
# Computes the determinant of matrix.
#
# The matrix is LU-factorized with LAPACK getrf. The determinant is the product of the diagonal
# of U times the sign of the row permutation; L is unit lower-triangular and contributes 1.
# A singular matrix yields a determinant of zero instead of an error.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @raise [ArgumentError] if `a` is not 2-dimensional, not square, or of an unsupported type.
# @raise [RuntimeError] if getrf reports an illegal argument.
# @return [Float/Complex] The determinant of `a`.
def det(a)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = "#{bchr}getrf".to_sym
  lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)

  # A positive info only signals that U has an exact zero on its diagonal. The factorization is
  # still complete, so the product below correctly evaluates to zero for a singular matrix.
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  det_u = lu.diagonal.prod
  # piv holds 1-based row indices; every row that was swapped flips the sign.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
|
38
|
+
|
13
39
|
# Computes the inverse matrix of a square matrix.
|
14
40
|
#
|
15
41
|
# @param a [Numo::NArray] n-by-n square matrix.
|
@@ -36,6 +62,64 @@ module Numo
|
|
36
62
|
end
|
37
63
|
end
|
38
64
|
|
65
|
+
# Computes the Moore-Penrose pseudo-inverse of a matrix from its singular value decomposition.
#
# Singular values not greater than `rcond` times the first (largest) singular value are treated
# as zero.
#
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
# @param rcond [Float] The cutoff for small singular values, relative to the largest one; defaults to
#   `a.shape.max * EPS`, where EPS is the machine epsilon of the singular values' type.
# @return [Numo::NArray] The n-by-m pseudo-inverse of `a`.
def pinv(a, driver: 'svd', rcond: nil)
  s, u, vh = svd(a, driver: driver, job: 'S')
  relative_cutoff = rcond.nil? ? a.shape.max * s.class::EPSILON : rcond
  rank = s.gt(relative_cutoff * s[0]).count

  # Divide the leading `rank` left singular vectors by their singular values, then take the
  # conjugate transpose of (U_r S_r^-1 Vh_r), which equals V_r S_r^-1 U_r^H.
  kept = 0...rank
  scaled_u = u[true, kept] / s[kept]
  scaled_u.dot(vh[kept, true]).conj.transpose
end
|
79
|
+
|
80
|
+
# Computes the QR decomposition of a matrix.
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition.
#   - "reduce" -- returns both Q [m, m] and R [m, n],
#   - "r" -- returns only R,
#   - "economic" -- returns both Q [m, n] and R [n, n] (when m >= n; otherwise the same as "reduce"),
#   - "raw" -- returns QR and TAU (LAPACK geqrf results).
# @return [Numo::NArray] if mode='r'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
def qr(a, mode: 'reduce')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  qr, tau, = Numo::TinyLinalg::Lapack.send(:"#{bchr}geqrf", a.dup)
  return [qr, tau] if mode == 'raw'

  m, n = qr.shape
  # For a tall matrix in economic mode only the leading n rows of R are kept.
  r = mode == 'economic' && m > n ? qr[0...n, true].triu : qr.triu
  return r if mode == 'r'

  # Real types form Q with ?orgqr, complex types with ?ungqr.
  q_generator = %w[d s].include?(bchr) ? :"#{bchr}orgqr" : :"#{bchr}ungqr"
  reflectors =
    if m < n
      qr[true, 0...m]
    elsif mode == 'economic'
      qr
    else
      # Full Q: embed the m-by-n reflectors into an m-by-m array so ?orgqr/?ungqr builds all m columns.
      padded = a.class.zeros(m, m)
      padded[0...m, 0...n] = qr
      padded
    end
  q = Numo::TinyLinalg::Lapack.send(q_generator, reflectors, tau)[0]

  [q, r]
end
|
122
|
+
|
39
123
|
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
|
40
124
|
#
|
41
125
|
# @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-08-02 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -46,11 +46,14 @@ files:
|
|
46
46
|
- ext/numo/tiny_linalg/blas/nrm2.hpp
|
47
47
|
- ext/numo/tiny_linalg/converter.hpp
|
48
48
|
- ext/numo/tiny_linalg/extconf.rb
|
49
|
+
- ext/numo/tiny_linalg/lapack/geqrf.hpp
|
49
50
|
- ext/numo/tiny_linalg/lapack/gesdd.hpp
|
50
51
|
- ext/numo/tiny_linalg/lapack/gesv.hpp
|
51
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
52
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
53
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
54
57
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
55
58
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|
56
59
|
- lib/numo/tiny_linalg.rb
|