numo-tiny_linalg 0.0.1 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/README.md +2 -0
- data/ext/numo/tiny_linalg/extconf.rb +64 -16
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +118 -0
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +148 -0
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +118 -0
- data/ext/numo/tiny_linalg/lapack/getri.hpp +127 -0
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +115 -0
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +115 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +33 -7
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +131 -1
- data/vendor/tmp/.gitkeep +0 -0
- metadata +17 -17
- data/.clang-format +0 -149
- data/.husky/commit-msg +0 -4
- data/.rubocop.yml +0 -47
- data/Gemfile +0 -15
- data/Rakefile +0 -30
- data/commitlint.config.js +0 -1
- data/numo-tiny_linalg.gemspec +0 -42
- data/package.json +0 -15
- /data/ext/numo/tiny_linalg/{dot.hpp → blas/dot.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{dot_sub.hpp → blas/dot_sub.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gemm.hpp → blas/gemm.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gemv.hpp → blas/gemv.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{nrm2.hpp → blas/nrm2.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gesdd.hpp → lapack/gesdd.hpp} +0 -0
- /data/ext/numo/tiny_linalg/{gesvd.hpp → lapack/gesvd.hpp} +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bbebba3b506ab283688f9d0935739e14c09106918a98571b9addb64b6689cc75
|
4
|
+
data.tar.gz: 85eaa28da383e21a4503407667baacb95fb5382f20f54804f02e4c2d49c15cd1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 262e38a4bbbbf6141cca723830f93c80cc9d33ac8a34753e3af1b9b5b239f4576140f18ad43999e8d0568438380f8bb2f6cf3f78c1f5d35402d57fcf55e4e253
|
7
|
+
data.tar.gz: ba6e767f8728022dff634e1a3824a166c5b3f4d17057b65f2e279a2ccda293c2432aea07783708fdd9f1d810e7a9bc9428ae2b9fe559b0bb519d8b1cb828b01e
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,20 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [[0.0.3](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.2...v0.0.3)] - 2023-08-02
|
4
|
+
- Add dgeqrf, sgeqrf, zgeqrf, and cgeqrf module functions to TinyLinalg::Lapack.
|
5
|
+
- Add dorgqr, sorgqr, zungqr, and cungqr module functions to TinyLinalg::Lapack.
|
6
|
+
- Add det module function to TinyLinalg.
|
7
|
+
- Add pinv module function to TinyLinalg.
|
8
|
+
- Add qr module function to TinyLinalg.
|
9
|
+
|
10
|
+
## [[0.0.2](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.1...v0.0.2)] - 2023-07-26
|
11
|
+
- Add automatic build of OpenBLAS if it is not found.
|
12
|
+
- Add dgesv, sgesv, zgesv, and cgesv module functions to TinyLinalg::Lapack.
|
13
|
+
- Add dgetrf, sgetrf, zgetrf, and cgetrf module functions to TinyLinalg::Lapack.
|
14
|
+
- Add dgetri, sgetri, zgetri, and cgetri module functions to TinyLinalg::Lapack.
|
15
|
+
- Add solve module function to TinyLinalg.
|
16
|
+
- Add inv module function to TinyLinalg.
|
17
|
+
|
3
18
|
## [0.0.1] - 2023-07-14
|
4
19
|
|
5
20
|
- Initial release
|
data/README.md
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
# Numo::TinyLinalg
|
2
2
|
|
3
|
+
[Gem Version](https://badge.fury.io/rb/numo-tiny_linalg)
|
3
4
|
[Build Status](https://github.com/yoshoku/numo-tiny_linalg/actions/workflows/main.yml)
|
4
5
|
[License](https://github.com/yoshoku/numo-tiny_linalg/blob/main/LICENSE.txt)
|
6
|
+
[Documentation](https://yoshoku.github.io/numo-tiny_linalg/doc/)
|
5
7
|
|
6
8
|
Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
7
9
|
|
@@ -2,6 +2,12 @@
|
|
2
2
|
|
3
3
|
require 'mkmf'
|
4
4
|
require 'numo/narray'
|
5
|
+
require 'open-uri'
|
6
|
+
require 'etc'
|
7
|
+
require 'fileutils'
|
8
|
+
require 'open3'
|
9
|
+
require 'digest/md5'
|
10
|
+
require 'rubygems/package'
|
5
11
|
|
6
12
|
$LOAD_PATH.each do |lp|
|
7
13
|
if File.exist?(File.join(lp, 'numo/numo/narray.h'))
|
@@ -22,33 +28,75 @@ if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
|
|
22
28
|
abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
|
23
29
|
end
|
24
30
|
|
25
|
-
|
26
|
-
|
27
|
-
$LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
|
28
|
-
end
|
31
|
+
abort 'libstdc++ is not found.' unless have_library('stdc++')
|
32
|
+
$CXXFLAGS << ' -std=c++11'
|
29
33
|
|
30
|
-
use_accelerate = false
|
31
34
|
# NOTE: Accelerate framework does not support LAPACKE.
|
35
|
+
# use_accelerate = false
|
32
36
|
# if RUBY_PLATFORM.include?('darwin') && have_framework('Accelerate')
|
33
37
|
# $CFLAGS << ' -DTINYLINALG_USE_ACCELERATE'
|
34
38
|
# use_accelerate = true
|
35
39
|
# end
|
36
40
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
41
|
+
build_openblas = false
|
42
|
+
unless find_library('openblas', 'LAPACKE_dsyevr')
|
43
|
+
build_openblas = true unless have_library('openblas')
|
44
|
+
build_openblas = true unless have_library('lapacke')
|
45
|
+
end
|
46
|
+
build_openblas = true unless have_header('cblas.h')
|
47
|
+
build_openblas = true unless have_header('lapacke.h')
|
48
|
+
build_openblas = true unless have_header('openblas_config.h')
|
49
|
+
|
50
|
+
if build_openblas
|
51
|
+
warn 'BLAS and LAPACKE APIs are not found. Downloading and Building OpenBLAS...'
|
52
|
+
|
53
|
+
VENDOR_DIR = File.expand_path("#{__dir__}/../../../vendor")
|
54
|
+
OPENBLAS_VER = '0.3.23'
|
55
|
+
OPENBLAS_KEY = '115634b39007de71eb7e75cf7591dfb2'
|
56
|
+
OPENBLAS_URI = "https://github.com/xianyi/OpenBLAS/archive/v#{OPENBLAS_VER}.tar.gz"
|
57
|
+
OPENBLAS_TGZ = "#{VENDOR_DIR}/tmp/openblas.tgz"
|
58
|
+
|
59
|
+
unless File.exist?("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
|
60
|
+
URI.parse(OPENBLAS_URI).open { |f| File.binwrite(OPENBLAS_TGZ, f.read) }
|
61
|
+
abort('MD5 digest of downloaded OpenBLAS does not match.') if Digest::MD5.file(OPENBLAS_TGZ).to_s != OPENBLAS_KEY
|
62
|
+
|
63
|
+
Gem::Package::TarReader.new(Zlib::GzipReader.open(OPENBLAS_TGZ)) do |tar|
|
64
|
+
tar.each do |entry|
|
65
|
+
next unless entry.file?
|
66
|
+
|
67
|
+
filename = "#{VENDOR_DIR}/tmp/#{entry.full_name}"
|
68
|
+
next if filename == File.dirname(filename)
|
69
|
+
|
70
|
+
FileUtils.mkdir_p("#{VENDOR_DIR}/tmp/#{File.dirname(entry.full_name)}")
|
71
|
+
File.binwrite(filename, entry.read)
|
72
|
+
File.chmod(entry.header.mode, filename)
|
73
|
+
end
|
74
|
+
end
|
75
|
+
|
76
|
+
Dir.chdir("#{VENDOR_DIR}/tmp/OpenBLAS-#{OPENBLAS_VER}") do
|
77
|
+
mkstdout, mkstatus = Open3.capture2e('make', "-j#{Etc.nprocessors}") # stderr is merged into the captured text so openblas.log also records compiler errors
|
78
|
+
File.open("#{VENDOR_DIR}/tmp/openblas.log", 'w') { |f| f.puts(mkstdout) }
|
79
|
+
abort("Failed to build OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless mkstatus.success?
|
80
|
+
|
81
|
+
insstdout, insstatus = Open3.capture2e('make', 'install', "PREFIX=#{VENDOR_DIR}") # argv form avoids shell word-splitting of VENDOR_DIR; stderr is merged so it reaches openblas.log
|
82
|
+
File.open("#{VENDOR_DIR}/tmp/openblas.log", 'a') { |f| f.puts(insstdout) }
|
83
|
+
abort("Failed to install OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless insstatus.success?
|
84
|
+
|
85
|
+
FileUtils.touch("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
|
86
|
+
end
|
43
87
|
end
|
44
88
|
|
45
|
-
abort
|
46
|
-
abort
|
47
|
-
abort
|
89
|
+
abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
|
90
|
+
abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
|
91
|
+
abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
|
92
|
+
abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
|
48
93
|
end
|
49
94
|
|
50
|
-
|
95
|
+
$CFLAGS << ' -DNUMO_TINY_LINALG_USE_OPENBLAS'
|
51
96
|
|
52
|
-
|
97
|
+
if RUBY_PLATFORM.include?('darwin') && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION) &&
|
98
|
+
try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
|
99
|
+
$LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
|
100
|
+
end
|
53
101
|
|
54
102
|
create_makefile('numo/tiny_linalg/tiny_linalg')
|
@@ -0,0 +1,118 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGeQrf {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
5
|
+
double* a, lapack_int lda, double* tau) {
|
6
|
+
return LAPACKE_dgeqrf(matrix_layout, m, n, a, lda, tau);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SGeQrf {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
12
|
+
float* a, lapack_int lda, float* tau) {
|
13
|
+
return LAPACKE_sgeqrf(matrix_layout, m, n, a, lda, tau);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZGeQrf {
|
18
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
19
|
+
lapack_complex_double* a, lapack_int lda, lapack_complex_double* tau) {
|
20
|
+
return LAPACKE_zgeqrf(matrix_layout, m, n, a, lda, tau);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CGeQrf {
|
25
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
26
|
+
lapack_complex_float* a, lapack_int lda, lapack_complex_float* tau) {
|
27
|
+
return LAPACKE_cgeqrf(matrix_layout, m, n, a, lda, tau);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
32
|
+
class GeQrf {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_geqrf), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct geqrf_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
};
|
42
|
+
|
43
|
+
static void iter_geqrf(na_loop_t* const lp) {
|
44
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
45
|
+
DType* tau = (DType*)NDL_PTR(lp, 1);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
47
|
+
geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
|
+
const lapack_int lda = n;
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_geqrf(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
|
58
|
+
VALUE a_vnary = Qnil;
|
59
|
+
VALUE kw_args = Qnil;
|
60
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
61
|
+
ID kw_table[1] = { rb_intern("order") };
|
62
|
+
VALUE kw_values[1] = { Qundef };
|
63
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
|
+
|
66
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
68
|
+
}
|
69
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
70
|
+
a_vnary = nary_dup(a_vnary);
|
71
|
+
}
|
72
|
+
|
73
|
+
narray_t* a_nary = NULL;
|
74
|
+
GetNArray(a_vnary, a_nary);
|
75
|
+
const int n_dims = NA_NDIM(a_nary);
|
76
|
+
if (n_dims != 2) {
|
77
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
78
|
+
return Qnil;
|
79
|
+
}
|
80
|
+
|
81
|
+
size_t m = NA_SHAPE(a_nary)[0];
|
82
|
+
size_t n = NA_SHAPE(a_nary)[1];
|
83
|
+
size_t shape[1] = { m < n ? m : n };
|
84
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
85
|
+
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
|
86
|
+
ndfunc_t ndf = { iter_geqrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
87
|
+
geqrf_opt opt = { matrix_layout };
|
88
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
89
|
+
|
90
|
+
VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
|
91
|
+
|
92
|
+
RB_GC_GUARD(a_vnary);
|
93
|
+
|
94
|
+
return ret;
|
95
|
+
}
|
96
|
+
|
97
|
+
static int get_matrix_layout(VALUE val) {
|
98
|
+
const char* option_str = StringValueCStr(val);
|
99
|
+
|
100
|
+
if (std::strlen(option_str) > 0) {
|
101
|
+
switch (option_str[0]) {
|
102
|
+
case 'r':
|
103
|
+
case 'R':
|
104
|
+
break;
|
105
|
+
case 'c':
|
106
|
+
case 'C':
|
107
|
+
rb_warn("Numo::TinyLinalg::Lapack.geqrf does not support column major.");  // this helper always returns LAPACK_ROW_MAJOR; the warning must name geqrf, not getrf
|
108
|
+
break;
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
RB_GC_GUARD(val);
|
113
|
+
|
114
|
+
return LAPACK_ROW_MAJOR;
|
115
|
+
}
|
116
|
+
};
|
117
|
+
|
118
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,148 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGESV {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
5
|
+
double* a, lapack_int lda, lapack_int* ipiv,
|
6
|
+
double* b, lapack_int ldb) {
|
7
|
+
return LAPACKE_dgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SGESV {
|
12
|
+
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
13
|
+
float* a, lapack_int lda, lapack_int* ipiv,
|
14
|
+
float* b, lapack_int ldb) {
|
15
|
+
return LAPACKE_sgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
struct ZGESV {
|
20
|
+
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
21
|
+
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv,
|
22
|
+
lapack_complex_double* b, lapack_int ldb) {
|
23
|
+
return LAPACKE_zgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
struct CGESV {
|
28
|
+
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
29
|
+
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv,
|
30
|
+
lapack_complex_float* b, lapack_int ldb) {
|
31
|
+
return LAPACKE_cgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
|
32
|
+
}
|
33
|
+
};
|
34
|
+
|
35
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
36
|
+
class GESV {
|
37
|
+
public:
|
38
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
39
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_gesv), -1);
|
40
|
+
}
|
41
|
+
|
42
|
+
private:
|
43
|
+
struct gesv_opt {
|
44
|
+
int matrix_layout;
|
45
|
+
};
|
46
|
+
|
47
|
+
static void iter_gesv(na_loop_t* const lp) {
|
48
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
49
|
+
DType* b = (DType*)NDL_PTR(lp, 1);
|
50
|
+
int* ipiv = (int*)NDL_PTR(lp, 2);
|
51
|
+
int* info = (int*)NDL_PTR(lp, 3);
|
52
|
+
gesv_opt* opt = (gesv_opt*)(lp->opt_ptr);
|
53
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
54
|
+
const lapack_int nhrs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
55
|
+
const lapack_int lda = n;
|
56
|
+
const lapack_int ldb = nhrs;
|
57
|
+
const lapack_int i = FncType().call(opt->matrix_layout, n, nhrs, a, lda, ipiv, b, ldb);
|
58
|
+
*info = static_cast<int>(i);
|
59
|
+
}
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_gesv(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
VALUE a_vnary = Qnil;
|
64
|
+
VALUE b_vnary = Qnil;
|
65
|
+
VALUE kw_args = Qnil;
|
66
|
+
|
67
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
68
|
+
|
69
|
+
ID kw_table[1] = { rb_intern("order") };
|
70
|
+
VALUE kw_values[1] = { Qundef };
|
71
|
+
|
72
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
73
|
+
|
74
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
75
|
+
|
76
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
77
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
78
|
+
}
|
79
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
80
|
+
a_vnary = nary_dup(a_vnary);
|
81
|
+
}
|
82
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
83
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
84
|
+
}
|
85
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
86
|
+
b_vnary = nary_dup(b_vnary);
|
87
|
+
}
|
88
|
+
|
89
|
+
narray_t* a_nary = NULL;
|
90
|
+
narray_t* b_nary = NULL;
|
91
|
+
GetNArray(a_vnary, a_nary);
|
92
|
+
GetNArray(b_vnary, b_nary);
|
93
|
+
const int a_n_dims = NA_NDIM(a_nary);
|
94
|
+
const int b_n_dims = NA_NDIM(b_nary);
|
95
|
+
if (a_n_dims != 2) {
|
96
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
97
|
+
return Qnil;
|
98
|
+
}
|
99
|
+
if (b_n_dims != 1 && b_n_dims != 2) {
|
100
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional");
|
101
|
+
return Qnil;
|
102
|
+
}
|
103
|
+
|
104
|
+
lapack_int n = NA_SHAPE(a_nary)[0];
|
105
|
+
lapack_int nb = NA_SHAPE(b_nary)[0];  // first axis of b, whether b is 1-D or 2-D
|
106
|
+
if (n != nb) {
|
107
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb);  // n is read from the first axis of a, so report it as shape1[0]
|
108
|
+
}
|
109
|
+
|
110
|
+
lapack_int nhrs = b_n_dims == 1 ? 1 : NA_SHAPE(b_nary)[1];
|
111
|
+
size_t shape[2] = { static_cast<size_t>(n), static_cast<size_t>(nhrs) };
|
112
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, b_n_dims } };
|
113
|
+
ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } };
|
114
|
+
|
115
|
+
ndfunc_t ndf = { iter_gesv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
|
116
|
+
gesv_opt opt = { matrix_layout };
|
117
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
118
|
+
|
119
|
+
VALUE ret = rb_ary_concat(rb_assoc_new(a_vnary, b_vnary), res);
|
120
|
+
|
121
|
+
RB_GC_GUARD(a_vnary);
|
122
|
+
RB_GC_GUARD(b_vnary);
|
123
|
+
|
124
|
+
return ret;
|
125
|
+
}
|
126
|
+
|
127
|
+
static int get_matrix_layout(VALUE val) {
|
128
|
+
const char* option_str = StringValueCStr(val);
|
129
|
+
|
130
|
+
if (std::strlen(option_str) > 0) {
|
131
|
+
switch (option_str[0]) {
|
132
|
+
case 'r':
|
133
|
+
case 'R':
|
134
|
+
break;
|
135
|
+
case 'c':
|
136
|
+
case 'C':
|
137
|
+
rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major.");
|
138
|
+
break;
|
139
|
+
}
|
140
|
+
}
|
141
|
+
|
142
|
+
RB_GC_GUARD(val);
|
143
|
+
|
144
|
+
return LAPACK_ROW_MAJOR;
|
145
|
+
}
|
146
|
+
};
|
147
|
+
|
148
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,118 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGETRF {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
5
|
+
double* a, lapack_int lda, lapack_int* ipiv) {
|
6
|
+
return LAPACKE_dgetrf(matrix_layout, m, n, a, lda, ipiv);
|
7
|
+
}
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SGETRF {
|
11
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
12
|
+
float* a, lapack_int lda, lapack_int* ipiv) {
|
13
|
+
return LAPACKE_sgetrf(matrix_layout, m, n, a, lda, ipiv);
|
14
|
+
}
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZGETRF {
|
18
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
19
|
+
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv) {
|
20
|
+
return LAPACKE_zgetrf(matrix_layout, m, n, a, lda, ipiv);
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CGETRF {
|
25
|
+
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
26
|
+
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv) {
|
27
|
+
return LAPACKE_cgetrf(matrix_layout, m, n, a, lda, ipiv);
|
28
|
+
}
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
32
|
+
class GETRF {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getrf), -1);
|
36
|
+
}
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct getrf_opt {
|
40
|
+
int matrix_layout;
|
41
|
+
};
|
42
|
+
|
43
|
+
static void iter_getrf(na_loop_t* const lp) {
|
44
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
45
|
+
int* ipiv = (int*)NDL_PTR(lp, 1);
|
46
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
47
|
+
getrf_opt* opt = (getrf_opt*)(lp->opt_ptr);
|
48
|
+
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
|
+
const lapack_int lda = n;
|
51
|
+
const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, ipiv);
|
52
|
+
*info = static_cast<int>(i);
|
53
|
+
}
|
54
|
+
|
55
|
+
static VALUE tiny_linalg_getrf(int argc, VALUE* argv, VALUE self) {
|
56
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
57
|
+
|
58
|
+
VALUE a_vnary = Qnil;
|
59
|
+
VALUE kw_args = Qnil;
|
60
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
61
|
+
ID kw_table[1] = { rb_intern("order") };
|
62
|
+
VALUE kw_values[1] = { Qundef };
|
63
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
|
+
|
66
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
68
|
+
}
|
69
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
70
|
+
a_vnary = nary_dup(a_vnary);
|
71
|
+
}
|
72
|
+
|
73
|
+
narray_t* a_nary = NULL;
|
74
|
+
GetNArray(a_vnary, a_nary);
|
75
|
+
const int n_dims = NA_NDIM(a_nary);
|
76
|
+
if (n_dims != 2) {
|
77
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
78
|
+
return Qnil;
|
79
|
+
}
|
80
|
+
|
81
|
+
size_t m = NA_SHAPE(a_nary)[0];
|
82
|
+
size_t n = NA_SHAPE(a_nary)[1];
|
83
|
+
size_t shape[1] = { m < n ? m : n };
|
84
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
85
|
+
ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } };
|
86
|
+
ndfunc_t ndf = { iter_getrf, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
|
87
|
+
getrf_opt opt = { matrix_layout };
|
88
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
89
|
+
|
90
|
+
VALUE ret = rb_ary_concat(rb_ary_new3(1, a_vnary), res);
|
91
|
+
|
92
|
+
RB_GC_GUARD(a_vnary);
|
93
|
+
|
94
|
+
return ret;
|
95
|
+
}
|
96
|
+
|
97
|
+
static int get_matrix_layout(VALUE val) {
|
98
|
+
const char* option_str = StringValueCStr(val);
|
99
|
+
|
100
|
+
if (std::strlen(option_str) > 0) {
|
101
|
+
switch (option_str[0]) {
|
102
|
+
case 'r':
|
103
|
+
case 'R':
|
104
|
+
break;
|
105
|
+
case 'c':
|
106
|
+
case 'C':
|
107
|
+
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
108
|
+
break;
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
RB_GC_GUARD(val);
|
113
|
+
|
114
|
+
return LAPACK_ROW_MAJOR;
|
115
|
+
}
|
116
|
+
};
|
117
|
+
|
118
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,127 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGETRI {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int n, double* a, lapack_int lda, const lapack_int* ipiv) {
|
5
|
+
return LAPACKE_dgetri(matrix_layout, n, a, lda, ipiv);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SGETRI {
|
10
|
+
lapack_int call(int matrix_layout, lapack_int n, float* a, lapack_int lda, const lapack_int* ipiv) {
|
11
|
+
return LAPACKE_sgetri(matrix_layout, n, a, lda, ipiv);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
struct ZGETRI {
|
16
|
+
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_double* a, lapack_int lda, const lapack_int* ipiv) {
|
17
|
+
return LAPACKE_zgetri(matrix_layout, n, a, lda, ipiv);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct CGETRI {
|
22
|
+
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_float* a, lapack_int lda, const lapack_int* ipiv) {
|
23
|
+
return LAPACKE_cgetri(matrix_layout, n, a, lda, ipiv);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, typename DType, typename FncType>
|
28
|
+
class GETRI {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getri), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
struct getri_opt {
|
36
|
+
int matrix_layout;
|
37
|
+
};
|
38
|
+
|
39
|
+
static void iter_getri(na_loop_t* const lp) {
|
40
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
41
|
+
lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1);
|
42
|
+
int* info = (int*)NDL_PTR(lp, 2);
|
43
|
+
getri_opt* opt = (getri_opt*)(lp->opt_ptr);
|
44
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
45
|
+
const lapack_int lda = n;
|
46
|
+
const lapack_int i = FncType().call(opt->matrix_layout, n, a, lda, ipiv);
|
47
|
+
*info = static_cast<int>(i);
|
48
|
+
}
|
49
|
+
|
50
|
+
static VALUE tiny_linalg_getri(int argc, VALUE* argv, VALUE self) {
|
51
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
52
|
+
|
53
|
+
VALUE a_vnary = Qnil;
|
54
|
+
VALUE ipiv_vnary = Qnil;
|
55
|
+
VALUE kw_args = Qnil;
|
56
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &ipiv_vnary, &kw_args);
|
57
|
+
ID kw_table[1] = { rb_intern("order") };
|
58
|
+
VALUE kw_values[1] = { Qundef };
|
59
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
60
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
61
|
+
|
62
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
64
|
+
}
|
65
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
66
|
+
a_vnary = nary_dup(a_vnary);
|
67
|
+
}
|
68
|
+
if (CLASS_OF(ipiv_vnary) != numo_cInt32) {
|
69
|
+
ipiv_vnary = rb_funcall(numo_cInt32, rb_intern("cast"), 1, ipiv_vnary);
|
70
|
+
}
|
71
|
+
if (!RTEST(nary_check_contiguous(ipiv_vnary))) {
|
72
|
+
ipiv_vnary = nary_dup(ipiv_vnary);
|
73
|
+
}
|
74
|
+
|
75
|
+
narray_t* a_nary = NULL;
|
76
|
+
GetNArray(a_vnary, a_nary);
|
77
|
+
if (NA_NDIM(a_nary) != 2) {
|
78
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
79
|
+
return Qnil;
|
80
|
+
}
|
81
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
82
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
83
|
+
return Qnil;
|
84
|
+
}
|
85
|
+
narray_t* ipiv_nary = NULL;
|
86
|
+
GetNArray(ipiv_vnary, ipiv_nary);
|
87
|
+
if (NA_NDIM(ipiv_nary) != 1) {
|
88
|
+
rb_raise(rb_eArgError, "input array ipiv must be 1-dimensional");
|
89
|
+
return Qnil;
|
90
|
+
}
|
91
|
+
|
92
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { numo_cInt32, 1 } };
|
93
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } };
|
94
|
+
ndfunc_t ndf = { iter_getri, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
95
|
+
getri_opt opt = { matrix_layout };
|
96
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, ipiv_vnary);
|
97
|
+
|
98
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res);
|
99
|
+
|
100
|
+
RB_GC_GUARD(a_vnary);
|
101
|
+
RB_GC_GUARD(ipiv_vnary);
|
102
|
+
|
103
|
+
return ret;
|
104
|
+
}
|
105
|
+
|
106
|
+
static int get_matrix_layout(VALUE val) {
|
107
|
+
const char* option_str = StringValueCStr(val);
|
108
|
+
|
109
|
+
if (std::strlen(option_str) > 0) {
|
110
|
+
switch (option_str[0]) {
|
111
|
+
case 'r':
|
112
|
+
case 'R':
|
113
|
+
break;
|
114
|
+
case 'c':
|
115
|
+
case 'C':
|
116
|
+
rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major.");
|
117
|
+
break;
|
118
|
+
}
|
119
|
+
}
|
120
|
+
|
121
|
+
RB_GC_GUARD(val);
|
122
|
+
|
123
|
+
return LAPACK_ROW_MAJOR;
|
124
|
+
}
|
125
|
+
};
|
126
|
+
|
127
|
+
} // namespace TinyLinalg
|