numo-tiny_linalg 0.0.1 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DOrgQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ double* a, lapack_int lda, const double* tau) {
6
+ return LAPACKE_dorgqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct SOrgQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ float* a, lapack_int lda, const float* tau) {
13
+ return LAPACKE_sorgqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class OrgQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_orgqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct orgqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_orgqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_orgqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_orgqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ orgqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZUngQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
6
+ return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct CUngQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
13
+ return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class UngQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct ungqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_ungqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ ungqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -1,12 +1,18 @@
1
1
  #include "tiny_linalg.hpp"
2
+ #include "blas/dot.hpp"
3
+ #include "blas/dot_sub.hpp"
4
+ #include "blas/gemm.hpp"
5
+ #include "blas/gemv.hpp"
6
+ #include "blas/nrm2.hpp"
2
7
  #include "converter.hpp"
3
- #include "dot.hpp"
4
- #include "dot_sub.hpp"
5
- #include "gemm.hpp"
6
- #include "gemv.hpp"
7
- #include "gesdd.hpp"
8
- #include "gesvd.hpp"
9
- #include "nrm2.hpp"
8
+ #include "lapack/geqrf.hpp"
9
+ #include "lapack/gesdd.hpp"
10
+ #include "lapack/gesv.hpp"
11
+ #include "lapack/gesvd.hpp"
12
+ #include "lapack/getrf.hpp"
13
+ #include "lapack/getri.hpp"
14
+ #include "lapack/orgqr.hpp"
15
+ #include "lapack/ungqr.hpp"
10
16
 
11
17
  VALUE rb_mTinyLinalg;
12
18
  VALUE rb_mTinyLinalgBlas;
@@ -237,6 +243,10 @@ extern "C" void Init_tiny_linalg(void) {
237
243
  TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
238
244
  TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
239
245
  TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
246
+ TinyLinalg::GESV<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGESV>::define_module_function(rb_mTinyLinalgLapack, "dgesv");
247
+ TinyLinalg::GESV<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGESV>::define_module_function(rb_mTinyLinalgLapack, "sgesv");
248
+ TinyLinalg::GESV<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGESV>::define_module_function(rb_mTinyLinalgLapack, "zgesv");
249
+ TinyLinalg::GESV<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGESV>::define_module_function(rb_mTinyLinalgLapack, "cgesv");
240
250
  TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
241
251
  TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
242
252
  TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
@@ -245,6 +255,22 @@ extern "C" void Init_tiny_linalg(void) {
245
255
  TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
246
256
  TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
247
257
  TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
258
+ TinyLinalg::GETRF<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRF>::define_module_function(rb_mTinyLinalgLapack, "dgetrf");
259
+ TinyLinalg::GETRF<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRF>::define_module_function(rb_mTinyLinalgLapack, "sgetrf");
260
+ TinyLinalg::GETRF<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRF>::define_module_function(rb_mTinyLinalgLapack, "zgetrf");
261
+ TinyLinalg::GETRF<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRF>::define_module_function(rb_mTinyLinalgLapack, "cgetrf");
262
+ TinyLinalg::GETRI<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRI>::define_module_function(rb_mTinyLinalgLapack, "dgetri");
263
+ TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
264
+ TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
265
+ TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
266
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
267
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
268
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
269
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
270
+ TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
271
+ TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
272
+ TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
273
+ TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
248
274
 
249
275
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
250
276
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.0.1'
8
+ VERSION = '0.0.3'
9
9
  end
10
10
  end
@@ -7,9 +7,139 @@ require_relative 'tiny_linalg/tiny_linalg'
7
7
  # Ruby/Numo (NUmerical MOdules)
8
8
  module Numo
9
9
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
10
- module TinyLinalg
10
+ module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
# Computes the determinant of matrix.
#
# The determinant is obtained from the LU factorization computed by LAPACK getrf:
# det(A) = det(P) * det(L) * det(U). L has a unit diagonal, so det(L) = 1.
# det(U) is the product of U's diagonal. det(P) is +1 or -1 depending on how many
# pivot indices differ from their own (1-based) position.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @raise [ArgumentError] if `a` is not a 2-dimensional square matrix or its type is unsupported.
# @raise [RuntimeError] if getrf reports an illegal argument.
# @return [Float/Complex] The determinant of `a` (zero when `a` is exactly singular).
def det(a)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = "#{bchr}getrf".to_sym
  lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)

  # A negative info means the -info-th argument passed to getrf was invalid.
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  # A positive info only signals that U(info, info) is exactly zero. getrf still
  # completes the factorization, so the product below evaluates to zero, which is
  # the determinant of a singular matrix. (Previously this raised an unrelated
  # "inverse matrix could not be computed" error.)
  det_u = lu.diagonal.prod
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
38
+
39
# Computes the inverse matrix of a square matrix.
#
# The matrix is LU-factorized with LAPACK getrf, and the inverse is then formed
# from that factorization with getri.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @param driver [String] Accepted for compatibility with Numo::Linalg.inv; ignored.
# @param uplo [String] Accepted for compatibility with Numo::Linalg.inv; ignored.
# @raise [ArgumentError] if `a` is not a 2-dimensional square matrix or its type is unsupported.
# @raise [RuntimeError] if `a` is singular or getrf reports an illegal argument.
# @return [Numo::NArray] The inverse matrix of `a`.
def inv(a, driver: 'getrf', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  lu, piv, info = Numo::TinyLinalg::Lapack.send(:"#{bchr}getrf", a.dup)
  raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  inverse, = Numo::TinyLinalg::Lapack.send(:"#{bchr}getri", lu, piv)
  inverse
end
64
+
65
# Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
#
# With the thin SVD a = U * diag(s) * Vh, the pseudo-inverse is
# (U[:, :r] / s[:r] * Vh[:r, :])^H, where r counts singular values above the cutoff.
#
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
# @param rcond [Float] Relative cutoff: singular values not greater than `rcond * s[0]` are discarded.
#   Defaults to `a.shape.max * EPSILON` of the singular values' type.
# @return [Numo::NArray] The pseudo-inverse of `a`.
def pinv(a, driver: 'svd', rcond: nil)
  s, u, vh = svd(a, driver: driver, job: 'S')
  rcond = a.shape.max * s.class::EPSILON if rcond.nil?

  # Relies on LAPACK returning singular values in descending order, so s[0] is the largest.
  cutoff = rcond * s[0]
  rank = s.gt(cutoff).count

  u_scaled = u[true, 0...rank] / s[0...rank]
  u_scaled.dot(vh[0...rank, true]).conj.transpose
end
79
+
80
# Compute QR decomposition of a matrix.
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition (k = min(m, n)).
#   - "reduce" -- returns both Q [m, m] and R [m, n],
#   - "r" -- returns only R [m, n],
#   - "economic" -- returns both Q [m, k] and R [k, n],
#   - "raw" -- returns QR and TAU (LAPACK geqrf results).
# @raise [ArgumentError] if `a` is not 2-dimensional, `mode` is unknown, or the type is unsupported.
# @return [Numo::NArray] if mode='r'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
def qr(a, mode: 'reduce')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  geqrf = "#{bchr}geqrf".to_sym
  qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)

  return [qr, tau] if mode == 'raw'

  m, n = qr.shape
  # 'raw' has already returned, so only 'economic' needs the trimmed n-by-n R.
  r = m > n && mode == 'economic' ? qr[0...n, true].triu : qr.triu

  return r if mode == 'r'

  org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr".to_sym : "#{bchr}ungqr".to_sym

  q = if m < n
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
      elsif mode == 'economic'
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr, tau)[0]
      else
        # Zero-pad the reflectors to m-by-m. Use the class of the geqrf result rather
        # than `a.class`: if `a` was not already the LAPACK dtype (e.g. an integer
        # array upcast by blas_char), assigning the float reflectors into an array of
        # a's class would truncate them.
        qqr = qr.class.zeros(m, m)
        qqr[0...m, 0...n] = qr
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qqr, tau)[0]
      end

  [q, r]
