numo-tiny_linalg 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,85 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZDotuSub {
4
+ void call(const int n, const double* x, const int incx, const double* y, const int incy, double* ret) {
5
+ cblas_zdotu_sub(n, x, incx, y, incy, ret);
6
+ }
7
+ };
8
+
9
+ struct CDotuSub {
10
+ void call(const int n, const float* x, const int incx, const float* y, const int incy, float* ret) {
11
+ cblas_cdotu_sub(n, x, incx, y, incy, ret);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class BlasFn>
16
+ class DotSub {
17
+ public:
18
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
19
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_dot_sub), 2);
20
+ };
21
+
22
+ private:
23
+ static void iter_dot_sub(na_loop_t* const lp) {
24
+ dtype* x = (dtype*)NDL_PTR(lp, 0);
25
+ dtype* y = (dtype*)NDL_PTR(lp, 1);
26
+ dtype* d = (dtype*)NDL_PTR(lp, 2);
27
+ const size_t n = NDL_SHAPE(lp, 0)[0];
28
+ BlasFn().call(n, x, 1, y, 1, d);
29
+ };
30
+
31
+ static VALUE tiny_linalg_dot_sub(VALUE self, VALUE x, VALUE y) {
32
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
33
+
34
+ if (CLASS_OF(x) != nary_dtype) {
35
+ x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
36
+ }
37
+ if (!RTEST(nary_check_contiguous(x))) {
38
+ x = nary_dup(x);
39
+ }
40
+ if (CLASS_OF(y) != nary_dtype) {
41
+ y = rb_funcall(nary_dtype, rb_intern("cast"), 1, y);
42
+ }
43
+ if (!RTEST(nary_check_contiguous(y))) {
44
+ y = nary_dup(y);
45
+ }
46
+
47
+ narray_t* x_nary = NULL;
48
+ GetNArray(x, x_nary);
49
+ narray_t* y_nary = NULL;
50
+ GetNArray(y, y_nary);
51
+
52
+ if (NA_NDIM(x_nary) != 1) {
53
+ rb_raise(rb_eArgError, "x must be 1-dimensional");
54
+ return Qnil;
55
+ }
56
+ if (NA_NDIM(y_nary) != 1) {
57
+ rb_raise(rb_eArgError, "y must be 1-dimensional");
58
+ return Qnil;
59
+ }
60
+ if (NA_SIZE(x_nary) == 0) {
61
+ rb_raise(rb_eArgError, "x must not be empty");
62
+ return Qnil;
63
+ }
64
+ if (NA_SIZE(y_nary) == 0) {
65
+ rb_raise(rb_eArgError, "x must not be empty");
66
+ return Qnil;
67
+ }
68
+ if (NA_SIZE(x_nary) != NA_SIZE(y_nary)) {
69
+ rb_raise(rb_eArgError, "x and y must have same size");
70
+ return Qnil;
71
+ }
72
+
73
+ ndfunc_arg_in_t ain[2] = { { nary_dtype, 1 }, { nary_dtype, 1 } };
74
+ size_t shape_out[1] = { 1 };
75
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
76
+ ndfunc_t ndf = { iter_dot_sub, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
77
+ VALUE ret = na_ndloop(&ndf, 2, x, y);
78
+
79
+ RB_GC_GUARD(x);
80
+ RB_GC_GUARD(y);
81
+ return ret;
82
+ };
83
+ };
84
+
85
+ } // namespace TinyLinalg
@@ -0,0 +1,54 @@
1
# frozen_string_literal: true

require 'mkmf'
require 'numo/narray'

# Point the compiler at the first load-path entry that ships numo/narray.h.
narray_header_root = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/numo/narray.h')) }
$INCFLAGS = "-I#{narray_header_root}/numo #{$INCFLAGS}" unless narray_header_root.nil?

abort 'numo/narray.h is not found' unless have_header('numo/narray.h')

