numo-tiny_linalg 0.3.5 → 0.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9165d44ad74a388e32e08564402c255b238ac80dc66655d264ce5074b6d9ae66
4
- data.tar.gz: d679aae9036c5da1ebb72854b554438dc201dd7d95376c381647513acb4ce5f7
3
+ metadata.gz: c508a2c2f31a965c26824084f1eeb1a2caf53ea82a3d23e523980d6cf78e6259
4
+ data.tar.gz: 4bffa14a1a73bfb60230c65fa35c98eeb957e176b8af2003f9c3e5b8aaf89b19
5
5
  SHA512:
6
- metadata.gz: dc6c2fdebcd684b07a29dfa3650a33da81f3a0339a90ddc17c872876854cc7c4b8fc42c7feaf932501b212fc4f6a460e871ff23aba776e8e97127e6af2d51f0e
7
- data.tar.gz: 56b0db97799b6e5c41f67503bc20d2fb08868b9a53a64a7a5a919b29791a61f08749c03f54190a238e84dd8e3f3aac44a2ac4bdb8d7e53d50e3491f7c9874f7e
6
+ metadata.gz: 4ad7a9af39fdc0c5ffb9ba7fc51dcf037fe8ffe7d81c33c70a0c669b384127929efee3a9ceb8440b7d1ec10c79ee66b362aed7098befbe02a892d4490961f169
7
+ data.tar.gz: 91c9935714ef5b929718affc9cfae518f6e08f85063f673f26ffa7c137bb7fe62d41c20386704657dccb89feebf57bdaed5b182750c6f0fe3ad5b20e55cc4221
data/CHANGELOG.md CHANGED
@@ -1,4 +1,12 @@
1
1
  ## [Unreleased]
2
+ ## [[0.3.6](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.5...v0.3.6)] - 2024-01-28
3
+
4
+ - Add solve_triangular module function to TinyLinalg.
5
+ - solve_triangular is not implemented in Numo::Linalg, but it is implemented here because some machine learning algorithms use it.
6
+ - Add dtrtrs, strtrs, ztrtrs, and ctrtrs module functions to TinyLinalg::Lapack.
7
+ - Add norm module function to TinyLinalg.
8
+ - Add dlange, slange, zlange, and clange module functions to TinyLinalg::Lapack.
9
+
2
10
  ## [[0.3.5](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.4...v0.3.5)] - 2024-01-03
3
11
  - Bump OpenBLAS to be downloaded from 0.3.25 to 0.3.26.
4
12
  - Minor changes using RuboCop.
@@ -83,8 +83,7 @@ if build_openblas
83
83
  end
84
84
  end
85
85
 
86
- libopenblas_dir = on_windows ? TINYLINALG_DIR : "#{VENDOR_DIR}/lib"
87
- abort('libopenblas is not found.') unless find_library('openblas', nil, libopenblas_dir)
86
+ abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
88
87
  abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
89
88
  abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
90
89
  abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
