numo-tiny_linalg 0.1.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0aeb68c259f8d4c4dcd44c15caef6926cbe3f26f9dd99e9229721d02a939d7c1
4
- data.tar.gz: 5465aebebc07612812b861932fcabf14b66ee0ec389603c9ffcf06694eb69ce4
3
+ metadata.gz: dd5b3d90b9c0bc323420f2eed5363490ec7c2ec4a3b0950573f54f234d534eaf
4
+ data.tar.gz: cd0a71621d1e3e935faba0a40036b3950c0f6ac32a74cdaad46b0e313b5dfa2f
5
5
  SHA512:
6
- metadata.gz: 3e1ddd9a4f43d89635d0fc4e07fc3eab0e302b108648e00506deaec106f147f10132b6c28818f6ef054d5fa0ef52bb2c33287dcdeb413fc12268fb35edfde859
7
- data.tar.gz: 78903f772d37936f73a69624b3241d43703632f08542a2f57e38a619b6dc1ac012d8885c79056459ef4fd50102d416f996ea54269b64b919ac91a3db46330d8c
6
+ metadata.gz: f904406ce9c883d59d93afc061bfc65c96073dc9660cb65e6eefa97e2be32cc8c79ac26970a8b181cd3acbdbd412cf97c08153017956c4cc4f2b29bf28bcda46
7
+ data.tar.gz: 8f3ee9b0e0e9c3789f61148db03d43c245dab66fc4cb428488cbe668dad00b0d7fb73ef825e5dc0c5772ba817d41bb90d5fce843ece1f1adf6fe81857b9917ff
data/CHANGELOG.md CHANGED
@@ -1,5 +1,15 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.3.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.2.0...v0.3.0)] - 2023-08-13
4
+ - Add cholesky and cho_solve module functions to TinyLinalg.
5
+
6
+ **Breaking change**
7
+ - Change to raise NotImplementedError when calling a method not yet implemented in Numo::TinyLinalg.
8
+
9
+ ## [[0.2.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.2...v0.2.0)] - 2023-08-11
10
+ **Breaking change**
11
+ - Change LAPACK function to call when array b is not given to TinyLinalg.eigh method.
12
+
3
13
  ## [[0.1.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.1...v0.1.2)] - 2023-08-09
4
14
  - Add dsyev, ssyev, zheev, and cheev module functions to TinyLinalg::Lapack.
5
15
  - Add dsyevd, ssyevd, zheevd, and cheevd module functions to TinyLinalg::Lapack.
data/README.md CHANGED
@@ -6,7 +6,9 @@
6
6
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/numo-tiny_linalg/doc/)
7
7
 
8
8
  Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
9
- The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
9
+ The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, cholesky, cho_solve, and svd.
10
+
11
+ Note that the version numbering rule of Numo::TinyLinalg is not compatible with that of Numo::Linalg.
10
12
 
11
13
  ## Installation
12
14
  Unlike Numo::Linalg, Numo::TinyLinalg only supports OpenBLAS as a backend library for BLAS and LAPACK.
@@ -93,7 +93,25 @@ private:
93
93
  return Qnil;
94
94
  }
95
95
 
96
+ if (range == 'V' && vu <= vl) {
97
+ rb_raise(rb_eArgError, "vu must be greater than vl");
98
+ return Qnil;
99
+ }
100
+
96
101
  const size_t n = NA_SHAPE(a_nary)[1];
102
+ if (range == 'I' && (il < 1 || il > n)) {
103
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
104
+ return Qnil;
105
+ }
106
+ if (range == 'I' && (iu < 1 || iu > n)) {
107
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
108
+ return Qnil;
109
+ }
110
+ if (range == 'I' && iu < il) {
111
+ rb_raise(rb_eArgError, "iu must be greater than or equal to il");
112
+ return Qnil;
113
+ }
114
+
97
115
  size_t m = range != 'I' ? n : iu - il + 1;
98
116
  size_t w_shape[1] = { m };
99
117
  size_t z_shape[2] = { n, m };
@@ -114,7 +114,25 @@ private:
114
114
  return Qnil;
115
115
  }
116
116
 