# On Windows-like platforms the extension also links against libnarray.a.
if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
  narray_lib_root = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/libnarray.a')) }
  $LDFLAGS = "-L#{narray_lib_root}/numo #{$LDFLAGS}" unless narray_lib_root.nil?
  abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
end

# On macOS with Ruby >= 3.1, add dynamic_lookup when the linker accepts it.
# The link probe runs only after the two cheaper conditions hold.
on_darwin = RUBY_PLATFORM.include?('darwin')
ruby_3_1_or_newer = Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.1.0')
if on_darwin && ruby_3_1_or_newer &&
   try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
  $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
end

use_accelerate = false
# NOTE: Accelerate framework does not support LAPACKE.
# if RUBY_PLATFORM.include?('darwin') && have_framework('Accelerate')
#   $CFLAGS << ' -DTINYLINALG_USE_ACCELERATE'
#   use_accelerate = true
# end

unless use_accelerate
  # Prefer OpenBLAS; fall back to a plain BLAS library.
  if have_library('openblas')
    $CFLAGS << ' -DTINYLINALG_USE_OPENBLAS'
  else
    abort 'libblas is not found' unless have_library('blas')
    $CFLAGS << ' -DTINYLINALG_USE_BLAS'
  end

  # LAPACKE may already be provided by the BLAS library found above.
  abort 'liblapacke is not found' unless have_func('LAPACKE_dsyevr') || have_library('lapacke')
  abort 'cblas.h is not found' unless have_header('cblas.h')
  abort 'lapacke.h is not found' unless have_header('lapacke.h')
end

abort 'libstdc++ is not found.' unless have_library('stdc++')

$CXXFLAGS << ' -std=c++11'

