numo-tiny_linalg 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.clang-format +149 -0
- data/.husky/commit-msg +4 -0
- data/.rubocop.yml +47 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/Gemfile +15 -0
- data/LICENSE.txt +27 -0
- data/README.md +99 -0
- data/Rakefile +30 -0
- data/commitlint.config.js +1 -0
- data/ext/numo/tiny_linalg/converter.hpp +75 -0
- data/ext/numo/tiny_linalg/dot.hpp +86 -0
- data/ext/numo/tiny_linalg/dot_sub.hpp +85 -0
- data/ext/numo/tiny_linalg/extconf.rb +54 -0
- data/ext/numo/tiny_linalg/gemm.hpp +223 -0
- data/ext/numo/tiny_linalg/gemv.hpp +211 -0
- data/ext/numo/tiny_linalg/gesdd.hpp +134 -0
- data/ext/numo/tiny_linalg/gesvd.hpp +182 -0
- data/ext/numo/tiny_linalg/nrm2.hpp +89 -0
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +251 -0
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +38 -0
- data/lib/numo/tiny_linalg/version.rb +10 -0
- data/lib/numo/tiny_linalg.rb +60 -0
- data/numo-tiny_linalg.gemspec +42 -0
- data/package.json +15 -0
- metadata +91 -0
@@ -0,0 +1,182 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DGESVD {
|
4
|
+
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
5
|
+
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt,
|
6
|
+
double* superb) {
|
7
|
+
return LAPACKE_dgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
8
|
+
};
|
9
|
+
};
|
10
|
+
|
11
|
+
struct SGESVD {
|
12
|
+
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
13
|
+
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt,
|
14
|
+
float* superb) {
|
15
|
+
return LAPACKE_sgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
16
|
+
};
|
17
|
+
};
|
18
|
+
|
19
|
+
struct ZGESVD {
|
20
|
+
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
21
|
+
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt,
|
22
|
+
double* superb) {
|
23
|
+
return LAPACKE_zgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
24
|
+
};
|
25
|
+
};
|
26
|
+
|
27
|
+
struct CGESVD {
|
28
|
+
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
29
|
+
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt,
|
30
|
+
float* superb) {
|
31
|
+
return LAPACKE_cgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
32
|
+
};
|
33
|
+
};
|
34
|
+
|
35
|
+
template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
|
36
|
+
class GESVD {
|
37
|
+
public:
|
38
|
+
static void define_module_function(VALUE mLapack, const char* mf_name) {
|
39
|
+
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesvd), -1);
|
40
|
+
};
|
41
|
+
|
42
|
+
private:
|
43
|
+
struct gesvd_opt {
|
44
|
+
int matrix_order;
|
45
|
+
char jobu;
|
46
|
+
char jobvt;
|
47
|
+
};
|
48
|
+
|
49
|
+
static void iter_gesvd(na_loop_t* const lp) {
|
50
|
+
DType* a = (DType*)NDL_PTR(lp, 0);
|
51
|
+
RType* s = (RType*)NDL_PTR(lp, 1);
|
52
|
+
DType* u = (DType*)NDL_PTR(lp, 2);
|
53
|
+
DType* vt = (DType*)NDL_PTR(lp, 3);
|
54
|
+
int* info = (int*)NDL_PTR(lp, 4);
|
55
|
+
gesvd_opt* opt = (gesvd_opt*)(lp->opt_ptr);
|
56
|
+
|
57
|
+
const size_t m = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0] : NDL_SHAPE(lp, 0)[1];
|
58
|
+
const size_t n = opt->matrix_order == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[1] : NDL_SHAPE(lp, 0)[0];
|
59
|
+
const size_t min_mn = m < n ? m : n;
|
60
|
+
const lapack_int lda = n;
|
61
|
+
const lapack_int ldu = opt->jobu == 'A' ? m : min_mn;
|
62
|
+
const lapack_int ldvt = n;
|
63
|
+
|
64
|
+
RType* superb = (RType*)ruby_xmalloc(min_mn * sizeof(RType));
|
65
|
+
|
66
|
+
lapack_int i = FncType().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
67
|
+
*info = static_cast<int>(i);
|
68
|
+
|
69
|
+
ruby_xfree(superb);
|
70
|
+
};
|
71
|
+
|
72
|
+
static VALUE tiny_linalg_gesvd(int argc, VALUE* argv, VALUE self) {
|
73
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
74
|
+
VALUE nary_rtype = NaryTypes[nary_rtype_id];
|
75
|
+
VALUE a_vnary = Qnil;
|
76
|
+
VALUE kw_args = Qnil;
|
77
|
+
|
78
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
|
79
|
+
|
80
|
+
ID kw_table[3] = { rb_intern("jobu"), rb_intern("jobvt"), rb_intern("order") };
|
81
|
+
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
|
82
|
+
|
83
|
+
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
|
84
|
+
|
85
|
+
const char jobu = kw_values[0] == Qundef ? 'A' : StringValueCStr(kw_values[0])[0];
|
86
|
+
const char jobvt = kw_values[1] == Qundef ? 'A' : StringValueCStr(kw_values[1])[0];
|
87
|
+
const char order = kw_values[2] == Qundef ? 'R' : StringValueCStr(kw_values[2])[0];
|
88
|
+
|
89
|
+
if (jobu == 'O' && jobvt == 'O') {
|
90
|
+
rb_raise(rb_eArgError, "jobu and jobvt cannot be both 'O'");
|
91
|
+
return Qnil;
|
92
|
+
}
|
93
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
94
|
+
rb_raise(rb_eTypeError, "type of input array is invalid for overwriting");
|
95
|
+
return Qnil;
|
96
|
+
}
|
97
|
+
|
98
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
99
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
100
|
+
}
|
101
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
102
|
+
a_vnary = nary_dup(a_vnary);
|
103
|
+
}
|
104
|
+
|
105
|
+
narray_t* a_nary = NULL;
|
106
|
+
GetNArray(a_vnary, a_nary);
|
107
|
+
const int n_dims = NA_NDIM(a_nary);
|
108
|
+
if (n_dims != 2) {
|
109
|
+
rb_raise(rb_eArgError, "input array must be 2-dimensional");
|
110
|
+
return Qnil;
|
111
|
+
}
|
112
|
+
|
113
|
+
const int matrix_order = order == 'C' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
|
114
|
+
const size_t m = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];
|
115
|
+
const size_t n = matrix_order == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[1] : NA_SHAPE(a_nary)[0];
|
116
|
+
|
117
|
+
const size_t min_mn = m < n ? m : n;
|
118
|
+
size_t shape_s[1] = { min_mn };
|
119
|
+
size_t shape_u[2] = { m, m };
|
120
|
+
size_t shape_vt[2] = { n, n };
|
121
|
+
|
122
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
|
123
|
+
ndfunc_arg_out_t aout[4] = { { nary_rtype, 1, shape_s }, { nary_dtype, 2, shape_u }, { nary_dtype, 2, shape_vt }, { numo_cInt32, 0 } };
|
124
|
+
|
125
|
+
switch (jobu) {
|
126
|
+
case 'A':
|
127
|
+
break;
|
128
|
+
case 'S':
|
129
|
+
shape_u[matrix_order == LAPACK_ROW_MAJOR ? 1 : 0] = min_mn;
|
130
|
+
break;
|
131
|
+
case 'O':
|
132
|
+
case 'N':
|
133
|
+
aout[1].dim = 0;
|
134
|
+
break;
|
135
|
+
default:
|
136
|
+
rb_raise(rb_eArgError, "jobu must be 'A', 'S', 'O', or 'N'");
|
137
|
+
return Qnil;
|
138
|
+
}
|
139
|
+
|
140
|
+
switch (jobvt) {
|
141
|
+
case 'A':
|
142
|
+
break;
|
143
|
+
case 'S':
|
144
|
+
shape_vt[matrix_order == LAPACK_ROW_MAJOR ? 0 : 1] = min_mn;
|
145
|
+
break;
|
146
|
+
case 'O':
|
147
|
+
case 'N':
|
148
|
+
aout[2].dim = 0;
|
149
|
+
break;
|
150
|
+
default:
|
151
|
+
rb_raise(rb_eArgError, "jobvt must be 'A', 'S', 'O', or 'N'");
|
152
|
+
return Qnil;
|
153
|
+
}
|
154
|
+
|
155
|
+
ndfunc_t ndf = { iter_gesvd, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };
|
156
|
+
gesvd_opt opt = { matrix_order, jobu, jobvt };
|
157
|
+
VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);
|
158
|
+
|
159
|
+
switch (jobu) {
|
160
|
+
case 'O':
|
161
|
+
rb_ary_store(ret, 1, a_vnary);
|
162
|
+
break;
|
163
|
+
case 'N':
|
164
|
+
rb_ary_store(ret, 1, Qnil);
|
165
|
+
break;
|
166
|
+
}
|
167
|
+
|
168
|
+
switch (jobvt) {
|
169
|
+
case 'O':
|
170
|
+
rb_ary_store(ret, 2, a_vnary);
|
171
|
+
break;
|
172
|
+
case 'N':
|
173
|
+
rb_ary_store(ret, 2, Qnil);
|
174
|
+
break;
|
175
|
+
}
|
176
|
+
|
177
|
+
RB_GC_GUARD(a_vnary);
|
178
|
+
return ret;
|
179
|
+
};
|
180
|
+
};
|
181
|
+
|
182
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,89 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DNrm2 {
|
4
|
+
double call(const int n, const double* x, const int incx) {
|
5
|
+
return cblas_dnrm2(n, x, incx);
|
6
|
+
}
|
7
|
+
};
|
8
|
+
|
9
|
+
struct SNrm2 {
|
10
|
+
float call(const int n, const float* x, const int incx) {
|
11
|
+
return cblas_snrm2(n, x, incx);
|
12
|
+
}
|
13
|
+
};
|
14
|
+
|
15
|
+
struct DZNrm2 {
|
16
|
+
double call(const int n, const void* x, const int incx) {
|
17
|
+
return cblas_dznrm2(n, x, incx);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
struct SCNrm2 {
|
22
|
+
double call(const int n, const void* x, const int incx) {
|
23
|
+
return cblas_scnrm2(n, x, incx);
|
24
|
+
}
|
25
|
+
};
|
26
|
+
|
27
|
+
template <int nary_dtype_id, typename dtype, class BlasFn>
|
28
|
+
class Nrm2 {
|
29
|
+
public:
|
30
|
+
static void define_module_function(VALUE mBlas, const char* modfn_name) {
|
31
|
+
rb_define_module_function(mBlas, modfn_name, RUBY_METHOD_FUNC(tiny_linalg_dot), -1);
|
32
|
+
}
|
33
|
+
|
34
|
+
private:
|
35
|
+
static void iter_nrm2(na_loop_t* const lp) {
|
36
|
+
dtype* x = (dtype*)NDL_PTR(lp, 0);
|
37
|
+
dtype* d = (dtype*)NDL_PTR(lp, 1);
|
38
|
+
const size_t n = NDL_SHAPE(lp, 0)[0];
|
39
|
+
dtype ret = BlasFn().call(n, x, 1);
|
40
|
+
*d = ret;
|
41
|
+
}
|
42
|
+
|
43
|
+
static VALUE tiny_linalg_dot(int argc, VALUE* argv, VALUE self) {
|
44
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
45
|
+
|
46
|
+
VALUE x = Qnil;
|
47
|
+
VALUE kw_args = Qnil;
|
48
|
+
rb_scan_args(argc, argv, "1:", &x, &kw_args);
|
49
|
+
|
50
|
+
ID kw_table[1] = { rb_intern("keepdims") };
|
51
|
+
VALUE kw_values[1] = { Qundef };
|
52
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
53
|
+
const bool keepdims = kw_values[0] != Qundef ? RTEST(kw_values[0]) : false;
|
54
|
+
|
55
|
+
if (CLASS_OF(x) != nary_dtype) {
|
56
|
+
x = rb_funcall(nary_dtype, rb_intern("cast"), 1, x);
|
57
|
+
}
|
58
|
+
if (!RTEST(nary_check_contiguous(x))) {
|
59
|
+
x = nary_dup(x);
|
60
|
+
}
|
61
|
+
|
62
|
+
narray_t* x_nary = NULL;
|
63
|
+
GetNArray(x, x_nary);
|
64
|
+
|
65
|
+
if (NA_NDIM(x_nary) != 1) {
|
66
|
+
rb_raise(rb_eArgError, "x must be 1-dimensional");
|
67
|
+
return Qnil;
|
68
|
+
}
|
69
|
+
if (NA_SIZE(x_nary) == 0) {
|
70
|
+
rb_raise(rb_eArgError, "x must not be empty");
|
71
|
+
return Qnil;
|
72
|
+
}
|
73
|
+
|
74
|
+
ndfunc_arg_in_t ain[1] = { { nary_dtype, 1 } };
|
75
|
+
size_t shape_out[1] = { 1 };
|
76
|
+
ndfunc_arg_out_t aout[1] = { { nary_dtype, 0, shape_out } };
|
77
|
+
ndfunc_t ndf = { iter_nrm2, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
78
|
+
if (keepdims) {
|
79
|
+
ndf.flag |= NDF_KEEP_DIM;
|
80
|
+
}
|
81
|
+
|
82
|
+
VALUE ret = na_ndloop(&ndf, 1, x);
|
83
|
+
|
84
|
+
RB_GC_GUARD(x);
|
85
|
+
return ret;
|
86
|
+
}
|
87
|
+
};
|
88
|
+
|
89
|
+
} // namespace TinyLinalg
|
@@ -0,0 +1,251 @@
|
|
1
|
+
#include "tiny_linalg.hpp"
|
2
|
+
#include "converter.hpp"
|
3
|
+
#include "dot.hpp"
|
4
|
+
#include "dot_sub.hpp"
|
5
|
+
#include "gemm.hpp"
|
6
|
+
#include "gemv.hpp"
|
7
|
+
#include "gesdd.hpp"
|
8
|
+
#include "gesvd.hpp"
|
9
|
+
#include "nrm2.hpp"
|
10
|
+
|
11
|
+
// Module handles assigned in Init_tiny_linalg.
VALUE rb_mTinyLinalg;       // Numo::TinyLinalg
VALUE rb_mTinyLinalgBlas;   // Numo::TinyLinalg::Blas
VALUE rb_mTinyLinalgLapack; // Numo::TinyLinalg::Lapack
|
14
|
+
|
15
|
+
/**
 * Chooses the BLAS precision prefix ('s', 'd', 'c', 'z') that can hold every
 * element of the Ruby Array `nary_arr`.
 *
 * Plain Ruby Arrays are first converted with Numo::NArray.asarray. Integer and
 * Bit arrays count as double precision. Elements of any other class are ignored.
 *
 * @return the prefix character, or 'n' when no element has a supported class.
 */
char blas_char(VALUE nary_arr) {
  char prefix = 'n';
  const long count = RARRAY_LEN(nary_arr);

  for (long idx = 0; idx < count; ++idx) {
    VALUE elem = rb_ary_entry(nary_arr, idx);
    if (RB_TYPE_P(elem, T_ARRAY)) {
      elem = rb_funcall(numo_cNArray, rb_intern("asarray"), 1, elem);
    }

    const VALUE klass = CLASS_OF(elem);
    const bool integral = klass == numo_cBit || klass == numo_cInt64 || klass == numo_cInt32 || klass == numo_cInt16 ||
                          klass == numo_cInt8 || klass == numo_cUInt64 || klass == numo_cUInt32 || klass == numo_cUInt16 ||
                          klass == numo_cUInt8;

    if (integral) {
      if (prefix == 'n') prefix = 'd';
    } else if (klass == numo_cDFloat) {
      // Double real promotes any complex prefix to double complex.
      prefix = (prefix == 'c' || prefix == 'z') ? 'z' : 'd';
    } else if (klass == numo_cSFloat) {
      if (prefix == 'n') prefix = 's';
    } else if (klass == numo_cDComplex) {
      prefix = 'z';
    } else if (klass == numo_cSComplex) {
      if (prefix == 'n' || prefix == 's') {
        prefix = 'c';
      } else if (prefix == 'd') {
        prefix = 'z';
      }
    }
  }

  return prefix;
}
|
51
|
+
|
52
|
+
/**
 * Numo::TinyLinalg.blas_char(*arrays): returns the BLAS prefix as a one-character String.
 * Raises TypeError when no argument has a BLAS/LAPACK-compatible type.
 */
static VALUE tiny_linalg_blas_char(int argc, VALUE* argv, VALUE self) {
  VALUE args = Qnil;
  rb_scan_args(argc, argv, "*", &args);

  const char prefix = blas_char(args);
  if (prefix != 'n') {
    return rb_str_new(&prefix, 1);
  }

  rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
  return Qnil;
}
|
64
|
+
|
65
|
+
/**
 * Numo::TinyLinalg::Blas.call(func, *args, **kw): invokes `<prefix><func>` on
 * `self`, where the prefix comes from blas_char(args).
 *
 * Positional arguments are forwarded unchanged. Keyword arguments, if given,
 * are forwarded as keywords. Raises TypeError when no prefix can be derived.
 */
static VALUE tiny_linalg_blas_call(int argc, VALUE* argv, VALUE self) {
  VALUE fn_name = Qnil;
  VALUE nary_arr = Qnil;
  VALUE kw_args = Qnil;
  rb_scan_args(argc, argv, "1*:", &fn_name, &nary_arr, &kw_args);

  const char prefix = blas_char(nary_arr);
  if (prefix == 'n') {
    rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
    return Qnil;
  }

  // e.g. prefix 'd' + :gemv -> :dgemv
  std::string full_name(1, prefix);
  full_name += rb_id2name(rb_to_id(rb_to_symbol(fn_name)));
  const ID fn_id = rb_intern(full_name.c_str());

  const size_t n_pos = RARRAY_LEN(nary_arr);
  const bool has_kw = !NIL_P(kw_args);
  const size_t n_all = has_kw ? n_pos + 1 : n_pos;

  VALUE* call_argv = ALLOCA_N(VALUE, n_all);
  for (size_t k = 0; k < n_pos; k++) {
    call_argv[k] = rb_ary_entry(nary_arr, k);
  }

  if (!has_kw) {
    return rb_funcallv(self, fn_id, n_pos, call_argv);
  }

  // The keyword hash travels as the trailing argument.
  call_argv[n_pos] = kw_args;
  return rb_funcallv_kw(self, fn_id, n_all, call_argv, RB_PASS_KEYWORDS);
}
|
100
|
+
|
101
|
+
static VALUE tiny_linalg_dot(VALUE self, VALUE a_, VALUE b_) {
|
102
|
+
VALUE a = IsNArray(a_) ? a_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, a_);
|
103
|
+
VALUE b = IsNArray(b_) ? b_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, b_);
|
104
|
+
|
105
|
+
VALUE arg_arr = rb_ary_new3(2, a, b);
|
106
|
+
const char type = blas_char(arg_arr);
|
107
|
+
if (type == 'n') {
|
108
|
+
rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
|
109
|
+
return Qnil;
|
110
|
+
}
|
111
|
+
|
112
|
+
VALUE ret = Qnil;
|
113
|
+
narray_t* a_nary = NULL;
|
114
|
+
narray_t* b_nary = NULL;
|
115
|
+
GetNArray(a, a_nary);
|
116
|
+
GetNArray(b, b_nary);
|
117
|
+
const int a_ndim = NA_NDIM(a_nary);
|
118
|
+
const int b_ndim = NA_NDIM(b_nary);
|
119
|
+
|
120
|
+
if (a_ndim == 1) {
|
121
|
+
if (b_ndim == 1) {
|
122
|
+
ID fn_id = type == 'c' || type == 'z' ? rb_intern("dotu") : rb_intern("dot");
|
123
|
+
ret = rb_funcall(rb_mTinyLinalgBlas, rb_intern("call"), 3, ID2SYM(fn_id), a, b);
|
124
|
+
} else {
|
125
|
+
VALUE kw_args = rb_hash_new();
|
126
|
+
if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
|
127
|
+
b = rb_funcall(b, rb_intern("transpose"), 0);
|
128
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
|
129
|
+
} else {
|
130
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
|
131
|
+
}
|
132
|
+
char fn_name[] = "xgemv";
|
133
|
+
fn_name[0] = type;
|
134
|
+
VALUE argv[3] = { b, a, kw_args };
|
135
|
+
ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
|
136
|
+
}
|
137
|
+
} else {
|
138
|
+
if (b_ndim == 1) {
|
139
|
+
VALUE kw_args = rb_hash_new();
|
140
|
+
if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
|
141
|
+
a = rb_funcall(a, rb_intern("transpose"), 0);
|
142
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
|
143
|
+
} else {
|
144
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
|
145
|
+
}
|
146
|
+
char fn_name[] = "xgemv";
|
147
|
+
fn_name[0] = type;
|
148
|
+
VALUE argv[3] = { a, b, kw_args };
|
149
|
+
ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
|
150
|
+
} else {
|
151
|
+
VALUE kw_args = rb_hash_new();
|
152
|
+
if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
|
153
|
+
a = rb_funcall(a, rb_intern("transpose"), 0);
|
154
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("T"));
|
155
|
+
} else {
|
156
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("N"));
|
157
|
+
}
|
158
|
+
if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
|
159
|
+
b = rb_funcall(b, rb_intern("transpose"), 0);
|
160
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("T"));
|
161
|
+
} else {
|
162
|
+
rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("N"));
|
163
|
+
}
|
164
|
+
char fn_name[] = "xgemm";
|
165
|
+
fn_name[0] = type;
|
166
|
+
VALUE argv[3] = { a, b, kw_args };
|
167
|
+
ret = rb_funcallv_kw(rb_mTinyLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
RB_GC_GUARD(a);
|
172
|
+
RB_GC_GUARD(b);
|
173
|
+
|
174
|
+
return ret;
|
175
|
+
}
|
176
|
+
|
177
|
+
extern "C" void Init_tiny_linalg(void) {
|
178
|
+
rb_require("numo/narray");
|
179
|
+
|
180
|
+
/**
|
181
|
+
* Document-module: Numo::TinyLinalg
|
182
|
+
* Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
183
|
+
*/
|
184
|
+
rb_mTinyLinalg = rb_define_module_under(rb_mNumo, "TinyLinalg");
|
185
|
+
/**
|
186
|
+
* Document-module: Numo::TinyLinalg::Blas
|
187
|
+
* Numo::TinyLinalg::Blas is wrapper module of BLAS functions.
|
188
|
+
*/
|
189
|
+
rb_mTinyLinalgBlas = rb_define_module_under(rb_mTinyLinalg, "Blas");
|
190
|
+
/**
|
191
|
+
* Document-module: Numo::TinyLinalg::Lapack
|
192
|
+
* Numo::TinyLinalg::Lapack is wrapper module of LAPACK functions.
|
193
|
+
*/
|
194
|
+
rb_mTinyLinalgLapack = rb_define_module_under(rb_mTinyLinalg, "Lapack");
|
195
|
+
|
196
|
+
/**
|
197
|
+
* Returns BLAS char ([sdcz]) defined by data-type of arguments.
|
198
|
+
*
|
199
|
+
* @overload blas_char(a, ...) -> String
|
200
|
+
* @param [Numo::NArray] a
|
201
|
+
* @return [String]
|
202
|
+
*/
|
203
|
+
rb_define_module_function(rb_mTinyLinalg, "blas_char", RUBY_METHOD_FUNC(tiny_linalg_blas_char), -1);
|
204
|
+
/**
|
205
|
+
* Calculates dot product of two vectors / matrices.
|
206
|
+
*
|
207
|
+
* @overload dot(a, b) -> [Float|Complex|Numo::NArray]
|
208
|
+
* @param [Numo::NArray] a
|
209
|
+
* @param [Numo::NArray] b
|
210
|
+
* @return [Float|Complex|Numo::NArray]
|
211
|
+
*/
|
212
|
+
rb_define_module_function(rb_mTinyLinalg, "dot", RUBY_METHOD_FUNC(tiny_linalg_dot), 2);
|
213
|
+
/**
|
214
|
+
* Calls BLAS function prefixed with BLAS char.
|
215
|
+
*
|
216
|
+
* @overload call(func, *args)
|
217
|
+
* @param func [Symbol] BLAS function name without BLAS char.
|
218
|
+
* @param args arguments of BLAS function.
|
219
|
+
* @example
|
220
|
+
* Numo::TinyLinalg::Blas.call(:gemv, a, b)
|
221
|
+
*/
|
222
|
+
rb_define_singleton_method(rb_mTinyLinalgBlas, "call", RUBY_METHOD_FUNC(tiny_linalg_blas_call), -1);
|
223
|
+
|
224
|
+
TinyLinalg::Dot<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DDot>::define_module_function(rb_mTinyLinalgBlas, "ddot");
|
225
|
+
TinyLinalg::Dot<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SDot>::define_module_function(rb_mTinyLinalgBlas, "sdot");
|
226
|
+
TinyLinalg::DotSub<TinyLinalg::numo_cDComplexId, double, TinyLinalg::ZDotuSub>::define_module_function(rb_mTinyLinalgBlas, "zdotu");
|
227
|
+
TinyLinalg::DotSub<TinyLinalg::numo_cSComplexId, float, TinyLinalg::CDotuSub>::define_module_function(rb_mTinyLinalgBlas, "cdotu");
|
228
|
+
TinyLinalg::Gemm<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGemm, TinyLinalg::DConverter>::define_module_function(rb_mTinyLinalgBlas, "dgemm");
|
229
|
+
TinyLinalg::Gemm<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGemm, TinyLinalg::SConverter>::define_module_function(rb_mTinyLinalgBlas, "sgemm");
|
230
|
+
TinyLinalg::Gemm<TinyLinalg::numo_cDComplexId, dcomplex, TinyLinalg::ZGemm, TinyLinalg::ZConverter>::define_module_function(rb_mTinyLinalgBlas, "zgemm");
|
231
|
+
TinyLinalg::Gemm<TinyLinalg::numo_cSComplexId, scomplex, TinyLinalg::CGemm, TinyLinalg::CConverter>::define_module_function(rb_mTinyLinalgBlas, "cgemm");
|
232
|
+
TinyLinalg::Gemv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGemv, TinyLinalg::DConverter>::define_module_function(rb_mTinyLinalgBlas, "dgemv");
|
233
|
+
TinyLinalg::Gemv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGemv, TinyLinalg::SConverter>::define_module_function(rb_mTinyLinalgBlas, "sgemv");
|
234
|
+
TinyLinalg::Gemv<TinyLinalg::numo_cDComplexId, dcomplex, TinyLinalg::ZGemv, TinyLinalg::ZConverter>::define_module_function(rb_mTinyLinalgBlas, "zgemv");
|
235
|
+
TinyLinalg::Gemv<TinyLinalg::numo_cSComplexId, scomplex, TinyLinalg::CGemv, TinyLinalg::CConverter>::define_module_function(rb_mTinyLinalgBlas, "cgemv");
|
236
|
+
TinyLinalg::Nrm2<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DNrm2>::define_module_function(rb_mTinyLinalgBlas, "dnrm2");
|
237
|
+
TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
|
238
|
+
TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
|
239
|
+
TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
|
240
|
+
TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
|
241
|
+
TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
|
242
|
+
TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
|
243
|
+
TinyLinalg::GESVD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESVD>::define_module_function(rb_mTinyLinalgLapack, "cgesvd");
|
244
|
+
TinyLinalg::GESDD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESDD>::define_module_function(rb_mTinyLinalgLapack, "dgesdd");
|
245
|
+
TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
|
246
|
+
TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
|
247
|
+
TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
|
248
|
+
|
249
|
+
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
250
|
+
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
251
|
+
}
|
@@ -0,0 +1,38 @@
|
|
1
|
+
#ifndef NUMO_TINY_LINALG_H
|
2
|
+
#define NUMO_TINY_LINALG_H 1
|
3
|
+
|
4
|
+
#if defined(TINYLINALG_USE_ACCELERATE)
|
5
|
+
#include <Accelerate/Accelerate.h>
|
6
|
+
#else
|
7
|
+
#include <cblas.h>
|
8
|
+
#include <lapacke.h>
|
9
|
+
#endif
|
10
|
+
|
11
|
+
#include <string>
|
12
|
+
|
13
|
+
#include <cstring>
|
14
|
+
|
15
|
+
#include <ruby.h>
|
16
|
+
|
17
|
+
#include <numo/narray.h>
|
18
|
+
#include <numo/template.h>
|
19
|
+
|
20
|
+
namespace TinyLinalg {
|
21
|
+
|
22
|
+
const VALUE NaryTypes[4] = {
|
23
|
+
numo_cDFloat,
|
24
|
+
numo_cSFloat,
|
25
|
+
numo_cDComplex,
|
26
|
+
numo_cSComplex
|
27
|
+
};
|
28
|
+
|
29
|
+
enum NaryType {
|
30
|
+
numo_cDFloatId,
|
31
|
+
numo_cSFloatId,
|
32
|
+
numo_cDComplexId,
|
33
|
+
numo_cSComplexId
|
34
|
+
};
|
35
|
+
|
36
|
+
} // namespace TinyLinalg
|
37
|
+
|
38
|
+
#endif /* NUMO_TINY_LINALG_H */
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Ruby/Numo (NUmerical MOdules)
module Numo
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
  module TinyLinalg
    # Version string of this Numo::TinyLinalg release (also read by the gemspec).
    VERSION = '0.0.1'
  end
