numo-tiny_linalg 0.0.1 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 15ddc75f758bbff2ef6887db5d71c13dd517f7eeb0c4d3ab181c5c2db8b8f995
4
- data.tar.gz: 20350f1d084a31e51e05c317dddc23f6285dd27430176a753edd3eea935331be
3
+ metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
4
+ data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
5
5
  SHA512:
6
- metadata.gz: 457d487b20bfffb3eade0fc80120f5de7acfa3f0678550f3b844abac139e7b6db469fdcf6b686337755ea2b49b8bad54ea66381368d2e4e554fdb32df6e2d87d
7
- data.tar.gz: 217765d951cf0d790e8620e1a4e4d28a884a5955bdca32cc4386fd4951c8ae2a133be029c24b20e987ce6fb2c7adfa101fe327c55cce327c2e89f4ff4d54a874
6
+ metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
7
+ data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
data/CHANGELOG.md CHANGED
@@ -1,5 +1,20 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
4
+ - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
5
+ - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
6
+ - Add det module function to TinyLinalg.
7
+ - Add pinv module function to TinyLinalg.
8
+ - Add qr module function to TinyLinalg.
9
+
10
+ ## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
11
+ - Add automatic build of OpenBLAS if it is not found.
12
+ - Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
13
+ - Add dgetrf, sgetrf, zgetrf, and cgetrf module functions to TinyLinalg::Lapack.
14
+ - Add dgetri, sgetri, zgetri, and cgetri module functions to TinyLinalg::Lapack.
15
+ - Add solve module function to TinyLinalg.
16
+ - Add inv module function to TinyLinalg.
17
+
3
18
  ## [0.0.1] - 2023-07-14
4
19
 
5
20
  - Initial release
data/README.md CHANGED
@@ -1,7 +1,9 @@
1
1
  # Numo::TinyLinalg
2
2
 