117
+ if (range == 'V' && vu <= vl) {
118
+ rb_raise(rb_eArgError, "vu must be greater than vl");
119
+ return Qnil;
120
+ }
121
+
117
122
  const size_t n = NA_SHAPE(a_nary)[1];
123
+ if (range == 'I' && (il < 1 || il > n)) {
124
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
125
+ return Qnil;
126
+ }
127
+ if (range == 'I' && (iu < 1 || iu > n)) {
128
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
129
+ return Qnil;
130
+ }
131
+ if (range == 'I' && iu < il) {
132
+ rb_raise(rb_eArgError, "il must be less than or equal to iu");
133
+ return Qnil;
134
+ }
135
+
118
136
  size_t m = range != 'I' ? n : iu - il + 1;
119
137
  size_t w_shape[1] = { m };
120
138
  size_t z_shape[2] = { n, m };
@@ -0,0 +1,93 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DPoTrf {
4
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, double* a, lapack_int lda) {
5
+ return LAPACKE_dpotrf(matrix_layout, uplo, n, a, lda);
6
+ }
7
+ };
8
+
9
+ struct SPoTrf {
10
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, float* a, lapack_int lda) {
11
+ return LAPACKE_spotrf(matrix_layout, uplo, n, a, lda);
12
+ }
13
+ };
14
+
15
+ struct ZPoTrf {
16
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda) {
17
+ return LAPACKE_zpotrf(matrix_layout, uplo, n, a, lda);
18
+ }
19
+ };
20
+
21
+ struct CPoTrf {
22
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda) {
23
+ return LAPACKE_cpotrf(matrix_layout, uplo, n, a, lda);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, typename dtype, class LapackFn>
28
+ class PoTrf {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrf), -1);
32
+ }
33
+
34
+ private:
35
+ struct potrf_opt {
36
+ int matrix_layout;
37
+ char uplo;
38
+ };
39
+
40
+ static void iter_potrf(na_loop_t* const lp) {
41
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
42
+ int* info = (int*)NDL_PTR(lp, 1);
43
+ potrf_opt* opt = (potrf_opt*)(lp->opt_ptr);
44
+ const lapack_int n = NDL_SHAPE(lp, 0)[0];
45
+ const lapack_int lda = NDL_SHAPE(lp, 0)[1];
46
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, a, lda);
47
+ *info = static_cast<int>(i);
48
+ }
49
+
50
+ static VALUE tiny_linalg_potrf(int argc, VALUE* argv, VALUE self) {
51
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
52
+
53
+ VALUE a_vnary = Qnil;
54
+ VALUE kw_args = Qnil;
55
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
56
+ ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
57
+ VALUE kw_values[2] = { Qundef, Qundef };
58
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
59
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
60
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+
69
+ narray_t* a_nary = NULL;
70
+ GetNArray(a_vnary, a_nary);
71
+ if (NA_NDIM(a_nary) != 2) {
72
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
73
+ return Qnil;
74
+ }
75
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
76
+ rb_raise(rb_eArgError, "input array a must be square");
77
+ return Qnil;
78
+ }
79
+
80
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
81
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
82
+ ndfunc_t ndf = { iter_potrf, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
83
+ potrf_opt opt = { matrix_layout, uplo };
84
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
85
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
86
+
87
+ RB_GC_GUARD(a_vnary);
88
+
89
+ return ret;
90
+ }
91
+ };
92
+
93
+ } // namespace TinyLinalg
@@ -0,0 +1,121 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DPoTrs {
4
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
5
+ const double* a, lapack_int lda, double* b, lapack_int ldb) {
6
+ return LAPACKE_dpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
7
+ }
8
+ };
9
+
10
+ struct SPoTrs {
11
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
12
+ const float* a, lapack_int lda, float* b, lapack_int ldb) {
13
+ return LAPACKE_spotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
14
+ }
15
+ };
16
+
17
+ struct ZPoTrs {
18
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
19
+ const lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb) {
20
+ return LAPACKE_zpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
21
+ }
22
+ };
23
+
24
+ struct CPoTrs {
25
+ lapack_int call(int matrix_layout, char uplo, lapack_int n, lapack_int nrhs,
26
+ const lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb) {
27
+ return LAPACKE_cpotrs(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
28
+ }
29
+ };
30
+
31
+ template <int nary_dtype_id, typename dtype, class LapackFn>
32
+ class PoTrs {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
35
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_potrs), -1);
36
+ }
37
+
38
+ private:
39
+ struct potrs_opt {
40
+ int matrix_layout;
41
+ char uplo;
42
+ };
43
+
44
+ static void iter_potrs(na_loop_t* const lp) {
45
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
46
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
47
+ int* info = (int*)NDL_PTR(lp, 2);
48
+ potrs_opt* opt = (potrs_opt*)(lp->opt_ptr);
49
+ const lapack_int n = NDL_SHAPE(lp, 0)[0];
50
+ const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
51
+ const lapack_int lda = n;
52
+ const lapack_int ldb = nrhs;
53
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->uplo, n, nrhs, a, lda, b, ldb);
54
+ *info = static_cast<int>(i);
55
+ }
56
+
57
+ static VALUE tiny_linalg_potrs(int argc, VALUE* argv, VALUE self) {
58
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
59
+
60
+ VALUE a_vnary = Qnil;
61
+ VALUE b_vnary = Qnil;
62
+ VALUE kw_args = Qnil;
63
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
64
+ ID kw_table[2] = { rb_intern("order"), rb_intern("uplo") };
65
+ VALUE kw_values[2] = { Qundef, Qundef };
66
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
67
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
68
+ const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
69
+
70
+ if (CLASS_OF(a_vnary) != nary_dtype) {
71
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
72
+ }
73
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
74
+ a_vnary = nary_dup(a_vnary);
75
+ }
76
+ if (CLASS_OF(b_vnary) != nary_dtype) {
77
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
78
+ }
79
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
80
+ b_vnary = nary_dup(b_vnary);
81
+ }
82
+
83
+ narray_t* a_nary = NULL;
84
+ GetNArray(a_vnary, a_nary);
85
+ if (NA_NDIM(a_nary) != 2) {
86
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
87
+ return Qnil;
88
+ }
89
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
90
+ rb_raise(rb_eArgError, "input array a must be square");
91
+ return Qnil;
92
+ }
93
+ narray_t* b_nary = NULL;
94
+ GetNArray(b_vnary, b_nary);
95
+ const int b_n_dims = NA_NDIM(b_nary);
96
+ if (b_n_dims != 1 && b_n_dims != 2) {
97
+ rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
98
+ return Qnil;
99
+ }
100
+
101
+ lapack_int n = NA_SHAPE(a_nary)[0];
102
+ lapack_int nb = NA_SHAPE(b_nary)[0];
103
+ if (n != nb) {
104
+ rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);
105
+ }
106
+
107
+ ndfunc_arg_in_t ain[2] = { { nary_dtype, 2 }, { OVERWRITE, b_n_dims } };
108
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
109
+ ndfunc_t ndf = { iter_potrs, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
110
+ potrs_opt opt = { matrix_layout, uplo };
111
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
112
+ VALUE ret = rb_ary_new3(2, b_vnary, res);
113
+
114
+ RB_GC_GUARD(a_vnary);
115
+ RB_GC_GUARD(b_vnary);
116
+
117
+ return ret;
118
+ }
119
+ };
120
+
121
+ } // namespace TinyLinalg
@@ -90,7 +90,25 @@ private:
90
90
  return Qnil;
