numo-tiny_linalg 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,85 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct ZDotuSub {
4
+ void call(const int n, const double* x, const int incx, const double* y, const int incy, double* ret) {
5
+ cblas_zdotu_sub(n, x, incx, y, incy, ret);
6
+ }
7
+ };
8
+
9
+ struct CDotuSub {
10
+ void call(const int n, const float* x, const int incx, const float* y, const int incy, float* ret) {
11
+ cblas_cdotu_sub(n, x, incx, y, incy, ret);
12
+ }
13
+ };
14
+
15
+ template <int nary_dtype_id, typename dtype, class BlasFn>
16
+ class DotSub {
17
+ public:
18
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
19
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_dot_sub), 2);
20
+ };
21
+
22
+ private:
23
+ static void iter_dot_sub(na_loop_t* const lp) {
24
+ dtype* x = (dtype*)NDL_PTR(lp, 0);
25
+ dtype* y = (dtype*)NDL_PTR(lp, 1);
26
+ dtype* d = (dtype*)NDL_PTR(lp, 2);
27
+ const size_t n = NDL_SHAPE(lp, 0)[0];
28
+ BlasFn().call(n, x, 1, y, 1, d);
29
+ };
30
+
31
+ static VALUE tiny_linalg_dot_sub(VALUE self, VALUE x, VALUE y) {
32
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
33
+
34
+ if (CLASS_OF(x) != nary_dtype) {
35
+ x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
36
+ }
37
+ if (!RTEST(nary_check_contiguous(x))) {
38
+ x = nary_dup(x);
39
+ }
40
+ if (CLASS_OF(y) != nary_dtype) {
41
+ y = rb_funcall(nary_dtype, rb_intern("cast"), 1, y);
42
+ }
43
+ if (!RTEST(nary_check_contiguous(y))) {
44
+ y = nary_dup(y);
45
+ }
46
+
47
+ narray_t* x_nary = NULL;
48
+ GetNArray(x, x_nary);
49
+ narray_t* y_nary = NULL;
50
+ GetNArray(y, y_nary);
51
+
52
+ if (NA_NDIM(x_nary) != 1) {
53
+ rb_raise(rb_eArgError, "x must be 1-dimensional");
54
+ return Qnil;
55
+ }
56
+ if (NA_NDIM(y_nary) != 1) {
57
+ rb_raise(rb_eArgError, "y must be 1-dimensional");
58
+ return Qnil;
59
+ }
60
+ if (NA_SIZE(x_nary) == 0) {
61
+ rb_raise(rb_eArgError, "x must not be empty");
62
+ return Qnil;
63
+ }
64
+ if (NA_SIZE(y_nary) == 0) {
65
+ rb_raise(rb_eArgError, "x must not be empty");
66
+ return Qnil;
67
+ }
68
+ if (NA_SIZE(x_nary) != NA_SIZE(y_nary)) {
69
+ rb_raise(rb_eArgError, "x and y must have same size");
70
+ return Qnil;
71
+ }
72
+
73
+ ndfunc_arg_in_t ain[2] = { { nary_dtype, 1 }, { nary_dtype, 1 } };
74
+ size_t shape_out[1] = { 1 };
75
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
76
+ ndfunc_t ndf = { iter_dot_sub, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout };
77
+ VALUE ret = na_ndloop(&ndf, 2, x, y);
78
+
79
+ RB_GC_GUARD(x);
80
+ RB_GC_GUARD(y);
81
+ return ret;
82
+ };
83
+ };
84
+
85
+ } // namespace TinyLinalg
@@ -0,0 +1,54 @@
1
# frozen_string_literal: true

require 'mkmf'
require 'numo/narray'

# Add the first load-path entry that ships numo/numo/narray.h to the include path.
narray_include_root = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/numo/narray.h')) }
$INCFLAGS = "-I#{narray_include_root}/numo #{$INCFLAGS}" unless narray_include_root.nil?

abort 'numo/narray.h is not found' unless have_header('numo/narray.h')

