numo-tiny_linalg 0.0.1 → 0.0.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 15ddc75f758bbff2ef6887db5d71c13dd517f7eeb0c4d3ab181c5c2db8b8f995
- data.tar.gz: 20350f1d084a31e51e05c317dddc23f6285dd27430176a753edd3eea935331be
+ metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
+ data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
  SHA512:
- metadata.gz: 457d487b20bfffb3eade0fc80120f5de7acfa3f0678550f3b844abac139e7b6db469fdcf6b686337755ea2b49b8bad54ea66381368d2e4e554fdb32df6e2d87d
- data.tar.gz: 217765d951cf0d790e8620e1a4e4d28a884a5955bdca32cc4386fd4951c8ae2a133be029c24b20e987ce6fb2c7adfa101fe327c55cce327c2e89f4ff4d54a874
+ metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
+ data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
data/CHANGELOG.md CHANGED
@@ -1,5 +1,20 @@
  ## [Unreleased]
 
+ ## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
+ - Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
+ - Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
+ - Add det module function to TinyLinalg.
+ - Add pinv module function to TinyLinalg.
+ - Add qr module function to TinyLinalg.
+
+ ## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
+ - Add automatic build of OpenBLAS if it is not found.
+ - Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
+ - Add dgetrf, sgetrf, zgetrf, and cgetrf module functions to TinyLinalg::Lapack.
+ - Add dgetri, sgetri, zgetri, and cgetri module functions to TinyLinalg::Lapack.
+ - Add solve module function to TinyLinalg.
+ - Add inv module function to TinyLinalg.
+
  ## [0.0.1] - 2023-07-14
 
  - Initial release
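
A quick sketch of calling the high-level functions these two releases add (illustrative only; the signatures are assumed to mirror Numo::Linalg's, whose interface TinyLinalg follows):

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat.new(3, 3).rand - 0.5
b = Numo::DFloat.new(3).rand

x = Numo::TinyLinalg.solve(a, b)  # 0.0.2: solves a.dot(x) = b via gesv
a_inv = Numo::TinyLinalg.inv(a)   # 0.0.2: inverse via getrf + getri
d = Numo::TinyLinalg.det(a)       # 0.0.3: determinant from the LU factors
q, r = Numo::TinyLinalg.qr(a)     # 0.0.3: QR decomposition via geqrf + orgqr
```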
data/README.md CHANGED
@@ -1,7 +1,9 @@
  # Numo::TinyLinalg
 
+ [![Gem Version](https://badge.fury.io/rb/numo-tiny_linalg.svg)](https://badge.fury.io/rb/numo-tiny_linalg)
  [![Build Status](https://github.com/yoshoku/numo-tiny_linalg/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/numo-tiny_linalg/actions/workflows/main.yml)
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/numo-tiny_linalg/blob/main/LICENSE.txt)
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/numo-tiny_linalg/doc/)
 
  Numo::TinyLinalg is a subset library of Numo::Linalg consisting only of the methods used in machine learning algorithms.
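
Since it mirrors the Numo::Linalg interface for what it implements, the gem can stand in for Numo::Linalg directly; a minimal sketch (the alias line is an illustration, not part of this diff):

```ruby
require 'numo/tiny_linalg'

# Let code written against Numo::Linalg pick up TinyLinalg instead.
Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
```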
data/ext/numo/tiny_linalg/extconf.rb CHANGED
@@ -2,6 +2,12 @@
  require 'mkmf'
  require 'numo/narray'
+ require 'open-uri'
+ require 'etc'
+ require 'fileutils'
+ require 'open3'
+ require 'digest/md5'
+ require 'rubygems/package'
 
  $LOAD_PATH.each do |lp|
    if File.exist?(File.join(lp, 'numo/numo/narray.h'))
@@ -22,33 +28,75 @@ if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
    abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
  end
 
- if RUBY_PLATFORM.include?('darwin') && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION) &&
-    try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
-   $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
- end
+ abort 'libstdc++ is not found.' unless have_library('stdc++')
+ $CXXFLAGS << ' -std=c++11'
 
- use_accelerate = false
  # NOTE: Accelerate framework does not support LAPACKE.
+ # use_accelerate = false
  # if RUBY_PLATFORM.include?('darwin') && have_framework('Accelerate')
  #   $CFLAGS << ' -DTINYLINALG_USE_ACCELERATE'
  #   use_accelerate = true
  # end
 
- unless use_accelerate
-   if have_library('openblas')
-     $CFLAGS << ' -DTINYLINALG_USE_OPENBLAS'
-   else
-     abort 'libblas is not found' unless have_library('blas')
-     $CFLAGS << ' -DTINYLINALG_USE_BLAS'
+ build_openblas = false
+ unless find_library('openblas', 'LAPACKE_dsyevr')
+   build_openblas = true unless have_library('openblas')
+   build_openblas = true unless have_library('lapacke')
+ end
+ build_openblas = true unless have_header('cblas.h')
+ build_openblas = true unless have_header('lapacke.h')
+ build_openblas = true unless have_header('openblas_config.h')
+
+ if build_openblas
+   warn 'BLAS and LAPACKE APIs are not found. Downloading and Building OpenBLAS...'
+
+   VENDOR_DIR = File.expand_path("#{__dir__}/../../../vendor")
+   OPENBLAS_VER = '0.3.23'
+   OPENBLAS_KEY = '115634b39007de71eb7e75cf7591dfb2'
+   OPENBLAS_URI = "https://github.com/xianyi/OpenBLAS/archive/v#{OPENBLAS_VER}.tar.gz"
+   OPENBLAS_TGZ = "#{VENDOR_DIR}/tmp/openblas.tgz"
+
+   unless File.exist?("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
+     URI.parse(OPENBLAS_URI).open { |f| File.binwrite(OPENBLAS_TGZ, f.read) }
+     abort('MD5 digest of downloaded OpenBLAS does not match.') if Digest::MD5.file(OPENBLAS_TGZ).to_s != OPENBLAS_KEY
+
+     Gem::Package::TarReader.new(Zlib::GzipReader.open(OPENBLAS_TGZ)) do |tar|
+       tar.each do |entry|
+         next unless entry.file?
+
+         filename = "#{VENDOR_DIR}/tmp/#{entry.full_name}"
+         next if filename == File.dirname(filename)
+
+         FileUtils.mkdir_p("#{VENDOR_DIR}/tmp/#{File.dirname(entry.full_name)}")
+         File.binwrite(filename, entry.read)
+         File.chmod(entry.header.mode, filename)
+       end
+     end
+
+     Dir.chdir("#{VENDOR_DIR}/tmp/OpenBLAS-#{OPENBLAS_VER}") do
+       mkstdout, _mkstderr, mkstatus = Open3.capture3("make -j#{Etc.nprocessors}")
+       File.open("#{VENDOR_DIR}/tmp/openblas.log", 'w') { |f| f.puts(mkstdout) }
+       abort("Failed to build OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless mkstatus.success?
+
+       insstdout, _insstderr, insstatus = Open3.capture3("make install PREFIX=#{VENDOR_DIR}")
+       File.open("#{VENDOR_DIR}/tmp/openblas.log", 'a') { |f| f.puts(insstdout) }
+       abort("Failed to install OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless insstatus.success?
+
+       FileUtils.touch("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
+     end
    end
 
-   abort 'liblapacke is not found' if !have_func('LAPACKE_dsyevr') && !have_library('lapacke')
-   abort 'cblas.h is not found' unless have_header('cblas.h')
-   abort 'lapacke.h is not found' unless have_header('lapacke.h')
+   abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
+   abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
+   abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
+   abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
  end
 
- abort 'libstdc++ is not found.' unless have_library('stdc++')
+ $CFLAGS << ' -DNUMO_TINY_LINALG_USE_OPENBLAS'
 
- $CXXFLAGS << ' -std=c++11'
+ if RUBY_PLATFORM.include?('darwin') && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION) &&
+    try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
+   $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
+ end
 
  create_makefile('numo/tiny_linalg/tiny_linalg')
@@ -0,0 +1,118 @@
+ namespace TinyLinalg {
+
+ struct DGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   double* a, lapack_int lda, double* tau) {
+     return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ struct SGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   float* a, lapack_int lda, float* tau) {
+     return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ struct ZGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
+     return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ struct CGeQrf {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
+     return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class GeQrf {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
+   }
+
+ private:
+   struct geqrf_opt {
+     int matrix_layout;
+   };
+
+   static void iter_geqrf(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* tau = (DType*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     const int n_dims = NA_NDIM(a_nary);
+     if (n_dims != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+
+     size_t m = NA_SHAPE(a_nary)[0];
+     size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { m < n ? m : n };
+     ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
+     geqrf_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
+
+     VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
+
+     RB_GC_GUARD(a_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.geqrf does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
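
For orientation, a hedged sketch of driving this wrapper from Ruby: dgeqrf returns [qr, tau, info] per the rb_ary_concat above, with R packed above the diagonal and the Householder reflectors below it. The dorgqr call and the triu helper are assumptions based on the changelog and Numo::NArray's API:

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat.new(3, 3).rand

qr_mat, tau, info = Numo::TinyLinalg::Lapack.dgeqrf(a.dup)
raise 'dgeqrf failed' unless info.zero?

r = qr_mat.triu                                   # R from the packed factorization
q, = Numo::TinyLinalg::Lapack.dorgqr(qr_mat, tau) # assemble Q from the reflectors

(q.dot(r) - a).abs.max # ~1e-15 if everything worked
```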
@@ -0,0 +1,148 @@
+ namespace TinyLinalg {
+
+ struct DGESV {
+   lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
+                   double* a, lapack_int lda, lapack_int* ipiv,
+                   double* b, lapack_int ldb) {
+     return LAPACKE_dgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
+   }
+ };
+
+ struct SGESV {
+   lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
+                   float* a, lapack_int lda, lapack_int* ipiv,
+                   float* b, lapack_int ldb) {
+     return LAPACKE_sgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
+   }
+ };
+
+ struct ZGESV {
+   lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
+                   lapack_complex_double* a, lapack_int lda, lapack_int* ipiv,
+                   lapack_complex_double* b, lapack_int ldb) {
+     return LAPACKE_zgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
+   }
+ };
+
+ struct CGESV {
+   lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
+                   lapack_complex_float* a, lapack_int lda, lapack_int* ipiv,
+                   lapack_complex_float* b, lapack_int ldb) {
+     return LAPACKE_cgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class GESV {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_gesv), -1);
+   }
+
+ private:
+   struct gesv_opt {
+     int matrix_layout;
+   };
+
+   static void iter_gesv(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     DType* b = (DType*)NDL_PTR(lp, 1);
+     int* ipiv = (int*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     gesv_opt* opt = (gesv_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[0];
+     const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
+     const lapack_int lda = n;
+     const lapack_int ldb = nrhs;
+     const lapack_int i = FncType().call(opt->matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_gesv(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     narray_t* b_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     GetNArray(b_vnary, b_nary);
+     const int a_n_dims = NA_NDIM(a_nary);
+     const int b_n_dims = NA_NDIM(b_nary);
+     if (a_n_dims != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (b_n_dims != 1 && b_n_dims != 2) {
+       rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
+       return Qnil;
+     }
+
+     lapack_int n = NA_SHAPE(a_nary)[0];
+     lapack_int nb = NA_SHAPE(b_nary)[0];
+     if (n != nb) {
+       rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", n, nb);
+     }
+
+     lapack_int nrhs = b_n_dims == 1 ? 1 : NA_SHAPE(b_nary)[1];
+     size_t shape[2] = { static_cast<size_t>(n), static_cast<size_t>(nrhs) };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, b_n_dims } };
+     ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } };
+
+     ndfunc_t ndf = { iter_gesv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     gesv_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+
+     VALUE ret = rb_ary_concat(rb_assoc_new(a_vnary, b_vnary), res);
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
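
A usage sketch, reading the return order [a, b, ipiv, info] off the rb_ary_concat call above (b comes back overwritten with the solution):

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat[[4.0, 3.0], [6.0, 3.0]]
b = Numo::DFloat[10.0, 12.0]

# dgesv overwrites a with its LU factors and b with the solution.
lu, x, ipiv, info = Numo::TinyLinalg::Lapack.dgesv(a.dup, b.dup)
raise 'dgesv failed' unless info.zero?

x # => Numo::DFloat[1.0, 2.0], since 4*1 + 3*2 = 10 and 6*1 + 3*2 = 12
```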
@@ -0,0 +1,118 @@
+ namespace TinyLinalg {
+
+ struct DGETRF {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   double* a, lapack_int lda, lapack_int* ipiv) {
+     return LAPACKE_dgetrf(matrix_layout, m, n, a, lda, ipiv);
+   }
+ };
+
+ struct SGETRF {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   float* a, lapack_int lda, lapack_int* ipiv) {
+     return LAPACKE_sgetrf(matrix_layout, m, n, a, lda, ipiv);
+   }
+ };
+
+ struct ZGETRF {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   lapack_complex_double* a, lapack_int lda, lapack_int* ipiv) {
+     return LAPACKE_zgetrf(matrix_layout, m, n, a, lda, ipiv);
+   }
+ };
+
+ struct CGETRF {
+   lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
+                   lapack_complex_float* a, lapack_int lda, lapack_int* ipiv) {
+     return LAPACKE_cgetrf(matrix_layout, m, n, a, lda, ipiv);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class GETRF {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getrf), -1);
+   }
+
+ private:
+   struct getrf_opt {
+     int matrix_layout;
+   };
+
+   static void iter_getrf(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     int* ipiv = (int*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     getrf_opt* opt = (getrf_opt*)(lp->opt_ptr);
+     const lapack_int m = NDL_SHAPE(lp, 0)[0];
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, ipiv);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_getrf(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     const int n_dims = NA_NDIM(a_nary);
+     if (n_dims != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+
+     size_t m = NA_SHAPE(a_nary)[0];
+     size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { m < n ? m : n };
+     ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_getrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
+     getrf_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
+
+     VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
+
+     RB_GC_GUARD(a_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
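
Called from Ruby, this wrapper returns [lu, ipiv, info] per the rb_ary_concat above; a hedged sketch:

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat[[1.0, 2.0], [3.0, 4.0]]

# dgetrf overwrites its argument with L and U packed into one matrix;
# ipiv holds the 1-based pivot indices from the row exchanges.
lu, ipiv, info = Numo::TinyLinalg::Lapack.dgetrf(a.dup)
raise 'dgetrf failed' unless info.zero?
```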
@@ -0,0 +1,127 @@
+ namespace TinyLinalg {
+
+ struct DGETRI {
+   lapack_int call(int matrix_layout, lapack_int n, double* a, lapack_int lda, const lapack_int* ipiv) {
+     return LAPACKE_dgetri(matrix_layout, n, a, lda, ipiv);
+   }
+ };
+
+ struct SGETRI {
+   lapack_int call(int matrix_layout, lapack_int n, float* a, lapack_int lda, const lapack_int* ipiv) {
+     return LAPACKE_sgetri(matrix_layout, n, a, lda, ipiv);
+   }
+ };
+
+ struct ZGETRI {
+   lapack_int call(int matrix_layout, lapack_int n, lapack_complex_double* a, lapack_int lda, const lapack_int* ipiv) {
+     return LAPACKE_zgetri(matrix_layout, n, a, lda, ipiv);
+   }
+ };
+
+ struct CGETRI {
+   lapack_int call(int matrix_layout, lapack_int n, lapack_complex_float* a, lapack_int lda, const lapack_int* ipiv) {
+     return LAPACKE_cgetri(matrix_layout, n, a, lda, ipiv);
+   }
+ };
+
+ template <int nary_dtype_id, typename DType, typename FncType>
+ class GETRI {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getri), -1);
+   }
+
+ private:
+   struct getri_opt {
+     int matrix_layout;
+   };
+
+   static void iter_getri(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1);
+     int* info = (int*)NDL_PTR(lp, 2);
+     getri_opt* opt = (getri_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[0];
+     const lapack_int lda = n;
+     const lapack_int i = FncType().call(opt->matrix_layout, n, a, lda, ipiv);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_getri(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE ipiv_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &ipiv_vnary, &kw_args);
+     ID kw_table[1] = { rb_intern("order") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(ipiv_vnary) != numo_cInt32) {
+       ipiv_vnary = rb_funcall(numo_cInt32, rb_intern("cast"), 1, ipiv_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(ipiv_vnary))) {
+       ipiv_vnary = nary_dup(ipiv_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* ipiv_nary = NULL;
+     GetNArray(ipiv_vnary, ipiv_nary);
+     if (NA_NDIM(ipiv_nary) != 1) {
+       rb_raise(rb_eArgError, "input array ipiv must be 1-dimensional");
+       return Qnil;
+     }
+
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { numo_cInt32, 1 } };
+     ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_getri, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
+     getri_opt opt = { matrix_layout };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, ipiv_vnary);
+
+     VALUE ret = rb_ary_new3(2, a_vnary, res);
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(ipiv_vnary);
+
+     return ret;
+   }
+
+   static int get_matrix_layout(VALUE val) {
+     const char* option_str = StringValueCStr(val);
+
+     if (std::strlen(option_str) > 0) {
+       switch (option_str[0]) {
+       case 'r':
+       case 'R':
+         break;
+       case 'c':
+       case 'C':
+         rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major.");
+         break;
+       }
+     }
+
+     RB_GC_GUARD(val);
+
+     return LAPACK_ROW_MAJOR;
+   }
+ };
+
+ } // namespace TinyLinalg
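
getri completes the inversion pipeline: factorize with getrf, then hand the LU factors and pivots to getri, which returns [inv, info] per the rb_ary_new3 above. A hedged sketch (presumably what TinyLinalg.inv does under the hood, per the changelog):

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat[[1.0, 2.0], [3.0, 4.0]]

lu, ipiv, info = Numo::TinyLinalg::Lapack.dgetrf(a.dup)
raise 'dgetrf failed' unless info.zero?

a_inv, info = Numo::TinyLinalg::Lapack.dgetri(lu, ipiv)
raise 'dgetri failed' unless info.zero?

a.dot(a_inv) # ≈ identity matrix
```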