91
91
  }
92
92
 
93
+ if (range == 'V' && vu <= vl) {
94
+ rb_raise(rb_eArgError, "vu must be greater than vl");
95
+ return Qnil;
96
+ }
97
+
93
98
  const size_t n = NA_SHAPE(a_nary)[1];
99
+ if (range == 'I' && (il < 1 || il > n)) {
100
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
101
+ return Qnil;
102
+ }
103
+ if (range == 'I' && (iu < 1 || iu > n)) {
104
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
105
+ return Qnil;
106
+ }
107
+ if (range == 'I' && iu < il) {
108
+ rb_raise(rb_eArgError, "iu must be greater than or equal to il");
109
+ return Qnil;
110
+ }
111
+
94
112
  size_t m = range != 'I' ? n : iu - il + 1;
95
113
  size_t w_shape[1] = { m };
96
114
  size_t z_shape[2] = { n, m };
@@ -113,7 +113,25 @@ private:
113
113
  return Qnil;
114
114
  }
115
115
 
116
+ if (range == 'V' && vu <= vl) {
117
+ rb_raise(rb_eArgError, "vu must be greater than vl");
118
+ return Qnil;
119
+ }
120
+
116
121
  const size_t n = NA_SHAPE(a_nary)[1];
122
+ if (range == 'I' && (il < 1 || il > n)) {
123
+ rb_raise(rb_eArgError, "il must satisfy 1 <= il <= n");
124
+ return Qnil;
125
+ }
126
+ if (range == 'I' && (iu < 1 || iu > n)) {
127
+ rb_raise(rb_eArgError, "iu must satisfy 1 <= iu <= n");
128
+ return Qnil;
129
+ }
130
+ if (range == 'I' && iu < il) {
131
+ rb_raise(rb_eArgError, "iu must be greater than or equal to il");
132
+ return Qnil;
133
+ }
134
+
117
135
  size_t m = range != 'I' ? n : iu - il + 1;
