numo-tiny_linalg 0.0.2 → 0.0.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZUngQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
6
+ return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct CUngQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
13
+ return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class UngQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct ungqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_ungqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ ungqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -5,11 +5,20 @@
5
5
  #include "blas/gemv.hpp"
6
6
  #include "blas/nrm2.hpp"
7
7
  #include "converter.hpp"
8
+ #include "lapack/geqrf.hpp"
8
9
  #include "lapack/gesdd.hpp"
9
10
  #include "lapack/gesv.hpp"
10
11
  #include "lapack/gesvd.hpp"
11
12
  #include "lapack/getrf.hpp"
12
13
  #include "lapack/getri.hpp"
14
+ #include "lapack/hegv.hpp"
15
+ #include "lapack/hegvd.hpp"
16
+ #include "lapack/hegvx.hpp"
17
+ #include "lapack/orgqr.hpp"
18
+ #include "lapack/sygv.hpp"
19
+ #include "lapack/sygvd.hpp"
20
+ #include "lapack/sygvx.hpp"
21
+ #include "lapack/ungqr.hpp"
13
22
 
14
23
  VALUE rb_mTinyLinalg;
15
24
  VALUE rb_mTinyLinalgBlas;
@@ -260,6 +269,26 @@ extern "C" void Init_tiny_linalg(void) {
260
269
  TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
261
270
  TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
262
271
  TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
272
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
273
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
274
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
275
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
276
+ TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
277
+ TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
278
+ TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
279
+ TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
280
+ TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
281
+ TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
282
+ TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
283
+ TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
284
+ TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
285
+ TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
286
+ TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
287
+ TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
288
+ TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
289
+ TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
290
+ TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
291
+ TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
263
292
 
264
293
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
265
294
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.0.2'
8
+ VERSION = '0.0.4'
9
9
  end
10
10
  end
@@ -7,9 +7,79 @@ require_relative 'tiny_linalg/tiny_linalg'
7
7
  # Ruby/Numo (NUmerical MOdules)
8
8
  module Numo
9
9
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
10
- module TinyLinalg
10
+ module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
# Solves an ordinary or generalized eigenvalue problem for a symmetric / Hermitian matrix
# and returns its eigenvalues and eigenvectors.
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, the identity matrix is used.
# @param vals_only [Boolean] Whether to return only the eigenvalues (eigenvectors are then nil).
# @param vals_range [Range/Array]
#   Zero-based indices of the first and last eigenvalues (in ascending order) to be returned,
#   read through `first` and `last`. If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] Accepted for compatibility with Numo::Linalg.solver; it is not used.
# @param turbo [Boolean] Whether to use the divide and conquer driver. Ignored if vals_range is given.
# @raise [ArgumentError] If a or b is not a 2-dimensional square array of a supported type.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors.
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise ArgumentError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  type_char = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if type_char == 'n'

  if b.nil?
    b = a.class.eye(a.shape[0])
  else
    raise ArgumentError, 'input array b must be 2-dimensional' unless b.ndim == 2
    raise ArgumentError, 'input array b must be square' unless b.shape[0] == b.shape[1]
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
  end

  jobz = vals_only ? 'N' : 'V'
  # Real types use the sy* drivers, everything else the he* drivers.
  family = %w[d s].include?(type_char) ? 'sygv' : 'hegv'

  if vals_range.nil?
    driver = "#{type_char}#{family}#{turbo ? 'd' : ''}".to_sym
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(driver, a.dup, b.dup, jobz: jobz)
  else
    driver = "#{type_char}#{family}x".to_sym
    # The drivers are passed one-based indices.
    lower = vals_range.first + 1
    upper = vals_range.last + 1
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      driver, a.dup, b.dup, jobz: jobz, range: 'I', il: lower, iu: upper
    )
  end

  [vals, vals_only ? nil : vecs]
end
56
+
57
# Computes the determinant of a square matrix from its LU factorization (LAPACK getrf).
#
# A singular matrix yields a determinant of zero: getrf reports an exactly zero pivot
# through a positive info value, but the factorization itself is complete in that case.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @raise [ArgumentError] If a is not a 2-dimensional square array of a supported type.
# @raise [RuntimeError] If getrf reports an illegal argument (negative info).
# @return [Float/Complex] The determinant of `a`.
def det(a)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = "#{bchr}getrf".to_sym
  lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)

  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  # L has a unit diagonal, so det(A) = det(P) * det(U).
  det_u = lu.diagonal.prod
  # Every pivot entry that differs from its own (one-based) position is one row swap.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