@@ -0,0 +1,89 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DLanGe {
4
+ double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const double* a, lapack_int lda) {
5
+ return LAPACKE_dlange(matrix_layout, norm, m, n, a, lda);
6
+ }
7
+ };
8
+
9
+ struct SLanGe {
10
+ float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const float* a, lapack_int lda) {
11
+ return LAPACKE_slange(matrix_layout, norm, m, n, a, lda);
12
+ }
13
+ };
14
+
15
+ struct ZLanGe {
16
+ double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_double* a, lapack_int lda) {
17
+ return LAPACKE_zlange(matrix_layout, norm, m, n, a, lda);
18
+ }
19
+ };
20
+
21
+ struct CLanGe {
22
+ float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_float* a, lapack_int lda) {
23
+ return LAPACKE_clange(matrix_layout, norm, m, n, a, lda);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, typename dtype, class LapackFn>
28
+ class LanGe {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_lange), -1);
32
+ }
33
+
34
+ private:
35
+ struct lange_opt {
36
+ int matrix_layout;
37
+ char norm;
38
+ };
39
+
40
+ static void iter_lange(na_loop_t* const lp) {
41
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
42
+ dtype* d = (dtype*)NDL_PTR(lp, 1);
43
+ lange_opt* opt = (lange_opt*)(lp->opt_ptr);
44
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
45
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
46
+ const lapack_int lda = n;
47
+ *d = LapackFn().call(opt->matrix_layout, opt->norm, m, n, a, lda);
48
+ }
49
+
50
+ static VALUE tiny_linalg_lange(int argc, VALUE* argv, VALUE self) {
51
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
52
+
53
+ VALUE a_vnary = Qnil;
54
+ VALUE kw_args = Qnil;
55
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
56
+ ID kw_table[2] = { rb_intern("order"), rb_intern("norm") };
57
+ VALUE kw_values[2] = { Qundef, Qundef };
58
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
59
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
60
+ const char norm = kw_values[1] != Qundef ? NUM2CHR(kw_values[1]) : 'F';
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+
69
+ narray_t* a_nary = NULL;
70
+ GetNArray(a_vnary, a_nary);
71
+ if (NA_NDIM(a_nary) != 2) {
72
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
73
+ return Qnil;
74
+ }
75
+
76
+ ndfunc_arg_in_t ain[1] = { { nary_dtype, 2 } };
77
+ size_t shape_out[1] = { 1 };
78
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
79
+ ndfunc_t ndf = { iter_lange, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
80
+ lange_opt opt = { matrix_layout, norm };
81
+ VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
82
+
83
+ RB_GC_GUARD(a_vnary);
84
+
85
+ return ret;
86
+ }
87
+ };
88
+
89
+ } // namespace TinyLinalg
@@ -0,0 +1,126 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DTrTrs {
4
+ lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
5
+ const double* a, lapack_int lda, double* b, lapack_int ldb) {
6
+ return LAPACKE_dtrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
7
+ }
8
+ };
9
+
10
+ struct STrTrs {
11
+ lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
12
+ const float* a, lapack_int lda, float* b, lapack_int ldb) {
13
+ return LAPACKE_strtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
14
+ }
15
+ };
16
+
17
+ struct ZTrTrs {
18
+ lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
19
+ const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
20
+ return LAPACKE_ztrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
21
+ }
22
+ };
23
+
24
+ struct CTrTrs {
25
+ lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
26
+ const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
27
+ return LAPACKE_ctrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
28
+ }
29
+ };
30
+
31
+ template <int nary_dtype_id, typename dtype, class LapackFn>
32
+ class TrTrs {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
35
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_trtrs), -1);
36
+ }
37
+
38
+ private:
39
+ struct trtrs_opt {
40
+ int matrix_layout;
41
+ char uplo;
42
+ char trans;
43
+ char diag;
44
+ };
45
+
46
+ static void iter_trtrs(na_loop_t* const lp) {
47
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
48
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
49
+ int* info = (int*)NDL_PTR(lp, 2);
50
+ trtrs_opt* opt = (trtrs_opt*)(lp->opt_ptr);
51
+ const lapack_int n = NDL_SHAPE(lp, 0)[0];
52
+ const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
53
+ const lapack_int lda = n;
54
+ const lapack_int ldb = nrhs;
55
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, opt->trans, opt->diag, n, nrhs, a, lda, b, ldb);
56
+ *info = static_cast<int>(i);
57
+ }
58
+
59
+ static VALUE tiny_linalg_trtrs(int argc, VALUE* argv, VALUE self) {
60
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
61
+
62
+ VALUE a_vnary = Qnil;
63
+ VALUE b_vnary = Qnil;
64
+ VALUE kw_args = Qnil;
65
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
66
+ ID kw_table[4] = { rb_intern("order"), rb_intern("uplo"), rb_intern("trans"), rb_intern("diag") };
67
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
68
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
69
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
70
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
71
+ const char trans = kw_values[2] != Qundef ? NUM2CHR(kw_values[2]) : 'N';
72
+ const char diag = kw_values[3] != Qundef ? NUM2CHR(kw_values[3]) : 'N';
73
+
74
+ if (CLASS_OF(a_vnary) != nary_dtype) {
75
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
76
+ }
77
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
78
+ a_vnary = nary_dup(a_vnary);
79
+ }
80
+ if (CLASS_OF(b_vnary) != nary_dtype) {
81
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
82
+ }
83
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
84
+ b_vnary = nary_dup(b_vnary);
85
+ }
86
+
87
+ narray_t* a_nary = NULL;
88
+ GetNArray(a_vnary, a_nary);
89
+ if (NA_NDIM(a_nary) != 2) {
90
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
91
+ return Qnil;
92
+ }
93
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
94
+ rb_raise(rb_eArgError, "input array a must be square");
95
+ return Qnil;
96
+ }
97
+
98
+ narray_t* b_nary = NULL;
99
+ GetNArray(b_vnary, b_nary);
100
+ const int b_n_dims = NA_NDIM(b_nary);
101
+ if (b_n_dims != 1 && b_n_dims != 2) {
102
+ rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
103
+ return Qnil;
104
+ }
105
+
106
+ lapack_int n = NA_SHAPE(a_nary)[0];
107
+ lapack_int nb = NA_SHAPE(b_nary)[0];
108
+ if (n != nb) {
109
+ rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
110
+ }
111
+
112
+ ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
113
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
114
+ ndfunc_t ndf = { iter_trtrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
115
+ trtrs_opt opt = { matrix_layout, uplo, trans, diag };
116
+ VALUE info = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
117
+ VALUE ret = rb_ary_new3(2, b_vnary, info);
118
+
119
+ RB_GC_GUARD(a_vnary);
120
+ RB_GC_GUARD(b_vnary);
121
+
122
+ return ret;
123
+ }
124
+ };
125
+
126
+ } // namespace TinyLinalg
@@ -50,6 +50,7 @@
50
50
  #include "lapack/hegv.hpp"
