numo-tiny_linalg 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,182 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DGESVD {
4
+ lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
5
+ double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt,
6
+ double* superb) {
7
+ return LAPACKE_dgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
8
+ };
9
+ };
10
+
11
+ struct SGESVD {
12
+ lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
13
+ float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt,
14
+ float* superb) {
15
+ return LAPACKE_sgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
16
+ };
17
+ };
18
+
19
+ struct ZGESVD {
20
+ lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
21
+ lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt,
22
+ double* superb) {
23
+ return LAPACKE_zgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
24
+ };
25
+ };
26
+
27
+ struct CGESVD {
28
+ lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
29
+ lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt,
30
+ float* superb) {
31
+ return LAPACKE_cgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
32
+ };
33
+ };
34
+
35
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
36
+ class GESVD {
37
+ public:
38
+ static void define_module_function(VALUE mLapack, const char* mf_name) {
39
+ rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesvd), -1);
40
+ };
41
+
42
+ private:
43
+ struct gesvd_opt {
44
+ int matrix_order;
45
+ char jobu;
46
+ char jobvt;
47
+ };
48
+
49
+ static void iter_gesvd(na_loop_t* const lp) {
50
+ DType* a = (DType*)NDL_PTR(lp, 0);
51
+ RType* s = (RType*)NDL_PTR(lp, 1);
52
+ DType* u = (DType*)NDL_PTR(lp, 2);
53
+ DType* vt = (DType*)NDL_PTR(lp, 3);
54
+ int* info = (int*)NDL_PTR(lp, 4);
55
+ gesvd_opt* opt = (gesvd_opt*)(lp->opt_ptr);
56
+
57
+ const size_t m = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0] : NDL_SHAPE(lp, 0)[1];
58
+ const size_t n = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[1] : NDL_SHAPE(lp, 0)[0];
59
+ const size_t min_mn = m < n ? m : n;
60
+ const lapack_int lda = n;
61
+ const lapack_int ldu = opt->jobu == 'A' ? m : min_mn;
62
+ const lapack_int ldvt = n;
63
+
64
+ RType* superb = (RType*)ruby_xmalloc(min_mn * sizeof(RType));
65
+
66
+ lapack_int i = FncType().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
67
+ *info = static_cast<int>(i);
68
+
69
+ ruby_xfree(superb);
70
+ };
71
+
72
+ static VALUE tiny_linalg_gesvd(int argc, VALUE* argv, VALUE self) {
73
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
74
+ VALUE nary_rtype = NaryTypes[nary_rtype_id];
75
+ VALUE a_vnary = Qnil;
76
+ VALUE kw_args = Qnil;
77
+
78
+ rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
79
+
80
+ ID kw_table[3] = { rb_intern("jobu"), rb_intern("jobvt"), rb_intern("order") };
81
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
82
+
83
+ rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
84
+
85
+ const char jobu = kw_values[0] == Qundef ? 'A' : StringValueCStr(kw_values[0])[0];
86
+ const char jobvt = kw_values[1] == Qundef ? 'A' : StringValueCStr(kw_values[1])[0];
87
+ const char order = kw_values[2] == Qundef ? 'R' : StringValueCStr(kw_values[2])[0];
88
+
89
+ if (jobu == 'O' && jobvt == 'O') {
90
+ rb_raise(rb_eArgError, "jobu and jobvt cannot be both 'O'");
91
+ return Qnil;
92
+ }
93
+ if (CLASS_OF(a_vnary) != nary_dtype) {
94
+ rb_raise(rb_eTypeError, "type of input array is invalid for overwriting");
95
+ return Qnil;
96
+ }
97
+
98
+ if (CLASS_OF(a_vnary) != nary_dtype) {
99
+ a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
100
+ }
101
+ if (!RTEST(nary_check_contiguous(a_vnary))) {
102
+ a_vnary = nary_dup(a_vnary);
103
+ }
104
+
105
+ narray_t* a_nary = NULL;
106
+ GetNArray(a_vnary, a_nary);
107
+ const int n_dims = NA_NDIM(a_nary);
108
+ if (n_dims != 2) {
109
+ rb_raise(rb_eArgError, "input array must be 2-dimensional");
110
+ return Qnil;
111
+ }
112
+
113
+ const int matrix_order = order == 'C' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
114
+ const size_t m = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];
115
+ const size_t n = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[1] : NA_SHAPE(a_nary)[0];
116
+
117
+ const size_t min_mn = m < n ? m : n;
118
+ size_t shape_s[1] = { min_mn };
119
+ size_t shape_u[2] = { m, m };
120
+ size_t shape_vt[2] = { n, n };
121
+
122
+ ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
123
+ ndfunc_arg_out_t aout[4] = { { nary_rtype, 1, shape_s }, { nary_dtype, 2, shape_u }, { nary_dtype, 2, shape_vt }, { numo_cInt32, 0 } };
124
+
125
+ switch (jobu) {
126
+ case 'A':
127
+ break;
128
+ case 'S':
129
+ shape_u[matrix_order == LAPACK_ROW_MAJOR ? 1 : 0] = min_mn;
130
+ break;
131
+ case 'O':
132
+ case 'N':
133
+ aout[1].dim = 0;
134
+ break;
135
+ default:
136
+ rb_raise(rb_eArgError, "jobu must be 'A', 'S', 'O', or 'N'");
137
+ return Qnil;
138
+ }
139
+
140
+ switch (jobvt) {
141
+ case 'A':
142
+ break;
143
+ case 'S':
144
+ shape_vt[matrix_order == LAPACK_ROW_MAJOR ? 0 : 1] = min_mn;
145
+ break;
146
+ case 'O':
147
+ case 'N':
148
+ aout[2].dim = 0;
149
+ break;
150
+ default:
151
+ rb_raise(rb_eArgError, "jobvt must be 'A', 'S', 'O', or 'N'");
152
+ return Qnil;
153
+ }
154
+
155
+ ndfunc_t ndf = { iter_gesvd, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };
156
+ gesvd_opt opt = { matrix_order, jobu, jobvt };
157
+ VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
158
+
159
+ switch (jobu) {
160
+ case 'O':
161
+ rb_ary_store(ret, 1, a_vnary);
162
+ break;
163
+ case 'N':
164
+ rb_ary_store(ret, 1, Qnil);
165
+ break;
166
+ }
167
+
168
+ switch (jobvt) {
169
+ case 'O':
170
+ rb_ary_store(ret, 2, a_vnary);
171
+ break;
172
+ case 'N':
173
+ rb_ary_store(ret, 2, Qnil);
174
+ break;
175
+ }
176
+
177
+ RB_GC_GUARD(a_vnary);
178
+ return ret;
179
+ };
180
+ };
181
+
182
+ } // namespace TinyLinalg
@@ -0,0 +1,89 @@
1
+ namespace TinyLinalg {
2
+
3
+ struct DNrm2 {
4
+ double call(const int n, const double* x, const int incx) {
5
+ return cblas_dnrm2(n, x, incx);
6
+ }
7
+ };
8
+
9
+ struct SNrm2 {
10
+ float call(const int n, const float* x, const int incx) {
11
+ return cblas_snrm2(n, x, incx);
12
+ }
13
+ };
14
+
15
+ struct DZNrm2 {
16
+ double call(const int n, const void* x, const int incx) {
17
+ return cblas_dznrm2(n, x, incx);
18
+ }
19
+ };
20
+
21
+ struct SCNrm2 {
22
+ double call(const int n, const void* x, const int incx) {
23
+ return cblas_scnrm2(n, x, incx);
24
+ }
25
+ };
26
+
27
+ template <int nary_dtype_id, typename dtype, class BlasFn>
28
+ class Nrm2 {
29
+ public:
30
+ static void define_module_function(VALUE mBlas, const char* modfn_name) {
31
+ rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_dot), -1);
32
+ }
33
+
34
+ private:
35
+ static void iter_nrm2(na_loop_t* const lp) {
36
+ dtype* x = (dtype*)NDL_PTR(lp, 0);
37
+ dtype* d = (dtype*)NDL_PTR(lp, 1);
38
+ const size_t n = NDL_SHAPE(lp, 0)[0];
39
+ dtype ret = BlasFn().call(n, x, 1);
40
+ *d = ret;
41
+ }
42
+
43
+ static VALUE tiny_linalg_dot(int argc, VALUE* argv, VALUE self) {
44
+ VALUE nary_dtype = NaryTypes[nary_dtype_id];
45
+
46
+ VALUE x = Qnil;
47
+ VALUE kw_args = Qnil;
48
+ rb_scan_args(argc, argv, "1:", &x, &kw_args);
49
+
50
+ ID kw_table[1] = { rb_intern("keepdims") };
51
+ VALUE kw_values[1] = { Qundef };
52
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
53
+ const bool keepdims = kw_values[0] != Qundef ? RTEST(kw_values[0]) : false;
54
+
55
+ if (CLASS_OF(x) != nary_dtype) {
56
+ x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
57
+ }
58
+ if (!RTEST(nary_check_contiguous(x))) {
59
+ x = nary_dup(x);
60
+ }
61
+
62
+ narray_t* x_nary = NULL;
63
+ GetNArray(x, x_nary);
64
+
65
+ if (NA_NDIM(x_nary) != 1) {
66
+ rb_raise(rb_eArgError, "x must be 1-dimensional");
67
+ return Qnil;
68
+ }
69
+ if (NA_SIZE(x_nary) == 0) {
70
+ rb_raise(rb_eArgError, "x must not be empty");
71
+ return Qnil;
72
+ }
73
+
74
+ ndfunc_arg_in_t ain[1] = { { nary_dtype, 1 } };
75
+ size_t shape_out[1] = { 1 };
76
+ ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
77
+ ndfunc_t ndf = { iter_nrm2, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
78
+ if (keepdims) {
79
+ ndf.flag |= NDF_KEEP_DIM;
80
+ }
81
+
82
+ VALUE ret = na_ndloop(&ndf, 1, x);
83
+
84
+ RB_GC_GUARD(x);
85
+ return ret;
86
+ }
87
+ };
88
+
89
+ } // namespace TinyLinalg
@@ -0,0 +1,251 @@
1
+ #include "tiny_linalg.hpp"
2
+ #include "converter.hpp"
3
+ #include "dot.hpp"
4
+ #include "dot_sub.hpp"
5
+ #include "gemm.hpp"
6
+ #include "gemv.hpp"
7
+ #include "gesdd.hpp"
8
+ #include "gesvd.hpp"
9
+ #include "nrm2.hpp"
10
+
11
+ VALUE rb_mTinyLinalg;
12
+ VALUE rb_mTinyLinalgBlas;
13
+ VALUE rb_mTinyLinalgLapack;
14
+
15
+ char blas_char(VALUE nary_arr) {
16
+ char type = 'n';
17
+ const size_t n = RARRAY_LEN(nary_arr);
18
+ for (size_t i = 0; i < n; i++) {
19
+ VALUE arg = rb_ary_entry(nary_arr, i);
20
+ if (RB_TYPE_P(arg, T_ARRAY)) {
21
+ arg = rb_funcall(numo_cNArray, rb_intern("asarray"), 1, arg);
22
+ }
23
+ if (CLASS_OF(arg) == numo_cBit || CLASS_OF(arg) == numo_cInt64 || CLASS_OF(arg) == numo_cInt32 ||
24
+ CLASS_OF(arg) == numo_cInt16 || CLASS_OF(arg) == numo_cInt8 || CLASS_OF(arg) == numo_cUInt64 ||
25
+ CLASS_OF(arg) == numo_cUInt32 || CLASS_OF(arg) == numo_cUInt16 || CLASS_OF(arg) == numo_cUInt8) {
26
+ if (type == 'n') {
27
+ type = 'd';
28
+ }
29
+ } else if (CLASS_OF(arg) == numo_cDFloat) {
30
+ if (type == 'c' || type == 'z') {
31
+ type = 'z';
32
+ } else {
33
+ type = 'd';
34
+ }
35
+ } else if (CLASS_OF(arg) == numo_cSFloat) {
36
+ if (type == 'n') {
37
+ type = 's';
38
+ }
39
+ } else if (CLASS_OF(arg) == numo_cDComplex) {
40
+ type = 'z';
41
+ } else if (CLASS_OF(arg) == numo_cSComplex) {
42
+ if (type == 'n' || type == 's') {
43
+ type = 'c';
44
+ } else if (type == 'd') {
45
+ type = 'z';
46
+ }
47
+ }
48
+ }
49
+ return type;
50
+ }
51
+
52
+ static VALUE tiny_linalg_blas_char(int argc, VALUE* argv, VALUE self) {
53
+ VALUE nary_arr = Qnil;
54
+ rb_scan_args(argc, argv, "*", &nary_arr);
55
+
56
+ const char type = blas_char(nary_arr);
57
+ if (type == 'n') {
58
+ rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
59
+ return Qnil;
60
+ }
61
+
62
+ return rb_str_new(&type, 1);
63
+ }
64
+
65
+ static VALUE tiny_linalg_blas_call(int argc, VALUE* argv, VALUE self) {
66
+ VALUE fn_name = Qnil;
67
+ VALUE nary_arr = Qnil;
68
+ VALUE kw_args = Qnil;
69
+ rb_scan_args(argc, argv, "1*:", &fn_name, &nary_arr, &kw_args);
70
+
71
+ const char type = blas_char(nary_arr);
72
+ if (type == 'n') {
73
+ rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
74
+ return Qnil;
75
+ }
76
+
77
+ std::string type_str{ type };
78
+ std::string fn_str = type_str + std::string(rb_id2name(rb_to_id(rb_to_symbol(fn_name))));
79
+ ID fn_id = rb_intern(fn_str.c_str());
80
+ size_t n = RARRAY_LEN(nary_arr);
81
+ VALUE ret = Qnil;
82
+
83
+ if (NIL_P(kw_args)) {
84
+ VALUE* args = ALLOCA_N(VALUE, n);
85
+ for (size_t i = 0; i < n; i++) {
86
+ args[i] = rb_ary_entry(nary_arr, i);
87
+ }
88
+ ret = rb_funcallv(self, fn_id, n, args);
89
+ } else {
90
+ VALUE* args = ALLOCA_N(VALUE, n + 1);
91
+ for (size_t i = 0; i < n; i++) {
92
+ args[i] = rb_ary_entry(nary_arr, i);
93
+ }
94
+ args[n] = kw_args;
95
+ ret = rb_funcallv_kw(self, fn_id, n + 1, args, RB_PASS_KEYWORDS);
96
+ }
97
+
98
+ return ret;
99
+ }
100
+
101
+ static VALUE tiny_linalg_dot(VALUE self, VALUE a_, VALUE b_) {
102
+ VALUE a = IsNArray(a_) ? a_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, a_);
103
+ VALUE b = IsNArray(b_) ? b_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, b_);
104
+
105
+ VALUE arg_arr = rb_ary_new3(2, a, b);
106
+ const char type = blas_char(arg_arr);
107
+ if (type == 'n') {
108
+ rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
109
+ return Qnil;
110
+ }
111
+
112
+ VALUE ret = Qnil;
113
+ narray_t* a_nary = NULL;
114
+ narray_t* b_nary = NULL;
115
+ GetNArray(a, a_nary);
116
+ GetNArray(b, b_nary);
117
+ const int a_ndim = NA_NDIM(a_nary);
118
+ const int b_ndim = NA_NDIM(b_nary);
119
+
120
+ if (a_ndim == 1) {
121
+ if (b_ndim == 1) {
122
+ ID fn_id = type == 'c' || type == 'z' ? rb_intern("dotu") : rb_intern("dot");
123
+ ret = rb_funcall(rb_mTinyLinalgBlas, rb_intern("call"), 3, ID2SYM(fn_id), a, b);
124
+ } else {
125
+ VALUE kw_args = rb_hash_new();
126
+ if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
127
+ b = rb_funcall(b, rb_intern("transpose"), 0);
128
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
129
+ } else {
130
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
131
+ }
132
+ char fn_name[] = "xgemv";
133
+ fn_name[0] = type;
134
+ VALUE argv[3] = { b, a, kw_args };
135
+ ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
136
+ }
137
+ } else {
138
+ if (b_ndim == 1) {
139
+ VALUE kw_args = rb_hash_new();
140
+ if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
141
+ a = rb_funcall(a, rb_intern("transpose"), 0);
142
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
143
+ } else {
144
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
145
+ }
146
+ char fn_name[] = "xgemv";
147
+ fn_name[0] = type;
148
+ VALUE argv[3] = { a, b, kw_args };
149
+ ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
150
+ } else {
151
+ VALUE kw_args = rb_hash_new();
152
+ if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
153
+ a = rb_funcall(a, rb_intern("transpose"), 0);
154
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("T"));
155
+ } else {
156
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("N"));
157
+ }
158
+ if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
159
+ b = rb_funcall(b, rb_intern("transpose"), 0);
160
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("T"));
161
+ } else {
162
+ rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("N"));
163
+ }
164
+ char fn_name[] = "xgemm";
165
+ fn_name[0] = type;
166
+ VALUE argv[3] = { a, b, kw_args };
167
+ ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
168
+ }
169
+ }
170
+
171
+ RB_GC_GUARD(a);
172
+ RB_GC_GUARD(b);
173
+
174
+ return ret;
175
+ }
176
+
177
+ extern "C" void Init_tiny_linalg(void) {
178
+ rb_require("numo/narray");
179
+
180
+ /**
181
+ * Document-module: Numo::TinyLinalg
182
+ * Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
183
+ */
184
+ rb_mTinyLinalg = rb_define_module_under(rb_mNumo, "TinyLinalg");
185
+ /**
186
+ * Document-module: Numo::TinyLinalg::Blas
187
+ * Numo::TinyLinalg::Blas is wrapper module of BLAS functions.
188
+ */
189
+ rb_mTinyLinalgBlas = rb_define_module_under(rb_mTinyLinalg, "Blas");
190
+ /**
191
+ * Document-module: Numo::TinyLinalg::Lapack
192
+ * Numo::TinyLinalg::Lapack is wrapper module of LAPACK functions.
193
+ */
194
+ rb_mTinyLinalgLapack = rb_define_module_under(rb_mTinyLinalg, "Lapack");
195
+
196
+ /**
197
+ * Returns BLAS char ([sdcz]) defined by data-type of arguments.
198
+ *
199
+ * @overload blas_char(a, ...) -> String
200
+ * @param [Numo::NArray] a
201
+ * @return [String]
202
+ */
203
+ rb_define_module_function(rb_mTinyLinalg, "blas_char", RUBY_METHOD_FUNC(tiny_linalg_blas_char), -1);
204
+ /**
205
+ * Calculates dot product of two vectors / matrices.
206
+ *
207
+ * @overload dot(a, b) -> [Float|Complex|Numo::NArray]
208
+ * @param [Numo::NArray] a
209
+ * @param [Numo::NArray] b
210
+ * @return [Float|Complex|Numo::NArray]
211
+ */
212
+ rb_define_module_function(rb_mTinyLinalg, "dot", RUBY_METHOD_FUNC(tiny_linalg_dot), 2);
213
+ /**
214
+ * Calls BLAS function prefixed with BLAS char.
215
+ *
216
+ * @overload call(func, *args)
217
+ * @param func [Symbol] BLAS function name without BLAS char.
218
+ * @param args arguments of BLAS function.
219
+ * @example
220
+ * Numo::TinyLinalg::Blas.call(:gemv, a, b)
221
+ */
222
+ rb_define_singleton_method(rb_mTinyLinalgBlas, "call", RUBY_METHOD_FUNC(tiny_linalg_blas_call), -1);
223
+
224
+ TinyLinalg::Dot<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DDot>::define_module_function(rb_mTinyLinalgBlas, "ddot");
225
+ TinyLinalg::Dot<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SDot>::define_module_function(rb_mTinyLinalgBlas, "sdot");
226
+ TinyLinalg::DotSub<TinyLinalg::numo_cDComplexId, double, TinyLinalg::ZDotuSub>::define_module_function(rb_mTinyLinalgBlas, "zdotu");
227
+ TinyLinalg::DotSub<TinyLinalg::numo_cSComplexId, float, TinyLinalg::CDotuSub>::define_module_function(rb_mTinyLinalgBlas, "cdotu");
228
+ TinyLinalg::Gemm<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGemm, TinyLinalg::DConverter>::define_module_function(rb_mTinyLinalgBlas, "dgemm");
229
+ TinyLinalg::Gemm<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGemm, TinyLinalg::SConverter>::define_module_function(rb_mTinyLinalgBlas, "sgemm");
230
+ TinyLinalg::Gemm<TinyLinalg::numo_cDComplexId, dcomplex, TinyLinalg::ZGemm, TinyLinalg::ZConverter>::define_module_function(rb_mTinyLinalgBlas, "zgemm");
231
+ TinyLinalg::Gemm<TinyLinalg::numo_cSComplexId, scomplex, TinyLinalg::CGemm, TinyLinalg::CConverter>::define_module_function(rb_mTinyLinalgBlas, "cgemm");
232
+ TinyLinalg::Gemv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGemv, TinyLinalg::DConverter>::define_module_function(rb_mTinyLinalgBlas, "dgemv");
233
+ TinyLinalg::Gemv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGemv, TinyLinalg::SConverter>::define_module_function(rb_mTinyLinalgBlas, "sgemv");
234
+ TinyLinalg::Gemv<TinyLinalg::numo_cDComplexId, dcomplex, TinyLinalg::ZGemv, TinyLinalg::ZConverter>::define_module_function(rb_mTinyLinalgBlas, "zgemv");
235
+ TinyLinalg::Gemv<TinyLinalg::numo_cSComplexId, scomplex, TinyLinalg::CGemv, TinyLinalg::CConverter>::define_module_function(rb_mTinyLinalgBlas, "cgemv");
236
+ TinyLinalg::Nrm2<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DNrm2>::define_module_function(rb_mTinyLinalgBlas, "dnrm2");
237
+ TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
238
+ TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
239
+ TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
240
+ TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
241
+ TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
242
+ TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
243
+ TinyLinalg::GESVD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESVD>::define_module_function(rb_mTinyLinalgLapack, "cgesvd");
244
+ TinyLinalg::GESDD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESDD>::define_module_function(rb_mTinyLinalgLapack, "dgesdd");
245
+ TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
246
+ TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
247
+ TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
248
+
249
+ rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
250
+ rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
251
+ }
@@ -0,0 +1,38 @@
1
+ #ifndef NUMO_TINY_LINALG_H
2
+ #define NUMO_TINY_LINALG_H 1
3
+
4
+ #if defined(TINYLINALG_USE_ACCELERATE)
5
+ #include <Accelerate/Accelerate.h>
6
+ #else
7
+ #include <cblas.h>
8
+ #include <lapacke.h>
9
+ #endif
10
+
11
+ #include <string>
12
+
13
+ #include <cstring>
14
+
15
+ #include <ruby.h>
16
+
17
+ #include <numo/narray.h>
18
+ #include <numo/template.h>
19
+
20
+ namespace TinyLinalg {
21
+
22
+ const VALUE NaryTypes[4] = {
23
+ numo_cDFloat,
24
+ numo_cSFloat,
25
+ numo_cDComplex,
26
+ numo_cSComplex
27
+ };
28
+
29
+ enum NaryType {
30
+ numo_cDFloatId,
31
+ numo_cSFloatId,
32
+ numo_cDComplexId,
33
+ numo_cSComplexId
34
+ };
35
+
36
+ } // namespace TinyLinalg
37
+
38
+ #endif /* NUMO_TINY_LINALG_H */
@@ -0,0 +1,10 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Ruby/Numo (NUmerical MOdules)
4
+ module Numo
5
+ # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
+ module TinyLinalg
7
+ # The version of Numo::TinyLinalg you install.
8
+ VERSION = '0.0.1'
9
+ end
10
+ end
@@ -0,0 +1,60 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+ require_relative 'tiny_linalg/version'
5
+ require_relative 'tiny_linalg/tiny_linalg'
6
+
7
+ # Ruby/Numo (NUmerical MOdules)
8
module Numo
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
  module TinyLinalg
    module_function

    # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
    #
    # The input is duplicated before it is handed to LAPACK, so `a` itself is
    # not modified.
    #
    # @param a [Numo::NArray] Matrix to be decomposed
    #   (Numo::DFloat, Numo::SFloat, Numo::DComplex or Numo::SComplex).
    # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
    # @param job [String] Job option ('A', 'S', or 'N'); matched case-insensitively
    #   on its first character.
    # @raise [ArgumentError] if job, driver or the class of `a` is not supported.
    # @raise [RuntimeError] if LAPACK reports an error through its info value.
    # @return [Array<Numo::NArray>] Singular values and singular vectors ([s, u, vt]).
    def svd(a, driver: 'svd', job: 'A')
      raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)

      # The validation above is case-insensitive, but the native gesvd wrapper
      # compares against upper-case letters only, so pass the canonical
      # one-letter upper-case form to both drivers.
      job = job.to_s[0].upcase

      case driver.to_s
      when 'sdd'
        s, u, vt, info = case a
                         when Numo::DFloat
                           Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
                         when Numo::SFloat
                           Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
                         when Numo::DComplex
                           Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
                         when Numo::SComplex
                           Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
                         else
                           raise ArgumentError, "invalid array type: #{a.class}"
                         end
      when 'svd'
        s, u, vt, info = case a
                         when Numo::DFloat
                           Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
                         when Numo::SFloat
                           Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
                         when Numo::DComplex
                           Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
                         when Numo::SComplex
                           Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
                         else
                           raise ArgumentError, "invalid array type: #{a.class}"
                         end
      else
        raise ArgumentError, "invalid driver: #{driver}"
      end

      # The specific -4 case must be tested before the generic negative case,
      # otherwise it can never be reached.
      raise 'input array has a NAN entry' if info == -4
      raise "the #{info.abs}-th argument had illegal value" if info.negative?
      raise 'svd did not converge' if info.positive?

      [s, u, vt]
    end
  end