end
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/narray'
|
4
|
+
require_relative 'tiny_linalg/version'
|
5
|
+
require_relative 'tiny_linalg/tiny_linalg'
|
6
|
+
|
7
|
+
# Ruby/Numo (NUmerical MOdules)
module Numo
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
  module TinyLinalg
    module_function

    # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
    #
    # @param a [Numo::NArray] Matrix to be decomposed (left unmodified; a copy is passed to LAPACK).
    # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
    # @param job [String] Job option ('A', 'S', or 'N'); case-insensitive, a Symbol is also accepted.
    # @raise [ArgumentError] if job, driver, or the class of a is not supported.
    # @raise [RuntimeError] if LAPACK reports an illegal argument, a NaN entry, or no convergence.
    # @return [Array<Numo::NArray>] Singular values and singular vectors ([s, u, vt]).
    def svd(a, driver: 'svd', job: 'A')
      raise ArgumentError, "invalid job: #{job}" unless /\A[ASN]/i.match?(job.to_s)

      # The native gesvd wrapper reads only the first character and matches it
      # case-sensitively, so normalize to an upper-case String.
      job = job.to_s.upcase

      case driver.to_s
      when 'sdd'
        s, u, vt, info = case a
                         when Numo::DFloat
                           Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
                         when Numo::SFloat
                           Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
                         when Numo::DComplex
                           Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
                         when Numo::SComplex
                           Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
                         else
                           raise ArgumentError, "invalid array type: #{a.class}"
                         end
        # LAPACKE reports a NaN in `a` as -(position of `a`): ?gesdd(layout, jobz, m, n, a, ...).
        nan_info = -5
      when 'svd'
        s, u, vt, info = case a
                         when Numo::DFloat
                           Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
                         when Numo::SFloat
                           Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
                         when Numo::DComplex
                           Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
                         when Numo::SComplex
                           Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
                         else
                           raise ArgumentError, "invalid array type: #{a.class}"
                         end
        # LAPACKE reports a NaN in `a` as -(position of `a`): ?gesvd(layout, jobu, jobvt, m, n, a, ...).
        nan_info = -6
      else
        raise ArgumentError, "invalid driver: #{driver}"
      end

      # The NaN check must precede the generic negative-info check, otherwise it is unreachable.
      raise 'input array has a NAN entry' if info == nan_info
      raise "the #{info.abs}-th argument had illegal value" if info.negative?
      raise 'svd did not converge' if info.positive?

      [s, u, vt]
    end
  end
