numo-tiny_linalg 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,115 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZUngQr {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
5
+ lapack_complex_double* a, lapack_int lda, const lapack_complex_double* tau) {
6
+ return LAPACKE_zungqr(matrix_layout, m, n, k, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct CUngQr {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n, lapack_int k,
12
+ lapack_complex_float* a, lapack_int lda, const lapack_complex_float* tau) {
13
+ return LAPACKE_cungqr(matrix_layout, m, n, k, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ template <int nary_dtype_id, typename DType, typename FncType>
18
+ class UngQr {
19
+ public:
20
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
21
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_ungqr), -1);
22
+ }
23
+
24
+ private:
25
+ struct ungqr_opt {
26
+ int matrix_layout;
27
+ };
28
+
29
+ static void iter_ungqr(na_loop_t* const lp) {
30
+ DType* a = (DType*)NDL_PTR(lp, 0);
31
+ DType* tau = (DType*)NDL_PTR(lp, 1);
32
+ int* info = (int*)NDL_PTR(lp, 2);
33
+ ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
34
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
+ const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
+ const lapack_int lda = n;
38
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
+ *info = static_cast<int>(i);
40
+ }
41
+
42
+ static VALUE tiny_linalg_ungqr(int argc, VALUE* argv, VALUE self) {
43
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
44
+
45
+ VALUE a_vnary = Qnil;
46
+ VALUE tau_vnary = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args);
49
+ ID kw_table[1] = { rb_intern("order") };
50
+ VALUE kw_values[1] = { Qundef };
51
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
+
54
+ if (CLASS_OF(a_vnary) != nary_dtype) {
55
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
56
+ }
57
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
58
+ a_vnary = nary_dup(a_vnary);
59
+ }
60
+ if (CLASS_OF(tau_vnary) != nary_dtype) {
61
+ tau_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, tau_vnary);
62
+ }
63
+ if (!RTEST(nary_check_contiguous(tau_vnary))) {
64
+ tau_vnary = nary_dup(tau_vnary);
65
+ }
66
+
67
+ narray_t* a_nary = NULL;
68
+ GetNArray(a_vnary, a_nary);
69
+ if (NA_NDIM(a_nary) != 2) {
70
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
71
+ return Qnil;
72
+ }
73
+ narray_t* tau_nary = NULL;
74
+ GetNArray(tau_vnary, tau_nary);
75
+ if (NA_NDIM(tau_nary) != 1) {
76
+ rb_raise(rb_eArgError, "input array tau must be 1-dimensional");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { nary_dtype, 1 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_ungqr, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
83
+ ungqr_opt opt = { matrix_layout };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary);
85
+
86
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
87
+
88
+ RB_GC_GUARD(a_vnary);
89
+ RB_GC_GUARD(tau_vnary);
90
+
91
+ return ret;
92
+ }
93
+
94
+ static int get_matrix_layout(VALUE val) {
95
+ const char* option_str = StringValueCStr(val);
96
+
97
+ if (std::strlen(option_str) > 0) {
98
+ switch (option_str[0]) {
99
+ case 'r':
100
+ case 'R':
101
+ break;
102
+ case 'c':
103
+ case 'C':
104
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
+ break;
106
+ }
107
+ }
108
+
109
+ RB_GC_GUARD(val);
110
+
111
+ return LAPACK_ROW_MAJOR;
112
+ }
113
+ };
114
+
115
+ } // namespace TinyLinalg
@@ -5,11 +5,20 @@
5
5
  #include "blas/gemv.hpp"
6
6
  #include "blas/nrm2.hpp"
7
7
  #include "converter.hpp"
8
+ #include "lapack/geqrf.hpp"
8
9
  #include "lapack/gesdd.hpp"
9
10
  #include "lapack/gesv.hpp"
10
11
  #include "lapack/gesvd.hpp"
11
12
  #include "lapack/getrf.hpp"
12
13
  #include "lapack/getri.hpp"
14
+ #include "lapack/hegv.hpp"
15
+ #include "lapack/hegvd.hpp"
16
+ #include "lapack/hegvx.hpp"
17
+ #include "lapack/orgqr.hpp"
18
+ #include "lapack/sygv.hpp"
19
+ #include "lapack/sygvd.hpp"
20
+ #include "lapack/sygvx.hpp"
21
+ #include "lapack/ungqr.hpp"
13
22
 
14
23
  VALUE rb_mTinyLinalg;
15
24
  VALUE rb_mTinyLinalgBlas;
