numo-tiny_linalg 0.0.1 → 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DOrgQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ double* a, lapack_int lda, const double* tau) {
6
+ return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct SOrgQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ float* a, lapack_int lda, const float* tau) {
13
+ return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class OrgQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct orgqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_orgqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ orgqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZUngQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
6
+ return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct CUngQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
13
+ return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class UngQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct ungqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_ungqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ ungqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -1,12 +1,18 @@
1
1
  #include "tiny_linalg.hpp"
2
+ #include "blas/dot.hpp"
3
+ #include "blas/dot_sub.hpp"
4
+ #include "blas/gemm.hpp"
5
+ #include "blas/gemv.hpp"
6
+ #include "blas/nrm2.hpp"
2
7
  #include "converter.hpp"
3
- #include "dot.hpp"
4
- #include "dot_sub.hpp"
5
- #include "gemm.hpp"
6
- #include "gemv.hpp"
7
- #include "gesdd.hpp"
8
- #include "gesvd.hpp"
9
- #include "nrm2.hpp"
8
+ #include "lapack/geqrf.hpp"
9
+ #include "lapack/gesdd.hpp"
10
+ #include "lapack/gesv.hpp"
11
+ #include "lapack/gesvd.hpp"
12
+ #include "lapack/getrf.hpp"
13
+ #include "lapack/getri.hpp"
14
+ #include "lapack/orgqr.hpp"
15
+ #include "lapack/ungqr.hpp"
10
16
 
11
17
  VALUE rb_mTinyLinalg;
12
18
  VALUE rb_mTinyLinalgBlas;
@@ -237,6 +243,10 @@ extern "C" void Init_tiny_linalg(void) {
237
243
  TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
238
244
  TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
239
245
  TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
246
+ TinyLinalg::GESV<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGESV>::define_module_function(rb_mTinyLinalgLapack, "dgesv");
247
+ TinyLinalg::GESV<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGESV>::define_module_function(rb_mTinyLinalgLapack, "sgesv");
248
+ TinyLinalg::GESV<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGESV>::define_module_function(rb_mTinyLinalgLapack, "zgesv");
249
+ TinyLinalg::GESV<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGESV>::define_module_function(rb_mTinyLinalgLapack, "cgesv");
240
250
  TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
241
251
  TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
242
252
  TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
@@ -245,6 +255,22 @@ extern "C" void Init_tiny_linalg(void) {
245
255
  TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
246
256
  TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
247
257
  TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
258
+ TinyLinalg::GETRF<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRF>::define_module_function(rb_mTinyLinalgLapack, "dgetrf");
259
+ TinyLinalg::GETRF<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRF>::define_module_function(rb_mTinyLinalgLapack, "sgetrf");
260
+ TinyLinalg::GETRF<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRF>::define_module_function(rb_mTinyLinalgLapack, "zgetrf");
261
+ TinyLinalg::GETRF<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRF>::define_module_function(rb_mTinyLinalgLapack, "cgetrf");
262
+ TinyLinalg::GETRI<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRI>::define_module_function(rb_mTinyLinalgLapack, "dgetri");
263
+ TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
264
+ TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
265
+ TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
266
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
267
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
268
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
269
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
270
+ TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
271
+ TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
272
+ TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
273
+ TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
248
274
 
249
275
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
250
276
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.0.1'
8
+ VERSION = '0.0.3'
9
9
  end
10
10
  end
@@ -7,9 +7,139 @@ require_relative 'tiny_linalg/tiny_linalg'
7
7
  # Ruby/Numo (NUmerical MOdules)
8
8
  module Numo
9
9
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
10
- module TinyLinalg
10
+ module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
+ # Computes the determinant of matrix.
14
+ #
15
+ # @param a [Numo::NArray] n-by-n square matrix.
16
+ # @return [Float/Complex] The determinant of `a`.
17
def det(a)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  # Factorize a copy so that the caller's array is left untouched.
  getrf = "#{bchr}getrf".to_sym
  lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)

  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  # A positive info means U has an exact zero on its diagonal, i.e. the matrix
  # is singular. Per the LAPACK getrf contract the factorization is still
  # complete in that case, so the diagonal product below yields the correct
  # determinant (zero) rather than an error.
  det_u = lu.diagonal.prod
  # Each pivot entry that differs from its own 1-based position is a row swap,
  # and every swap flips the sign of the determinant.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
38
+
39
+ # Computes the inverse matrix of a square matrix.
40
+ #
41
+ # @param a [Numo::NArray] n-by-n square matrix.
42
+ # @param driver [String] This argument is for compatibility with Numo::Linalg.inv, and is not used.
43
+ # @param uplo [String] This argument is for compatibility with Numo::Linalg.inv, and is not used.
44
+ # @return [Numo::NArray] The inverse matrix of `a`.
45
def inv(a, driver: 'getrf', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  type_char = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if type_char == 'n'

  # LU-factorize a copy of the input; the caller's array is not modified.
  factored, pivots, status = Numo::TinyLinalg::Lapack.send("#{type_char}getrf".to_sym, a.dup)

  # Guard clauses for the two failure modes reported through the status code.
  raise 'the factor U is singular, and the inverse matrix could not be computed.' if status.positive?
  raise "the #{-status}-th argument of getrf had illegal value" unless status.zero?

  # The first element of the getri result is the inverse matrix.
  Numo::TinyLinalg::Lapack.send("#{type_char}getri".to_sym, factored, pivots)[0]
end
64
+
65
+ # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
66
+ #
67
+ # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
68
+ # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
69
+ # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
70
+ # @return [Numo::NArray] The pseudo-inverse of `a`.
71
def pinv(a, driver: 'svd', rcond: nil)
  sing_vals, left, right_h = svd(a, driver: driver, job: 'S')

  # Default threshold: largest dimension of `a` times the EPSILON of the singular-value class.
  threshold = rcond.nil? ? a.shape.max * sing_vals.class::EPSILON : rcond

  # Count the singular values exceeding the threshold scaled by the first one.
  cutoff = threshold * sing_vals[0]
  rank = sing_vals.gt(cutoff).count
  kept = 0...rank

  # Divide the retained columns of U by their singular values, multiply by the
  # retained rows of V^H, then take the conjugate transpose of the product.
  scaled_left = left[true, kept] / sing_vals[kept]
  product = scaled_left.dot(right_h[kept, true])
  product.conj.transpose
end
79
+
80
+ # Compute QR decomposition of a matrix.
81
+ #
82
+ # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
83
+ # @param mode [String] The mode of decomposition.
84
+ # - "reduce" -- returns both Q [m, m] and R [m, n],
85
+ # - "r" -- returns only R,
86
+ # - "economic" -- returns both Q [m, n] and R [n, n],
87
+ # - "raw" -- returns QR and TAU (LAPACK geqrf results).
88
+ # @return [Numo::NArray] if mode='r'
89
+ # @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
90
+ # @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
91
def qr(a, mode: 'reduce')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  # Factorize a copy so that the caller's array is left untouched.
  geqrf = "#{bchr}geqrf".to_sym
  qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)

  return [qr, tau] if mode == 'raw'

  m, n = qr.shape
  # 'raw' has already returned above, so only 'economic' trims R to n-by-n.
  r = m > n && mode == 'economic' ? qr[0...n, true].triu : qr.triu

  return r if mode == 'r'

  # Real types use xORGQR, complex types use xUNGQR.
  org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr".to_sym : "#{bchr}ungqr".to_sym

  q = if m < n
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
      elsif mode == 'economic'
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr, tau)[0]
      else
        # Pad the factorization to m-by-m to obtain the full Q. The buffer uses
        # the class of the geqrf result rather than the class of `a`, so that
        # copying into it cannot change the element type.
        # NOTE(review): presumably matters when `a` is not already of the LAPACK dtype - confirm against blas_char.
        qqr = qr.class.zeros(m, m)
        qqr[0...m, 0...n] = qr
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qqr, tau)[0]
      end

  [q, r]