118
136
  size_t w_shape[1] = { m };
119
137
  size_t z_shape[2] = { n, m };
@@ -51,6 +51,8 @@
51
51
  #include "lapack/hegvd.hpp"
52
52
  #include "lapack/hegvx.hpp"
53
53
  #include "lapack/orgqr.hpp"
54
+ #include "lapack/potrf.hpp"
55
+ #include "lapack/potrs.hpp"
54
56
  #include "lapack/syev.hpp"
55
57
  #include "lapack/syevd.hpp"
56
58
  #include "lapack/syevr.hpp"
@@ -310,6 +312,14 @@ extern "C" void Init_tiny_linalg(void) {
310
312
  TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
311
313
  TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
312
314
  TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
315
+ TinyLinalg::PoTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrf>::define_module_function(rb_mTinyLinalgLapack, "dpotrf");
316
+ TinyLinalg::PoTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrf>::define_module_function(rb_mTinyLinalgLapack, "spotrf");
317
+ TinyLinalg::PoTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrf>::define_module_function(rb_mTinyLinalgLapack, "zpotrf");
318
+ TinyLinalg::PoTrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrf>::define_module_function(rb_mTinyLinalgLapack, "cpotrf");
319
+ TinyLinalg::PoTrs<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DPoTrs>::define_module_function(rb_mTinyLinalgLapack, "dpotrs");
320
+ TinyLinalg::PoTrs<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SPoTrs>::define_module_function(rb_mTinyLinalgLapack, "spotrs");
321
+ TinyLinalg::PoTrs<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZPoTrs>::define_module_function(rb_mTinyLinalgLapack, "zpotrs");
322
+ TinyLinalg::PoTrs<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CPoTrs>::define_module_function(rb_mTinyLinalgLapack, "cpotrs");
313
323
  TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
314
324
  TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
315
325
  TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.1.2'
8
+ VERSION = '0.3.0'
9
9
  end
10
10
  end
@@ -48,38 +48,143 @@ module Numo
48
48
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
49
49
  # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
50
50
  # @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
51
- def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
51
+ def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/ParameterLists, Metrics/PerceivedComplexity, Lint/UnusedMethodArgument
52
52
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
53
53
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
54
54
 
55
+ b_given = !b.nil?
56
+ raise ArgumentError, 'input array b must be 2-dimensional' if b_given && b.ndim != 2
57
+ raise ArgumentError, 'input array b must be square' if b_given && b.shape[0] != b.shape[1]
58
+ raise ArgumentError, "invalid array type: #{b.class}" if b_given && blas_char(b) == 'n'
59
+
55
60
  bchr = blas_char(a)
56
61
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
57
62
 
58
- unless b.nil?
59
- raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
60
- raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
61
- raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
62
- end
63
-
64
63
  jobz = vals_only ? 'N' : 'V'
65
- b = a.class.eye(a.shape[0]) if b.nil?
66
- sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
67
64
 
68
- if vals_range.nil?
69
- sy_he_gv << 'd' if turbo
70
- vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
65
+ if b_given
66
+ fnc = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
67
+ if vals_range.nil?
68
+ fnc << 'd' if turbo
69
+ vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
70
+ else
71
+ fnc << 'x'
72
+ il = vals_range.first(1)[0] + 1
73
+ iu = vals_range.last(1)[0] + 1
74
+ _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
75
+ fnc.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
76
+ )
77
+ end
71
78
  else