end
@@ -0,0 +1,42 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'lib/numo/tiny_linalg/version'
4
+
5
+ Gem::Specification.new do |spec|
6
+ spec.name = 'numo-tiny_linalg'
7
+ spec.version = Numo::TinyLinalg::VERSION
8
+ spec.authors = ['yoshoku']
9
+ spec.email = ['yoshoku@outlook.com']
10
+
11
+ spec.summary = <<~MSG
12
+ Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
13
+ MSG
14
+ spec.description = <<~MSG
15
+ Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
16
+ MSG
17
+ spec.homepage = 'https://github.com/yoshoku/numo-tiny_linalg'
18
+ spec.license = 'BSD-3-Clause'
19
+
20
+ spec.metadata['homepage_uri'] = spec.homepage
21
+ spec.metadata['source_code_uri'] = spec.homepage
22
+ spec.metadata['changelog_uri'] = "#{spec.homepage}/blob/main/CHANGELOG.md"
23
+
24
+ # Specify which files should be added to the gem when it is released.
25
+ # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
26
+ spec.files = Dir.chdir(__dir__) do
27
+ `git ls-files -z`.split("\x0").reject do |f|
28
+ (f == __FILE__) || f.match(%r{\A(?:(?:bin|test|spec|features)/|\.(?:git|circleci)|appveyor)})
29
+ end
30
+ end
31
+ spec.bindir = 'exe'
32
+ spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
33
+ spec.require_paths = ['lib']
34
+ spec.extensions = ['ext/numo/tiny_linalg/extconf.rb']
35
+
36
+ # Runtime dependencies of the gem.
37
+ spec.add_dependency 'numo-narray', '>= 0.9.1'
38
+
39
+ # For more information and examples about making a new gem, check out our
40
+ # guide at: https://bundler.io/guides/creating_gem.html
41
+ spec.metadata['rubygems_mfa_required'] = 'true'
42
+ end
data/package.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "name": "numo-tiny_linalg",
3
+ "repository": "git@github.com:yoshoku/numo-tiny_linalg.git",
4
+ "author": "Atsushi Tatsuma <yoshoku@outlook.com>",
5
+ "license": "BSD-3-Clause",
6
+ "private": true,
7
+ "scripts": {
8
+ "prepare": "husky install"
9
+ },
10
+ "devDependencies": {
11
+ "@commitlint/cli": "^17.6.1",
12
+ "@commitlint/config-conventional": "^17.6.1",
13
+ "husky": "^8.0.3"
14
+ }
15
+ }