numo-tiny_linalg 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.clang-format +149 -0
- data/.husky/commit-msg +4 -0
- data/.rubocop.yml +47 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/Gemfile +15 -0
- data/LICENSE.txt +27 -0
- data/README.md +99 -0
- data/Rakefile +30 -0
- data/commitlint.config.js +1 -0
- data/ext/numo/tiny_linalg/converter.hpp +75 -0
- data/ext/numo/tiny_linalg/dot.hpp +86 -0
- data/ext/numo/tiny_linalg/dot_sub.hpp +85 -0
- data/ext/numo/tiny_linalg/extconf.rb +54 -0
- data/ext/numo/tiny_linalg/gemm.hpp +223 -0
- data/ext/numo/tiny_linalg/gemv.hpp +211 -0
- data/ext/numo/tiny_linalg/gesdd.hpp +134 -0
- data/ext/numo/tiny_linalg/gesvd.hpp +182 -0
- data/ext/numo/tiny_linalg/nrm2.hpp +89 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +251 -0
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +38 -0
- data/lib/numo/tiny_linalg/version.rb +10 -0
- data/lib/numo/tiny_linalg.rb +60 -0
- data/numo-tiny_linalg.gemspec +42 -0
- data/package.json +15 -0
- metadata +91 -0
@@ -0,0 +1,85 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct ZDotuSub {
|
4
|
+
void call(const int n, const double* x, const int incx, const double* y, const int incy, double* ret) {
|
5
|
+
cblas_zdotu_sub(n, x, incx, y, incy, ret);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct CDotuSub {
|
10
|
+
void call(const int n, const float* x, const int incx, const float* y, const int incy, float* ret) {
|
11
|
+
cblas_cdotu_sub(n, x, incx, y, incy, ret);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
template <int nary_dtype_id, typename dtype, class BlasFn>
|
16
|
+
class DotSub {
|
17
|
+
public:
|
18
|
+
static void define_module_function(VALUE mBlas, const char* modfn_name) {
|
19
|
+
rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_dot_sub), 2);
|
20
|
+
};
|
21
|
+
|
22
|
+
private:
|
23
|
+
static void iter_dot_sub(na_loop_t* const lp) {
|
24
|
+
dtype* x = (dtype*)NDL_PTR(lp, 0);
|
25
|
+
dtype* y = (dtype*)NDL_PTR(lp, 1);
|
26
|
+
dtype* d = (dtype*)NDL_PTR(lp, 2);
|
27
|
+
const size_t n = NDL_SHAPE(lp, 0)[0];
|
28
|
+
BlasFn().call(n, x, 1, y, 1, d);
|
29
|
+
};
|
30
|
+
|
31
|
+
static VALUE tiny_linalg_dot_sub(VALUE self, VALUE x, VALUE y) {
|
32
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
33
|
+
|
34
|
+
if (CLASS_OF(x) != nary_dtype) {
|
35
|
+
x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
|
36
|
+
}
|
37
|
+
if (!RTEST(nary_check_contiguous(x))) {
|
38
|
+
x = nary_dup(x);
|
39
|
+
}
|
40
|
+
if (CLASS_OF(y) != nary_dtype) {
|
41
|
+
y = rb_funcall(nary_dtype, rb_intern("cast"), 1, y);
|
42
|
+
}
|
43
|
+
if (!RTEST(nary_check_contiguous(y))) {
|
44
|
+
y = nary_dup(y);
|
45
|
+
}
|
46
|
+
|
47
|
+
narray_t* x_nary = NULL;
|
48
|
+
GetNArray(x, x_nary);
|
49
|
+
narray_t* y_nary = NULL;
|
50
|
+
GetNArray(y, y_nary);
|
51
|
+
|
52
|
+
if (NA_NDIM(x_nary) != 1) {
|
53
|
+
rb_raise(rb_eArgError, "x must be 1-dimensional");
|
54
|
+
return Qnil;
|
55
|
+
}
|
56
|
+
if (NA_NDIM(y_nary) != 1) {
|
57
|
+
rb_raise(rb_eArgError, "y must be 1-dimensional");
|
58
|
+
return Qnil;
|
59
|
+
}
|
60
|
+
if (NA_SIZE(x_nary) == 0) {
|
61
|
+
rb_raise(rb_eArgError, "x must not be empty");
|
62
|
+
return Qnil;
|
63
|
+
}
|
64
|
+
if (NA_SIZE(y_nary) == 0) {
|
65
|
+
rb_raise(rb_eArgError, "x must not be empty");
|
66
|
+
return Qnil;
|
67
|
+
}
|
68
|
+
if (NA_SIZE(x_nary) != NA_SIZE(y_nary)) {
|
69
|
+
rb_raise(rb_eArgError, "x and y must have same size");
|
70
|
+
return Qnil;
|
71
|
+
}
|
72
|
+
|
73
|
+
ndfunc_arg_in_t ain[2] = { { nary_dtype, 1 }, { nary_dtype, 1 } };
|
74
|
+
size_t shape_out[1] = { 1 };
|
75
|
+
ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
|
76
|
+
ndfunc_t ndf = { iter_dot_sub, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
|
77
|
+
VALUE ret = na_ndloop(&ndf, 2, x, y);
|
78
|
+
|
79
|
+
RB_GC_GUARD(x);
|
80
|
+
RB_GC_GUARD(y);
|
81
|
+
return ret;
|
82
|
+
};
|
83
|
+
};
|
84
|
+
|
85
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'mkmf'
|
4
|
+
require 'numo/narray'
|
5
|
+
|
6
|
+
# Locate the numo-narray header: the first load-path entry that contains
# numo/numo/narray.h is put at the front of the include flags.
narray_inc_dir = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/numo/narray.h')) }
$INCFLAGS = "-I#{narray_inc_dir}/numo #{$INCFLAGS}" unless narray_inc_dir.nil?

abort 'numo/narray.h is not found' unless have_header('numo/narray.h')

# On Windows-like platforms the extension also links against libnarray.a.
if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
  narray_lib_dir = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/libnarray.a')) }
  $LDFLAGS = "-L#{narray_lib_dir}/numo #{$LDFLAGS}" unless narray_lib_dir.nil?
  abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
end

# On darwin with Ruby >= 3.1.0, add the dynamic_lookup linker option when the
# toolchain accepts it. The link probe only runs when the first two conditions hold.
needs_dynamic_lookup = RUBY_PLATFORM.include?('darwin') &&
                       Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.1.0') &&
                       try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
$LDFLAGS << ' -Wl,-undefined,dynamic_lookup' if needs_dynamic_lookup

use_accelerate = false
# NOTE: Accelerate framework does not support LAPACKE.
# if RUBY_PLATFORM.include?('darwin') && have_framework('Accelerate')
#   $CFLAGS << ' -DTINYLINALG_USE_ACCELERATE'
#   use_accelerate = true
# end

unless use_accelerate
  # Prefer OpenBLAS; fall back to a plain BLAS library.
  if have_library('openblas')
    $CFLAGS << ' -DTINYLINALG_USE_OPENBLAS'
  elsif have_library('blas')
    $CFLAGS << ' -DTINYLINALG_USE_BLAS'
  else
    abort 'libblas is not found'
  end

  # LAPACKE may already be provided by the BLAS library found above;
  # liblapacke is only probed when LAPACKE_dsyevr is not available yet.
  abort 'liblapacke is not found' unless have_func('LAPACKE_dsyevr') || have_library('lapacke')
  abort 'cblas.h is not found' unless have_header('cblas.h')
  abort 'lapacke.h is not found' unless have_header('lapacke.h')
end

abort 'libstdc++ is not found.' unless have_library('stdc++')

$CXXFLAGS << ' -std=c++11'

create_makefile('numo/tiny_linalg/tiny_linalg')
|
@@ -0,0 +1,223 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGemm {
|
4
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
|
5
|
+
const blasint m, const blasint n, const blasint k,
|
6
|
+
const double alpha, const double* a, const blasint lda, const double* b, const blasint ldb, const double beta,
|
7
|
+
double* c, const blasint ldc) {
|
8
|
+
cblas_dgemm(order, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct SGemm {
|
13
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
|
14
|
+
const blasint m, const blasint n, const blasint k,
|
15
|
+
const float alpha, const float* a, const blasint lda, const float* b, const blasint ldb, const float beta,
|
16
|
+
float* c, const blasint ldc) {
|
17
|
+
cblas_sgemm(order, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct ZGemm {
|
22
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
|
23
|
+
const blasint m, const blasint n, const blasint k,
|
24
|
+
const dcomplex alpha, const dcomplex* a, const blasint lda, const dcomplex* b, const blasint ldb, const dcomplex beta,
|
25
|
+
dcomplex* c, const blasint ldc) {
|
26
|
+
cblas_zgemm(order, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
|
27
|
+
}
|
28
|
+
};
|
29
|
+
|
30
|
+
struct CGemm {
|
31
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
|
32
|
+
const blasint m, const blasint n, const blasint k,
|
33
|
+
const scomplex alpha, const scomplex* a, const blasint lda, const scomplex* b, const blasint ldb, const scomplex beta,
|
34
|
+
scomplex* c, const blasint ldc) {
|
35
|
+
cblas_cgemm(order, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
|
36
|
+
}
|
37
|
+
};
|
38
|
+
|
39
|
+
template <int nary_dtype_id, typename dtype, class BlasFn, class Converter>
|
40
|
+
class Gemm {
|
41
|
+
public:
|
42
|
+
static void define_module_function(VALUE mBlas, const char* modfn_name) {
|
43
|
+
rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_gemm), -1);
|
44
|
+
};
|
45
|
+
|
46
|
+
private:
|
47
|
+
struct options {
|
48
|
+
dtype alpha;
|
49
|
+
dtype beta;
|
50
|
+
enum CBLAS_ORDER order;
|
51
|
+
enum CBLAS_TRANSPOSE transa;
|
52
|
+
enum CBLAS_TRANSPOSE transb;
|
53
|
+
blasint m;
|
54
|
+
blasint n;
|
55
|
+
blasint k;
|
56
|
+
};
|
57
|
+
|
58
|
+
static void iter_gemm(na_loop_t* const lp) {
|
59
|
+
const dtype* a = (dtype*)NDL_PTR(lp, 0);
|
60
|
+
const dtype* b = (dtype*)NDL_PTR(lp, 1);
|
61
|
+
dtype* c = (dtype*)NDL_PTR(lp, 2);
|
62
|
+
const options* opt = (options*)(lp->opt_ptr);
|
63
|
+
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m;
|
64
|
+
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k;
|
65
|
+
const blasint ldc = opt->n;
|
66
|
+
BlasFn().call(opt->order, opt->transa, opt->transb, opt->m, opt->n, opt->k, opt->alpha, a, lda, b, ldb, opt->beta, c, ldc);
|
67
|
+
};
|
68
|
+
|
69
|
+
static VALUE tiny_linalg_gemm(int argc, VALUE* argv, VALUE self) {
|
70
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
71
|
+
|
72
|
+
VALUE a = Qnil;
|
73
|
+
VALUE b = Qnil;
|
74
|
+
VALUE c = Qnil;
|
75
|
+
VALUE kw_args = Qnil;
|
76
|
+
rb_scan_args(argc, argv, "21:", &a, &b, &c, &kw_args);
|
77
|
+
|
78
|
+
ID kw_table[5] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), rb_intern("transa"), rb_intern("transb") };
|
79
|
+
VALUE kw_values[5] = { Qundef, Qundef, Qundef, Qundef, Qundef };
|
80
|
+
rb_get_kwargs(kw_args, kw_table, 0, 5, kw_values);
|
81
|
+
|
82
|
+
if (CLASS_OF(a) != nary_dtype) {
|
83
|
+
a = rb_funcall(nary_dtype, rb_intern("cast"), 1, a);
|
84
|
+
}
|
85
|
+
if (!RTEST(nary_check_contiguous(a))) {
|
86
|
+
a = nary_dup(a);
|
87
|
+
}
|
88
|
+
if (CLASS_OF(b) != nary_dtype) {
|
89
|
+
b = rb_funcall(nary_dtype, rb_intern("cast"), 1, b);
|
90
|
+
}
|
91
|
+
if (!RTEST(nary_check_contiguous(b))) {
|
92
|
+
b = nary_dup(b);
|
93
|
+
}
|
94
|
+
if (!NIL_P(c)) {
|
95
|
+
if (CLASS_OF(c) != nary_dtype) {
|
96
|
+
c = rb_funcall(nary_dtype, rb_intern("cast"), 1, c);
|
97
|
+
}
|
98
|
+
if (!RTEST(nary_check_contiguous(c))) {
|
99
|
+
c = nary_dup(c);
|
100
|
+
}
|
101
|
+
}
|
102
|
+
|
103
|
+
dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
|
104
|
+
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
|
105
|
+
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
|
106
|
+
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
107
|
+
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans;
|
108
|
+
|
109
|
+
narray_t* a_nary = NULL;
|
110
|
+
GetNArray(a, a_nary);
|
111
|
+
narray_t* b_nary = NULL;
|
112
|
+
GetNArray(b, b_nary);
|
113
|
+
|
114
|
+
if (NA_NDIM(a_nary) != 2) {
|
115
|
+
rb_raise(rb_eArgError, "a must be 2-dimensional");
|
116
|
+
return Qnil;
|
117
|
+
}
|
118
|
+
if (NA_NDIM(b_nary) != 2) {
|
119
|
+
rb_raise(rb_eArgError, "b must be 2-dimensional");
|
120
|
+
return Qnil;
|
121
|
+
}
|
122
|
+
if (NA_SIZE(a_nary) == 0) {
|
123
|
+
rb_raise(rb_eArgError, "a must not be empty");
|
124
|
+
return Qnil;
|
125
|
+
}
|
126
|
+
if (NA_SIZE(b_nary) == 0) {
|
127
|
+
rb_raise(rb_eArgError, "b must not be empty");
|
128
|
+
return Qnil;
|
129
|
+
}
|
130
|
+
|
131
|
+
const blasint ma = NA_SHAPE(a_nary)[0];
|
132
|
+
const blasint ka = NA_SHAPE(a_nary)[1];
|
133
|
+
const blasint kb = NA_SHAPE(b_nary)[0];
|
134
|
+
const blasint nb = NA_SHAPE(b_nary)[1];
|
135
|
+
const blasint m = transa == CblasNoTrans ? ma : ka;
|
136
|
+
const blasint n = transb == CblasNoTrans ? nb : kb;
|
137
|
+
const blasint k = transa == CblasNoTrans ? ka : ma;
|
138
|
+
const blasint l = transb == CblasNoTrans ? kb : nb;
|
139
|
+
|
140
|
+
if (k != l) {
|
141
|
+
rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", k, l);
|
142
|
+
return Qnil;
|
143
|
+
}
|
144
|
+
|
145
|
+
options opt = { alpha, beta, order, transa, transb, m, n, k };
|
146
|
+
size_t shape_out[2] = { static_cast<size_t>(m), static_cast<size_t>(n) };
|
147
|
+
ndfunc_arg_out_t aout[1] = { { nary_dtype, 2, shape_out } };
|
148
|
+
VALUE ret = Qnil;
|
149
|
+
|
150
|
+
if (!NIL_P(c)) {
|
151
|
+
narray_t* c_nary = NULL;
|
152
|
+
GetNArray(c, c_nary);
|
153
|
+
blasint nc = NA_SHAPE(c_nary)[0];
|
154
|
+
if (m > nc) {
|
155
|
+
rb_raise(nary_eShapeError, "shape3[0](=%d) >= shape1[0]=%d", nc, m);
|
156
|
+
return Qnil;
|
157
|
+
}
|
158
|
+
ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 2 }, { OVERWRITE, 2 } };
|
159
|
+
ndfunc_t ndf = { iter_gemm, NO_LOOP, 3, 0, ain, aout };
|
160
|
+
na_ndloop3(&ndf, &opt, 3, a, b, c);
|
161
|
+
ret = c;
|
162
|
+
} else {
|
163
|
+
c = INT2NUM(0);
|
164
|
+
ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 2 }, { sym_init, 0 } };
|
165
|
+
ndfunc_t ndf = { iter_gemm, NO_LOOP, 3, 1, ain, aout };
|
166
|
+
ret = na_ndloop3(&ndf, &opt, 3, a, b, c);
|
167
|
+
}
|
168
|
+
|
169
|
+
RB_GC_GUARD(a);
|
170
|
+
RB_GC_GUARD(b);
|
171
|
+
RB_GC_GUARD(c);
|
172
|
+
|
173
|
+
return ret;
|
174
|
+
};
|
175
|
+
|
176
|
+
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
177
|
+
const char* option_str = StringValueCStr(val);
|
178
|
+
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
179
|
+
|
180
|
+
if (std::strlen(option_str) > 0) {
|
181
|
+
switch (option_str[0]) {
|
182
|
+
case 'n':
|
183
|
+
case 'N':
|
184
|
+
res = CblasNoTrans;
|
185
|
+
break;
|
186
|
+
case 't':
|
187
|
+
case 'T':
|
188
|
+
res = CblasTrans;
|
189
|
+
break;
|
190
|
+
case 'c':
|
191
|
+
case 'C':
|
192
|
+
res = CblasConjTrans;
|
193
|
+
break;
|
194
|
+
}
|
195
|
+
}
|
196
|
+
|
197
|
+
RB_GC_GUARD(val);
|
198
|
+
|
199
|
+
return res;
|
200
|
+
}
|
201
|
+
|
202
|
+
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
203
|
+
const char* option_str = StringValueCStr(val);
|
204
|
+
|
205
|
+
if (std::strlen(option_str) > 0) {
|
206
|
+
switch (option_str[0]) {
|
207
|
+
case 'r':
|
208
|
+
case 'R':
|
209
|
+
break;
|
210
|
+
case 'c':
|
211
|
+
case 'C':
|
212
|
+
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
|
213
|
+
break;
|
214
|
+
}
|
215
|
+
}
|
216
|
+
|
217
|
+
RB_GC_GUARD(val);
|
218
|
+
|
219
|
+
return CblasRowMajor;
|
220
|
+
}
|
221
|
+
};
|
222
|
+
|
223
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,211 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGemv {
|
4
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
|
5
|
+
const double alpha, const double* a, const blasint lda,
|
6
|
+
const double* x, const blasint incx, const double beta, double* y, const blasint incy) {
|
7
|
+
cblas_dgemv(order, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
|
8
|
+
}
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SGemv {
|
12
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
|
13
|
+
const float alpha, const float* a, const blasint lda,
|
14
|
+
const float* x, const blasint incx, const float beta, float* y, const blasint incy) {
|
15
|
+
cblas_sgemv(order, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
|
16
|
+
}
|
17
|
+
};
|
18
|
+
|
19
|
+
struct ZGemv {
|
20
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
|
21
|
+
const dcomplex alpha, const dcomplex* a, const blasint lda,
|
22
|
+
const dcomplex* x, const blasint incx, const dcomplex beta, dcomplex* y, const blasint incy) {
|
23
|
+
cblas_zgemv(order, trans, m, n, &alpha, a, lda, x, incx, &beta, y, incy);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
struct CGemv {
|
28
|
+
void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
|
29
|
+
const scomplex alpha, const scomplex* a, const blasint lda,
|
30
|
+
const scomplex* x, const blasint incx, const scomplex beta, scomplex* y, const blasint incy) {
|
31
|
+
cblas_cgemv(order, trans, m, n, &alpha, a, lda, x, incx, &beta, y, incy);
|
32
|
+
}
|
33
|
+
};
|
34
|
+
|
35
|
+
template <int nary_dtype_id, typename dtype, class BlasFn, class Converter>
|
36
|
+
class Gemv {
|
37
|
+
public:
|
38
|
+
static void define_module_function(VALUE mBlas, const char* modfn_name) {
|
39
|
+
rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_gemv), -1);
|
40
|
+
};
|
41
|
+
|
42
|
+
private:
|
43
|
+
struct options {
|
44
|
+
dtype alpha;
|
45
|
+
dtype beta;
|
46
|
+
enum CBLAS_ORDER order;
|
47
|
+
enum CBLAS_TRANSPOSE trans;
|
48
|
+
blasint m;
|
49
|
+
blasint n;
|
50
|
+
};
|
51
|
+
|
52
|
+
static void iter_gemv(na_loop_t* const lp) {
|
53
|
+
const dtype* a = (dtype*)NDL_PTR(lp, 0);
|
54
|
+
const dtype* x = (dtype*)NDL_PTR(lp, 1);
|
55
|
+
dtype* y = (dtype*)NDL_PTR(lp, 2);
|
56
|
+
const options* opt = (options*)(lp->opt_ptr);
|
57
|
+
const blasint lda = opt->n;
|
58
|
+
BlasFn().call(opt->order, opt->trans, opt->m, opt->n, opt->alpha, a, lda, x, 1, opt->beta, y, 1);
|
59
|
+
};
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_gemv(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
|
64
|
+
VALUE a = Qnil;
|
65
|
+
VALUE x = Qnil;
|
66
|
+
VALUE y = Qnil;
|
67
|
+
VALUE kw_args = Qnil;
|
68
|
+
rb_scan_args(argc, argv, "21:", &a, &x, &y, &kw_args);
|
69
|
+
|
70
|
+
ID kw_table[4] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), rb_intern("trans") };
|
71
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
|
72
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
|
73
|
+
|
74
|
+
if (CLASS_OF(a) != nary_dtype) {
|
75
|
+
a = rb_funcall(nary_dtype, rb_intern("cast"), 1, a);
|
76
|
+
}
|
77
|
+
if (!RTEST(nary_check_contiguous(a))) {
|
78
|
+
a = nary_dup(a);
|
79
|
+
}
|
80
|
+
if (CLASS_OF(x) != nary_dtype) {
|
81
|
+
x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
|
82
|
+
}
|
83
|
+
if (!RTEST(nary_check_contiguous(x))) {
|
84
|
+
x = nary_dup(x);
|
85
|
+
}
|
86
|
+
if (!NIL_P(y)) {
|
87
|
+
if (CLASS_OF(y) != nary_dtype) {
|
88
|
+
y = rb_funcall(nary_dtype, rb_intern("cast"), 1, y);
|
89
|
+
}
|
90
|
+
if (!RTEST(nary_check_contiguous(y))) {
|
91
|
+
y = nary_dup(y);
|
92
|
+
}
|
93
|
+
}
|
94
|
+
|
95
|
+
dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
|
96
|
+
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
|
97
|
+
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
|
98
|
+
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
99
|
+
|
100
|
+
narray_t* a_nary = NULL;
|
101
|
+
GetNArray(a, a_nary);
|
102
|
+
narray_t* x_nary = NULL;
|
103
|
+
GetNArray(x, x_nary);
|
104
|
+
|
105
|
+
if (NA_NDIM(a_nary) != 2) {
|
106
|
+
rb_raise(rb_eArgError, "a must be 2-dimensional");
|
107
|
+
return Qnil;
|
108
|
+
}
|
109
|
+
if (NA_NDIM(x_nary) != 1) {
|
110
|
+
rb_raise(rb_eArgError, "x must be 1-dimensional");
|
111
|
+
return Qnil;
|
112
|
+
}
|
113
|
+
if (NA_SIZE(a_nary) == 0) {
|
114
|
+
rb_raise(rb_eArgError, "a must not be empty");
|
115
|
+
return Qnil;
|
116
|
+
}
|
117
|
+
if (NA_SIZE(x_nary) == 0) {
|
118
|
+
rb_raise(rb_eArgError, "x must not be empty");
|
119
|
+
return Qnil;
|
120
|
+
}
|
121
|
+
|
122
|
+
const blasint ma = NA_SHAPE(a_nary)[0];
|
123
|
+
const blasint na = NA_SHAPE(a_nary)[1];
|
124
|
+
const blasint mx = NA_SHAPE(x_nary)[0];
|
125
|
+
const blasint m = trans == CblasNoTrans ? ma : na;
|
126
|
+
const blasint n = trans == CblasNoTrans ? na : ma;
|
127
|
+
|
128
|
+
if (n != mx) {
|
129
|
+
rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", n, mx);
|
130
|
+
return Qnil;
|
131
|
+
}
|
132
|
+
|
133
|
+
options opt = { alpha, beta, order, trans, ma, na };
|
134
|
+
size_t shape_out[1] = { static_cast<size_t>(m) };
|
135
|
+
ndfunc_arg_out_t aout[1] = { { nary_dtype, 1, shape_out } };
|
136
|
+
VALUE ret = Qnil;
|
137
|
+
|
138
|
+
if (!NIL_P(y)) {
|
139
|
+
narray_t* y_nary = NULL;
|
140
|
+
GetNArray(y, y_nary);
|
141
|
+
blasint my = NA_SHAPE(y_nary)[0];
|
142
|
+
if (m > my) {
|
143
|
+
rb_raise(nary_eShapeError, "shape3[0](=%d) >= shape1[0]=%d", my, m);
|
144
|
+
return Qnil;
|
145
|
+
}
|
146
|
+
ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 1 }, { OVERWRITE, 1 } };
|
147
|
+
ndfunc_t ndf = { iter_gemv, NO_LOOP, 3, 0, ain, aout };
|
148
|
+
na_ndloop3(&ndf, &opt, 3, a, x, y);
|
149
|
+
ret = y;
|
150
|
+
} else {
|
151
|
+
y = INT2NUM(0);
|
152
|
+
ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 1 }, { sym_init, 0 } };
|
153
|
+
ndfunc_t ndf = { iter_gemv, NO_LOOP, 3, 1, ain, aout };
|
154
|
+
ret = na_ndloop3(&ndf, &opt, 3, a, x, y);
|
155
|
+
}
|
156
|
+
|
157
|
+
RB_GC_GUARD(a);
|
158
|
+
RB_GC_GUARD(x);
|
159
|
+
RB_GC_GUARD(y);
|
160
|
+
|
161
|
+
return ret;
|
162
|
+
}
|
163
|
+
|
164
|
+
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
165
|
+
const char* option_str = StringValueCStr(val);
|
166
|
+
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
167
|
+
|
168
|
+
if (std::strlen(option_str) > 0) {
|
169
|
+
switch (option_str[0]) {
|
170
|
+
case 'n':
|
171
|
+
case 'N':
|
172
|
+
res = CblasNoTrans;
|
173
|
+
break;
|
174
|
+
case 't':
|
175
|
+
case 'T':
|
176
|
+
res = CblasTrans;
|
177
|
+
break;
|
178
|
+
case 'c':
|
179
|
+
case 'C':
|
180
|
+
res = CblasConjTrans;
|
181
|
+
break;
|
182
|
+
}
|
183
|
+
}
|
184
|
+
|
185
|
+
RB_GC_GUARD(val);
|
186
|
+
|
187
|
+
return res;
|
188
|
+
}
|
189
|
+
|
190
|
+
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
191
|
+
const char* option_str = StringValueCStr(val);
|
192
|
+
|
193
|
+
if (std::strlen(option_str) > 0) {
|
194
|
+
switch (option_str[0]) {
|
195
|
+
case 'r':
|
196
|
+
case 'R':
|
197
|
+
break;
|
198
|
+
case 'c':
|
199
|
+
case 'C':
|
200
|
+
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
|
201
|
+
break;
|
202
|
+
}
|
203
|
+
}
|
204
|
+
|
205
|
+
RB_GC_GUARD(val);
|
206
|
+
|
207
|
+
return CblasRowMajor;
|
208
|
+
}
|
209
|
+
};
|
210
|
+
|
211
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,134 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGESDD {
|
4
|
+
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
5
|
+
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt) {
|
6
|
+
return LAPACKE_dgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
7
|
+
};
|
8
|
+
};
|
9
|
+
|
10
|
+
struct SGESDD {
|
11
|
+
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
12
|
+
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt) {
|
13
|
+
return LAPACKE_sgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
14
|
+
};
|
15
|
+
};
|
16
|
+
|
17
|
+
struct ZGESDD {
|
18
|
+
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
19
|
+
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt) {
|
20
|
+
return LAPACKE_zgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
21
|
+
};
|
22
|
+
};
|
23
|
+
|
24
|
+
struct CGESDD {
|
25
|
+
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
26
|
+
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt) {
|
27
|
+
return LAPACKE_cgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
28
|
+
};
|
29
|
+
};
|
30
|
+
|
31
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
32
|
+
class GESDD {
|
33
|
+
public:
|
34
|
+
static void define_module_function(VALUE mLapack, const char* mf_name) {
|
35
|
+
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesdd), -1);
|
36
|
+
};
|
37
|
+
|
38
|
+
private:
|
39
|
+
struct gesdd_opt {
|
40
|
+
int matrix_order;
|
41
|
+
char jobz;
|
42
|
+
};
|
43
|
+
|
44
|
+
static void iter_gesdd(na_loop_t* const lp) {
|
45
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
46
|
+
RType* s = (RType*)NDL_PTR(lp, 1);
|
47
|
+
DType* u = (DType*)NDL_PTR(lp, 2);
|
48
|
+
DType* vt = (DType*)NDL_PTR(lp, 3);
|
49
|
+
int* info = (int*)NDL_PTR(lp, 4);
|
50
|
+
gesdd_opt* opt = (gesdd_opt*)(lp->opt_ptr);
|
51
|
+
|
52
|
+
const size_t m = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0] : NDL_SHAPE(lp, 0)[1];
|
53
|
+
const size_t n = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[1] : NDL_SHAPE(lp, 0)[0];
|
54
|
+
const size_t min_mn = m < n ? m : n;
|
55
|
+
const lapack_int lda = n;
|
56
|
+
const lapack_int ldu = opt->jobz == 'S' ? min_mn : m;
|
57
|
+
const lapack_int ldvt = opt->jobz == 'S' ? min_mn : n;
|
58
|
+
|
59
|
+
lapack_int i = FncType().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
60
|
+
*info = static_cast<int>(i);
|
61
|
+
};
|
62
|
+
|
63
|
+
static VALUE tiny_linalg_gesdd(int argc, VALUE* argv, VALUE self) {
|
64
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
65
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
66
|
+
VALUE a_vnary = Qnil;
|
67
|
+
VALUE kw_args = Qnil;
|
68
|
+
|
69
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
70
|
+
|
71
|
+
ID kw_table[2] = { rb_intern("jobz"), rb_intern("order") };
|
72
|
+
VALUE kw_values[2] = { Qundef, Qundef };
|
73
|
+
|
74
|
+
rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
|
75
|
+
|
76
|
+
const char jobz = kw_values[0] == Qundef ? 'A' : StringValueCStr(kw_values[0])[0];
|
77
|
+
const char order = kw_values[1] == Qundef ? 'R' : StringValueCStr(kw_values[1])[0];
|
78
|
+
|
79
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
80
|
+
rb_raise(rb_eTypeError, "type of input array is invalid for overwriting");
|
81
|
+
return Qnil;
|
82
|
+
}
|
83
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
84
|
+
a_vnary = nary_dup(a_vnary);
|
85
|
+
}
|
86
|
+
|
87
|
+
narray_t* a_nary = NULL;
|
88
|
+
GetNArray(a_vnary, a_nary);
|
89
|
+
const int n_dims = NA_NDIM(a_nary);
|
90
|
+
if (n_dims != 2) {
|
91
|
+
rb_raise(rb_eArgError, "input array must be 2-dimensional");
|
92
|
+
return Qnil;
|
93
|
+
}
|
94
|
+
|
95
|
+
const int matrix_order = order == 'C' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
|
96
|
+
const size_t m = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];
|
97
|
+
const size_t n = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[1] : NA_SHAPE(a_nary)[0];
|
98
|
+
|
99
|
+
const size_t min_mn = m < n ? m : n;
|
100
|
+
size_t shape_s[1] = { min_mn };
|
101
|
+
size_t shape_u[2] = { m, m };
|
102
|
+
size_t shape_vt[2] = { n, n };
|
103
|
+
|
104
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
105
|
+
ndfunc_arg_out_t aout[4] = { { nary_rtype, 1, shape_s }, { nary_dtype, 2, shape_u }, { nary_dtype, 2, shape_vt }, { numo_cInt32, 0 } };
|
106
|
+
|
107
|
+
switch (jobz) {
|
108
|
+
case 'A':
|
109
|
+
break;
|
110
|
+
case 'S':
|
111
|
+
shape_u[matrix_order == LAPACK_ROW_MAJOR ? 1 : 0] = min_mn;
|
112
|
+
shape_vt[matrix_order == LAPACK_ROW_MAJOR ? 0 : 1] = min_mn;
|
113
|
+
break;
|
114
|
+
case 'O':
|
115
|
+
break;
|
116
|
+
case 'N':
|
117
|
+
aout[1].dim = 0;
|
118
|
+
aout[2].dim = 0;
|
119
|
+
break;
|
120
|
+
default:
|
121
|
+
rb_raise(rb_eArgError, "jobz must be one of 'A', 'S', 'O', or 'N'");
|
122
|
+
return Qnil;
|
123
|
+
}
|
124
|
+
|
125
|
+
ndfunc_t ndf = { iter_gesdd, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };
|
126
|
+
gesdd_opt opt = { matrix_order, jobz };
|
127
|
+
VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
128
|
+
|
129
|
+
RB_GC_GUARD(a_vnary);
|
130
|
+
return ret;
|
131
|
+
};
|
132
|
+
};
|
133
|
+
|
134
|
+
} // namespace TinyLinalg
|