82
+
13
83
  # Computes the inverse matrix of a square matrix.
14
84
  #
15
85
  # @param a [Numo::NArray] n-by-n square matrix.
@@ -36,6 +106,64 @@ module Numo
36
106
  end
37
107
  end
38
108
 
109
# Computes the (Moore-Penrose) pseudo-inverse of a matrix from its singular value decomposition.
#
# Singular values not greater than `rcond` times the first singular value are discarded.
#
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
# @param rcond [Float] The threshold for small singular values of `a`;
#   defaults to `a.shape.max * EPSILON` of the singular value array's class.
# @return [Numo::NArray] The pseudo-inverse of `a`.
def pinv(a, driver: 'svd', rcond: nil)
  sing_vals, left, right_h = svd(a, driver: driver, job: 'S')
  rcond = a.shape.max * sing_vals.class::EPSILON if rcond.nil?

  cutoff = rcond * sing_vals[0]
  kept = 0...sing_vals.gt(cutoff).count

  scaled_left = left[true, kept] / sing_vals[kept]
  scaled_left.dot(right_h[kept, true]).conj.transpose
end
123
+
124
# Computes the QR decomposition of a matrix.
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition.
#   - "reduce"   -- returns both Q [m, m] and R [m, n],
#   - "r"        -- returns only R,
#   - "economic" -- returns both Q [m, n] and R [n, n],
#   - "raw"      -- returns QR and TAU (LAPACK geqrf results).
# @raise [ArgumentError] If a is not 2-dimensional, has an unsupported type, or mode is unknown.
# @return [Numo::NArray] if mode='r'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
def qr(a, mode: 'reduce')
  raise ArgumentError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  type_char = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if type_char == 'n'

  factored, tau, = Numo::TinyLinalg::Lapack.send("#{type_char}geqrf".to_sym, a.dup)
  return [factored, tau] if mode == 'raw'

  m, n = factored.shape
  # R is taken before Q is generated, because `factored` itself may be handed to the generator below.
  r = if m > n && mode == 'economic'
        factored[0...n, true].triu
      else
        factored.triu
      end
  return r if mode == 'r'

  # Real types use orgqr, everything else ungqr.
  generator = (%w[d s].include?(type_char) ? "#{type_char}orgqr" : "#{type_char}ungqr").to_sym

  reflectors = if m < n
                 factored[true, 0...m]
               elsif mode == 'economic'
                 factored
               else
                 # Full Q: embed the m-by-n factor in an m-by-m zero matrix.
                 square = a.class.zeros(m, m)
                 square[0...m, 0...n] = factored
                 square
               end
  q = Numo::TinyLinalg::Lapack.send(generator, reflectors, tau)[0]

  [q, r]
end
166
+
39
167
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
40
168
  #
41
169
  # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.2
4
+ version: 0.0.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-07-26 00:00:00.000000000 Z
11
+ date: 2023-08-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -46,11 +46,20 @@ files:
46
46
  - ext/numo/tiny_linalg/blas/nrm2.hpp
47
47
  - ext/numo/tiny_linalg/converter.hpp
48
48
  - ext/numo/tiny_linalg/extconf.rb
49
+ - ext/numo/tiny_linalg/lapack/geqrf.hpp
49
50
  - ext/numo/tiny_linalg/lapack/gesdd.hpp
50
51
  - ext/numo/tiny_linalg/lapack/gesv.hpp
51
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
52
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
53
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/hegv.hpp
56
+ - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
+ - ext/numo/tiny_linalg/lapack/hegvx.hpp
58
+ - ext/numo/tiny_linalg/lapack/orgqr.hpp
59
+ - ext/numo/tiny_linalg/lapack/sygv.hpp
60
+ - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
+ - ext/numo/tiny_linalg/lapack/sygvx.hpp
62
+ - ext/numo/tiny_linalg/lapack/ungqr.hpp
54
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
55
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
56
65
  - lib/numo/tiny_linalg.rb