# On Windows-like platforms the extension must also link against libnarray.a.
if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
  narray_lib_root = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/libnarray.a')) }
  $LDFLAGS = "-L#{narray_lib_root}/numo #{$LDFLAGS}" unless narray_lib_root.nil?
  abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
end

# macOS with Ruby >= 3.1: allow undefined symbols to be resolved at load time,
# but only if the linker accepts the flag.
darwin_dynamic_lookup = RUBY_PLATFORM.include?('darwin') &&
                        Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.1.0') &&
                        try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
$LDFLAGS << ' -Wl,-undefined,dynamic_lookup' if darwin_dynamic_lookup

use_accelerate = false
# NOTE: Accelerate framework does not support LAPACKE.
# if RUBY_PLATFORM.include?('darwin') && have_framework('Accelerate')
#   $CFLAGS << ' -DTINYLINALG_USE_ACCELERATE'
#   use_accelerate = true
# end

unless use_accelerate
  # Prefer OpenBLAS; fall back to a reference libblas.
  if have_library('openblas')
    $CFLAGS << ' -DTINYLINALG_USE_OPENBLAS'
  elsif have_library('blas')
    $CFLAGS << ' -DTINYLINALG_USE_BLAS'
  else
    abort 'libblas is not found'
  end

  # LAPACKE may already be provided by the BLAS library (e.g. OpenBLAS).
  lapacke_available = have_func('LAPACKE_dsyevr') || have_library('lapacke')
  abort 'liblapacke is not found' unless lapacke_available
  abort 'cblas.h is not found' unless have_header('cblas.h')
  abort 'lapacke.h is not found' unless have_header('lapacke.h')
end

abort 'libstdc++ is not found.' unless have_library('stdc++')

$CXXFLAGS << ' -std=c++11'