end
122
+
123
+ # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
124
+ #
125
+ # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
126
+ # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
127
+ # @param driver [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
128
+ # @param uplo [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
129
+ # @return [Numo::NArray] The solution vector / matrix `x`.
130
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  bchr = blas_char(a, b)
  # Fail loudly on unsupported element types instead of silently returning nil.
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" unless %w[d s z c].include?(bchr)

  # Work on copies so that the caller's arrays are left untouched.
  # The second element of the gesv result is the solution.
  gesv = "#{bchr}gesv".to_sym
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
end
142
+
13
143
  # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
14
144
  #
15
145
  # @param a [Numo::NArray] Matrix to be decomposed.
File without changes
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-07-14 00:00:00.000000000 Z
11
+ date: 2023-08-02 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -35,31 +35,30 @@ extensions:
35
35
  - ext/numo/tiny_linalg/extconf.rb
36
36
  extra_rdoc_files: []
37
37
  files:
38
- - ".clang-format"
39
- - ".husky/commit-msg"
40
- - ".rubocop.yml"
41
38
  - CHANGELOG.md
42
39
  - CODE_OF_CONDUCT.md
43
- - Gemfile
44
40
  - LICENSE.txt
45
41
  - README.md
46
- - Rakefile
47
- - commitlint.config.js
42
+ - ext/numo/tiny_linalg/blas/dot.hpp
43
+ - ext/numo/tiny_linalg/blas/dot_sub.hpp
44
+ - ext/numo/tiny_linalg/blas/gemm.hpp
45
+ - ext/numo/tiny_linalg/blas/gemv.hpp
46
+ - ext/numo/tiny_linalg/blas/nrm2.hpp
48
47
  - ext/numo/tiny_linalg/converter.hpp