72
- sy_he_gv << 'x'
73
- il = vals_range.first(1)[0] + 1
74
- iu = vals_range.last(1)[0] + 1
75
- _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
76
- sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
77
- )
79
+ fnc = %w[d s].include?(bchr) ? "#{bchr}syev" : "#{bchr}heev"
80
+ if vals_range.nil?
81
+ fnc << 'd' if turbo
82
+ vecs, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, jobz: jobz)
83
+ else
84
+ fnc << 'r'
85
+ il = vals_range.first(1)[0] + 1
86
+ iu = vals_range.last(1)[0] + 1
87
+ _a, _m, vals, vecs, _isuppz, _info = Numo::TinyLinalg::Lapack.send(
88
+ fnc.to_sym, a.dup, jobz: jobz, range: 'I', il: il, iu: iu
89
+ )
90
+ end
78
91
  end
92
+
79
93
  vecs = nil if vals_only
94
+
80
95
  [vals, vecs]
81
96
  end
82
97
 
98
+ # Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
99
+ #
100
+ # @example
101
+ # require 'numo/tiny_linalg'
102
+ #
103
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
104
+ #
105
+ # s = Numo::DFloat.new(3, 3).rand - 0.5
106
+ # a = s.transpose.dot(s)
107
+ # u = Numo::Linalg.cholesky(a)
108
+ #
109
+ # pp u
110
+ # # =>
111
+ # # Numo::DFloat#shape=[3,3]
112
+ # # [[0.532006, 0.338183, -0.18036],
113
+ # # [0, 0.325153, 0.011721],
114
+ # # [0, 0, 0.436738]]
115
+ #
116
+ # pp (a - u.transpose.dot(u)).abs.max
117
+ # # => 1.3877787807814457e-17
118
+ #
119
+ # l = Numo::Linalg.cholesky(a, uplo: 'L')
120
+ #
121
+ # pp l
122
+ # # =>
123
+ # # Numo::DFloat#shape=[3,3]
124
+ # # [[0.532006, 0, 0],
125
+ # # [0.338183, 0.325153, 0],
126
+ # # [-0.18036, 0.011721, 0.436738]]
127
+ #
128
+ # pp (a - l.dot(l.transpose)).abs.max
129
+ # # => 1.3877787807814457e-17
130
+ #
131
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian positive-definite matrix.
132
+ # @param uplo [String] Whether to compute the upper- or lower-triangular Cholesky factor ('U' or 'L').
133
+ # @return [Numo::NArray] The upper- or lower-triangular Cholesky factor of a.
134
+ def cholesky(a, uplo: 'U')
135
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
136
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
137
+
138
+ bchr = blas_char(a)
139
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
140
+
141
+ fnc = "#{bchr}potrf".to_sym
142
+ c, _info = Numo::TinyLinalg::Lapack.send(fnc, a.dup, uplo: uplo)
143
+
144
+ case uplo
145
+ when 'U'
146
+ c.triu
147
+ when 'L'
148
+ c.tril
149
+ else
150
+ raise ArgumentError, "invalid uplo: #{uplo}"
151
+ end
152
+ end
153
+
154
+ # Solves linear equation `A * x = b` or `A * X = B` for `x` with the Cholesky factorization of `A`.
155
+ #
156
+ # @example
157
+ # require 'numo/tiny_linalg'
158
+ #
159
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
160
+ #
161
+ # s = Numo::DFloat.new(3, 3).rand - 0.5
162
+ # a = s.transpose.dot(s)
163
+ # u = Numo::Linalg.cholesky(a)
164
+ #
165
+ # b = Numo::DFloat.new(3).rand
166
+ # x = Numo::Linalg.cho_solve(u, b)
167
+ #
168
+ # puts (b - a.dot(x)).abs.max
169
+ # # => 0.0
170
+ #
171
+ # @param a [Numo::NArray] The n-by-n Cholesky factor.
172
+ # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
173
+ # @param uplo [String] Whether the given Cholesky factor `a` is upper- or lower-triangular ('U' or 'L').
174
+ # @return [Numo::NArray] The solution vector or matrix `X`.
175
+ def cho_solve(a, b, uplo: 'U')
176
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
177
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
178
+ raise ArgumentError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}" if a.shape[0] != b.shape[0]
179
+
180
+ bchr = blas_char(a, b)
181
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
182
+
183
+ fnc = "#{bchr}potrs".to_sym
184
+ x, _info = Numo::TinyLinalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
185
+ x
186
+ end
187
+
83
188
  # Computes the determinant of matrix.