create_makefile('numo/tiny_linalg/tiny_linalg')
@@ -0,0 +1,223 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGemm {
4
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
5
+ const blasint m, const blasint n, const blasint k,
6
+ const double alpha, const double* a, const blasint lda, const double* b, const blasint ldb, const double beta,
7
+ double* c, const blasint ldc) {
8
+ cblas_dgemm(order, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
9
+ }
10
+ };
11
+
12
+ struct SGemm {
13
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
14
+ const blasint m, const blasint n, const blasint k,
15
+ const float alpha, const float* a, const blasint lda, const float* b, const blasint ldb, const float beta,
16
+ float* c, const blasint ldc) {
17
+ cblas_sgemm(order, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
18
+ }
19
+ };
20
+
21
+ struct ZGemm {
22
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
23
+ const blasint m, const blasint n, const blasint k,
24
+ const dcomplex alpha, const dcomplex* a, const blasint lda, const dcomplex* b, const blasint ldb, const dcomplex beta,
25
+ dcomplex* c, const blasint ldc) {
26
+ cblas_zgemm(order, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
27
+ }
28
+ };
29
+
30
+ struct CGemm {
31
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transa, const enum CBLAS_TRANSPOSE transb,
32
+ const blasint m, const blasint n, const blasint k,
33
+ const scomplex alpha, const scomplex* a, const blasint lda, const scomplex* b, const blasint ldb, const scomplex beta,
34
+ scomplex* c, const blasint ldc) {
35
+ cblas_cgemm(order, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
36
+ }
37
+ };
38
+
39
+ template <int nary_dtype_id, typename dtype, class BlasFn, class Converter>
40
+ class Gemm {
41
+ public:
42
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
43
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_gemm), -1);
44
+ };
45
+
46
+ private:
47
+ struct options {
48
+ dtype alpha;
49
+ dtype beta;
50
+ enum CBLAS_ORDER order;
51
+ enum CBLAS_TRANSPOSE transa;
52
+ enum CBLAS_TRANSPOSE transb;
53
+ blasint m;
54
+ blasint n;
55
+ blasint k;
56
+ };
57
+
58
+ static void iter_gemm(na_loop_t* const lp) {
59
+ const dtype* a = (dtype*)NDL_PTR(lp, 0);
60
+ const dtype* b = (dtype*)NDL_PTR(lp, 1);
61
+ dtype* c = (dtype*)NDL_PTR(lp, 2);
62
+ const options* opt = (options*)(lp->opt_ptr);
63
+ const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m;
64
+ const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k;
65
+ const blasint ldc = opt->n;
66
+ BlasFn().call(opt->order, opt->transa, opt->transb, opt->m, opt->n, opt->k, opt->alpha, a, lda, b, ldb, opt->beta, c, ldc);
67
+ };
68
+
69
+ static VALUE tiny_linalg_gemm(int argc, VALUE* argv, VALUE self) {
70
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
71
+
72
+ VALUE a = Qnil;
73
+ VALUE b = Qnil;
74
+ VALUE c = Qnil;
75
+ VALUE kw_args = Qnil;
76
+ rb_scan_args(argc, argv, "21:", &a, &b, &c, &kw_args);
77
+
78
+ ID kw_table[5] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), rb_intern("transa"), rb_intern("transb") };
79
+ VALUE kw_values[5] = { Qundef, Qundef, Qundef, Qundef, Qundef };
80
+ rb_get_kwargs(kw_args, kw_table, 0, 5, kw_values);
81
+
82
+ if (CLASS_OF(a) != nary_dtype) {
83
+ a = rb_funcall(nary_dtype, rb_intern("cast"), 1, a);
84
+ }
85
+ if (!RTEST(nary_check_contiguous(a))) {
86
+ a = nary_dup(a);
87
+ }
88
+ if (CLASS_OF(b) != nary_dtype) {
89
+ b = rb_funcall(nary_dtype, rb_intern("cast"), 1, b);
90
+ }
91
+ if (!RTEST(nary_check_contiguous(b))) {
92
+ b = nary_dup(b);
93
+ }
94
+ if (!NIL_P(c)) {
95
+ if (CLASS_OF(c) != nary_dtype) {
96
+ c = rb_funcall(nary_dtype, rb_intern("cast"), 1, c);
97
+ }
98
+ if (!RTEST(nary_check_contiguous(c))) {
99
+ c = nary_dup(c);
100
+ }
101
+ }
102
+
103
+ dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
104
+ dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
105
+ enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
106
+ enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
107
+ enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans;
108
+
109
+ narray_t* a_nary = NULL;
110
+ GetNArray(a, a_nary);
111
+ narray_t* b_nary = NULL;
112
+ GetNArray(b, b_nary);
113
+
114
+ if (NA_NDIM(a_nary) != 2) {
115
+ rb_raise(rb_eArgError, "a must be 2-dimensional");
116
+ return Qnil;
117
+ }
118
+ if (NA_NDIM(b_nary) != 2) {
119
+ rb_raise(rb_eArgError, "b must be 2-dimensional");
120
+ return Qnil;
121
+ }
122
+ if (NA_SIZE(a_nary) == 0) {
123
+ rb_raise(rb_eArgError, "a must not be empty");
124
+ return Qnil;
125
+ }
126
+ if (NA_SIZE(b_nary) == 0) {
127
+ rb_raise(rb_eArgError, "b must not be empty");
128
+ return Qnil;
129
+ }
130
+
131
+ const blasint ma = NA_SHAPE(a_nary)[0];
132
+ const blasint ka = NA_SHAPE(a_nary)[1];
133
+ const blasint kb = NA_SHAPE(b_nary)[0];
134
+ const blasint nb = NA_SHAPE(b_nary)[1];
135
+ const blasint m = transa == CblasNoTrans ? ma : ka;
136
+ const blasint n = transb == CblasNoTrans ? nb : kb;
137
+ const blasint k = transa == CblasNoTrans ? ka : ma;
138
+ const blasint l = transb == CblasNoTrans ? kb : nb;
139
+
140
+ if (k != l) {
141
+ rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", k, l);
142
+ return Qnil;
143
+ }
144
+
145
+ options opt = { alpha, beta, order, transa, transb, m, n, k };
146
+ size_t shape_out[2] = { static_cast<size_t>(m), static_cast<size_t>(n) };
147
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 2, shape_out } };
148
+ VALUE ret = Qnil;
149
+
150
+ if (!NIL_P(c)) {
151
+ narray_t* c_nary = NULL;
152
+ GetNArray(c, c_nary);
153
+ blasint nc = NA_SHAPE(c_nary)[0];
154
+ if (m > nc) {
155
+ rb_raise(nary_eShapeError, "shape3[0](=%d) >= shape1[0]=%d", nc, m);
156
+ return Qnil;
157
+ }
158
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 2 }, { OVERWRITE, 2 } };
159
+ ndfunc_t ndf = { iter_gemm, NO_LOOP, 3, 0, ain, aout };
160
+ na_ndloop3(&ndf, &opt, 3, a, b, c);
161
+ ret = c;
162
+ } else {
163
+ c = INT2NUM(0);
164
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 2 }, { sym_init, 0 } };
165
+ ndfunc_t ndf = { iter_gemm, NO_LOOP, 3, 1, ain, aout };
166
+ ret = na_ndloop3(&ndf, &opt, 3, a, b, c);
167
+ }
168
+
169
+ RB_GC_GUARD(a);
170
+ RB_GC_GUARD(b);
171
+ RB_GC_GUARD(c);
172
+
173
+ return ret;
174
+ };
175
+
176
+ static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
177
+ const char* option_str = StringValueCStr(val);
178
+ enum CBLAS_TRANSPOSE res = CblasNoTrans;
179
+
180
+ if (std::strlen(option_str) > 0) {
181
+ switch (option_str[0]) {
182
+ case 'n':
183
+ case 'N':
184
+ res = CblasNoTrans;
185
+ break;
186
+ case 't':
187
+ case 'T':
188
+ res = CblasTrans;
189
+ break;
190
+ case 'c':
191
+ case 'C':
192
+ res = CblasConjTrans;
193
+ break;
194
+ }
195
+ }
196
+
197
+ RB_GC_GUARD(val);
198
+
199
+ return res;
200
+ }
201
+
202
+ static enum CBLAS_ORDER get_cblas_order(VALUE val) {
203
+ const char* option_str = StringValueCStr(val);
204
+
205
+ if (std::strlen(option_str) > 0) {
206
+ switch (option_str[0]) {
207
+ case 'r':
208
+ case 'R':
209
+ break;
210
+ case 'c':
211
+ case 'C':
212
+ rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
213
+ break;
214
+ }
215
+ }
216
+
217
+ RB_GC_GUARD(val);
218
+
219
+ return CblasRowMajor;
220
+ }
221
+ };
222
+
223
+ } // namespace TinyLinalg
@@ -0,0 +1,211 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGemv {
4
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
5
+ const double alpha, const double* a, const blasint lda,
6
+ const double* x, const blasint incx, const double beta, double* y, const blasint incy) {
7
+ cblas_dgemv(order, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
8
+ }
9
+ };
10
+
11
+ struct SGemv {
12
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
13
+ const float alpha, const float* a, const blasint lda,
14
+ const float* x, const blasint incx, const float beta, float* y, const blasint incy) {
15
+ cblas_sgemv(order, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
16
+ }
17
+ };
18
+
19
+ struct ZGemv {
20
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
21
+ const dcomplex alpha, const dcomplex* a, const blasint lda,
22
+ const dcomplex* x, const blasint incx, const dcomplex beta, dcomplex* y, const blasint incy) {
23
+ cblas_zgemv(order, trans, m, n, &alpha, a, lda, x, incx, &beta, y, incy);
24
+ }
25
+ };
26
+
27
+ struct CGemv {
28
+ void call(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const blasint m, const blasint n,
29
+ const scomplex alpha, const scomplex* a, const blasint lda,
30
+ const scomplex* x, const blasint incx, const scomplex beta, scomplex* y, const blasint incy) {
31
+ cblas_cgemv(order, trans, m, n, &alpha, a, lda, x, incx, &beta, y, incy);
32
+ }
33
+ };
34
+
35
+ template <int nary_dtype_id, typename dtype, class BlasFn, class Converter>
36
+ class Gemv {
37
+ public:
38
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
39
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_gemv), -1);
40
+ };
41
+
42
+ private:
43
+ struct options {
44
+ dtype alpha;
45
+ dtype beta;
46
+ enum CBLAS_ORDER order;
47
+ enum CBLAS_TRANSPOSE trans;
48
+ blasint m;
49
+ blasint n;
50
+ };
51
+
52
+ static void iter_gemv(na_loop_t* const lp) {
53
+ const dtype* a = (dtype*)NDL_PTR(lp, 0);
54
+ const dtype* x = (dtype*)NDL_PTR(lp, 1);
55
+ dtype* y = (dtype*)NDL_PTR(lp, 2);
56
+ const options* opt = (options*)(lp->opt_ptr);
57
+ const blasint lda = opt->n;
58
+ BlasFn().call(opt->order, opt->trans, opt->m, opt->n, opt->alpha, a, lda, x, 1, opt->beta, y, 1);
59
+ };
60
+
61
+ static VALUE tiny_linalg_gemv(int argc, VALUE* argv, VALUE self) {
62
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
63
+
64
+ VALUE a = Qnil;
65
+ VALUE x = Qnil;
66
+ VALUE y = Qnil;
67
+ VALUE kw_args = Qnil;
68
+ rb_scan_args(argc, argv, "21:", &a, &x, &y, &kw_args);
69
+
70
+ ID kw_table[4] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), rb_intern("trans") };
71
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
72
+ rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
73
+
74
+ if (CLASS_OF(a) != nary_dtype) {
75
+ a = rb_funcall(nary_dtype, rb_intern("cast"), 1, a);
76
+ }
77
+ if (!RTEST(nary_check_contiguous(a))) {
78
+ a = nary_dup(a);
79
+ }
80
+ if (CLASS_OF(x) != nary_dtype) {
81
+ x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
82
+ }
83
+ if (!RTEST(nary_check_contiguous(x))) {
84
+ x = nary_dup(x);
85
+ }
86
+ if (!NIL_P(y)) {
87
+ if (CLASS_OF(y) != nary_dtype) {
88
+ y = rb_funcall(nary_dtype, rb_intern("cast"), 1, y);
89
+ }
90
+ if (!RTEST(nary_check_contiguous(y))) {
91
+ y = nary_dup(y);
92
+ }
93
+ }
94
+
95
+ dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
96
+ dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
97
+ enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
98
+ enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
99
+
100
+ narray_t* a_nary = NULL;
101
+ GetNArray(a, a_nary);
102
+ narray_t* x_nary = NULL;
103
+ GetNArray(x, x_nary);
104
+
105
+ if (NA_NDIM(a_nary) != 2) {
106
+ rb_raise(rb_eArgError, "a must be 2-dimensional");
107
+ return Qnil;
108
+ }
109
+ if (NA_NDIM(x_nary) != 1) {
110
+ rb_raise(rb_eArgError, "x must be 1-dimensional");
111
+ return Qnil;
112
+ }
113
+ if (NA_SIZE(a_nary) == 0) {
114
+ rb_raise(rb_eArgError, "a must not be empty");
115
+ return Qnil;
116
+ }
117
+ if (NA_SIZE(x_nary) == 0) {
118
+ rb_raise(rb_eArgError, "x must not be empty");
119
+ return Qnil;
120
+ }
121
+
122
+ const blasint ma = NA_SHAPE(a_nary)[0];
123
+ const blasint na = NA_SHAPE(a_nary)[1];
124
+ const blasint mx = NA_SHAPE(x_nary)[0];
125
+ const blasint m = trans == CblasNoTrans ? ma : na;
126
+ const blasint n = trans == CblasNoTrans ? na : ma;
127
+
128
+ if (n != mx) {
129
+ rb_raise(nary_eShapeError, "shape1[1](=%d) != shape2[0](=%d)", n, mx);
130
+ return Qnil;
131
+ }
132
+
133
+ options opt = { alpha, beta, order, trans, ma, na };
134
+ size_t shape_out[1] = { static_cast<size_t>(m) };
135
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 1, shape_out } };
136
+ VALUE ret = Qnil;
137
+
138
+ if (!NIL_P(y)) {
139
+ narray_t* y_nary = NULL;
140
+ GetNArray(y, y_nary);
141
+ blasint my = NA_SHAPE(y_nary)[0];
142
+ if (m > my) {
143
+ rb_raise(nary_eShapeError, "shape3[0](=%d) >= shape1[0]=%d", my, m);
144
+ return Qnil;
145
+ }
146
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 1 }, { OVERWRITE, 1 } };
147
+ ndfunc_t ndf = { iter_gemv, NO_LOOP, 3, 0, ain, aout };
148
+ na_ndloop3(&ndf, &opt, 3, a, x, y);
149
+ ret = y;
150
+ } else {
151
+ y = INT2NUM(0);
152
+ ndfunc_arg_in_t ain[3] = { { nary_dtype, 2 }, { nary_dtype, 1 }, { sym_init, 0 } };
153
+ ndfunc_t ndf = { iter_gemv, NO_LOOP, 3, 1, ain, aout };
154
+ ret = na_ndloop3(&ndf, &opt, 3, a, x, y);
155
+ }
156
+
157
+ RB_GC_GUARD(a);
158
+ RB_GC_GUARD(x);
159
+ RB_GC_GUARD(y);
160
+
161
+ return ret;
162
+ }
163
+
164
+ static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
165
+ const char* option_str = StringValueCStr(val);
166
+ enum CBLAS_TRANSPOSE res = CblasNoTrans;
167
+
168
+ if (std::strlen(option_str) > 0) {
169
+ switch (option_str[0]) {
170
+ case 'n':
171
+ case 'N':
172
+ res = CblasNoTrans;
173
+ break;
174
+ case 't':
175
+ case 'T':
176
+ res = CblasTrans;
177
+ break;
178
+ case 'c':
179
+ case 'C':
180
+ res = CblasConjTrans;
181
+ break;
182
+ }
183
+ }
184
+
185
+ RB_GC_GUARD(val);
186
+
187
+ return res;
188
+ }
189
+
190
+ static enum CBLAS_ORDER get_cblas_order(VALUE val) {
191
+ const char* option_str = StringValueCStr(val);
192
+
193
+ if (std::strlen(option_str) > 0) {
194
+ switch (option_str[0]) {
195
+ case 'r':
196
+ case 'R':
197
+ break;
198
+ case 'c':
199
+ case 'C':
200
+ rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
201
+ break;
202
+ }
203
+ }
204
+
205
+ RB_GC_GUARD(val);
206
+
207
+ return CblasRowMajor;
208
+ }
209
+ };
210
+
211
+ } // namespace TinyLinalg
@@ -0,0 +1,134 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGESDD {
4
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt) {
6
+ return LAPACKE_dgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
7
+ };
8
+ };
9
+
10
+ struct SGESDD {
11
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
12
+ float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt) {
13
+ return LAPACKE_sgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
14
+ };
15
+ };
16
+
17
+ struct ZGESDD {
18
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
19
+ lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt) {
20
+ return LAPACKE_zgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
21
+ };
22
+ };
23
+
24
+ struct CGESDD {
25
+ lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
26
+ lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt) {
27
+ return LAPACKE_cgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
28
+ };
29
+ };
30
+
31
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
32
+ class GESDD {
33
+ public:
34
+ static void define_module_function(VALUE mLapack, const char* mf_name) {
35
+ rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesdd), -1);
36
+ };
37
+
38
+ private:
39
+ struct gesdd_opt {
40
+ int matrix_order;
41
+ char jobz;
42
+ };
43
+
44
+ static void iter_gesdd(na_loop_t* const lp) {
45
+ DType* a = (DType*)NDL_PTR(lp, 0);
46
+ RType* s = (RType*)NDL_PTR(lp, 1);
47
+ DType* u = (DType*)NDL_PTR(lp, 2);
48
+ DType* vt = (DType*)NDL_PTR(lp, 3);
49
+ int* info = (int*)NDL_PTR(lp, 4);
50
+ gesdd_opt* opt = (gesdd_opt*)(lp->opt_ptr);
51
+
52
+ const size_t m = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0] : NDL_SHAPE(lp, 0)[1];
53
+ const size_t n = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[1] : NDL_SHAPE(lp, 0)[0];
54
+ const size_t min_mn = m < n ? m : n;
55
+ const lapack_int lda = n;
56
+ const lapack_int ldu = opt->jobz == 'S' ? min_mn : m;
57
+ const lapack_int ldvt = opt->jobz == 'S' ? min_mn : n;
58
+
59
+ lapack_int i = FncType().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
60
+ *info = static_cast<int>(i);
61
+ };
62
+
63
+ static VALUE tiny_linalg_gesdd(int argc, VALUE* argv, VALUE self) {
64
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
65
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
66
+ VALUE a_vnary = Qnil;
67
+ VALUE kw_args = Qnil;
68
+
69
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
70
+
71
+ ID kw_table[2] = { rb_intern("jobz"), rb_intern("order") };
72
+ VALUE kw_values[2] = { Qundef, Qundef };
73
+
74
+ rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
75
+
76
+ const char jobz = kw_values[0] == Qundef ? 'A' : StringValueCStr(kw_values[0])[0];
77
+ const char order = kw_values[1] == Qundef ? 'R' : StringValueCStr(kw_values[1])[0];
78
+
79
+ if (CLASS_OF(a_vnary) != nary_dtype) {
80
+ rb_raise(rb_eTypeError, "type of input array is invalid for overwriting");
81
+ return Qnil;
82
+ }
83
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
84
+ a_vnary = nary_dup(a_vnary);
85
+ }
86
+
87
+ narray_t* a_nary = NULL;
88
+ GetNArray(a_vnary, a_nary);
89
+ const int n_dims = NA_NDIM(a_nary);
90
+ if (n_dims != 2) {
91
+ rb_raise(rb_eArgError, "input array must be 2-dimensional");
92
+ return Qnil;
93
+ }
94
+
95
+ const int matrix_order = order == 'C' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
96
+ const size_t m = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];
97
+ const size_t n = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[1] : NA_SHAPE(a_nary)[0];
98
+
99
+ const size_t min_mn = m < n ? m : n;
100
+ size_t shape_s[1] = { min_mn };
101
+ size_t shape_u[2] = { m, m };
102
+ size_t shape_vt[2] = { n, n };
103
+
104
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
105
+ ndfunc_arg_out_t aout[4] = { { nary_rtype, 1, shape_s }, { nary_dtype, 2, shape_u }, { nary_dtype, 2, shape_vt }, { numo_cInt32, 0 } };
106
+
107
+ switch (jobz) {
108
+ case 'A':
109
+ break;
110
+ case 'S':
111
+ shape_u[matrix_order == LAPACK_ROW_MAJOR ? 1 : 0] = min_mn;
112
+ shape_vt[matrix_order == LAPACK_ROW_MAJOR ? 0 : 1] = min_mn;
113
+ break;
114
+ case 'O':
115
+ break;
116
+ case 'N':
117
+ aout[1].dim = 0;
118
+ aout[2].dim = 0;
119
+ break;
120
+ default:
121
+ rb_raise(rb_eArgError, "jobz must be one of 'A', 'S', 'O', or 'N'");
122
+ return Qnil;
123
+ }
124
+
125
+ ndfunc_t ndf = { iter_gesdd, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };
126
+ gesdd_opt opt = { matrix_order, jobz };
127
+ VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
128
+
129
+ RB_GC_GUARD(a_vnary);
130
+ return ret;
131
+ };
132
+ };
133
+
134
+ } // namespace TinyLinalg