51
51
  #include "lapack/hegvd.hpp"
52
52
  #include "lapack/hegvx.hpp"
53
+ #include "lapack/lange.hpp"
53
54
  #include "lapack/orgqr.hpp"
54
55
  #include "lapack/potrf.hpp"
55
56
  #include "lapack/potrs.hpp"
@@ -59,6 +60,7 @@
59
60
  #include "lapack/sygv.hpp"
60
61
  #include "lapack/sygvd.hpp"
61
62
  #include "lapack/sygvx.hpp"
63
+ #include "lapack/trtrs.hpp"
62
64
  #include "lapack/ungqr.hpp"
63
65
 
64
66
  VALUE rb_mTinyLinalg;
@@ -314,6 +316,10 @@ extern "C" void Init_tiny_linalg(void) {
314
316
  TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
315
317
  TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
316
318
  TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
319
+ TinyLinalg::TrTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DTrTrs>::define_module_function(rb_mTinyLinalgLapack, "dtrtrs");
320
+ TinyLinalg::TrTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::STrTrs>::define_module_function(rb_mTinyLinalgLapack, "strtrs");
321
+ TinyLinalg::TrTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ztrtrs");
322
+ TinyLinalg::TrTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ctrtrs");
317
323
  TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
318
324
  TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
319
325
  TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
@@ -354,6 +360,10 @@ extern "C" void Init_tiny_linalg(void) {
354
360
  TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
355
361
  TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
356
362
  TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
363
+ TinyLinalg::LanGe<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DLanGe>::define_module_function(rb_mTinyLinalgLapack, "dlange");
364
+ TinyLinalg::LanGe<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SLanGe>::define_module_function(rb_mTinyLinalgLapack, "slange");
365
+ TinyLinalg::LanGe<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZLanGe>::define_module_function(rb_mTinyLinalgLapack, "zlange");
366
+ TinyLinalg::LanGe<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CLanGe>::define_module_function(rb_mTinyLinalgLapack, "clange");
357
367
 
358
368
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
359
369
  rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.3.5'
8
+ VERSION = '0.3.6'
9
9
  end
10
10
  end
@@ -95,6 +95,172 @@ module Numo
95
95
  [vals, vecs]
96
96
  end
97
97
 
98
# Computes the matrix or vector norm.
#
# | ord   | matrix norm            | vector norm                 |
# | ----- | ---------------------- | --------------------------- |
# | nil   | Frobenius norm         | 2-norm                      |
# | 'fro' | Frobenius norm         | -                           |
# | 'nuc' | nuclear norm           | -                           |
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max                   |
# | 0     | -                      | (x.ne 0).sum                |
# | 1     | x.abs.sum(axis:-2).max | same as below               |
# | 2     | 2-norm (max sing_vals) | same as below               |
# | other | -                      | (x.abs**ord).sum**(1.0/ord) |
#
# When two axes are given, the matrix norm is computed for every matrix spanned by those axes.
# Arrays with more than two dimensions are therefore treated as stacks of matrices.
#
# @example
#   require 'numo/tiny_linalg'
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   # matrix norm
#   x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
#   pp Numo::Linalg.norm(x)
#   # => 10
#
#   # vector norm
#   x = Numo::DFloat[3, -4]
#   pp Numo::Linalg.norm(x)
#   # => 5
#
# @param a [Numo::NArray] The matrix or vector (>= 1-dimensional NArray)
# @param ord [String/Numeric] The order of the norm.
# @param axis [Integer/Array] The applied axes.
# @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
# @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

  return 0.0 if a.empty?

  # for compatibility with Numo::Linalg.norm
  if ord.is_a?(String)
    if ord == 'inf'
      ord = Float::INFINITY
    elsif ord == '-inf'
      ord = -Float::INFINITY
    end
  end

  # Fast paths over the whole array using BLAS nrm2 / LAPACK lange.
  if axis.nil?
    norm = case a.ndim
           when 1
             Numo::TinyLinalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
           when 2
             if ord.nil? || ord == 'fro'
               Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
             elsif ord.is_a?(Numeric)
               if ord == 1
                 Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
               elsif !ord.infinite?.nil? && ord.infinite?.positive?
                 Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
               end
             end
           else
             if ord.nil?
               b = a.flatten.dup
               Numo::TinyLinalg::Blas.send(:"#{blas_char(b)}nrm2", b)
             end
           end
    unless norm.nil?
      # Every axis was reduced, so keepdims yields an all-ones shape.
      norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
      return norm
    end
  end

  if axis.nil?
    axis = Array.new(a.ndim) { |d| d }
  else
    case axis
    when Integer
      axis = [axis]
    when Array, Numo::NArray
      axis = axis.flatten.to_a
    else
      raise ArgumentError, "invalid axis: #{axis}"
    end
  end

  raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless axis.size == 1 || axis.size == 2
  raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

  if axis.size == 1
    ord ||= 2
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

    ord_inf = ord.infinite?
    if ord_inf.nil?
      case ord
      when 0
        a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
      when 1
        a.abs.sum(axis: axis, keepdims: keepdims)
      else
        (a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
      end
    elsif ord_inf.positive?
      a.abs.max(axis: axis, keepdims: keepdims)
    else
      a.abs.min(axis: axis, keepdims: keepdims)
    end
  else
    ord ||= 'fro'
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)

    # Normalize negative axes before the duplicate check so that e.g. [1, -1] on a 2-D array is rejected.
    r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }
    raise ArgumentError, "invalid axis: #{axis}" if r_axis == c_axis

    # Index of each matrix axis after the other one has been reduced away.
    r_axis_rest = r_axis > c_axis ? r_axis - 1 : r_axis
    c_axis_rest = c_axis > r_axis ? c_axis - 1 : c_axis

    # Singular values of every matrix spanned by (r_axis, c_axis). The matrix axes are moved to the
    # end, the remaining axes are flattened into a batch, and gesvd is applied to each matrix.
    # The result has the remaining axes (in order) followed by the min(m, n) singular values.
    singular_values = lambda do
      rest_axes = (0...a.ndim).to_a - [r_axis, c_axis]
      batch_shape = rest_axes.map { |ax| a.shape[ax] }
      n_batch = batch_shape.inject(1, :*)
      mats = a.transpose(*rest_axes, r_axis, c_axis).dup.reshape(n_batch, a.shape[r_axis], a.shape[c_axis])
      gesvd = :"#{blas_char(mats)}gesvd"
      svals = nil
      n_batch.times do |i|
        s, = Numo::TinyLinalg::Lapack.send(gesvd, mats[i, true, true].dup, jobu: 'N', jobvt: 'N')
        svals ||= s.class.zeros(n_batch, s.size)
        svals[i, true] = s
      end
      svals.reshape(*batch_shape, svals.shape[1])
    end

    norm = if ord.is_a?(String)
             raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

             if ord == 'fro'
               Numo::NMath.sqrt((a.abs**2).sum(axis: [r_axis, c_axis]))
             else
               singular_values.call.sum(axis: -1)
             end
           else
             ord_inf = ord.infinite?
             if ord_inf.nil?
               case ord
               when -2
                 singular_values.call.min(axis: -1)
               when -1
                 a.abs.sum(axis: r_axis).min(axis: c_axis_rest)
               when 1
                 a.abs.sum(axis: r_axis).max(axis: c_axis_rest)
               when 2
                 singular_values.call.max(axis: -1)
               else
                 raise ArgumentError, "invalid ord: #{ord}"
               end
             elsif ord_inf.positive?
               a.abs.sum(axis: c_axis).max(axis: r_axis_rest)
             else
               a.abs.sum(axis: c_axis).min(axis: r_axis_rest)
             end
           end
    if keepdims
      # Re-insert size-one dimensions at the two reduced axes only; any batch axes keep their size.
      norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
      kept_shape = a.shape.dup
      kept_shape[r_axis] = 1
      kept_shape[c_axis] = 1
      norm = norm.reshape(*kept_shape)
    end

    norm
  end
end
263
+
98
264
  # Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
99
265
  #
100
266
  # @example
@@ -399,6 +565,47 @@ module Numo
399
565
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
400
566
  end
401
567
 
568
# Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   a = Numo::DFloat.new(3, 3).rand.triu
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve_triangular(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[16.1932, -52.0604, 30.5283],
#   #  [0, 8.61765, -17.9585],
#   #  [0, 0, 6.05735]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 4.071100642430302e-16
#
# @param a [Numo::NArray] The n-by-n triangular matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
# @return [Numo::NArray] The solution vector / matrix `X`.
# @raise [ArgumentError] if `a` is not a square matrix or the array types are not supported.
# @raise [RuntimeError] if `a` is singular (a zero diagonal element) or LAPACK rejects an argument.
def solve_triangular(a, b, lower: false)
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  trtrs = :"#{bchr}trtrs"
  uplo = lower ? 'L' : 'U'
  # b is duplicated because trtrs overwrites its right-hand side with the solution.
  x, info = Numo::TinyLinalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
  # A negative info is minus the index of the offending argument.
  raise "wrong value is given to the #{-info}-th argument of #{trtrs} used internally" if info.negative?
  # A positive info means the info-th diagonal element is zero, so trtrs did not compute a solution.
  raise "the #{info}-th diagonal element of a is zero, so a is singular and the solution could not be computed" if info.positive?

  x
end
608
+
402
609
  # Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
403
610
  #
404
611
  # @example
@@ -534,11 +741,6 @@ module Numo
534
741
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
535
742
  end
536
743
 
537
- # @!visibility private
538
- def norm(*args)
539
- raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
540
- end
541
-
542
744
  # @!visibility private
543
745
  def cond(*args)
544
746
  raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.5
4
+ version: 0.3.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2024-01-03 00:00:00.000000000 Z
11
+ date: 2024-01-28 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -57,6 +57,7 @@ files:
57
57
  - ext/numo/tiny_linalg/lapack/hegv.hpp
58
58
  - ext/numo/tiny_linalg/lapack/hegvd.hpp
59
59
  - ext/numo/tiny_linalg/lapack/hegvx.hpp
60
+ - ext/numo/tiny_linalg/lapack/lange.hpp
60
61
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
61
62
  - ext/numo/tiny_linalg/lapack/potrf.hpp
62
63
  - ext/numo/tiny_linalg/lapack/potrs.hpp
@@ -66,6 +67,7 @@ files:
66
67
  - ext/numo/tiny_linalg/lapack/sygv.hpp
67
68
  - ext/numo/tiny_linalg/lapack/sygvd.hpp
68
69
  - ext/numo/tiny_linalg/lapack/sygvx.hpp
70
+ - ext/numo/tiny_linalg/lapack/trtrs.hpp
69
71
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
70
72
  - ext/numo/tiny_linalg/tiny_linalg.cpp
71
73
  - ext/numo/tiny_linalg/tiny_linalg.hpp
@@ -97,7 +99,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
97
99
  - !ruby/object:Gem::Version
98
100
  version: '0'
99
101
  requirements: []
100
- rubygems_version: 3.4.22
102
+ rubygems_version: 3.4.19
101
103
  signing_key:
102
104
  specification_version: 4
103
105
  summary: Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of