end
122
+
123
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
#
# @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
# @param driver [String] Accepted for compatibility with Numo::Linalg.solve; ignored.
# @param uplo [String] Accepted for compatibility with Numo::Linalg.solve; ignored.
# @raise [ArgumentError] if the data types of `a` and `b` are not supported.
# @return [Numo::NArray] The solution vector / matrix `x`.
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  bchr = blas_char(a, b)
  # Previously an unsupported type fell through the `case` and silently returned nil.
  # Raise like the other methods in this module do.
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  gesv = "#{bchr}gesv".to_sym
  # gesv returns the solution as the second element of its result array.
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
end
142
+
13
143
  # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
14
144
  #
15
145
  # @param a [Numo::NArray] Matrix to be decomposed.
File without changes
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-07-14 00:00:00.000000000 Z
11
+ date: 2023-08-02 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -35,31 +35,30 @@ extensions:
35
35
  - ext/numo/tiny_linalg/extconf.rb
36
36
  extra_rdoc_files: []
37
37
  files:
38
- - ".clang-format"
39
- - ".husky/commit-msg"
40
- - ".rubocop.yml"
41
38
  - CHANGELOG.md
42
39
  - CODE_OF_CONDUCT.md
43
- - Gemfile
44
40
  - LICENSE.txt
45
41
  - README.md
