numo-tiny_linalg 0.3.4 → 0.3.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: d640eb16ee97cc31ec5a5c15d9600b9cd3b01395954e5fab80110251445ff56b
- data.tar.gz: 5133899439b92170f1cffd13a3ce10f792ce5e4434e733d7767b88a7ce3f30a2
+ metadata.gz: c508a2c2f31a965c26824084f1eeb1a2caf53ea82a3d23e523980d6cf78e6259
+ data.tar.gz: 4bffa14a1a73bfb60230c65fa35c98eeb957e176b8af2003f9c3e5b8aaf89b19
  SHA512:
- metadata.gz: d53553a0e5aabb16012377acaf2b1b3ff47e039dad2f44ec7c70a8e62a038153366b1b5a6d57a7d2c6e44c5ffb991d29132f3eda3a0cb6ff38987b117f7f6ccb
- data.tar.gz: e8fa35b07eac5f69d4a1a98118cb34bc2cbcfd506852659b0c4a59990a047326348b5d7a612940fb861b635d03ec4efa4420302f01f616f979b58b63b965defa
+ metadata.gz: 4ad7a9af39fdc0c5ffb9ba7fc51dcf037fe8ffe7d81c33c70a0c669b384127929efee3a9ceb8440b7d1ec10c79ee66b362aed7098befbe02a892d4490961f169
+ data.tar.gz: 91c9935714ef5b929718affc9cfae518f6e08f85063f673f26ffa7c137bb7fe62d41c20386704657dccb89feebf57bdaed5b182750c6f0fe3ad5b20e55cc4221
data/CHANGELOG.md CHANGED
@@ -1,4 +1,15 @@
  ## [Unreleased]
+ ## [[0.3.6](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.5...v0.3.6)] - 2024-01-28
+
+ - Add solve_triangular module function to TinyLinalg.
+   - solve_triangular is not implemented in Numo::Linalg, but I have implemented it because it is used in some machine learning algorithms.
+ - Add dtrtrs, strtrs, ztrtrs, and ctrtrs module functions to TinyLinalg::Lapack.
+ - Add norm module function to TinyLinalg.
+ - Add dlange, slange, zlange, and clange module functions to TinyLinalg::Lapack.
+
+ ## [[0.3.5](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.4...v0.3.5)] - 2024-01-03
+ - Bump OpenBLAS to be downloaded from 0.3.25 to 0.3.26.
+ - Minor changes using RuboCop.

  ## [[0.3.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.3.3...v0.3.4)] - 2023-11-19
  - Bump OpenBLAS to be downloaded from 0.3.24 to 0.3.25.
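For orientation, the two user-facing methods added in 0.3.6 are used as follows (a minimal sketch based on the method signatures added later in this diff; the commented values are illustrative, not captured output):

```ruby
require 'numo/tiny_linalg'

# TinyLinalg.norm: matrix/vector norms backed by the new *lange wrappers.
v = Numo::DFloat[3.0, -4.0]
puts Numo::TinyLinalg.norm(v) # Euclidean norm, 5.0

# TinyLinalg.solve_triangular: triangular solves backed by the new *trtrs wrappers.
a = Numo::DFloat[[2.0, 1.0], [0.0, 3.0]] # upper triangular
b = Numo::DFloat[3.0, 6.0]
x = Numo::TinyLinalg.solve_triangular(a, b)
puts x.inspect # solution of a.dot(x) == b, i.e. [0.5, 2.0]
```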
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
- Copyright (c) 2023 Atsushi Tatsuma
+ Copyright (c) 2023-2024 Atsushi Tatsuma
  All rights reserved.

  Redistribution and use in source and binary forms, with or without
@@ -46,8 +46,8 @@ if build_openblas

  VENDOR_DIR = File.expand_path("#{__dir__}/../../../vendor")
  TINYLINALG_DIR = File.expand_path("#{__dir__}/../../../lib/numo/tiny_linalg")
- OPENBLAS_VER = '0.3.25'
- OPENBLAS_KEY = '48384e324cd1cdcfbdb0d2e16ca55327'
+ OPENBLAS_VER = '0.3.26'
+ OPENBLAS_KEY = 'bd496a1c81769ed19a161c1f8f904ccd'
  OPENBLAS_URI = "https://github.com/OpenMathLib/OpenBLAS/archive/v#{OPENBLAS_VER}.tar.gz"
  OPENBLAS_TGZ = "#{VENDOR_DIR}/tmp/openblas.tgz"

@@ -83,8 +83,7 @@ if build_openblas
  end
  end

- libopenblas_dir = on_windows ? TINYLINALG_DIR : "#{VENDOR_DIR}/lib"
- abort('libopenblas is not found.') unless find_library('openblas', nil, libopenblas_dir)
+ abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
  abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
  abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
  abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
data/ext/numo/tiny_linalg/lapack/lange.hpp ADDED
@@ -0,0 +1,89 @@
+ namespace TinyLinalg {
+
+ struct DLanGe {
+   double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const double* a, lapack_int lda) {
+     return LAPACKE_dlange(matrix_layout, norm, m, n, a, lda);
+   }
+ };
+
+ struct SLanGe {
+   float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const float* a, lapack_int lda) {
+     return LAPACKE_slange(matrix_layout, norm, m, n, a, lda);
+   }
+ };
+
+ struct ZLanGe {
+   double call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_double* a, lapack_int lda) {
+     return LAPACKE_zlange(matrix_layout, norm, m, n, a, lda);
+   }
+ };
+
+ struct CLanGe {
+   float call(int matrix_layout, char norm, lapack_int m, lapack_int n, const lapack_complex_float* a, lapack_int lda) {
+     return LAPACKE_clange(matrix_layout, norm, m, n, a, lda);
+   }
+ };
+
+ template <int nary_dtype_id, typename dtype, class LapackFn>
+ class LanGe {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_lange), -1);
+   }
+
+ private:
+   struct lange_opt {
+     int matrix_layout;
+     char norm;
+   };
+
+   static void iter_lange(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* d = (dtype*)NDL_PTR(lp, 1);
+     lange_opt* opt = (lange_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = n;
+     *d = LapackFn().call(opt->matrix_layout, opt->norm, m, n, a, lda);
+   }
+
+   static VALUE tiny_linalg_lange(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
+     ID kw_table[2] = { rb_intern("order"), rb_intern("norm") };
+     VALUE kw_values[2] = { Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+     const char norm = kw_values[1] != Qundef ? NUM2CHR(kw_values[1]) : 'F';
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+
+     ndfunc_arg_in_t ain[1] = { { nary_dtype, 2 } };
+     size_t shape_out[1] = { 1 };
+     ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
+     ndfunc_t ndf = { iter_lange, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
+     lange_opt opt = { matrix_layout, norm };
+     VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
+
+     RB_GC_GUARD(a_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
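The `LanGe` template above registers one Ruby module function per element type (`dlange`, `slange`, `zlange`, `clange`). Each takes a 2-D NArray plus optional `order:` and `norm:` keywords, with the norm character defaulting to `'F'` (Frobenius). A minimal sketch of calling the low-level wrapper directly (illustrative values, not captured output):

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat[[1.0, 2.0], [-3.0, 4.0]]

fro = Numo::TinyLinalg::Lapack.dlange(a)            # Frobenius norm (default norm: 'F')
one = Numo::TinyLinalg::Lapack.dlange(a, norm: '1') # maximum absolute column sum
inf = Numo::TinyLinalg::Lapack.dlange(a, norm: 'I') # maximum absolute row sum
p [fro, one, inf]
```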
data/ext/numo/tiny_linalg/lapack/trtrs.hpp ADDED
@@ -0,0 +1,126 @@
+ namespace TinyLinalg {
+
+ struct DTrTrs {
+   lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
+                   const double* a, lapack_int lda, double* b, lapack_int ldb) {
+     return LAPACKE_dtrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
+   }
+ };
+
+ struct STrTrs {
+   lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
+                   const float* a, lapack_int lda, float* b, lapack_int ldb) {
+     return LAPACKE_strtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
+   }
+ };
+
+ struct ZTrTrs {
+   lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
+                   const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
+     return LAPACKE_ztrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
+   }
+ };
+
+ struct CTrTrs {
+   lapack_int call(int matrix_layout, char uplo, char trans, char diag, lapack_int n, lapack_int nrhs,
+                   const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
+     return LAPACKE_ctrtrs(matrix_layout, uplo, trans, diag, n, nrhs, a, lda, b, ldb);
+   }
+ };
+
+ template <int nary_dtype_id, typename dtype, class LapackFn>
+ class TrTrs {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_trtrs), -1);
+   }
+
+ private:
+   struct trtrs_opt {
+     int matrix_layout;
+     char uplo;
+     char trans;
+     char diag;
+   };
+
+   static void iter_trtrs(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* b = (dtype*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     trtrs_opt* opt = (trtrs_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[0];
+     const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
+     const lapack_int lda = n;
+     const lapack_int ldb = nrhs;
+     const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, opt->trans, opt->diag, n, nrhs, a, lda, b, ldb);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_trtrs(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("order"), rb_intern("uplo"), rb_intern("trans"), rb_intern("diag") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+     const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
+     const char trans = kw_values[2] != Qundef ? NUM2CHR(kw_values[2]) : 'N';
+     const char diag = kw_values[3] != Qundef ? NUM2CHR(kw_values[3]) : 'N';
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+
+     narray_t* b_nary = NULL;
+     GetNArray(b_vnary, b_nary);
+     const int b_n_dims = NA_NDIM(b_nary);
+     if (b_n_dims != 1 && b_n_dims != 2) {
+       rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
+       return Qnil;
+     }
+
+     lapack_int n = NA_SHAPE(a_nary)[0];
+     lapack_int nb = NA_SHAPE(b_nary)[0];
+     if (n != nb) {
+       rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
+     }
+
+     ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
+     ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_trtrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
+     trtrs_opt opt = { matrix_layout, uplo, trans, diag };
+     VALUE info = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(2, b_vnary, info);
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
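As with `LanGe`, one wrapper per element type is registered (`dtrtrs`, `strtrs`, `ztrtrs`, `ctrtrs`). Each takes the triangular matrix `a` and a 1-D or 2-D right-hand side `b`, overwrites `b` with the solution, and returns `[b, info]`; `uplo:`, `trans:`, and `diag:` default to `'U'`, `'N'`, and `'N'`. A minimal usage sketch (illustrative values, not captured output):

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat[[4.0, 2.0], [0.0, 2.0]] # upper triangular
b = Numo::DFloat[8.0, 4.0]

# Pass a copy of b when the original right-hand side should be preserved,
# since the wrapper overwrites its b argument in place.
x, info = Numo::TinyLinalg::Lapack.dtrtrs(a, b.dup, uplo: 'U')
puts info      # 0 on success
puts x.inspect # solution of a.dot(x) == b, i.e. [1.0, 2.0]
```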
@@ -1,5 +1,5 @@
  /**
-  * Copyright (c) 2023 Atsushi Tatsuma
+  * Copyright (c) 2023-2024 Atsushi Tatsuma
   * All rights reserved.
   *
   * Redistribution and use in source and binary forms, with or without
@@ -50,6 +50,7 @@
  #include "lapack/hegv.hpp"
  #include "lapack/hegvd.hpp"
  #include "lapack/hegvx.hpp"
+ #include "lapack/lange.hpp"
  #include "lapack/orgqr.hpp"
  #include "lapack/potrf.hpp"
  #include "lapack/potrs.hpp"
@@ -59,6 +60,7 @@
  #include "lapack/sygv.hpp"
  #include "lapack/sygvd.hpp"
  #include "lapack/sygvx.hpp"
+ #include "lapack/trtrs.hpp"
  #include "lapack/ungqr.hpp"

  VALUE rb_mTinyLinalg;
@@ -314,6 +316,10 @@ extern "C" void Init_tiny_linalg(void) {
    TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
    TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
    TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
+   TinyLinalg::TrTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DTrTrs>::define_module_function(rb_mTinyLinalgLapack, "dtrtrs");
+   TinyLinalg::TrTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::STrTrs>::define_module_function(rb_mTinyLinalgLapack, "strtrs");
+   TinyLinalg::TrTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ztrtrs");
+   TinyLinalg::TrTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CTrTrs>::define_module_function(rb_mTinyLinalgLapack, "ctrtrs");
    TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
    TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
    TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
@@ -354,6 +360,10 @@ extern "C" void Init_tiny_linalg(void) {
    TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
    TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
    TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
+   TinyLinalg::LanGe<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DLanGe>::define_module_function(rb_mTinyLinalgLapack, "dlange");
+   TinyLinalg::LanGe<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SLanGe>::define_module_function(rb_mTinyLinalgLapack, "slange");
+   TinyLinalg::LanGe<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZLanGe>::define_module_function(rb_mTinyLinalgLapack, "zlange");
+   TinyLinalg::LanGe<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CLanGe>::define_module_function(rb_mTinyLinalgLapack, "clange");

    rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
    rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
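These registrations follow the gem's existing pattern of exposing one module function per BLAS character, so Ruby-side code can dispatch on the array's dtype at runtime, exactly as the `norm` and `solve_triangular` methods later in this diff do. A small sketch of that dispatch pattern (assuming `blas_char` is callable as a module function, as it is in Numo::Linalg):

```ruby
require 'numo/tiny_linalg'

def frobenius_norm(a)
  # blas_char returns 'd', 's', 'z', or 'c' depending on the NArray dtype,
  # selecting dlange/slange/zlange/clange as registered above ('n' means unsupported).
  chr = Numo::TinyLinalg.blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if chr == 'n'

  Numo::TinyLinalg::Lapack.send(:"#{chr}lange", a, norm: 'F')
end

puts frobenius_norm(Numo::DFloat[[3.0, 4.0], [0.0, 0.0]]) # 5.0
```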
@@ -1,5 +1,5 @@
  /**
-  * Copyright (c) 2023 Atsushi Tatsuma
+  * Copyright (c) 2023-2024 Atsushi Tatsuma
   * All rights reserved.
   *
   * Redistribution and use in source and binary forms, with or without
@@ -5,6 +5,6 @@ module Numo
    # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
    module TinyLinalg
      # The version of Numo::TinyLinalg you install.
-     VERSION = '0.3.4'
+     VERSION = '0.3.6'
    end
  end
@@ -95,6 +95,172 @@ module Numo
        [vals, vecs]
      end

+     # Computes the matrix or vector norm.
+     #
+     # | ord   | matrix norm            | vector norm                 |
+     # | ----- | ---------------------- | --------------------------- |
+     # | nil   | Frobenius norm         | 2-norm                      |
+     # | 'fro' | Frobenius norm         | -                           |
+     # | 'nuc' | nuclear norm           | -                           |
+     # | 'inf' | x.abs.sum(axis:-1).max | x.abs.max                   |
+     # | 0     | -                      | (x.ne 0).sum                |
+     # | 1     | x.abs.sum(axis:-2).max | same as below               |
+     # | 2     | 2-norm (max sing_vals) | same as below               |
+     # | other | -                      | (x.abs**ord).sum**(1.0/ord) |
+     #
+     # @example
+     #   require 'numo/tiny_linalg'
+     #   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
+     #
+     #   # matrix norm
+     #   x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
+     #   pp Numo::Linalg.norm(x)
+     #   # => 10
+     #
+     #   # vector norm
+     #   x = Numo::DFloat[3, -4]
+     #   pp Numo::Linalg.norm(x)
+     #   # => 5
+     #
+     # @param a [Numo::NArray] The matrix or vector (>= 1-dimensional NArray).
+     # @param ord [String/Numeric] The order of the norm.
+     # @param axis [Integer/Array] The applied axes.
+     # @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
+     # @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
+     def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
+       a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)
+
+       return 0.0 if a.empty?
+
+       # for compatibility with Numo::Linalg.norm
+       if ord.is_a?(String)
+         if ord == 'inf'
+           ord = Float::INFINITY
+         elsif ord == '-inf'
+           ord = -Float::INFINITY
+         end
+       end
+
+       if axis.nil?
+         norm = case a.ndim
+                when 1
+                  Numo::TinyLinalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
+                when 2
+                  if ord.nil? || ord == 'fro'
+                    Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
+                  elsif ord.is_a?(Numeric)
+                    if ord == 1
+                      Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
+                    elsif !ord.infinite?.nil? && ord.infinite?.positive?
+                      Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
+                    end
+                  end
+                else
+                  if ord.nil?
+                    b = a.flatten.dup
+                    Numo::TinyLinalg::Blas.send(:"#{blas_char(b)}nrm2", b)
+                  end
+                end
+         unless norm.nil?
+           norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
+           return norm
+         end
+       end
+
+       if axis.nil?
+         axis = Array.new(a.ndim) { |d| d }
+       else
+         case axis
+         when Integer
+           axis = [axis]
+         when Array, Numo::NArray
+           axis = axis.flatten.to_a
+         else
+           raise ArgumentError, "invalid axis: #{axis}"
+         end
+       end
+
+       raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless axis.size == 1 || axis.size == 2
+       raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }
+
+       if axis.size == 1
+         ord ||= 2
+         raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)
+
+         ord_inf = ord.infinite?
+         if ord_inf.nil?
+           case ord
+           when 0
+             a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
+           when 1
+             a.abs.sum(axis: axis, keepdims: keepdims)
+           else
+             (a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
+           end
+         elsif ord_inf.positive?
+           a.abs.max(axis: axis, keepdims: keepdims)
+         else
+           a.abs.min(axis: axis, keepdims: keepdims)
+         end
+       else
+         ord ||= 'fro'
+         raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)
+         raise ArgumentError, "invalid axis: #{axis}" if axis.uniq.size == 1
+
+         r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }
+
+         norm = if ord.is_a?(String)
+                  raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)
+
+                  if ord == 'fro'
+                    Numo::NMath.sqrt((a.abs**2).sum(axis: axis))
+                  else
+                    b = a.transpose(c_axis, r_axis).dup
+                    gesvd = :"#{blas_char(b)}gesvd"
+                    s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
+                    s.sum(axis: -1)
+                  end
+                else
+                  ord_inf = ord.infinite?
+                  if ord_inf.nil?
+                    case ord
+                    when -2
+                      b = a.transpose(c_axis, r_axis).dup
+                      gesvd = :"#{blas_char(b)}gesvd"
+                      s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
+                      s.min(axis: -1)
+                    when -1
+                      c_axis -= 1 if c_axis > r_axis
+                      a.abs.sum(axis: r_axis).min(axis: c_axis)
+                    when 1
+                      c_axis -= 1 if c_axis > r_axis
+                      a.abs.sum(axis: r_axis).max(axis: c_axis)
+                    when 2
+                      b = a.transpose(c_axis, r_axis).dup
+                      gesvd = :"#{blas_char(b)}gesvd"
+                      s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
+                      s.max(axis: -1)
+                    else
+                      raise ArgumentError, "invalid ord: #{ord}"
+                    end
+                  else
+                    r_axis -= 1 if r_axis > c_axis
+                    if ord_inf.positive?
+                      a.abs.sum(axis: c_axis).max(axis: r_axis)
+                    else
+                      a.abs.sum(axis: c_axis).min(axis: r_axis)
+                    end
+                  end
+                end
+         if keepdims
+           norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
+           norm = norm.reshape(*([1] * a.ndim))
+         end
+
+         norm
+       end
+     end
+
      # Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
      #
      # @example
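The `ord` table at the top of `norm` mirrors Numo::Linalg's behaviour. A few illustrative calls following that table (a minimal sketch; the commented values are hand-computed, not captured output):

```ruby
require 'numo/tiny_linalg'

x = Numo::DFloat[[1.0, -2.0], [3.0, 4.0]]
p Numo::TinyLinalg.norm(x)        # Frobenius norm: sqrt(1 + 4 + 9 + 16) ~ 5.477
p Numo::TinyLinalg.norm(x, 1)     # max absolute column sum: max(1 + 3, 2 + 4) = 6.0
p Numo::TinyLinalg.norm(x, 'inf') # max absolute row sum: max(1 + 2, 3 + 4) = 7.0

v = Numo::DFloat[1.0, -2.0, 2.0]
p Numo::TinyLinalg.norm(v)          # 2-norm: 3.0
p Numo::TinyLinalg.norm(v, 1)       # sum of absolute values: 5.0
p Numo::TinyLinalg.norm(v, axis: 0) # 2-norm again, computed along the given axis
```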
@@ -138,7 +304,7 @@ module Numo
        bchr = blas_char(a)
        raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

-       fnc = "#{bchr}potrf".to_sym
+       fnc = :"#{bchr}potrf"
        c, _info = Numo::TinyLinalg::Lapack.send(fnc, a.dup, uplo: uplo)

        case uplo
@@ -180,7 +346,7 @@ module Numo
        bchr = blas_char(a, b)
        raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

-       fnc = "#{bchr}potrs".to_sym
+       fnc = :"#{bchr}potrs"
        x, _info = Numo::TinyLinalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
        x
      end
@@ -205,7 +371,7 @@ module Numo
        bchr = blas_char(a)
        raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

-       getrf = "#{bchr}getrf".to_sym
+       getrf = :"#{bchr}getrf"
        lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)

        if info.zero?
@@ -248,8 +414,8 @@ module Numo
        bchr = blas_char(a)
        raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

-       getrf = "#{bchr}getrf".to_sym
-       getri = "#{bchr}getri".to_sym
+       getrf = :"#{bchr}getrf"
+       getri = :"#{bchr}getri"

        lu, piv, info = Numo::TinyLinalg::Lapack.send(getrf, a.dup)
        if info.zero?
@@ -336,7 +502,7 @@ module Numo
        bchr = blas_char(a)
        raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

-       geqrf = "#{bchr}geqrf".to_sym
+       geqrf = :"#{bchr}geqrf"
        qr, tau, = Numo::TinyLinalg::Lapack.send(geqrf, a.dup)

        return [qr, tau] if mode == 'raw'
@@ -346,7 +512,7 @@ module Numo

        return r if mode == 'r'

-       org_ung_qr = %w[d s].include?(bchr) ? "#{bchr}orgqr".to_sym : "#{bchr}ungqr".to_sym
+       org_ung_qr = %w[d s].include?(bchr) ? :"#{bchr}orgqr" : :"#{bchr}ungqr"

        q = if m < n
              Numo::TinyLinalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
@@ -395,10 +561,51 @@ module Numo
        bchr = blas_char(a, b)
        raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

-       gesv = "#{bchr}gesv".to_sym
+       gesv = :"#{bchr}gesv"
        Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
      end

+     # Solves the linear equation `A * x = b` or `A * X = B` for `x`, assuming `A` is a triangular matrix.
+     #
+     # @example
+     #   require 'numo/tiny_linalg'
+     #
+     #   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
+     #
+     #   a = Numo::DFloat.new(3, 3).rand.triu
+     #   b = Numo::DFloat.eye(3)
+     #
+     #   x = Numo::Linalg.solve_triangular(a, b)
+     #
+     #   pp x
+     #   # =>
+     #   # Numo::DFloat#shape=[3,3]
+     #   # [[16.1932, -52.0604, 30.5283],
+     #   #  [0, 8.61765, -17.9585],
+     #   #  [0, 0, 6.05735]]
+     #
+     #   pp (b - a.dot(x)).abs.max
+     #   # => 4.071100642430302e-16
+     #
+     # @param a [Numo::NArray] The n-by-n triangular matrix.
+     # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
+     # @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
+     # @return [Numo::NArray] The solution vector / matrix `X`.
+     def solve_triangular(a, b, lower: false)
+       raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
+       raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
+
+       bchr = blas_char(a, b)
+       raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
+
+       trtrs = :"#{bchr}trtrs"
+       uplo = lower ? 'L' : 'U'
+       x, info = Numo::TinyLinalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
+       raise "wrong value is given to the #{-info}-th argument of #{trtrs} used internally" if info.negative?
+
+       x
+     end
+
      # Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
      #
      # @example
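`solve_triangular` forwards to the per-dtype `*trtrs` wrappers, with `uplo` derived from the `lower:` flag. To complement the upper-triangular example in the documentation above, here is a lower-triangular sketch (illustrative values; forward substitution with a Cholesky factor is a typical machine-learning use of this method):

```ruby
require 'numo/tiny_linalg'

# Solve L * y = b for a lower-triangular L (for example, a Cholesky factor).
l = Numo::DFloat[[2.0, 0.0], [1.0, 3.0]]
b = Numo::DFloat[4.0, 11.0]

y = Numo::TinyLinalg.solve_triangular(l, b, lower: true)
puts y.inspect              # expected [2.0, 3.0]
puts (l.dot(y) - b).abs.max # residual, ~0.0
```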
@@ -443,10 +650,10 @@ module Numo

        case driver.to_s
        when 'sdd'
-         gesdd = "#{bchr}gesdd".to_sym
+         gesdd = :"#{bchr}gesdd"
          s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
        when 'svd'
-         gesvd = "#{bchr}gesvd".to_sym
+         gesvd = :"#{bchr}gesvd"
          s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
        else
          raise ArgumentError, "invalid driver: #{driver}"
@@ -534,11 +741,6 @@ module Numo
        raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
      end

-     # @!visibility private
-     def norm(*args)
-       raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
-     end
-
      # @!visibility private
      def cond(*args)
        raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: numo-tiny_linalg
  version: !ruby/object:Gem::Version
-   version: 0.3.4
+   version: 0.3.6
  platform: ruby
  authors:
  - yoshoku
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2023-11-19 00:00:00.000000000 Z
+ date: 2024-01-28 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -57,6 +57,7 @@ files:
  - ext/numo/tiny_linalg/lapack/hegv.hpp
  - ext/numo/tiny_linalg/lapack/hegvd.hpp
  - ext/numo/tiny_linalg/lapack/hegvx.hpp
+ - ext/numo/tiny_linalg/lapack/lange.hpp
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
  - ext/numo/tiny_linalg/lapack/potrf.hpp
  - ext/numo/tiny_linalg/lapack/potrs.hpp
@@ -66,6 +67,7 @@ files:
  - ext/numo/tiny_linalg/lapack/sygv.hpp
  - ext/numo/tiny_linalg/lapack/sygvd.hpp
  - ext/numo/tiny_linalg/lapack/sygvx.hpp
+ - ext/numo/tiny_linalg/lapack/trtrs.hpp
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
  - ext/numo/tiny_linalg/tiny_linalg.cpp
  - ext/numo/tiny_linalg/tiny_linalg.hpp
@@ -97,7 +99,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
    - !ruby/object:Gem::Version
      version: '0'
  requirements: []
- rubygems_version: 3.4.20
+ rubygems_version: 3.4.19
  signing_key:
  specification_version: 4
  summary: Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of