numo-tiny_linalg 0.0.2 → 0.0.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: '06379198010df5ee43b42c8d6776e551ac7870e0a81bb43ecfd776868464edf9'
-   data.tar.gz: e0e38cd51c48332a496e8b9405e219cdcd9f7dcebd83e53d977daa6b82a5b1fc
+   metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
+   data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
  SHA512:
-   metadata.gz: 5288ec9be4365280fc177d4d57523e7187360e8d298bb5d60819a2776274f03979a2e0319aee1081532f0abfcc23e0c3a897dbca47d1cd1bde36df5ac87db011
-   data.tar.gz: 4f03737044a25e81aab2fda5ce2bb555c8565a092bb6edc52776c54b758c1d76147ccdea845ea3e834f510b79a6254e8e9379f264e88421a1e446057f8bac058
+   metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
+   data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
data/CHANGELOG.md CHANGED
@@ -1,5 +1,12 @@
  ## [Unreleased]

+ ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
+ - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
+ - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
+ - Add det module function to TinyLinalg.
+ - Add pinv module function to TinyLinalg.
+ - Add qr module function to TinyLinalg.
+
  ## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
  - Add automatic build of OpenBLAS if it is not found.
  - Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
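A minimal usage sketch of the new 0.0.3 module functions listed above: det builds on the existing getrf binding, pinv on svd, and qr on the new geqrf/orgqr/ungqr bindings shown below. It assumes the gem was built against an available BLAS/LAPACK.

  require 'numo/tiny_linalg'

  a = Numo::DFloat[[2.0, 0.0], [1.0, 3.0]]

  Numo::TinyLinalg.det(a)        # => 6.0
  Numo::TinyLinalg.pinv(a)       # pseudo-inverse; equals the inverse for this nonsingular matrix
  q, r = Numo::TinyLinalg.qr(a)  # Q orthogonal, R upper triangular, q.dot(r) is approximately a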
data/ext/numo/tiny_linalg/lapack/geqrf.hpp ADDED
@@ -0,0 +1,118 @@
+ namespace TinyLinalg {
+
+ struct DGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   double* a, lapack_int lda, double* tau) {
+     return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ struct SGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   float* a, lapack_int lda, float* tau) {
+     return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ struct ZGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
+     return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ struct CGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
+     return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class GeQrf {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
+   }
+
+ private:
+   struct geqrf_opt {
+     int matrix_layout;
+   };
+
+   static void iter_geqrf(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* tau = (DType*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     const int n_dims = NA_NDIM(a_nary);
+     if (n_dims != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+
+     size_t m = NA_SHAPE(a_nary)[0];
+     size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { m < n ? m : n };
+     ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
+     geqrf_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
+
+     VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
+
+     RB_GC_GUARD(a_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.geqrf does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
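A short sketch of calling the geqrf binding defined above from Ruby. The input array is overwritten with the packed factorization (R on and above the diagonal, Householder reflectors below it), and the wrapper returns that array together with tau and the LAPACK info status.

  require 'numo/tiny_linalg'

  a = Numo::DFloat.new(4, 3).rand
  # Returns [a overwritten with the packed QR, tau, info]; pass a copy to keep a intact.
  qr, tau, info = Numo::TinyLinalg::Lapack.dgeqrf(a.dup)
  r = qr.triu      # the upper-triangular factor R
  info.zero?       # => true on success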
data/ext/numo/tiny_linalg/lapack/orgqr.hpp ADDED
@@ -0,0 +1,115 @@
+ namespace TinyLinalg {
+
+ struct DOrgQr {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
+                   double* a, lapack_int lda, const double* tau) {
+     return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
+   }
+ };
+
+ struct SOrgQr {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
+                   float* a, lapack_int lda, const float* tau) {
+     return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class OrgQr {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
+   }
+
+ private:
+   struct orgqr_opt {
+     int matrix_layout;
+   };
+
+   static void iter_orgqr(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* tau = (DType*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int k = NDL_SHAPE(lp, 1)[0];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE tau_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(tau_vnary) != nary_dtype) {
+       tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(tau_vnary))) {
+       tau_vnary = nary_dup(tau_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     narray_t* tau_nary = NULL;
+     GetNArray(tau_vnary, tau_nary);
+     if (NA_NDIM(tau_nary) != 1) {
+       rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
+       return Qnil;
+     }
+
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
+     ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
+     orgqr_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
+
+     VALUE ret = rb_ary_new3(2, a_vnary, res);
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(tau_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.orgqr does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
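The orgqr binding above turns the packed geqrf output into an explicit Q with orthonormal columns, returning the overwritten array and the info status. A short sketch for the double-precision case:

  require 'numo/tiny_linalg'

  a = Numo::DFloat.new(5, 3).rand
  qr, tau, = Numo::TinyLinalg::Lapack.dgeqrf(a.dup)
  q, info = Numo::TinyLinalg::Lapack.dorgqr(qr, tau)  # overwrites qr with the 5x3 Q
  q.transpose.dot(q)  # approximately the 3x3 identity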
data/ext/numo/tiny_linalg/lapack/ungqr.hpp ADDED
@@ -0,0 +1,115 @@
+ namespace TinyLinalg {
+
+ struct ZUngQr {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
+                   lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
+     return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
+   }
+ };
+
+ struct CUngQr {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
+                   lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
+     return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class UngQr {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
+   }
+
+ private:
+   struct ungqr_opt {
+     int matrix_layout;
+   };
+
+   static void iter_ungqr(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* tau = (DType*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int k = NDL_SHAPE(lp, 1)[0];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE tau_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(tau_vnary) != nary_dtype) {
+       tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(tau_vnary))) {
+       tau_vnary = nary_dup(tau_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     narray_t* tau_nary = NULL;
+     GetNArray(tau_vnary, tau_nary);
+     if (NA_NDIM(tau_nary) != 1) {
+       rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
+       return Qnil;
+     }
+
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
+     ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
+     ungqr_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
+
+     VALUE ret = rb_ary_new3(2, a_vnary, res);
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(tau_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.ungqr does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
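ungqr is the complex counterpart of orgqr: zungqr and cungqr assemble the unitary Q from zgeqrf/cgeqrf output. A short sketch with a small Numo::DComplex matrix:

  require 'numo/tiny_linalg'

  a = Numo::DComplex[[Complex(1, 1), Complex(2, 0)],
                     [Complex(0, 3), Complex(1, -1)],
                     [Complex(2, 2), Complex(0, 1)]]
  qr, tau, = Numo::TinyLinalg::Lapack.zgeqrf(a.dup)
  q, info = Numo::TinyLinalg::Lapack.zungqr(qr, tau)
  q.transpose.conj.dot(q)  # approximately the 2x2 identity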
data/ext/numo/tiny_linalg/tiny_linalg.cpp CHANGED
@@ -5,11 +5,14 @@
  #include "blas/gemv.hpp"
  #include "blas/nrm2.hpp"
  #include "converter.hpp"
+ #include "lapack/geqrf.hpp"
  #include "lapack/gesdd.hpp"
  #include "lapack/gesv.hpp"
  #include "lapack/gesvd.hpp"
  #include "lapack/getrf.hpp"
  #include "lapack/getri.hpp"
+ #include "lapack/orgqr.hpp"
+ #include "lapack/ungqr.hpp"

  VALUE rb_mTinyLinalg;
  VALUE rb_mTinyLinalgBlas;
@@ -260,6 +263,14 @@ extern "C" void Init_tiny_linalg(void) {
    TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
    TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
    TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
+   TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
+   TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
+   TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
+   TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
+   TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
+   TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
+   TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
+   TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");

    rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
    rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
data/lib/numo/tiny_linalg/version.rb CHANGED
@@ -5,6 +5,6 @@ module Numo
    # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
    module TinyLinalg
      # The version of Numo::TinyLinalg you install.
-     VERSION = '0.0.2'
+     VERSION = '0.0.3'
    end
  end
data/lib/numo/tiny_linalg.rb CHANGED
@@ -7,9 +7,35 @@ require_relative 'tiny_linalg/tiny_linalg'
  # Ruby/Numo (NUmerical MOdules)
  module Numo
    # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
-   module TinyLinalg
+   module TinyLinalg # rubocop:disable Metrics/ModuleLength
      module_function

+     # Computes the determinant of a matrix.
+     #
+     # @param a [Numo::NArray] n-by-n square matrix.
+     # @return [Float/Complex] The determinant of `a`.
+     def det(a)
+       raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
+       raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
+
+       bchr = blas_char(a)
+       raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
+
+       getrf = "#{bchr}getrf".to_sym
+       lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)
+
+       if info.zero?
+         det_l = 1
+         det_u = lu.diagonal.prod
+         det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
+         det_l * det_u * det_p
+       elsif info.positive?
+         raise 'the factor U is singular, and the determinant could not be computed.'
+       else
+         raise "the #{-info}-th argument of getrf had illegal value"
+       end
+     end
+
      # Computes the inverse matrix of a square matrix.
      #
      # @param a [Numo::NArray] n-by-n square matrix.
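The det implementation above relies on the P*L*U factorization from getrf: L has a unit diagonal, so the determinant is the product of U's diagonal entries times the sign of the row permutation encoded in the pivot vector (each actual row swap flips the sign). A small hand-checkable example:

  require 'numo/tiny_linalg'

  a = Numo::DFloat[[0.0, 2.0], [3.0, 4.0]]  # partial pivoting performs one row swap here
  Numo::TinyLinalg.det(a)                   # => -6.0  (0 * 4 - 2 * 3)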
@@ -36,6 +62,64 @@ module Numo
        end
      end

+     # Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
+     #
+     # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
+     # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
+     # @param rcond [Float] The threshold value for small singular values of `a`. The default value is `a.shape.max * EPS`.
+     # @return [Numo::NArray] The pseudo-inverse of `a`.
+     def pinv(a, driver: 'svd', rcond: nil)
+       s, u, vh = svd(a, driver: driver, job: 'S')
+       rcond = a.shape.max * s.class::EPSILON if rcond.nil?
+       rank = s.gt(rcond * s[0]).count
+
+       u = u[true, 0...rank] / s[0...rank]
+       u.dot(vh[0...rank, true]).conj.transpose
+     end
+
+     # Computes the QR decomposition of a matrix.
+     #
+     # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
+     # @param mode [String] The mode of decomposition.
+     #   - "reduce"   -- returns both Q [m, m] and R [m, n],
+     #   - "r"        -- returns only R,
+     #   - "economic" -- returns both Q [m, n] and R [n, n],
+     #   - "raw"      -- returns QR and TAU (LAPACK geqrf results).
+     # @return [Numo::NArray] if mode='r'
+     # @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
+     # @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
+     def qr(a, mode: 'reduce')
+       raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
+       raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)
+
+       bchr = blas_char(a)
+       raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
+
+       geqrf = "#{bchr}geqrf".to_sym
+       qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)
+
+       return [qr, tau] if mode == 'raw'
+
+       m, n = qr.shape
+       r = m > n && %w[economic raw].include?(mode) ? qr[0...n, true].triu : qr.triu
+
+       return r if mode == 'r'
+
+       org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr".to_sym : "#{bchr}ungqr".to_sym
+
+       q = if m < n
+             Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
+           elsif mode == 'economic'
+             Numo::TinyLinalg::Lapack.send(org_ung_qr, qr, tau)[0]
+           else
+             qqr = a.class.zeros(m, m)
+             qqr[0...m, 0...n] = qr
+             Numo::TinyLinalg::Lapack.send(org_ung_qr, qqr, tau)[0]
+           end
+
+       [q, r]
+     end
+
      # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
      #
      # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
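The qr modes documented above differ only in which factors are materialized, and pinv satisfies the Moore-Penrose identity A * pinv(A) * A = A. A short sketch with a tall real matrix:

  require 'numo/tiny_linalg'

  a = Numo::DFloat.new(6, 4).rand

  q, r = Numo::TinyLinalg.qr(a)                      # 'reduce': Q is 6x6, R is 6x4
  qe, re = Numo::TinyLinalg.qr(a, mode: 'economic')  # Q is 6x4, R is 4x4
  r_only = Numo::TinyLinalg.qr(a, mode: 'r')         # only the 6x4 R factor
  (qe.dot(re) - a).abs.max                           # close to 0

  a_pinv = Numo::TinyLinalg.pinv(a)                  # 4x6
  (a.dot(a_pinv).dot(a) - a).abs.max                 # close to 0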
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: numo-tiny_linalg
  version: !ruby/object:Gem::Version
-   version: 0.0.2
+   version: 0.0.3
  platform: ruby
  authors:
  - yoshoku
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2023-07-26 00:00:00.000000000 Z
+ date: 2023-08-02 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -46,11 +46,14 @@ files:
  - ext/numo/tiny_linalg/blas/nrm2.hpp
  - ext/numo/tiny_linalg/converter.hpp
  - ext/numo/tiny_linalg/extconf.rb
+ - ext/numo/tiny_linalg/lapack/geqrf.hpp
  - ext/numo/tiny_linalg/lapack/gesdd.hpp
  - ext/numo/tiny_linalg/lapack/gesv.hpp
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
  - ext/numo/tiny_linalg/lapack/getrf.hpp
  - ext/numo/tiny_linalg/lapack/getri.hpp
+ - ext/numo/tiny_linalg/lapack/orgqr.hpp
+ - ext/numo/tiny_linalg/lapack/ungqr.hpp
  - ext/numo/tiny_linalg/tiny_linalg.cpp
  - ext/numo/tiny_linalg/tiny_linalg.hpp
  - lib/numo/tiny_linalg.rb