46
- - Rakefile
47
- - commitlint.config.js
42
+ - ext/numo/tiny_linalg/blas/dot.hpp
43
+ - ext/numo/tiny_linalg/blas/dot_sub.hpp
44
+ - ext/numo/tiny_linalg/blas/gemm.hpp
45
+ - ext/numo/tiny_linalg/blas/gemv.hpp
46
+ - ext/numo/tiny_linalg/blas/nrm2.hpp
48
47
  - ext/numo/tiny_linalg/converter.hpp
49
- - ext/numo/tiny_linalg/dot.hpp
50
- - ext/numo/tiny_linalg/dot_sub.hpp
51
48
  - ext/numo/tiny_linalg/extconf.rb
52
- - ext/numo/tiny_linalg/gemm.hpp
53
- - ext/numo/tiny_linalg/gemv.hpp
54
- - ext/numo/tiny_linalg/gesdd.hpp
55
- - ext/numo/tiny_linalg/gesvd.hpp
56
- - ext/numo/tiny_linalg/nrm2.hpp
49
+ - ext/numo/tiny_linalg/lapack/geqrf.hpp
50
+ - ext/numo/tiny_linalg/lapack/gesdd.hpp
51
+ - ext/numo/tiny_linalg/lapack/gesv.hpp
52
+ - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
+ - ext/numo/tiny_linalg/lapack/getrf.hpp
54
+ - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/orgqr.hpp
56
+ - ext/numo/tiny_linalg/lapack/ungqr.hpp
57
57
  - ext/numo/tiny_linalg/tiny_linalg.cpp
58
58
  - ext/numo/tiny_linalg/tiny_linalg.hpp
59
59
  - lib/numo/tiny_linalg.rb
60
60
  - lib/numo/tiny_linalg/version.rb
61
- - numo-tiny_linalg.gemspec
62
- - package.json
61
+ - vendor/tmp/.gitkeep
63
62
  homepage: https://github.com/yoshoku/numo-tiny_linalg
64
63
  licenses:
65
64
  - BSD-3-Clause
@@ -67,6 +66,7 @@ metadata:
67
66
  homepage_uri: https://github.com/yoshoku/numo-tiny_linalg
68
67
  source_code_uri: https://github.com/yoshoku/numo-tiny_linalg
69
68
  changelog_uri: https://github.com/yoshoku/numo-tiny_linalg/blob/main/CHANGELOG.md
69
+ documentation_uri: https://yoshoku.github.io/numo-tiny_linalg/doc/
70
70
  rubygems_mfa_required: 'true'
71
71
  post_install_message:
72
72
  rdoc_options: []