49
- - ext/numo/tiny_linalg/dot.hpp
50
- - ext/numo/tiny_linalg/dot_sub.hpp
51
48
  - ext/numo/tiny_linalg/extconf.rb
52
- - ext/numo/tiny_linalg/gemm.hpp
53
- - ext/numo/tiny_linalg/gemv.hpp
54
- - ext/numo/tiny_linalg/gesdd.hpp
55
- - ext/numo/tiny_linalg/gesvd.hpp
56
- - ext/numo/tiny_linalg/nrm2.hpp
49
+ - ext/numo/tiny_linalg/lapack/geqrf.hpp
50
+ - ext/numo/tiny_linalg/lapack/gesdd.hpp
51
+ - ext/numo/tiny_linalg/lapack/gesv.hpp
52
+ - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
+ - ext/numo/tiny_linalg/lapack/getrf.hpp
54
+ - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/orgqr.hpp
56
+ - ext/numo/tiny_linalg/lapack/ungqr.hpp
57
57
  - ext/numo/tiny_linalg/tiny_linalg.cpp
58
58
  - ext/numo/tiny_linalg/tiny_linalg.hpp
59
59
  - lib/numo/tiny_linalg.rb
60
60
  - lib/numo/tiny_linalg/version.rb
61
- - numo-tiny_linalg.gemspec
62
- - package.json
61
+ - vendor/tmp/.gitkeep
63
62
  homepage: https://github.com/yoshoku/numo-tiny_linalg
64
63
  licenses:
65
64
  - BSD-3-Clause
@@ -67,6 +66,7 @@ metadata:
67
66
  homepage_uri: https://github.com/yoshoku/numo-tiny_linalg
68
67
  source_code_uri: https://github.com/yoshoku/numo-tiny_linalg
69
68
  changelog_uri: https://github.com/yoshoku/numo-tiny_linalg/blob/main/CHANGELOG.md
69
+ documentation_uri: https://yoshoku.github.io/numo-tiny_linalg/doc/
70
70
  rubygems_mfa_required: 'true'
71
71
  post_install_message:
72
72
  rdoc_options: []