end
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative 'lib/numo/tiny_linalg/version'
|
4
|
+
|
5
|
+
Gem::Specification.new do |spec|
  spec.name = 'numo-tiny_linalg'
  spec.version = Numo::TinyLinalg::VERSION
  spec.authors = ['yoshoku']
  spec.email = ['yoshoku@outlook.com']

  spec.summary = <<~MSG
    Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
  MSG
  spec.description = <<~MSG
    Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
  MSG
  spec.homepage = 'https://github.com/yoshoku/numo-tiny_linalg'
  spec.license = 'BSD-3-Clause'

  # The native extension calls rb_funcallv_kw / RB_PASS_KEYWORDS, which first appeared in Ruby 2.7.
  spec.required_ruby_version = '>= 2.7.0'

  spec.metadata['homepage_uri'] = spec.homepage
  spec.metadata['source_code_uri'] = spec.homepage
  spec.metadata['changelog_uri'] = "#{spec.homepage}/blob/main/CHANGELOG.md"

  # Package the files tracked by git, excluding this gemspec, tests, and CI configuration.
  spec.files = Dir.chdir(__dir__) do
    `git ls-files -z`.split("\x0").reject do |f|
      (f == __FILE__) || f.match(%r{\A(?:(?:bin|test|spec|features)/|\.(?:git|circleci)|appveyor)})
    end
  end
  spec.bindir = 'exe'
  spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
  spec.require_paths = ['lib']
  spec.extensions = ['ext/numo/tiny_linalg/extconf.rb']

  # Runtime dependency: the extension is built against Numo::NArray headers and classes.
  spec.add_dependency 'numo-narray', '>= 0.9.1'

  spec.metadata['rubygems_mfa_required'] = 'true'
end
|
data/package.json
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
{
|
2
|
+
"name": "numo-tiny_linalg",
|
3
|
+
"repository": "git@github.com:yoshoku/numo-tiny_linalg.git",
|
4
|
+
"author": "Atsushi Tatsuma <yoshoku@outlook.com>",
|
5
|
+
"license": "BSD-3-Clause",
|
6
|
+
"private": true,
|
7
|
+
"scripts": {
|
8
|
+
"prepare": "husky install"
|
9
|
+
},
|
10
|
+
"devDependencies": {
|
11
|
+
"@commitlint/cli": "^17.6.1",
|
12
|
+
"@commitlint/config-conventional": "^17.6.1",
|
13
|
+
"husky": "^8.0.3"
|
14
|
+
}
|
15
|
+
}
|