numo-tiny_linalg 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,182 @@
+ namespace TinyLinalg {
+
+ struct DGESVD {
+   lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
+                   double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt,
+                   double* superb) {
+     return LAPACKE_dgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
+   }
+ };
+
+ struct SGESVD {
+   lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
+                   float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt,
+                   float* superb) {
+     return LAPACKE_sgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
+   }
+ };
+
+ struct ZGESVD {
+   lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
+                   lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt,
+                   double* superb) {
+     return LAPACKE_zgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
+   }
+ };
+
+ struct CGESVD {
+   lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
+                   lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt,
+                   float* superb) {
+     return LAPACKE_cgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
+ class GESVD {
+ public:
+   static void define_module_function(VALUE mLapack, const char* mf_name) {
+     rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesvd), -1);
+   }
+
+ private:
+   struct gesvd_opt {
+     int matrix_order;
+     char jobu;
+     char jobvt;
+   };
+
+   static void iter_gesvd(na_loop_t* const lp) {
+     DType* a = (DType*)NDL_PTR(lp, 0);
+     RType* s = (RType*)NDL_PTR(lp, 1);
+     DType* u = (DType*)NDL_PTR(lp, 2);
+     DType* vt = (DType*)NDL_PTR(lp, 3);
+     int* info = (int*)NDL_PTR(lp, 4);
+     gesvd_opt* opt = (gesvd_opt*)(lp->opt_ptr);
+
+     const size_t m = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0] : NDL_SHAPE(lp, 0)[1];
+     const size_t n = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[1] : NDL_SHAPE(lp, 0)[0];
+     const size_t min_mn = m < n ? m : n;
+     const lapack_int lda = n;
+     const lapack_int ldu = opt->jobu == 'A' ? m : min_mn;
+     const lapack_int ldvt = n;
+
+     RType* superb = (RType*)ruby_xmalloc(min_mn * sizeof(RType));
+
+     lapack_int i = FncType().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
+     *info = static_cast<int>(i);
+
+     ruby_xfree(superb);
+   }
+
+   static VALUE tiny_linalg_gesvd(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+     VALUE a_vnary = Qnil;
+     VALUE kw_args = Qnil;
+
+     rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
+
+     ID kw_table[3] = { rb_intern("jobu"), rb_intern("jobvt"), rb_intern("order") };
+     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+
+     rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
+
+     const char jobu = kw_values[0] == Qundef ? 'A' : StringValueCStr(kw_values[0])[0];
+     const char jobvt = kw_values[1] == Qundef ? 'A' : StringValueCStr(kw_values[1])[0];
+     const char order = kw_values[2] == Qundef ? 'R' : StringValueCStr(kw_values[2])[0];
+
+     if (jobu == 'O' && jobvt == 'O') {
+       rb_raise(rb_eArgError, "jobu and jobvt cannot be both 'O'");
+       return Qnil;
+     }
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       // the input matrix is overwritten in place, so a mismatched dtype cannot be cast here
+       rb_raise(rb_eTypeError, "type of input array is invalid for overwriting");
+       return Qnil;
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+
+     narray_t* a_nary = NULL;
+     GetNArray(a_vnary, a_nary);
+     const int n_dims = NA_NDIM(a_nary);
+     if (n_dims != 2) {
+       rb_raise(rb_eArgError, "input array must be 2-dimensional");
+       return Qnil;
+     }
+
+     const int matrix_order = order == 'C' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
+     const size_t m = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];
+     const size_t n = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[1] : NA_SHAPE(a_nary)[0];
+
+     const size_t min_mn = m < n ? m : n;
+     size_t shape_s[1] = { min_mn };
+     size_t shape_u[2] = { m, m };
+     size_t shape_vt[2] = { n, n };
+
+     ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[4] = { { nary_rtype, 1, shape_s }, { nary_dtype, 2, shape_u }, { nary_dtype, 2, shape_vt }, { numo_cInt32, 0 } };
+
+     switch (jobu) {
+       case 'A':
+         break;
+       case 'S':
+         shape_u[matrix_order == LAPACK_ROW_MAJOR ? 1 : 0] = min_mn;
+         break;
+       case 'O':
+       case 'N':
+         aout[1].dim = 0;
+         break;
+       default:
+         rb_raise(rb_eArgError, "jobu must be 'A', 'S', 'O', or 'N'");
+         return Qnil;
+     }
+
+     switch (jobvt) {
+       case 'A':
+         break;
+       case 'S':
+         shape_vt[matrix_order == LAPACK_ROW_MAJOR ? 0 : 1] = min_mn;
+         break;
+       case 'O':
+       case 'N':
+         aout[2].dim = 0;
+         break;
+       default:
+         rb_raise(rb_eArgError, "jobvt must be 'A', 'S', 'O', or 'N'");
+         return Qnil;
+     }
+
+     ndfunc_t ndf = { iter_gesvd, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };
+     gesvd_opt opt = { matrix_order, jobu, jobvt };
+     VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
+
+     switch (jobu) {
+       case 'O':
+         rb_ary_store(ret, 1, a_vnary);
+         break;
+       case 'N':
+         rb_ary_store(ret, 1, Qnil);
+         break;
+     }
+
+     switch (jobvt) {
+       case 'O':
+         rb_ary_store(ret, 2, a_vnary);
+         break;
+       case 'N':
+         rb_ary_store(ret, 2, Qnil);
+         break;
+     }
+
+     RB_GC_GUARD(a_vnary);
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
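The GESVD template above backs the `*gesvd` module functions that are registered under `Numo::TinyLinalg::Lapack` later in this diff. A minimal usage sketch in Ruby, assuming the gem is built and installed; the 3x2 matrix is illustrative only, and the keyword names come from the `kw_table` above:

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat[[1, 2], [3, 4], [5, 6]]

# dgesvd overwrites its input, so pass a copy; it returns [s, u, vt, info].
s, u, vt, info = Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: 'S', jobvt: 'S')
raise 'gesvd failed' unless info.zero?

# With jobu/jobvt = 'S', u is 3x2 and vt is 2x2, so scaling the columns of u
# by s approximately reconstructs a.
p((u * s).dot(vt))
```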
@@ -0,0 +1,89 @@
+ namespace TinyLinalg {
+
+ struct DNrm2 {
+   double call(const int n, const double* x, const int incx) {
+     return cblas_dnrm2(n, x, incx);
+   }
+ };
+
+ struct SNrm2 {
+   float call(const int n, const float* x, const int incx) {
+     return cblas_snrm2(n, x, incx);
+   }
+ };
+
+ struct DZNrm2 {
+   double call(const int n, const void* x, const int incx) {
+     return cblas_dznrm2(n, x, incx);
+   }
+ };
+
+ struct SCNrm2 {
+   // cblas_scnrm2 returns a float, so mirror SNrm2 and return float here as well
+   float call(const int n, const void* x, const int incx) {
+     return cblas_scnrm2(n, x, incx);
+   }
+ };
+
+ template <int nary_dtype_id, typename dtype, class BlasFn>
+ class Nrm2 {
+ public:
+   static void define_module_function(VALUE mBlas, const char* modfn_name) {
+     rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_nrm2), -1);
+   }
+
+ private:
+   // computes the Euclidean norm of the 1-D input and stores it as a scalar
+   static void iter_nrm2(na_loop_t* const lp) {
+     dtype* x = (dtype*)NDL_PTR(lp, 0);
+     dtype* d = (dtype*)NDL_PTR(lp, 1);
+     const size_t n = NDL_SHAPE(lp, 0)[0];
+     dtype ret = BlasFn().call(n, x, 1);
+     *d = ret;
+   }
+
+   static VALUE tiny_linalg_nrm2(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE x = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "1:", &x, &kw_args);
+
+     ID kw_table[1] = { rb_intern("keepdims") };
+     VALUE kw_values[1] = { Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
+     const bool keepdims = kw_values[0] != Qundef ? RTEST(kw_values[0]) : false;
+
+     if (CLASS_OF(x) != nary_dtype) {
+       x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
+     }
+     if (!RTEST(nary_check_contiguous(x))) {
+       x = nary_dup(x);
+     }
+
+     narray_t* x_nary = NULL;
+     GetNArray(x, x_nary);
+
+     if (NA_NDIM(x_nary) != 1) {
+       rb_raise(rb_eArgError, "x must be 1-dimensional");
+       return Qnil;
+     }
+     if (NA_SIZE(x_nary) == 0) {
+       rb_raise(rb_eArgError, "x must not be empty");
+       return Qnil;
+     }
+
+     ndfunc_arg_in_t ain[1] = { { nary_dtype, 1 } };
+     size_t shape_out[1] = { 1 };
+     ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
+     ndfunc_t ndf = { iter_nrm2, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
+     if (keepdims) {
+       ndf.flag |= NDF_KEEP_DIM;
+     }
+
+     VALUE ret = na_ndloop(&ndf, 1, x);
+
+     RB_GC_GUARD(x);
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
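For reference, `tiny_linalg.cpp` registers this template as `dnrm2`, `snrm2`, `dznrm2`, and `scnrm2` under `Numo::TinyLinalg::Blas`. A small usage sketch (the vector is illustrative):

```ruby
require 'numo/tiny_linalg'

x = Numo::DFloat[3.0, 4.0]

# Euclidean norm of a 1-D array; the 0-dim result is extracted to a plain Float.
puts Numo::TinyLinalg::Blas.dnrm2(x) # => 5.0

# keepdims: true sets NDF_KEEP_DIM, so the result is kept as an NArray
# instead of being extracted to a Float.
p Numo::TinyLinalg::Blas.dnrm2(x, keepdims: true)
```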
@@ -0,0 +1,251 @@
+ #include "tiny_linalg.hpp"
+ #include "converter.hpp"
+ #include "dot.hpp"
+ #include "dot_sub.hpp"
+ #include "gemm.hpp"
+ #include "gemv.hpp"
+ #include "gesdd.hpp"
+ #include "gesvd.hpp"
+ #include "nrm2.hpp"
+
+ VALUE rb_mTinyLinalg;
+ VALUE rb_mTinyLinalgBlas;
+ VALUE rb_mTinyLinalgLapack;
+
+ char blas_char(VALUE nary_arr) {
+   char type = 'n';
+   const size_t n = RARRAY_LEN(nary_arr);
+   for (size_t i = 0; i < n; i++) {
+     VALUE arg = rb_ary_entry(nary_arr, i);
+     if (RB_TYPE_P(arg, T_ARRAY)) {
+       arg = rb_funcall(numo_cNArray, rb_intern("asarray"), 1, arg);
+     }
+     if (CLASS_OF(arg) == numo_cBit || CLASS_OF(arg) == numo_cInt64 || CLASS_OF(arg) == numo_cInt32 ||
+         CLASS_OF(arg) == numo_cInt16 || CLASS_OF(arg) == numo_cInt8 || CLASS_OF(arg) == numo_cUInt64 ||
+         CLASS_OF(arg) == numo_cUInt32 || CLASS_OF(arg) == numo_cUInt16 || CLASS_OF(arg) == numo_cUInt8) {
+       if (type == 'n') {
+         type = 'd';
+       }
+     } else if (CLASS_OF(arg) == numo_cDFloat) {
+       if (type == 'c' || type == 'z') {
+         type = 'z';
+       } else {
+         type = 'd';
+       }
+     } else if (CLASS_OF(arg) == numo_cSFloat) {
+       if (type == 'n') {
+         type = 's';
+       }
+     } else if (CLASS_OF(arg) == numo_cDComplex) {
+       type = 'z';
+     } else if (CLASS_OF(arg) == numo_cSComplex) {
+       if (type == 'n' || type == 's') {
+         type = 'c';
+       } else if (type == 'd') {
+         type = 'z';
+       }
+     }
+   }
+   return type;
+ }
+
+ static VALUE tiny_linalg_blas_char(int argc, VALUE* argv, VALUE self) {
+   VALUE nary_arr = Qnil;
+   rb_scan_args(argc, argv, "*", &nary_arr);
+
+   const char type = blas_char(nary_arr);
+   if (type == 'n') {
+     rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
+     return Qnil;
+   }
+
+   return rb_str_new(&type, 1);
+ }
+
+ static VALUE tiny_linalg_blas_call(int argc, VALUE* argv, VALUE self) {
+   VALUE fn_name = Qnil;
+   VALUE nary_arr = Qnil;
+   VALUE kw_args = Qnil;
+   rb_scan_args(argc, argv, "1*:", &fn_name, &nary_arr, &kw_args);
+
+   const char type = blas_char(nary_arr);
+   if (type == 'n') {
+     rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
+     return Qnil;
+   }
+
+   std::string type_str{ type };
+   std::string fn_str = type_str + std::string(rb_id2name(rb_to_id(rb_to_symbol(fn_name))));
+   ID fn_id = rb_intern(fn_str.c_str());
+   size_t n = RARRAY_LEN(nary_arr);
+   VALUE ret = Qnil;
+
+   if (NIL_P(kw_args)) {
+     VALUE* args = ALLOCA_N(VALUE, n);
+     for (size_t i = 0; i < n; i++) {
+       args[i] = rb_ary_entry(nary_arr, i);
+     }
+     ret = rb_funcallv(self, fn_id, n, args);
+   } else {
+     VALUE* args = ALLOCA_N(VALUE, n + 1);
+     for (size_t i = 0; i < n; i++) {
+       args[i] = rb_ary_entry(nary_arr, i);
+     }
+     args[n] = kw_args;
+     ret = rb_funcallv_kw(self, fn_id, n + 1, args, RB_PASS_KEYWORDS);
+   }
+
+   return ret;
+ }
+
+ static VALUE tiny_linalg_dot(VALUE self, VALUE a_, VALUE b_) {
+   VALUE a = IsNArray(a_) ? a_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, a_);
+   VALUE b = IsNArray(b_) ? b_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, b_);
+
+   VALUE arg_arr = rb_ary_new3(2, a, b);
+   const char type = blas_char(arg_arr);
+   if (type == 'n') {
+     rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
+     return Qnil;
+   }
+
+   VALUE ret = Qnil;
+   narray_t* a_nary = NULL;
+   narray_t* b_nary = NULL;
+   GetNArray(a, a_nary);
+   GetNArray(b, b_nary);
+   const int a_ndim = NA_NDIM(a_nary);
+   const int b_ndim = NA_NDIM(b_nary);
+
+   if (a_ndim == 1) {
+     if (b_ndim == 1) {
+       ID fn_id = type == 'c' || type == 'z' ? rb_intern("dotu") : rb_intern("dot");
+       ret = rb_funcall(rb_mTinyLinalgBlas, rb_intern("call"), 3, ID2SYM(fn_id), a, b);
+     } else {
+       VALUE kw_args = rb_hash_new();
+       if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
+         b = rb_funcall(b, rb_intern("transpose"), 0);
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
+       } else {
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
+       }
+       char fn_name[] = "xgemv";
+       fn_name[0] = type;
+       VALUE argv[3] = { b, a, kw_args };
+       ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
+     }
+   } else {
+     if (b_ndim == 1) {
+       VALUE kw_args = rb_hash_new();
+       // when a is Fortran-contiguous but not C-contiguous, use its transposed view and compensate with trans
+       if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(a, rb_intern("fortran_contiguous?"), 0))) {
+         a = rb_funcall(a, rb_intern("transpose"), 0);
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
+       } else {
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
+       }
+       char fn_name[] = "xgemv";
+       fn_name[0] = type;
+       VALUE argv[3] = { a, b, kw_args };
+       ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
+     } else {
+       VALUE kw_args = rb_hash_new();
+       // when a is Fortran-contiguous but not C-contiguous, use its transposed view and compensate with transa
+       if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(a, rb_intern("fortran_contiguous?"), 0))) {
+         a = rb_funcall(a, rb_intern("transpose"), 0);
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("T"));
+       } else {
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("N"));
+       }
+       if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
+         b = rb_funcall(b, rb_intern("transpose"), 0);
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("T"));
+       } else {
+         rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("N"));
+       }
+       char fn_name[] = "xgemm";
+       fn_name[0] = type;
+       VALUE argv[3] = { a, b, kw_args };
+       ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
+     }
+   }
+
+   RB_GC_GUARD(a);
+   RB_GC_GUARD(b);
+
+   return ret;
+ }
+
+ extern "C" void Init_tiny_linalg(void) {
+   rb_require("numo/narray");
+
+   /**
+    * Document-module: Numo::TinyLinalg
+    * Numo::TinyLinalg is a subset of Numo::Linalg that provides only the methods used in machine learning algorithms.
+    */
+   rb_mTinyLinalg = rb_define_module_under(rb_mNumo, "TinyLinalg");
+   /**
+    * Document-module: Numo::TinyLinalg::Blas
+    * Numo::TinyLinalg::Blas is a wrapper module for BLAS functions.
+    */
+   rb_mTinyLinalgBlas = rb_define_module_under(rb_mTinyLinalg, "Blas");
+   /**
+    * Document-module: Numo::TinyLinalg::Lapack
+    * Numo::TinyLinalg::Lapack is a wrapper module for LAPACK functions.
+    */
+   rb_mTinyLinalgLapack = rb_define_module_under(rb_mTinyLinalg, "Lapack");
+
+   /**
+    * Returns the BLAS character ([sdcz]) determined by the data types of the arguments.
+    *
+    * @overload blas_char(a, ...) -> String
+    * @param [Numo::NArray] a
+    * @return [String]
+    */
+   rb_define_module_function(rb_mTinyLinalg, "blas_char", RUBY_METHOD_FUNC(tiny_linalg_blas_char), -1);
+   /**
+    * Calculates the dot product of two vectors or matrices.
+    *
+    * @overload dot(a, b) -> [Float|Complex|Numo::NArray]
+    * @param [Numo::NArray] a
+    * @param [Numo::NArray] b
+    * @return [Float|Complex|Numo::NArray]
+    */
+   rb_define_module_function(rb_mTinyLinalg, "dot", RUBY_METHOD_FUNC(tiny_linalg_dot), 2);
+   /**
+    * Calls the BLAS function whose name is prefixed with the BLAS character.
+    *
+    * @overload call(func, *args)
+    * @param func [Symbol] BLAS function name without the BLAS character prefix.
+    * @param args arguments passed to the BLAS function.
+    * @example
+    *   Numo::TinyLinalg::Blas.call(:gemv, a, b)
+    */
+   rb_define_singleton_method(rb_mTinyLinalgBlas, "call", RUBY_METHOD_FUNC(tiny_linalg_blas_call), -1);
+
+   TinyLinalg::Dot<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DDot>::define_module_function(rb_mTinyLinalgBlas, "ddot");
+   TinyLinalg::Dot<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SDot>::define_module_function(rb_mTinyLinalgBlas, "sdot");
+   TinyLinalg::DotSub<TinyLinalg::numo_cDComplexId, double, TinyLinalg::ZDotuSub>::define_module_function(rb_mTinyLinalgBlas, "zdotu");
+   TinyLinalg::DotSub<TinyLinalg::numo_cSComplexId, float, TinyLinalg::CDotuSub>::define_module_function(rb_mTinyLinalgBlas, "cdotu");
+   TinyLinalg::Gemm<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGemm, TinyLinalg::DConverter>::define_module_function(rb_mTinyLinalgBlas, "dgemm");
+   TinyLinalg::Gemm<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGemm, TinyLinalg::SConverter>::define_module_function(rb_mTinyLinalgBlas, "sgemm");
+   TinyLinalg::Gemm<TinyLinalg::numo_cDComplexId, dcomplex, TinyLinalg::ZGemm, TinyLinalg::ZConverter>::define_module_function(rb_mTinyLinalgBlas, "zgemm");
+   TinyLinalg::Gemm<TinyLinalg::numo_cSComplexId, scomplex, TinyLinalg::CGemm, TinyLinalg::CConverter>::define_module_function(rb_mTinyLinalgBlas, "cgemm");
+   TinyLinalg::Gemv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGemv, TinyLinalg::DConverter>::define_module_function(rb_mTinyLinalgBlas, "dgemv");
+   TinyLinalg::Gemv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGemv, TinyLinalg::SConverter>::define_module_function(rb_mTinyLinalgBlas, "sgemv");
+   TinyLinalg::Gemv<TinyLinalg::numo_cDComplexId, dcomplex, TinyLinalg::ZGemv, TinyLinalg::ZConverter>::define_module_function(rb_mTinyLinalgBlas, "zgemv");
+   TinyLinalg::Gemv<TinyLinalg::numo_cSComplexId, scomplex, TinyLinalg::CGemv, TinyLinalg::CConverter>::define_module_function(rb_mTinyLinalgBlas, "cgemv");
+   TinyLinalg::Nrm2<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DNrm2>::define_module_function(rb_mTinyLinalgBlas, "dnrm2");
+   TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
+   TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
+   TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
+   TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
+   TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
+   TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
+   TinyLinalg::GESVD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESVD>::define_module_function(rb_mTinyLinalgLapack, "cgesvd");
+   TinyLinalg::GESDD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESDD>::define_module_function(rb_mTinyLinalgLapack, "dgesdd");
+   TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
+   TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
+   TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
+
+   rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
+   rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
+ }
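A sketch of how the functions registered above fit together from Ruby: `blas_char` picks the type prefix from the argument classes, `Blas.call` prepends that prefix to the function name, and `dot` dispatches to dot/gemv/gemm depending on the ranks of its arguments. The arrays are illustrative:

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat.new(3, 2).seq
b = Numo::DFloat.new(2).seq

Numo::TinyLinalg.blas_char(a, b)         # => "d"
Numo::TinyLinalg::Blas.call(:gemv, a, b) # :gemv is dispatched to :dgemv
Numo::TinyLinalg.dot(a, b)               # matrix x vector, routed through dgemv
```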
@@ -0,0 +1,38 @@
+ #ifndef NUMO_TINY_LINALG_H
+ #define NUMO_TINY_LINALG_H 1
+
+ #if defined(TINYLINALG_USE_ACCELERATE)
+ #include <Accelerate/Accelerate.h>
+ #else
+ #include <cblas.h>
+ #include <lapacke.h>
+ #endif
+
+ #include <string>
+
+ #include <cstring>
+
+ #include <ruby.h>
+
+ #include <numo/narray.h>
+ #include <numo/template.h>
+
+ namespace TinyLinalg {
+
+ const VALUE NaryTypes[4] = {
+   numo_cDFloat,
+   numo_cSFloat,
+   numo_cDComplex,
+   numo_cSComplex
+ };
+
+ enum NaryType {
+   numo_cDFloatId,
+   numo_cSFloatId,
+   numo_cDComplexId,
+   numo_cSComplexId
+ };
+
+ } // namespace TinyLinalg
+
+ #endif /* NUMO_TINY_LINALG_H */
@@ -0,0 +1,10 @@
+ # frozen_string_literal: true
+
+ # Ruby/Numo (NUmerical MOdules)
+ module Numo
+   # Numo::TinyLinalg is a subset of Numo::Linalg that provides only the methods used in machine learning algorithms.
+   module TinyLinalg
+     # The version of Numo::TinyLinalg you install.
+     VERSION = '0.0.1'
+   end
+ end
@@ -0,0 +1,60 @@
+ # frozen_string_literal: true
+
+ require 'numo/narray'
+ require_relative 'tiny_linalg/version'
+ require_relative 'tiny_linalg/tiny_linalg'
+
+ # Ruby/Numo (NUmerical MOdules)
+ module Numo
+   # Numo::TinyLinalg is a subset of Numo::Linalg that provides only the methods used in machine learning algorithms.
+   module TinyLinalg
+     module_function
+
+     # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
+     #
+     # @param a [Numo::NArray] Matrix to be decomposed.
+     # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
+     # @param job [String] Job option ('A', 'S', or 'N').
+     # @return [Array<Numo::NArray>] Singular values and singular vectors ([s, u, vt]).
+     def svd(a, driver: 'svd', job: 'A')
+       raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
+
+       case driver.to_s
+       when 'sdd'
+         s, u, vt, info =
+           case a
+           when Numo::DFloat
+             Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
+           when Numo::SFloat
+             Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
+           when Numo::DComplex
+             Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
+           when Numo::SComplex
+             Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
+           else
+             raise ArgumentError, "invalid array type: #{a.class}"
+           end
+       when 'svd'
+         s, u, vt, info =
+           case a
+           when Numo::DFloat
+             Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
+           when Numo::SFloat
+             Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
+           when Numo::DComplex
+             Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
+           when Numo::SComplex
+             Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
+           else
+             raise ArgumentError, "invalid array type: #{a.class}"
+           end
+       else
+         raise ArgumentError, "invalid driver: #{driver}"
+       end
+
+       raise 'input array has a NAN entry' if info == -4
+       raise "the #{info.abs}-th argument had illegal value" if info.negative?
+       raise 'svd did not converge' if info.positive?
+
+       [s, u, vt]
+     end
+   end
+ end
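A short usage sketch for the `svd` helper above; the random matrix is illustrative, and with `job: 'S'` the thin factors reconstruct the input by scaling the columns of `u` by `s`:

```ruby
require 'numo/tiny_linalg'

a = Numo::DFloat.new(4, 3).rand

s, u, vt = Numo::TinyLinalg.svd(a, driver: 'sdd', job: 'S')

# u is 4x3, s has 3 entries, vt is 3x3; the reconstruction error should be tiny.
puts((a - (u * s).dot(vt)).abs.max)
```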
@@ -0,0 +1,42 @@
+ # frozen_string_literal: true
+
+ require_relative 'lib/numo/tiny_linalg/version'
+
+ Gem::Specification.new do |spec|
+   spec.name = 'numo-tiny_linalg'
+   spec.version = Numo::TinyLinalg::VERSION
+   spec.authors = ['yoshoku']
+   spec.email = ['yoshoku@outlook.com']
+
+   spec.summary = <<~MSG
+     Numo::TinyLinalg is a subset of Numo::Linalg that provides only the methods used in machine learning algorithms.
+   MSG
+   spec.description = <<~MSG
+     Numo::TinyLinalg is a subset of Numo::Linalg that provides only the methods used in machine learning algorithms.
+   MSG
+   spec.homepage = 'https://github.com/yoshoku/numo-tiny_linalg'
+   spec.license = 'BSD-3-Clause'
+
+   spec.metadata['homepage_uri'] = spec.homepage
+   spec.metadata['source_code_uri'] = spec.homepage
+   spec.metadata['changelog_uri'] = "#{spec.homepage}/blob/main/CHANGELOG.md"
+
+   # Specify which files should be added to the gem when it is released.
+   # `git ls-files -z` lists the files that have been added to the git repository.
+   spec.files = Dir.chdir(__dir__) do
+     `git ls-files -z`.split("\x0").reject do |f|
+       (f == __FILE__) || f.match(%r{\A(?:(?:bin|test|spec|features)/|\.(?:git|circleci)|appveyor)})
+     end
+   end
+   spec.bindir = 'exe'
+   spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
+   spec.require_paths = ['lib']
+   spec.extensions = ['ext/numo/tiny_linalg/extconf.rb']
+
+   # Runtime dependencies of the gem.
+   spec.add_dependency 'numo-narray', '>= 0.9.1'
+
+   # For more information and examples about making a new gem, check out our
+   # guide at: https://bundler.io/guides/creating_gem.html
+   spec.metadata['rubygems_mfa_required'] = 'true'
+ end
data/package.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "name": "numo-tiny_linalg",
+   "repository": "git@github.com:yoshoku/numo-tiny_linalg.git",
+   "author": "Atsushi Tatsuma <yoshoku@outlook.com>",
+   "license": "BSD-3-Clause",
+   "private": true,
+   "scripts": {
+     "prepare": "husky install"
+   },
+   "devDependencies": {
+     "@commitlint/cli": "^17.6.1",
+     "@commitlint/config-conventional": "^17.6.1",
+     "husky": "^8.0.3"
+   }
+ }