data/.clang-format DELETED
@@ -1,149 +0,0 @@
1
- ---
2
- Language: Cpp
3
- # BasedOnStyle: LLVM
4
- AccessModifierOffset: -2
5
- AlignAfterOpenBracket: Align
6
- AlignConsecutiveMacros: false
7
- AlignConsecutiveAssignments: false
8
- AlignConsecutiveBitFields: false
9
- AlignConsecutiveDeclarations: false
10
- AlignEscapedNewlines: Right
11
- AlignOperands: Align
12
- AlignTrailingComments: true
13
- AllowAllArgumentsOnNextLine: true
14
- AllowAllConstructorInitializersOnNextLine: true
15
- AllowAllParametersOfDeclarationOnNextLine: true
16
- AllowShortEnumsOnASingleLine: true
17
- AllowShortBlocksOnASingleLine: Never
18
- AllowShortCaseLabelsOnASingleLine: false
19
- AllowShortFunctionsOnASingleLine: All
20
- AllowShortLambdasOnASingleLine: All
21
- AllowShortIfStatementsOnASingleLine: true
22
- AllowShortLoopsOnASingleLine: true
23
- AlwaysBreakAfterDefinitionReturnType: None
24
- AlwaysBreakAfterReturnType: None
25
- AlwaysBreakBeforeMultilineStrings: false
26
- AlwaysBreakTemplateDeclarations: MultiLine
27
- BinPackArguments: true
28
- BinPackParameters: true
29
- BraceWrapping:
30
- AfterCaseLabel: false
31
- AfterClass: false
32
- AfterControlStatement: Never
33
- AfterEnum: false
34
- AfterFunction: false
35
- AfterNamespace: false
36
- AfterObjCDeclaration: false
37
- AfterStruct: false
38
- AfterUnion: false
39
- AfterExternBlock: false
40
- BeforeCatch: false
41
- BeforeElse: false
42
- BeforeLambdaBody: false
43
- BeforeWhile: false
44
- IndentBraces: false
45
- SplitEmptyFunction: true
46
- SplitEmptyRecord: true
47
- SplitEmptyNamespace: true
48
- BreakBeforeBinaryOperators: None
49
- BreakBeforeBraces: Attach
50
- BreakBeforeInheritanceComma: false
51
- BreakInheritanceList: BeforeColon
52
- BreakBeforeTernaryOperators: true
53
- BreakConstructorInitializersBeforeComma: false
54
- BreakConstructorInitializers: BeforeColon
55
- BreakAfterJavaFieldAnnotations: false
56
- BreakStringLiterals: true
57
- ColumnLimit: 0
58
- CommentPragmas: '^ IWYU pragma:'
59
- CompactNamespaces: false
60
- ConstructorInitializerAllOnOneLineOrOnePerLine: false
61
- ConstructorInitializerIndentWidth: 4
62
- ContinuationIndentWidth: 2
63
- Cpp11BracedListStyle: false
64
- DeriveLineEnding: true
65
- DerivePointerAlignment: false
66
- DisableFormat: false
67
- ExperimentalAutoDetectBinPacking: false
68
- FixNamespaceComments: true
69
- ForEachMacros:
70
- - foreach
71
- - Q_FOREACH
72
- - BOOST_FOREACH
73
- IncludeBlocks: Preserve
74
- IncludeCategories:
75
- - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
76
- Priority: 2
77
- SortPriority: 0
78
- - Regex: '^(<|"(gtest|gmock|isl|json)/)'
79
- Priority: 3
80
- SortPriority: 0
81
- - Regex: '.*'
82
- Priority: 1
83
- SortPriority: 0
84
- IncludeIsMainRegex: '(Test)?$'
85
- IncludeIsMainSourceRegex: ''
86
- IndentCaseLabels: false
87
- IndentCaseBlocks: false
88
- IndentGotoLabels: true
89
- IndentPPDirectives: None
90
- IndentExternBlock: AfterExternBlock
91
- IndentWidth: 2
92
- IndentWrappedFunctionNames: false
93
- InsertTrailingCommas: None
94
- JavaScriptQuotes: Leave
95
- JavaScriptWrapImports: true
96
- KeepEmptyLinesAtTheStartOfBlocks: true
97
- MacroBlockBegin: ''
98
- MacroBlockEnd: ''
99
- MaxEmptyLinesToKeep: 1
100
- NamespaceIndentation: None
101
- ObjCBinPackProtocolList: Auto
102
- ObjCBlockIndentWidth: 2
103
- ObjCBreakBeforeNestedBlockParam: true
104
- ObjCSpaceAfterProperty: false
105
- ObjCSpaceBeforeProtocolList: true
106
- PenaltyBreakAssignment: 2
107
- PenaltyBreakBeforeFirstCallParameter: 19
108
- PenaltyBreakComment: 300
109
- PenaltyBreakFirstLessLess: 120
110
- PenaltyBreakString: 1000
111
- PenaltyBreakTemplateDeclaration: 10
112
- PenaltyExcessCharacter: 1000000
113
- PenaltyReturnTypeOnItsOwnLine: 60
114
- PointerAlignment: Left
115
- ReflowComments: true
116
- SortIncludes: true
117
- SortUsingDeclarations: true
118
- SpaceAfterCStyleCast: false
119
- SpaceAfterLogicalNot: false
120
- SpaceAfterTemplateKeyword: true
121
- SpaceBeforeAssignmentOperators: true
122
- SpaceBeforeCpp11BracedList: false
123
- SpaceBeforeCtorInitializerColon: true
124
- SpaceBeforeInheritanceColon: true
125
- SpaceBeforeParens: ControlStatements
126
- SpaceBeforeRangeBasedForLoopColon: true
127
- SpaceInEmptyBlock: false
128
- SpaceInEmptyParentheses: false
129
- SpacesBeforeTrailingComments: 1
130
- SpacesInAngles: false
131
- SpacesInConditionalStatement: false
132
- SpacesInContainerLiterals: true
133
- SpacesInCStyleCastParentheses: false
134
- SpacesInParentheses: false
135
- SpacesInSquareBrackets: false
136
- SpaceBeforeSquareBrackets: false
137
- Standard: Latest
138
- StatementMacros:
139
- - Q_UNUSED
140
- - QT_REQUIRE_VERSION
141
- TabWidth: 8
142
- UseCRLF: false
143
- UseTab: Never
144
- WhitespaceSensitiveMacros:
145
- - STRINGIZE
146
- - PP_STRINGIZE
147
- - BOOST_PP_STRINGIZE
148
- ...
149
-
data/.husky/commit-msg DELETED
@@ -1,4 +0,0 @@
1
- #!/usr/bin/env sh
2
- . "$(dirname -- "$0")/_/husky.sh"
3
-
4
- yarn commitlint --edit "$1"