@@ -260,6 +269,26 @@ extern "C" void Init_tiny_linalg(void) {
260
269
  TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
261
270
  TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
262
271
  TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
272
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
273
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
274
+ TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
275
+ TinyLinalg::GeQrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeQrf>::define_module_function(rb_mTinyLinalgLapack, "cgeqrf");
276
+ TinyLinalg::OrgQr<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DOrgQr>::define_module_function(rb_mTinyLinalgLapack, "dorgqr");
277
+ TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
278
+ TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
279
+ TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
280
+ TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
281
+ TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
282
+ TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
283
+ TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
284
+ TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
285
+ TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
286
+ TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
287
+ TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
288
+ TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
289
+ TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
290
+ TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
291
+ TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
263
292
 
264
293
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
265
294
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.0.2'
8
+ VERSION = '0.0.4'
9
9
  end
10
10
  end
@@ -7,9 +7,79 @@ require_relative 'tiny_linalg/tiny_linalg'
7
7
  # Ruby/Numo (NUmerical MOdules)
8
8
  module Numo
9
9
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
10
- module TinyLinalg
10
+ module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary or generalized eigenvalue problem.
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range/Array]
#   The zero-based range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
#   Both inclusive (e.g. 0..2) and exclusive (e.g. 0...3) ranges are accepted; an Array is read as [first, last], both inclusive.
#   If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for compatibility with Numo::Linalg, and is not used.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors (eigenvectors are nil when vals_only is true).
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  unless b.nil?
    raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
    raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
  end

  jobz = vals_only ? 'N' : 'V'
  b = a.class.eye(a.shape[0]) if b.nil?
  sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"

  if vals_range.nil?
    # Build a new name instead of appending with << so this works even if interpolated strings are frozen.
    sy_he_gv = "#{sy_he_gv}d" if turbo
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
  else
    # LAPACK's il/iu are one-based and inclusive.
    il = vals_range.first + 1
    iu = vals_range.last + 1
    # An exclusive Range (e.g. 0...3) does not contain its end value.
    iu -= 1 if vals_range.is_a?(Range) && vals_range.exclude_end?
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      "#{sy_he_gv}x".to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
    )
  end
  vecs = nil if vals_only
  [vals, vecs]
end
56
+
57
# Computes the determinant of matrix.
#
# @param a [Numo::NArray] n-by-n square matrix.
# @return [Float/Complex] The determinant of `a` (zero when `a` is singular).
# @raise [RuntimeError] if getrf reports an illegal argument value.
def det(a)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = "#{bchr}getrf".to_sym
  lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  # A positive info means a diagonal entry of U is exactly zero; the factorization still
  # completed, so the product below is 0, which is the determinant of a singular matrix.
  det_u = lu.diagonal.prod
  # getrf pivot indices are one-based; every actual row interchange flips the sign.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
82
+
13
83
  # Computes the inverse matrix of a square matrix.
14
84
  #
15
85
  # @param a [Numo::NArray] n-by-n square matrix.
@@ -36,6 +106,64 @@ module Numo
36
106
  end
37
107
  end
38
108
 
109
# Computes the (Moore-Penrose) pseudo-inverse of a matrix from its singular value decomposition.
# Singular values not exceeding `rcond` times the first (assumed largest) singular value are discarded.
#
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
# @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
# @return [Numo::NArray] The pseudo-inverse of `a`.
def pinv(a, driver: 'svd', rcond: nil)
  sing, lvecs, rvecs_h = svd(a, driver: driver, job: 'S')
  tol = rcond.nil? ? a.shape.max * sing.class::EPSILON : rcond
  # Number of singular values above the relative cutoff.
  rank = sing.gt(tol * sing[0]).count

  # U_r * S_r^-1: divide each kept left singular vector by its singular value.
  scaled_u = lvecs[true, 0...rank] / sing[0...rank]
  # (U_r S_r^-1 Vh_r)^H == V_r S_r^-1 U_r^H
  scaled_u.dot(rvecs_h[0...rank, true]).conj.transpose
end
123
+
124
# Compute QR decomposition of a matrix.
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition.
#   - "reduce" -- returns both Q [m, m] and R [m, n],
#   - "r" -- returns only R,
#   - "economic" -- returns both Q [m, n] and R [n, n],
#   - "raw" -- returns QR and TAU (LAPACK geqrf results).
# @return [Numo::NArray] if mode='r'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
# @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
def qr(a, mode: 'reduce')
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  geqrf = "#{bchr}geqrf".to_sym
  qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)

  return [qr, tau] if mode == 'raw'

  m, n = qr.shape
  # Economic mode on a tall matrix keeps only the top n rows of R ('raw' has already returned).
  r = m > n && mode == 'economic' ? qr[0...n, true].triu : qr.triu

  return r if mode == 'r'

  org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr".to_sym : "#{bchr}ungqr".to_sym

  q = if m < n
        # Wide matrix: only the leading m columns hold Householder reflectors.
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
      elsif mode == 'economic'
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qr, tau)[0]
      else
        # Full m-by-m Q: embed the reflectors in a square buffer. Allocate it with qr.class,
        # the dtype geqrf actually worked in, which may differ from a.class if `a` was cast.
        qqr = qr.class.zeros(m, m)
        qqr[0...m, 0...n] = qr
        Numo::TinyLinalg::Lapack.send(org_ung_qr, qqr, tau)[0]
      end

  [q, r]
end
166
+
39
167
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
40
168
  #
41
169
  # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensional NArray).
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.2
4
+ version: 0.0.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-07-26 00:00:00.000000000 Z
11
+ date: 2023-08-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -46,11 +46,20 @@ files:
46
46
  - ext/numo/tiny_linalg/blas/nrm2.hpp
47
47
  - ext/numo/tiny_linalg/converter.hpp
48
48
  - ext/numo/tiny_linalg/extconf.rb
49
+ - ext/numo/tiny_linalg/lapack/geqrf.hpp
49
50
  - ext/numo/tiny_linalg/lapack/gesdd.hpp
50
51
  - ext/numo/tiny_linalg/lapack/gesv.hpp
51
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
52
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
53
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/hegv.hpp
56
+ - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
+ - ext/numo/tiny_linalg/lapack/hegvx.hpp
58
+ - ext/numo/tiny_linalg/lapack/orgqr.hpp
59
+ - ext/numo/tiny_linalg/lapack/sygv.hpp
60
+ - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
+ - ext/numo/tiny_linalg/lapack/sygvx.hpp
62
+ - ext/numo/tiny_linalg/lapack/ungqr.hpp
54
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
55
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
56
65
  - lib/numo/tiny_linalg.rb