create_makefile('numo/tiny_linalg/tiny_linalg')
@@ -0,0 +1,223 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGemm {
4
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
5
+ const blasint m, const blasint n, const blasint k,
6
+ const double alpha, const double* a, const blasint lda, const double* b, const blasint ldb, const double beta,
7
+ double* c, const blasint ldc) {
8
+ cblas_dgemm(order, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
9
+ }
10
+ };
11
+
12
+ struct SGemm {
13
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
14
+ const blasint m, const blasint n, const blasint k,
15
+ const float alpha, const float* a, const blasint lda, const float* b, const blasint ldb, const float beta,
16
+ float* c, const blasint ldc) {
17
+ cblas_sgemm(order, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
18
+ }
19
+ };
20
+
21
+ struct ZGemm {
22
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
23
+ const blasint m, const blasint n, const blasint k,
24
+ const dcomplex alpha, const dcomplex* a, const blasint lda, const dcomplex* b, const blasint ldb, const dcomplex beta,
25
+ dcomplex* c, const blasint ldc) {
26
+ cblas_zgemm(order, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
27
+ }
28
+ };
29
+
30
+ struct CGemm {
31
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
32
+ const blasint m, const blasint n, const blasint k,
33
+ const scomplex alpha, const scomplex* a, const blasint lda, const scomplex* b, const blasint ldb, const scomplex beta,
34
+ scomplex* c, const blasint ldc) {
35
+ cblas_cgemm(order, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
36
+ }
37
+ };
38
+
39
+ template <int nary_dtype_id, typename dtype, class BlasFn, class Converter>
40
+ class Gemm {
41
+ public:
42
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
43
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_gemm), -1);
44
+ };
45
+
46
+ private:
47
+ struct options {
48
+ dtype alpha;
49
+ dtype beta;
50
+ enum CBLAS_ORDER order;
51
+ enum CBLAS_TRANSPOSE transa;
52
+ enum CBLAS_TRANSPOSE transb;
53
+ blasint m;
54
+ blasint n;
55
+ blasint k;
56
+ };
57
+
58
+ static void iter_gemm(na_loop_t* const lp) {
59
+ const dtype* a = (dtype*)NDL_PTR(lp, 0);
60
+ const dtype* b = (dtype*)NDL_PTR(lp, 1);
61
+ dtype* c = (dtype*)NDL_PTR(lp, 2);
62
+ const options* opt = (options*)(lp->opt_ptr);
63
+ const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m;
64
+ const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k;
65
+ const blasint ldc = opt->n;
66
+ BlasFn().call(opt->order, opt->transa, opt->transb, opt->m, opt->n, opt->k, opt->alpha, a, lda, b, ldb, opt->beta, c, ldc);
67
+ };
68
+
69
+ static VALUE tiny_linalg_gemm(int argc, VALUE* argv, VALUE self) {
70
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
71
+
72
+ VALUE a = Qnil;
73
+ VALUE b = Qnil;
74
+ VALUE c = Qnil;
75
+ VALUE kw_args = Qnil;
76
+ rb_scan_args(argc, argv, "21:", &a, &b, &c, &kw_args);
77
+
78
+ ID kw_table[5] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), rb_intern("transa"), rb_intern("transb") };
79
+ VALUE kw_values[5] = { Qundef, Qundef, Qundef, Qundef, Qundef };
80
+ rb_get_kwargs(kw_args, kw_table, 0, 5, kw_values);
81
+
82
+ if (CLASS_OF(a) != nary_dtype) {
83
+ a = rb_funcall(nary_dtype, rb_intern("cast"), 1, a);
84
+ }
85
+ if (!RTEST(nary_check_contiguous(a))) {
86
+ a = nary_dup(a);
87
+ }
88
+ if (CLASS_OF(b) != nary_dtype) {
89
+ b = rb_funcall(nary_dtype, rb_intern("cast"), 1, b);
90
+ }
91
+ if (!RTEST(nary_check_contiguous(b))) {
92
+ b = nary_dup(b);
93
+ }
94
+ if (!NIL_P(c)) {
95
+ if (CLASS_OF(c) != nary_dtype) {
96
+ c = rb_funcall(nary_dtype, rb_intern("cast"), 1, c);
97
+ }
98
+ if (!RTEST(nary_check_contiguous(c))) {
99
+ c = nary_dup(c);
100
+ }
101
+ }
102
+
103
+ dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
104
+ dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
105
+ enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
106
+ enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
107
+ enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans;
108
+
109
+ narray_t* a_nary = NULL;
110
+ GetNArray(a, a_nary);
111
+ narray_t* b_nary = NULL;
112
+ GetNArray(b, b_nary);
113
+
114
+ if (NA_NDIM(a_nary) != 2) {
115
+ rb_raise(rb_eArgError, "a must be 2-dimensional");
116
+ return Qnil;
117
+ }
118
+ if (NA_NDIM(b_nary) != 2) {
119
+ rb_raise(rb_eArgError, "b must be 2-dimensional");
120
+ return Qnil;
121
+ }
122
+ if (NA_SIZE(a_nary) == 0) {
123
+ rb_raise(rb_eArgError, "a must not be empty");
124
+ return Qnil;
125
+ }
126
+ if (NA_SIZE(b_nary) == 0) {
127
+ rb_raise(rb_eArgError, "b must not be empty");
128
+ return Qnil;
129
+ }
130
+
131
+ const blasint ma = NA_SHAPE(a_nary)[0];
132
+ const blasint ka = NA_SHAPE(a_nary)[1];
133
+ const blasint kb = NA_SHAPE(b_nary)[0];
134
+ const blasint nb = NA_SHAPE(b_nary)[1];
135
+ const blasint m = transa == CblasNoTrans ? ma : ka;
136
+ const blasint n = transb == CblasNoTrans ? nb : kb;
137
+ const blasint k = transa == CblasNoTrans ? ka : ma;
138
+ const blasint l = transb == CblasNoTrans ? kb : nb;
139
+
140
+ if (k != l) {
141
+ rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", k, l);
142
+ return Qnil;
143
+ }
144
+
145
+ options opt = { alpha, beta, order, transa, transb, m, n, k };
146
+ size_t shape_out[2] = { static_cast<size_t>(m), static_cast<size_t>(n) };
147
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 2, shape_out } };
148
+ VALUE ret = Qnil;
149
+
150
+ if (!NIL_P(c)) {
151
+ narray_t* c_nary = NULL;
152
+ GetNArray(c, c_nary);
153
+ blasint nc = NA_SHAPE(c_nary)[0];
154
+ if (m > nc) {
155
+ rb_raise(nary_eShapeError, "shape3[0](=%d) >= shape1[0]=%d", nc, m);
156
+ return Qnil;
157
+ }
158
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 2 }, { OVERWRITE, 2 } };
159
+ ndfunc_t ndf = { iter_gemm, NO_LOOP, 3, 0, ain, aout };
160
+ na_ndloop3(&ndf, &opt, 3, a, b, c);
161
+ ret = c;
162
+ } else {
163
+ c = INT2NUM(0);
164
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 2 }, { sym_init, 0 } };
165
+ ndfunc_t ndf = { iter_gemm, NO_LOOP, 3, 1, ain, aout };
166
+ ret = na_ndloop3(&ndf, &opt, 3, a, b, c);
167
+ }
168
+
169
+ RB_GC_GUARD(a);
170
+ RB_GC_GUARD(b);
171
+ RB_GC_GUARD(c);
172
+
173
+ return ret;
174
+ };
175
+
176
+ static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
177
+ const char* option_str = StringValueCStr(val);
178
+ enum CBLAS_TRANSPOSE res = CblasNoTrans;
179
+
180
+ if (std::strlen(option_str) > 0) {
181
+ switch (option_str[0]) {
182
+ case 'n':
183
+ case 'N':
184
+ res = CblasNoTrans;
185
+ break;
186
+ case 't':
187
+ case 'T':
188
+ res = CblasTrans;
189
+ break;
190
+ case 'c':
191
+ case 'C':
192
+ res = CblasConjTrans;
193
+ break;
194
+ }
195
+ }
196
+
197
+ RB_GC_GUARD(val);
198
+
199
+ return res;
200
+ }
201
+
202
+ static enum CBLAS_ORDER get_cblas_order(VALUE val) {
203
+ const char* option_str = StringValueCStr(val);
204
+
205
+ if (std::strlen(option_str) > 0) {
206
+ switch (option_str[0]) {
207
+ case 'r':
208
+ case 'R':
209
+ break;
210
+ case 'c':
211
+ case 'C':
212
+ rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
213
+ break;
214
+ }
215
+ }
216
+
217
+ RB_GC_GUARD(val);
218
+
219
+ return CblasRowMajor;
220
+ }
221
+ };
222
+
223
+ } // namespace TinyLinalg
@@ -0,0 +1,211 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGemv {
4
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
5
+ const double alpha, const double* a, const blasint lda,
6
+ const double* x, const blasint incx, const double beta, double* y, const blasint incy) {
7
+ cblas_dgemv(order, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
8
+ }
9
+ };
10
+
11
+ struct SGemv {
12
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
13
+ const float alpha, const float* a, const blasint lda,
14
+ const float* x, const blasint incx, const float beta, float* y, const blasint incy) {
15
+ cblas_sgemv(order, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
16
+ }
17
+ };
18
+
19
+ struct ZGemv {
20
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
21
+ const dcomplex alpha, const dcomplex* a, const blasint lda,
22
+ const dcomplex* x, const blasint incx, const dcomplex beta, dcomplex* y, const blasint incy) {
23
+ cblas_zgemv(order, trans, m, n, &alpha, a, lda, x, incx, &beta, y, incy);
24
+ }
25
+ };
26
+
27
+ struct CGemv {
28
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
29
+ const scomplex alpha, const scomplex* a, const blasint lda,
30
+ const scomplex* x, const blasint incx, const scomplex beta, scomplex* y, const blasint incy) {
31
+ cblas_cgemv(order, trans, m, n, &alpha, a, lda, x, incx, &beta, y, incy);
32
+ }
33
+ };
34
+
35
+ template <int nary_dtype_id, typename dtype, class BlasFn, class Converter>
36
+ class Gemv {
37
+ public:
38
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
39
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_gemv), -1);
40
+ };
41
+
42
+ private:
43
+ struct options {
44
+ dtype alpha;
45
+ dtype beta;
46
+ enum CBLAS_ORDER order;
47
+ enum CBLAS_TRANSPOSE trans;
48
+ blasint m;
49
+ blasint n;
50
+ };
51
+
52
+ static void iter_gemv(na_loop_t* const lp) {
53
+ const dtype* a = (dtype*)NDL_PTR(lp, 0);
54
+ const dtype* x = (dtype*)NDL_PTR(lp, 1);
55
+ dtype* y = (dtype*)NDL_PTR(lp, 2);
56
+ const options* opt = (options*)(lp->opt_ptr);
57
+ const blasint lda = opt->n;
58
+ BlasFn().call(opt->order, opt->trans, opt->m, opt->n, opt->alpha, a, lda, x, 1, opt->beta, y, 1);
59
+ };
60
+
61
+ static VALUE tiny_linalg_gemv(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+
64
+ VALUE a = Qnil;
65
+ VALUE x = Qnil;
66
+ VALUE y = Qnil;
67
+ VALUE kw_args = Qnil;
68
+ rb_scan_args(argc, argv, "21:", &a, &x, &y, &kw_args);
69
+
70
+ ID kw_table[4] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), rb_intern("trans") };
71
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
72
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
73
+
74
+ if (CLASS_OF(a) != nary_dtype) {
75
+ a = rb_funcall(nary_dtype, rb_intern("cast"), 1, a);
76
+ }
77
+ if (!RTEST(nary_check_contiguous(a))) {
78
+ a = nary_dup(a);
79
+ }
80
+ if (CLASS_OF(x) != nary_dtype) {
81
+ x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
82
+ }
83
+ if (!RTEST(nary_check_contiguous(x))) {
84
+ x = nary_dup(x);
85
+ }
86
+ if (!NIL_P(y)) {
87
+ if (CLASS_OF(y) != nary_dtype) {
88
+ y = rb_funcall(nary_dtype, rb_intern("cast"), 1, y);
89
+ }
90
+ if (!RTEST(nary_check_contiguous(y))) {
91
+ y = nary_dup(y);
92
+ }
93
+ }
94
+
95
+ dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
96
+ dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
97
+ enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
98
+ enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
99
+
100
+ narray_t* a_nary = NULL;
101
+ GetNArray(a, a_nary);
102
+ narray_t* x_nary = NULL;
103
+ GetNArray(x, x_nary);
104
+
105
+ if (NA_NDIM(a_nary) != 2) {
106
+ rb_raise(rb_eArgError, "a must be 2-dimensional");
107
+ return Qnil;
108
+ }
109
+ if (NA_NDIM(x_nary) != 1) {
110
+ rb_raise(rb_eArgError, "x must be 1-dimensional");
111
+ return Qnil;
112
+ }
113
+ if (NA_SIZE(a_nary) == 0) {
114
+ rb_raise(rb_eArgError, "a must not be empty");
115
+ return Qnil;
116
+ }
117
+ if (NA_SIZE(x_nary) == 0) {
118
+ rb_raise(rb_eArgError, "x must not be empty");
119
+ return Qnil;
120
+ }
121
+
122
+ const blasint ma = NA_SHAPE(a_nary)[0];
123
+ const blasint na = NA_SHAPE(a_nary)[1];
124
+ const blasint mx = NA_SHAPE(x_nary)[0];
125
+ const blasint m = trans == CblasNoTrans ? ma : na;
126
+ const blasint n = trans == CblasNoTrans ? na : ma;
127
+
128
+ if (n != mx) {
129
+ rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", n, mx);
130
+ return Qnil;
131
+ }
132
+
133
+ options opt = { alpha, beta, order, trans, ma, na };
134
+ size_t shape_out[1] = { static_cast<size_t>(m) };
135
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 1, shape_out } };
136
+ VALUE ret = Qnil;
137
+
138
+ if (!NIL_P(y)) {
139
+ narray_t* y_nary = NULL;
140
+ GetNArray(y, y_nary);
141
+ blasint my = NA_SHAPE(y_nary)[0];
142
+ if (m > my) {
143
+ rb_raise(nary_eShapeError, "shape3[0](=%d) >= shape1[0]=%d", my, m);
144
+ return Qnil;
145
+ }
146
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 1 }, { OVERWRITE, 1 } };
147
+ ndfunc_t ndf = { iter_gemv, NO_LOOP, 3, 0, ain, aout };
148
+ na_ndloop3(&ndf, &opt, 3, a, x, y);
149
+ ret = y;
150
+ } else {
151
+ y = INT2NUM(0);
152
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 1 }, { sym_init, 0 } };
153
+ ndfunc_t ndf = { iter_gemv, NO_LOOP, 3, 1, ain, aout };
154
+ ret = na_ndloop3(&ndf, &opt, 3, a, x, y);
155
+ }
156
+
157
+ RB_GC_GUARD(a);
158
+ RB_GC_GUARD(x);
159
+ RB_GC_GUARD(y);
160
+
161
+ return ret;
162
+ }
163
+
164
+ static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
165
+ const char* option_str = StringValueCStr(val);
166
+ enum CBLAS_TRANSPOSE res = CblasNoTrans;
167
+
168
+ if (std::strlen(option_str) > 0) {
169
+ switch (option_str[0]) {
170
+ case 'n':
171
+ case 'N':
172
+ res = CblasNoTrans;
173
+ break;
174
+ case 't':
175
+ case 'T':
176
+ res = CblasTrans;
177
+ break;
178
+ case 'c':
179
+ case 'C':
180
+ res = CblasConjTrans;
181
+ break;
182
+ }
183
+ }
184
+
185
+ RB_GC_GUARD(val);
186
+
187
+ return res;
188
+ }
189
+
190
+ static enum CBLAS_ORDER get_cblas_order(VALUE val) {
191
+ const char* option_str = StringValueCStr(val);
192
+
193
+ if (std::strlen(option_str) > 0) {
194
+ switch (option_str[0]) {
195
+ case 'r':
196
+ case 'R':
197
+ break;
198
+ case 'c':
199
+ case 'C':
200
+ rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
201
+ break;
202
+ }
203
+ }
204
+
205
+ RB_GC_GUARD(val);
206
+
207
+ return CblasRowMajor;
208
+ }
209
+ };
210
+
211
+ } // namespace TinyLinalg
@@ -0,0 +1,134 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGESDD {
4
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt) {
6
+ return LAPACKE_dgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
7
+ };
8
+ };
9
+
10
+ struct SGESDD {
11
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
12
+ float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt) {
13
+ return LAPACKE_sgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
14
+ };
15
+ };
16
+
17
+ struct ZGESDD {
18
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
19
+ lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt) {
20
+ return LAPACKE_zgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
21
+ };
22
+ };
23
+
24
+ struct CGESDD {
25
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
26
+ lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt) {
27
+ return LAPACKE_cgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
28
+ };
29
+ };
30
+
31
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
32
+ class GESDD {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* mf_name) {
35
+ rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesdd), -1);
36
+ };
37
+
38
+ private:
39
+ struct gesdd_opt {
40
+ int matrix_order;
41
+ char jobz;
42
+ };
43
+
44
+ static void iter_gesdd(na_loop_t* const lp) {
45
+ DType* a = (DType*)NDL_PTR(lp, 0);
46
+ RType* s = (RType*)NDL_PTR(lp, 1);
47
+ DType* u = (DType*)NDL_PTR(lp, 2);
48
+ DType* vt = (DType*)NDL_PTR(lp, 3);
49
+ int* info = (int*)NDL_PTR(lp, 4);
50
+ gesdd_opt* opt = (gesdd_opt*)(lp->opt_ptr);
51
+
52
+ const size_t m = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0] : NDL_SHAPE(lp, 0)[1];
53
+ const size_t n = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[1] : NDL_SHAPE(lp, 0)[0];
54
+ const size_t min_mn = m < n ? m : n;
55
+ const lapack_int lda = n;
56
+ const lapack_int ldu = opt->jobz == 'S' ? min_mn : m;
57
+ const lapack_int ldvt = opt->jobz == 'S' ? min_mn : n;
58
+
59
+ lapack_int i = FncType().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
60
+ *info = static_cast<int>(i);
61
+ };
62
+
63
+ static VALUE tiny_linalg_gesdd(int argc, VALUE* argv, VALUE self) {
64
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
65
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
66
+ VALUE a_vnary = Qnil;
67
+ VALUE kw_args = Qnil;
68
+
69
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
70
+
71
+ ID kw_table[2] = { rb_intern("jobz"), rb_intern("order") };
72
+ VALUE kw_values[2] = { Qundef, Qundef };
73
+
74
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
75
+
76
+ const char jobz = kw_values[0] == Qundef ? 'A' : StringValueCStr(kw_values[0])[0];
77
+ const char order = kw_values[1] == Qundef ? 'R' : StringValueCStr(kw_values[1])[0];
78
+
79
+ if (CLASS_OF(a_vnary) != nary_dtype) {
80
+ rb_raise(rb_eTypeError, "type of input array is invalid for overwriting");
81
+ return Qnil;
82
+ }
83
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
84
+ a_vnary = nary_dup(a_vnary);
85
+ }
86
+
87
+ narray_t* a_nary = NULL;
88
+ GetNArray(a_vnary, a_nary);
89
+ const int n_dims = NA_NDIM(a_nary);
90
+ if (n_dims != 2) {
91
+ rb_raise(rb_eArgError, "input array must be 2-dimensional");
92
+ return Qnil;
93
+ }
94
+
95
+ const int matrix_order = order == 'C' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
96
+ const size_t m = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];
97
+ const size_t n = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[1] : NA_SHAPE(a_nary)[0];
98
+
99
+ const size_t min_mn = m < n ? m : n;
100
+ size_t shape_s[1] = { min_mn };
101
+ size_t shape_u[2] = { m, m };
102
+ size_t shape_vt[2] = { n, n };
103
+
104
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
105
+ ndfunc_arg_out_t aout[4] = { { nary_rtype, 1, shape_s }, { nary_dtype, 2, shape_u }, { nary_dtype, 2, shape_vt }, { numo_cInt32, 0 } };
106
+
107
+ switch (jobz) {
108
+ case 'A':
109
+ break;
110
+ case 'S':
111
+ shape_u[matrix_order == LAPACK_ROW_MAJOR ? 1 : 0] = min_mn;
112
+ shape_vt[matrix_order == LAPACK_ROW_MAJOR ? 0 : 1] = min_mn;
113
+ break;
114
+ case 'O':
115
+ break;
116
+ case 'N':
117
+ aout[1].dim = 0;
118
+ aout[2].dim = 0;
119
+ break;
120
+ default:
121
+ rb_raise(rb_eArgError, "jobz must be one of 'A', 'S', 'O', or 'N'");
122
+ return Qnil;
123
+ }
124
+
125
+ ndfunc_t ndf = { iter_gesdd, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };
126
+ gesdd_opt opt = { matrix_order, jobz };
127
+ VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
128
+
129
+ RB_GC_GUARD(a_vnary);
130
+ return ret;
131
+ };
132
+ };
133
+
134
+ } // namespace TinyLinalg