84
189
  #
85
190
  # @example
@@ -256,7 +361,7 @@ module Numo
256
361
  [q, r]
257
362
  end
258
363
 
259
- # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
364
+ # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `A`.
260
365
  #
261
366
  # @example
262
367
  # require 'numo/tiny_linalg'
@@ -279,10 +384,10 @@ module Numo
279
384
  # # => 2.1081041547796492e-16
280
385
  #
281
386
  # @param a [Numo::NArray] The n-by-n square matrix.
282
- # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensinal NArray).
387
+ # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
283
388
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
284
389
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
285
- # @return [Numo::NArray] The solusion vector / matrix `x`.
390
+ # @return [Numo::NArray] The solution vector / matrix `X`.
286
391
  def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
287
392
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
288
393
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
@@ -353,5 +458,110 @@ module Numo
353
458
 
354
459
  [s, u, vt]
355
460
  end
461
+
462
+ # @!visibility private
463
+ def matmul(*args)
464
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
465
+ end
466
+
467
+ # @!visibility private
468
+ def matrix_power(*args)
469
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
470
+ end
471
+
472
+ # @!visibility private
473
+ def svdvals(*args)
474
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
475
+ end
476
+
477
+ # @!visibility private
478
+ def orth(*args)
479
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
480
+ end
481
+
482
+ # @!visibility private
483
+ def null_space(*args)
484
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
485
+ end
486
+
487
+ # @!visibility private
488
+ def lu(*args)
489
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
490
+ end
491
+
492
+ # @!visibility private
493
+ def lu_fact(*args)
494
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
495
+ end
496
+
497
+ # @!visibility private
498
+ def lu_inv(*args)
499
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
500
+ end
501
+
502
+ # @!visibility private
503
+ def lu_solve(*args)
504
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
505
+ end
506
+
507
+ # @!visibility private
508
+ def ldl(*args)
509
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
510
+ end
511
+
512
+ # @!visibility private
513
+ def cho_fact(*args)
514
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
515
+ end
516
+
517
+ # @!visibility private
518
+ def cho_inv(*args)
519
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
520
+ end
521
+
522
+ # @!visibility private
523
+ def eig(*args)
524
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
525
+ end
526
+
527
+ # @!visibility private
528
+ def eigvals(*args)
529
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
530
+ end
531
+
532
+ # @!visibility private
533
+ def eigvalsh(*args)
534
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
535
+ end
536
+
537
+ # @!visibility private
538
+ def norm(*args)
539
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
540
+ end
541
+
542
+ # @!visibility private
543
+ def cond(*args)
544
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
545
+ end
546
+
547
+ # @!visibility private
548
+ def slogdet(*args)
549
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
550
+ end
551
+
552
+ # @!visibility private
553
+ def matrix_rank(*args)
554
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
555
+ end
556
+
557
+ # @!visibility private
558
+ def lstsq(*args)
559
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
560
+ end
561
+
562
+ # @!visibility private
563
+ def expm(*args)
564
+ raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
565
+ end
356
566
  end
357
567
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-08 00:00:00.000000000 Z
11
+ date: 2023-08-13 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -59,6 +59,8 @@ files:
59
59
  - ext/numo/tiny_linalg/lapack/hegvd.hpp
60
60
  - ext/numo/tiny_linalg/lapack/hegvx.hpp
61
61
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
62
+ - ext/numo/tiny_linalg/lapack/potrf.hpp
63
+ - ext/numo/tiny_linalg/lapack/potrs.hpp
62
64
  - ext/numo/tiny_linalg/lapack/syev.hpp
63
65
  - ext/numo/tiny_linalg/lapack/syevd.hpp
64
66
  - ext/numo/tiny_linalg/lapack/syevr.hpp