data/.clang-format DELETED
@@ -1,149 +0,0 @@
1
- ---
2
- Language: Cpp
3
- # BasedOnStyle: LLVM
4
- AccessModifierOffset: -2
5
- AlignAfterOpenBracket: Align
6
- AlignConsecutiveMacros: false
7
- AlignConsecutiveAssignments: false
8
- AlignConsecutiveBitFields: false
9
- AlignConsecutiveDeclarations: false
10
- AlignEscapedNewlines: Right
11
- AlignOperands: Align
12
- AlignTrailingComments: true
13
- AllowAllArgumentsOnNextLine: true
14
- AllowAllConstructorInitializersOnNextLine: true
15
- AllowAllParametersOfDeclarationOnNextLine: true
16
- AllowShortEnumsOnASingleLine: true
17
- AllowShortBlocksOnASingleLine: Never
18
- AllowShortCaseLabelsOnASingleLine: false
19
- AllowShortFunctionsOnASingleLine: All
20
- AllowShortLambdasOnASingleLine: All
21
- AllowShortIfStatementsOnASingleLine: true
22
- AllowShortLoopsOnASingleLine: true
23
- AlwaysBreakAfterDefinitionReturnType: None
24
- AlwaysBreakAfterReturnType: None
25
- AlwaysBreakBeforeMultilineStrings: false
26
- AlwaysBreakTemplateDeclarations: MultiLine
27
- BinPackArguments: true
28
- BinPackParameters: true
29
- BraceWrapping:
30
- AfterCaseLabel: false
31
- AfterClass: false
32
- AfterControlStatement: Never
33
- AfterEnum: false
34
- AfterFunction: false
35
- AfterNamespace: false
36
- AfterObjCDeclaration: false
37
- AfterStruct: false
38
- AfterUnion: false
39
- AfterExternBlock: false
40
- BeforeCatch: false
41
- BeforeElse: false
42
- BeforeLambdaBody: false
43
- BeforeWhile: false
44
- IndentBraces: false
45
- SplitEmptyFunction: true
46
- SplitEmptyRecord: true
47
- SplitEmptyNamespace: true
48
- BreakBeforeBinaryOperators: None
49
- BreakBeforeBraces: Attach
50
- BreakBeforeInheritanceComma: false
51
- BreakInheritanceList: BeforeColon
52
- BreakBeforeTernaryOperators: true
53
- BreakConstructorInitializersBeforeComma: false
54
- BreakConstructorInitializers: BeforeColon
55
- BreakAfterJavaFieldAnnotations: false
56
- BreakStringLiterals: true
57
- ColumnLimit: 0
58
- CommentPragmas: '^ IWYU pragma:'
59
- CompactNamespaces: false
60
- ConstructorInitializerAllOnOneLineOrOnePerLine: false
61
- ConstructorInitializerIndentWidth: 4
62
- ContinuationIndentWidth: 2
63
- Cpp11BracedListStyle: false
64
- DeriveLineEnding: true
65
- DerivePointerAlignment: false
66
- DisableFormat: false
67
- ExperimentalAutoDetectBinPacking: false
68
- FixNamespaceComments: true
69
- ForEachMacros:
70
- - foreach
71
- - Q_FOREACH
72
- - BOOST_FOREACH
73
- IncludeBlocks: Preserve
74
- IncludeCategories:
75
- - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
76
- Priority: 2
77
- SortPriority: 0
78
- - Regex: '^(<|"(gtest|gmock|isl|json)/)'
79
- Priority: 3
80
- SortPriority: 0
81
- - Regex: '.*'
82
- Priority: 1
83
- SortPriority: 0
84
- IncludeIsMainRegex: '(Test)?$'
85
- IncludeIsMainSourceRegex: ''
86
- IndentCaseLabels: false
87
- IndentCaseBlocks: false
88
- IndentGotoLabels: true
89
- IndentPPDirectives: None
90
- IndentExternBlock: AfterExternBlock
91
- IndentWidth: 2
92
- IndentWrappedFunctionNames: false
93
- InsertTrailingCommas: None
94
- JavaScriptQuotes: Leave
95
- JavaScriptWrapImports: true
96
- KeepEmptyLinesAtTheStartOfBlocks: true
97
- MacroBlockBegin: ''
98
- MacroBlockEnd: ''
99
- MaxEmptyLinesToKeep: 1
100
- NamespaceIndentation: None
101
- ObjCBinPackProtocolList: Auto
102
- ObjCBlockIndentWidth: 2
103
- ObjCBreakBeforeNestedBlockParam: true
104
- ObjCSpaceAfterProperty: false
105
- ObjCSpaceBeforeProtocolList: true
106
- PenaltyBreakAssignment: 2
107
- PenaltyBreakBeforeFirstCallParameter: 19
108
- PenaltyBreakComment: 300
109
- PenaltyBreakFirstLessLess: 120
110
- PenaltyBreakString: 1000
111
- PenaltyBreakTemplateDeclaration: 10
112
- PenaltyExcessCharacter: 1000000
113
- PenaltyReturnTypeOnItsOwnLine: 60
114
- PointerAlignment: Left
115
- ReflowComments: true
116
- SortIncludes: true
117
- SortUsingDeclarations: true
118
- SpaceAfterCStyleCast: false
119
- SpaceAfterLogicalNot: false
120
- SpaceAfterTemplateKeyword: true
121
- SpaceBeforeAssignmentOperators: true
122
- SpaceBeforeCpp11BracedList: false
123
- SpaceBeforeCtorInitializerColon: true
124
- SpaceBeforeInheritanceColon: true
125
- SpaceBeforeParens: ControlStatements
126
- SpaceBeforeRangeBasedForLoopColon: true
127
- SpaceInEmptyBlock: false
128
- SpaceInEmptyParentheses: false
129
- SpacesBeforeTrailingComments: 1
130
- SpacesInAngles: false
131
- SpacesInConditionalStatement: false
132
- SpacesInContainerLiterals: true
133
- SpacesInCStyleCastParentheses: false
134
- SpacesInParentheses: false
135
- SpacesInSquareBrackets: false
136
- SpaceBeforeSquareBrackets: false
137
- Standard: Latest
138
- StatementMacros:
139
- - Q_UNUSED
140
- - QT_REQUIRE_VERSION
141
- TabWidth: 8
142
- UseCRLF: false
143
- UseTab: Never
144
- WhitespaceSensitiveMacros:
145
- - STRINGIZE
146
- - PP_STRINGIZE
147
- - BOOST_PP_STRINGIZE
148
- ...
149
-
data/.husky/commit-msg DELETED
@@ -1,4 +0,0 @@
1
- #!/usr/bin/env sh
2
- . "$(dirname -- "$0")/_/husky.sh"
3
-
4
- yarn commitlint --edit "$1"