3
+ [![Gem Version](https://badge.fury.io/rb/numo-tiny_linalg.svg)](https://badge.fury.io/rb/numo-tiny_linalg)
3
4
  [![Build Status](https://github.com/yoshoku/numo-tiny_linalg/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/numo-tiny_linalg/actions/workflows/main.yml)
4
5
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/numo-tiny_linalg/blob/main/LICENSE.txt)
6
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/numo-tiny_linalg/doc/)
5
7
 
6
8
  Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
7
9
 
@@ -2,6 +2,12 @@
2
2
 
3
3
  require 'mkmf'
4
4
  require 'numo/narray'
5
+ require 'open-uri'
6
+ require 'etc'
7
+ require 'fileutils'
8
+ require 'open3'
9
+ require 'digest/md5'
10
+ require 'rubygems/package'
5
11
 
6
12
  $LOAD_PATH.each do |lp|
7
13
  if File.exist?(File.join(lp, 'numo/numo/narray.h'))
@@ -22,33 +28,75 @@ if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
22
28
  abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
23
29
  end
24
30
 
25
- if RUBY_PLATFORM.include?('darwin') && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION) &&
26
- try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
27
- $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
28
- end
31
+ abort 'libstdc++ is not found.' unless have_library('stdc++')
32
+ $CXXFLAGS << ' -std=c++11'
29
33
 
30
- use_accelerate = false
31
34
  # NOTE: Accelerate framework does not support LAPACKE.
35
+ # use_accelerate = false
32
36
  # if RUBY_PLATFORM.include?('darwin') && have_framework('Accelerate')
33
37
  # $CFLAGS << ' -DTINYLINALG_USE_ACCELERATE'
34
38
  # use_accelerate = true
35
39
  # end
36
40
 
37
- unless use_accelerate
38
- if have_library('openblas')
39
- $CFLAGS << ' -DTINYLINALG_USE_OPENBLAS'
40
- else
41
- abort 'libblas is not found' unless have_library('blas')
42
- $CFLAGS << ' -DTINYLINALG_USE_BLAS'
41
+ build_openblas = false
42
+ unless find_library('openblas', 'LAPACKE_dsyevr')
43
+ build_openblas = true unless have_library('openblas')
44
+ build_openblas = true unless have_library('lapacke')
45
+ end
46
+ build_openblas = true unless have_header('cblas.h')
47
+ build_openblas = true unless have_header('lapacke.h')
48
+ build_openblas = true unless have_header('openblas_config.h')
49
+
50
+ if build_openblas
51
+ warn 'BLAS and LAPACKE APIs are not found. Downloading and Building OpenBLAS...'
52
+
53
+ VENDOR_DIR = File.expand_path("#{__dir__}/../../../vendor")
54
+ OPENBLAS_VER = '0.3.23'
55
+ OPENBLAS_KEY = '115634b39007de71eb7e75cf7591dfb2'
56
+ OPENBLAS_URI = "https://github.com/xianyi/OpenBLAS/archive/v#{OPENBLAS_VER}.tar.gz"
57
+ OPENBLAS_TGZ = "#{VENDOR_DIR}/tmp/openblas.tgz"
58
+
59
+ unless File.exist?("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
60
+ URI.parse(OPENBLAS_URI).open { |f| File.binwrite(OPENBLAS_TGZ, f.read) }
61
+ abort('MD5 digest of downloaded OpenBLAS does not match.') if Digest::MD5.file(OPENBLAS_TGZ).to_s != OPENBLAS_KEY
62
+
63
+ Gem::Package::TarReader.new(Zlib::GzipReader.open(OPENBLAS_TGZ)) do |tar|
64
+ tar.each do |entry|
65
+ next unless entry.file?
66
+
67
+ filename = "#{VENDOR_DIR}/tmp/#{entry.full_name}"
68
+ next if filename == File.dirname(filename)
69
+
70
+ FileUtils.mkdir_p("#{VENDOR_DIR}/tmp/#{File.dirname(entry.full_name)}")
71
+ File.binwrite(filename, entry.read)
72
+ File.chmod(entry.header.mode, filename)
73
+ end
74
+ end
75
+
76
+ Dir.chdir("#{VENDOR_DIR}/tmp/OpenBLAS-#{OPENBLAS_VER}") do
77
+ mkstdout, _mkstderr, mkstatus = Open3.capture3("make -j#{Etc.nprocessors}")
78
+ File.open("#{VENDOR_DIR}/tmp/openblas.log", 'w') { |f| f.puts(mkstdout) }
79
+ abort("Failed to build OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless mkstatus.success?
80
+
81
+ insstdout, _insstderr, insstatus = Open3.capture3("make install PREFIX=#{VENDOR_DIR}")
82
+ File.open("#{VENDOR_DIR}/tmp/openblas.log", 'a') { |f| f.puts(insstdout) }
83
+ abort("Failed to install OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless insstatus.success?
84
+
85
+ FileUtils.touch("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
86
+ end
43
87
  end
44
88
 
45
- abort 'liblapacke is not found' if !have_func('LAPACKE_dsyevr') && !have_library('lapacke')
46
- abort 'cblas.h is not found' unless have_header('cblas.h')
47
- abort 'lapacke.h is not found' unless have_header('lapacke.h')
89
+ abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
90
+ abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
91
+ abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
92
+ abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
48
93
  end
49
94
 
50
- abort 'libstdc++ is not found.' unless have_library('stdc++')
95
+ $CFLAGS << ' -DNUMO_TINY_LINALG_USE_OPENBLAS'
51
96
 
52
- $CXXFLAGS << ' -std=c++11'
97
+ if RUBY_PLATFORM.include?('darwin') && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION) &&
98
+ try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
99
+ $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
100
+ end
53
101
 
54
102
  create_makefile('numo/tiny_linalg/tiny_linalg')
@@ -0,0 +1,118 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGeQrf {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, double* tau) {
6
+ return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
7
+ }
8
+ };
9
+
10
+ struct SGeQrf {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
12
+ float* a, lapack_int lda, float* tau) {
13
+ return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
14
+ }
15
+ };
16
+
17
+ struct ZGeQrf {
18
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
19
+ lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
20
+ return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
21
+ }
22
+ };
23
+
24
+ struct CGeQrf {
25
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
26
+ lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
27
+ return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
28
+ }
29
+ };
30
+
31
+ template <int nary_dtype_id, typename DType, typename FncType>
32
+ class GeQrf {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
35
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
36
+ }
37
+
38
+ private:
39
+ struct geqrf_opt {
40
+ int matrix_layout;
41
+ };
42
+
43
+ static void iter_geqrf(na_loop_t* const lp) {
44
+ DType* a = (DType*)NDL_PTR(lp, 0);
45
+ DType* tau = (DType*)NDL_PTR(lp, 1);
46
+ int* info = (int*)NDL_PTR(lp, 2);
47
+ geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
48
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
49
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
50
+ const lapack_int lda = n;
51
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+
58
+ VALUE a_vnary = Qnil;
59
+ VALUE kw_args = Qnil;
60
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
61
+ ID kw_table[1] = { rb_intern("order") };
62
+ VALUE kw_values[1] = { Qundef };
63
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
64
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
65
+
66
+ if (CLASS_OF(a_vnary) != nary_dtype) {
67
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
68
+ }
69
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
70
+ a_vnary = nary_dup(a_vnary);
71
+ }
72
+
73
+ narray_t* a_nary = NULL;
74
+ GetNArray(a_vnary, a_nary);
75
+ const int n_dims = NA_NDIM(a_nary);
76
+ if (n_dims != 2) {
77
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
78
+ return Qnil;
79
+ }
80
+
81
+ size_t m = NA_SHAPE(a_nary)[0];
82
+ size_t n = NA_SHAPE(a_nary)[1];
83
+ size_t shape[1] = { m < n ? m : n };
84
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
85
+ ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
86
+ ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
87
+ geqrf_opt opt = { matrix_layout };
88
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
89
+
90
+ VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
91
+
92
+ RB_GC_GUARD(a_vnary);
93
+
94
+ return ret;
95
+ }
96
+
97
+ static int get_matrix_layout(VALUE val) {
98
+ const char* option_str = StringValueCStr(val);
99
+
100
+ if (std::strlen(option_str) > 0) {
101
+ switch (option_str[0]) {
102
+ case 'r':
103
+ case 'R':
104
+ break;
105
+ case 'c':
106
+ case 'C':
107
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
108
+ break;
109
+ }
110
+ }
111
+
112
+ RB_GC_GUARD(val);
113
+
114
+ return LAPACK_ROW_MAJOR;
115
+ }
116
+ };
117
+
118
+ } // namespace TinyLinalg
@@ -0,0 +1,148 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGESV {
4
+ lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
5
+ double* a, lapack_int lda, lapack_int* ipiv,
6
+ double* b, lapack_int ldb) {
7
+ return LAPACKE_dgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
8
+ }
9
+ };
10
+
11
+ struct SGESV {
12
+ lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
13
+ float* a, lapack_int lda, lapack_int* ipiv,
14
+ float* b, lapack_int ldb) {
15
+ return LAPACKE_sgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
16
+ }
17
+ };
18
+
19
+ struct ZGESV {
20
+ lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
21
+ lapack_complex_double* a, lapack_int lda, lapack_int* ipiv,
22
+ lapack_complex_double* b, lapack_int ldb) {
23
+ return LAPACKE_zgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
24
+ }
25
+ };
26
+
27
+ struct CGESV {
28
+ lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
29
+ lapack_complex_float* a, lapack_int lda, lapack_int* ipiv,
30
+ lapack_complex_float* b, lapack_int ldb) {
31
+ return LAPACKE_cgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
32
+ }
33
+ };
34
+
35
+ template <int nary_dtype_id, typename DType, typename FncType>
36
+ class GESV {
37
+ public:
38
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
39
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_gesv), -1);
40
+ }
41
+
42
+ private:
43
+ struct gesv_opt {
44
+ int matrix_layout;
45
+ };
46
+
47
+ static void iter_gesv(na_loop_t* const lp) {
48
+ DType* a = (DType*)NDL_PTR(lp, 0);
49
+ DType* b = (DType*)NDL_PTR(lp, 1);
50
+ int* ipiv = (int*)NDL_PTR(lp, 2);
51
+ int* info = (int*)NDL_PTR(lp, 3);
52
+ gesv_opt* opt = (gesv_opt*)(lp->opt_ptr);
53
+ const lapack_int n = NDL_SHAPE(lp, 0)[0];
54
+ const lapack_int nhrs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
55
+ const lapack_int lda = n;
56
+ const lapack_int ldb = nhrs;
57
+ const lapack_int i = FncType().call(opt->matrix_layout, n, nhrs, a, lda, ipiv, b, ldb);
58
+ *info = static_cast<int>(i);
59
+ }
60
+
61
+ static VALUE tiny_linalg_gesv(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+ VALUE a_vnary = Qnil;
64
+ VALUE b_vnary = Qnil;
65
+ VALUE kw_args = Qnil;
66
+
67
+ rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
68
+
69
+ ID kw_table[1] = { rb_intern("order") };
70
+ VALUE kw_values[1] = { Qundef };
71
+
72
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
73
+
74
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
75
+
76
+ if (CLASS_OF(a_vnary) != nary_dtype) {
77
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
78
+ }
79
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
80
+ a_vnary = nary_dup(a_vnary);
81
+ }
82
+ if (CLASS_OF(b_vnary) != nary_dtype) {
83
+ b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
84
+ }
85
+ if (!RTEST(nary_check_contiguous(b_vnary))) {
86
+ b_vnary = nary_dup(b_vnary);
87
+ }
88
+
89
+ narray_t* a_nary = NULL;
90
+ narray_t* b_nary = NULL;
91
+ GetNArray(a_vnary, a_nary);
92
+ GetNArray(b_vnary, b_nary);
93
+ const int a_n_dims = NA_NDIM(a_nary);
94
+ const int b_n_dims = NA_NDIM(b_nary);
95
+ if (a_n_dims != 2) {
96
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
97
+ return Qnil;
98
+ }
99
+ if (b_n_dims != 1 && b_n_dims != 2) {
100
+ rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
101
+ return Qnil;
102
+ }
103
+
104
+ lapack_int n = NA_SHAPE(a_nary)[0];
105
+ lapack_int nb = b_n_dims == 1 ? NA_SHAPE(b_nary)[0] : NA_SHAPE(b_nary)[0];
106
+ if (n != nb) {
107
+ rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", n, nb);
108
+ }
109
+
110
+ lapack_int nhrs = b_n_dims == 1 ? 1 : NA_SHAPE(b_nary)[1];
111
+ size_t shape[2] = { static_cast<size_t>(n), static_cast<size_t>(nhrs) };
112
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, b_n_dims } };
113
+ ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } };
114
+
115
+ ndfunc_t ndf = { iter_gesv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
116
+ gesv_opt opt = { matrix_layout };
117
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
118
+
119
+ VALUE ret = rb_ary_concat(rb_assoc_new(a_vnary, b_vnary), res);
120
+
121
+ RB_GC_GUARD(a_vnary);
122
+ RB_GC_GUARD(b_vnary);
123
+
124
+ return ret;
125
+ }
126
+
127
+ static int get_matrix_layout(VALUE val) {
128
+ const char* option_str = StringValueCStr(val);
129
+
130
+ if (std::strlen(option_str) > 0) {
131
+ switch (option_str[0]) {
132
+ case 'r':
133
+ case 'R':
134
+ break;
135
+ case 'c':
136
+ case 'C':
137
+ rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major.");
138
+ break;
139
+ }
140
+ }
141
+
142
+ RB_GC_GUARD(val);
143
+
144
+ return LAPACK_ROW_MAJOR;
145
+ }
146
+ };
147
+
148
+ } // namespace TinyLinalg
@@ -0,0 +1,118 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGETRF {
4
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, lapack_int* ipiv) {
6
+ return LAPACKE_dgetrf(matrix_layout, m, n, a, lda, ipiv);
7
+ }
8
+ };
9
+
10
+ struct SGETRF {
11
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
12
+ float* a, lapack_int lda, lapack_int* ipiv) {
13
+ return LAPACKE_sgetrf(matrix_layout, m, n, a, lda, ipiv);
14
+ }
15
+ };
16
+
17
+ struct ZGETRF {
18
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
19
+ lapack_complex_double* a, lapack_int lda, lapack_int* ipiv) {
20
+ return LAPACKE_zgetrf(matrix_layout, m, n, a, lda, ipiv);
21
+ }
22
+ };
23
+
24
+ struct CGETRF {
25
+ lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
26
+ lapack_complex_float* a, lapack_int lda, lapack_int* ipiv) {
27
+ return LAPACKE_cgetrf(matrix_layout, m, n, a, lda, ipiv);
28
+ }
29
+ };
30
+
31
+ template <int nary_dtype_id, typename DType, typename FncType>
32
+ class GETRF {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
35
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getrf), -1);
36
+ }
37
+
38
+ private:
39
+ struct getrf_opt {
40
+ int matrix_layout;
41
+ };
42
+
43
+ static void iter_getrf(na_loop_t* const lp) {
44
+ DType* a = (DType*)NDL_PTR(lp, 0);
45
+ int* ipiv = (int*)NDL_PTR(lp, 1);
46
+ int* info = (int*)NDL_PTR(lp, 2);
47
+ getrf_opt* opt = (getrf_opt*)(lp->opt_ptr);
48
+ const lapack_int m = NDL_SHAPE(lp, 0)[0];
49
+ const lapack_int n = NDL_SHAPE(lp, 0)[1];
50
+ const lapack_int lda = n;
51
+ const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, ipiv);
52
+ *info = static_cast<int>(i);
53
+ }
54
+
55
+ static VALUE tiny_linalg_getrf(int argc, VALUE* argv, VALUE self) {
56
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
57
+
58
+ VALUE a_vnary = Qnil;
59
+ VALUE kw_args = Qnil;
60
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
61
+ ID kw_table[1] = { rb_intern("order") };
62
+ VALUE kw_values[1] = { Qundef };
63
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
64
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
65
+
66
+ if (CLASS_OF(a_vnary) != nary_dtype) {
67
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
68
+ }
69
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
70
+ a_vnary = nary_dup(a_vnary);
71
+ }
72
+
73
+ narray_t* a_nary = NULL;
74
+ GetNArray(a_vnary, a_nary);
75
+ const int n_dims = NA_NDIM(a_nary);
76
+ if (n_dims != 2) {
77
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
78
+ return Qnil;
79
+ }
80
+
81
+ size_t m = NA_SHAPE(a_nary)[0];
82
+ size_t n = NA_SHAPE(a_nary)[1];
83
+ size_t shape[1] = { m < n ? m : n };
84
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
85
+ ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } };
86
+ ndfunc_t ndf = { iter_getrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
87
+ getrf_opt opt = { matrix_layout };
88
+ VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
89
+
90
+ VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
91
+
92
+ RB_GC_GUARD(a_vnary);
93
+
94
+ return ret;
95
+ }
96
+
97
+ static int get_matrix_layout(VALUE val) {
98
+ const char* option_str = StringValueCStr(val);
99
+
100
+ if (std::strlen(option_str) > 0) {
101
+ switch (option_str[0]) {
102
+ case 'r':
103
+ case 'R':
104
+ break;
105
+ case 'c':
106
+ case 'C':
107
+ rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
108
+ break;
109
+ }
110
+ }
111
+
112
+ RB_GC_GUARD(val);
113
+
114
+ return LAPACK_ROW_MAJOR;
115
+ }
116
+ };
117
+
118
+ } // namespace TinyLinalg
@@ -0,0 +1,127 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGETRI {
4
+ lapack_int call(int matrix_layout, lapack_int n, double* a, lapack_int lda, const lapack_int* ipiv) {
5
+ return LAPACKE_dgetri(matrix_layout, n, a, lda, ipiv);
6
+ }
7
+ };
8
+
9
+ struct SGETRI {
10
+ lapack_int call(int matrix_layout, lapack_int n, float* a, lapack_int lda, const lapack_int* ipiv) {
11
+ return LAPACKE_sgetri(matrix_layout, n, a, lda, ipiv);
12
+ }
13
+ };
14
+
15
+ struct ZGETRI {
16
+ lapack_int call(int matrix_layout, lapack_int n, lapack_complex_double* a, lapack_int lda, const lapack_int* ipiv) {
17
+ return LAPACKE_zgetri(matrix_layout, n, a, lda, ipiv);
18
+ }
19
+ };
20
+
21
+ struct CGETRI {
22
+ lapack_int call(int matrix_layout, lapack_int n, lapack_complex_float* a, lapack_int lda, const lapack_int* ipiv) {
23
+ return LAPACKE_cgetri(matrix_layout, n, a, lda, ipiv);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, typename DType, typename FncType>
28
+ class GETRI {
29
+ public:
30
+ static void define_module_function(VALUE mLapack, const char* fnc_name) {
31
+ rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getri), -1);
32
+ }
33
+
34
+ private:
35
+ struct getri_opt {
36
+ int matrix_layout;
37
+ };
38
+
39
+ static void iter_getri(na_loop_t* const lp) {
40
+ DType* a = (DType*)NDL_PTR(lp, 0);
41
+ lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1);
42
+ int* info = (int*)NDL_PTR(lp, 2);
43
+ getri_opt* opt = (getri_opt*)(lp->opt_ptr);
44
+ const lapack_int n = NDL_SHAPE(lp, 0)[0];
45
+ const lapack_int lda = n;
46
+ const lapack_int i = FncType().call(opt->matrix_layout, n, a, lda, ipiv);
47
+ *info = static_cast<int>(i);
48
+ }
49
+
50
+ static VALUE tiny_linalg_getri(int argc, VALUE* argv, VALUE self) {
51
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
52
+
53
+ VALUE a_vnary = Qnil;
54
+ VALUE ipiv_vnary = Qnil;
55
+ VALUE kw_args = Qnil;
56
+ rb_scan_args(argc, argv, "2:", &a_vnary, &ipiv_vnary, &kw_args);
57
+ ID kw_table[1] = { rb_intern("order") };
58
+ VALUE kw_values[1] = { Qundef };
59
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
60
+ const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
61
+
62
+ if (CLASS_OF(a_vnary) != nary_dtype) {
63
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
64
+ }
65
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
66
+ a_vnary = nary_dup(a_vnary);
67
+ }
68
+ if (CLASS_OF(ipiv_vnary) != numo_cInt32) {
69
+ ipiv_vnary = rb_funcall(numo_cInt32, rb_intern("cast"), 1, ipiv_vnary);
70
+ }
71
+ if (!RTEST(nary_check_contiguous(ipiv_vnary))) {
72
+ ipiv_vnary = nary_dup(ipiv_vnary);
73
+ }
74
+
75
+ narray_t* a_nary = NULL;
76
+ GetNArray(a_vnary, a_nary);
77
+ if (NA_NDIM(a_nary) != 2) {
78
+ rb_raise(rb_eArgError, "input array a must be 2-dimensional");
79
+ return Qnil;
80
+ }
81
+ if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
82
+ rb_raise(rb_eArgError, "input array a must be square");
83
+ return Qnil;
84
+ }
85
+ narray_t* ipiv_nary = NULL;
86
+ GetNArray(ipiv_vnary, ipiv_nary);
87
+ if (NA_NDIM(ipiv_nary) != 1) {
88
+ rb_raise(rb_eArgError, "input array ipiv must be 1-dimensional");
89
+ return Qnil;
90
+ }
91
+
92
+ ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { numo_cInt32, 1 } };
93
+ ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
94
+ ndfunc_t ndf = { iter_getri, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
95
+ getri_opt opt = { matrix_layout };
96
+ VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, ipiv_vnary);
97
+
98
+ VALUE ret = rb_ary_new3(2, a_vnary, res);
99
+
100
+ RB_GC_GUARD(a_vnary);
101
+ RB_GC_GUARD(ipiv_vnary);
102
+
103
+ return ret;
104
+ }
105
+
106
+ static int get_matrix_layout(VALUE val) {
107
+ const char* option_str = StringValueCStr(val);
108
+
109
+ if (std::strlen(option_str) > 0) {
110
+ switch (option_str[0]) {
111
+ case 'r':
112
+ case 'R':
113
+ break;
114
+ case 'c':
115
+ case 'C':
116
+ rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major.");
117
+ break;
118
+ }
119
+ }
120
+
121
+ RB_GC_GUARD(val);
122
+
123
+ return LAPACK_ROW_MAJOR;
124
+ }
125
+ };
126
+
127
+ } // namespace TinyLinalg