numo-linalg-alt 0.3.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/numo/linalg/blas/dot.c +59 -59
- data/ext/numo/linalg/blas/dot_sub.c +58 -58
- data/ext/numo/linalg/blas/gemm.c +157 -148
- data/ext/numo/linalg/blas/gemv.c +131 -127
- data/ext/numo/linalg/blas/nrm2.c +50 -50
- data/ext/numo/linalg/lapack/gees.c +239 -220
- data/ext/numo/linalg/lapack/geev.c +127 -110
- data/ext/numo/linalg/lapack/gelsd.c +81 -70
- data/ext/numo/linalg/lapack/geqrf.c +52 -51
- data/ext/numo/linalg/lapack/gerqf.c +70 -0
- data/ext/numo/linalg/lapack/gerqf.h +15 -0
- data/ext/numo/linalg/lapack/gesdd.c +96 -86
- data/ext/numo/linalg/lapack/gesv.c +80 -78
- data/ext/numo/linalg/lapack/gesvd.c +140 -129
- data/ext/numo/linalg/lapack/getrf.c +51 -50
- data/ext/numo/linalg/lapack/getri.c +64 -63
- data/ext/numo/linalg/lapack/getrs.c +92 -88
- data/ext/numo/linalg/lapack/gges.c +214 -0
- data/ext/numo/linalg/lapack/gges.h +15 -0
- data/ext/numo/linalg/lapack/heev.c +54 -52
- data/ext/numo/linalg/lapack/heevd.c +54 -52
- data/ext/numo/linalg/lapack/heevr.c +109 -98
- data/ext/numo/linalg/lapack/hegv.c +77 -74
- data/ext/numo/linalg/lapack/hegvd.c +77 -74
- data/ext/numo/linalg/lapack/hegvx.c +132 -120
- data/ext/numo/linalg/lapack/hetrf.c +54 -50
- data/ext/numo/linalg/lapack/lange.c +45 -44
- data/ext/numo/linalg/lapack/orgqr.c +63 -62
- data/ext/numo/linalg/lapack/orgrq.c +78 -0
- data/ext/numo/linalg/lapack/orgrq.h +15 -0
- data/ext/numo/linalg/lapack/potrf.c +49 -48
- data/ext/numo/linalg/lapack/potri.c +49 -48
- data/ext/numo/linalg/lapack/potrs.c +74 -72
- data/ext/numo/linalg/lapack/syev.c +54 -52
- data/ext/numo/linalg/lapack/syevd.c +54 -52
- data/ext/numo/linalg/lapack/syevr.c +107 -98
- data/ext/numo/linalg/lapack/sygv.c +77 -73
- data/ext/numo/linalg/lapack/sygvd.c +77 -73
- data/ext/numo/linalg/lapack/sygvx.c +132 -120
- data/ext/numo/linalg/lapack/sytrf.c +54 -50
- data/ext/numo/linalg/lapack/trtrs.c +79 -75
- data/ext/numo/linalg/lapack/ungqr.c +63 -62
- data/ext/numo/linalg/lapack/ungrq.c +78 -0
- data/ext/numo/linalg/lapack/ungrq.h +15 -0
- data/ext/numo/linalg/linalg.c +20 -10
- data/ext/numo/linalg/linalg.h +4 -0
- data/ext/numo/linalg/util.c +8 -0
- data/ext/numo/linalg/util.h +1 -0
- data/lib/numo/linalg/version.rb +1 -1
- data/lib/numo/linalg.rb +139 -3
- metadata +10 -2
data/ext/numo/linalg/blas/gemm.c
CHANGED
@@ -1,161 +1,170 @@
|
|
1
1
|
#include "gemm.h"
|
2
2
|
|
3
|
-
#define DEF_LINALG_OPTIONS(tDType)
|
4
|
-
struct _gemm_options_##tDType {
|
5
|
-
tDType alpha;
|
6
|
-
tDType beta;
|
7
|
-
enum CBLAS_ORDER order;
|
8
|
-
enum CBLAS_TRANSPOSE transa;
|
9
|
-
enum CBLAS_TRANSPOSE transb;
|
10
|
-
blasint m;
|
11
|
-
blasint n;
|
12
|
-
blasint k;
|
3
|
+
#define DEF_LINALG_OPTIONS(tDType) \
|
4
|
+
struct _gemm_options_##tDType { \
|
5
|
+
tDType alpha; \
|
6
|
+
tDType beta; \
|
7
|
+
enum CBLAS_ORDER order; \
|
8
|
+
enum CBLAS_TRANSPOSE transa; \
|
9
|
+
enum CBLAS_TRANSPOSE transb; \
|
10
|
+
blasint m; \
|
11
|
+
blasint n; \
|
12
|
+
blasint k; \
|
13
13
|
};
|
14
14
|
|
15
|
-
#define DEF_LINALG_ITER_FUNC(tDType, fBlasFunc)
|
16
|
-
static void _iter_##fBlasFunc(na_loop_t* const lp) {
|
17
|
-
const tDType* a = (tDType*)NDL_PTR(lp, 0);
|
18
|
-
const tDType* b = (tDType*)NDL_PTR(lp, 1);
|
19
|
-
tDType* c = (tDType*)NDL_PTR(lp, 2);
|
20
|
-
const struct _gemm_options_##tDType* opt = (struct _gemm_options_##tDType*)(lp->opt_ptr);
|
21
|
-
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m;
|
22
|
-
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k;
|
23
|
-
const blasint ldc = opt->n;
|
24
|
-
cblas_##fBlasFunc(
|
25
|
-
|
15
|
+
#define DEF_LINALG_ITER_FUNC(tDType, fBlasFunc) \
|
16
|
+
static void _iter_##fBlasFunc(na_loop_t* const lp) { \
|
17
|
+
const tDType* a = (tDType*)NDL_PTR(lp, 0); \
|
18
|
+
const tDType* b = (tDType*)NDL_PTR(lp, 1); \
|
19
|
+
tDType* c = (tDType*)NDL_PTR(lp, 2); \
|
20
|
+
const struct _gemm_options_##tDType* opt = (struct _gemm_options_##tDType*)(lp->opt_ptr); \
|
21
|
+
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m; \
|
22
|
+
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k; \
|
23
|
+
const blasint ldc = opt->n; \
|
24
|
+
cblas_##fBlasFunc( \
|
25
|
+
opt->order, opt->transa, opt->transb, opt->m, opt->n, opt->k, opt->alpha, a, lda, b, \
|
26
|
+
ldb, opt->beta, c, ldc \
|
27
|
+
); \
|
26
28
|
}
|
27
29
|
|
28
|
-
#define DEF_LINALG_ITER_FUNC_COMPLEX(tDType, fBlasFunc)
|
29
|
-
static void _iter_##fBlasFunc(na_loop_t* const lp) {
|
30
|
-
const tDType* a = (tDType*)NDL_PTR(lp, 0);
|
31
|
-
const tDType* b = (tDType*)NDL_PTR(lp, 1);
|
32
|
-
tDType* c = (tDType*)NDL_PTR(lp, 2);
|
33
|
-
const struct _gemm_options_##tDType* opt = (struct _gemm_options_##tDType*)(lp->opt_ptr);
|
34
|
-
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m;
|
35
|
-
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k;
|
36
|
-
const blasint ldc = opt->n;
|
37
|
-
cblas_##fBlasFunc(
|
38
|
-
|
30
|
+
#define DEF_LINALG_ITER_FUNC_COMPLEX(tDType, fBlasFunc) \
|
31
|
+
static void _iter_##fBlasFunc(na_loop_t* const lp) { \
|
32
|
+
const tDType* a = (tDType*)NDL_PTR(lp, 0); \
|
33
|
+
const tDType* b = (tDType*)NDL_PTR(lp, 1); \
|
34
|
+
tDType* c = (tDType*)NDL_PTR(lp, 2); \
|
35
|
+
const struct _gemm_options_##tDType* opt = (struct _gemm_options_##tDType*)(lp->opt_ptr); \
|
36
|
+
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m; \
|
37
|
+
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k; \
|
38
|
+
const blasint ldc = opt->n; \
|
39
|
+
cblas_##fBlasFunc( \
|
40
|
+
opt->order, opt->transa, opt->transb, opt->m, opt->n, opt->k, &opt->alpha, a, lda, b, \
|
41
|
+
ldb, &opt->beta, c, ldc \
|
42
|
+
); \
|
39
43
|
}
|
40
44
|
|
41
|
-
#define DEF_LINALG_ITER_FUNC(tDType, fBlasFunc)
|
42
|
-
static void _iter_##fBlasFunc(na_loop_t* const lp) {
|
43
|
-
const tDType* a = (tDType*)NDL_PTR(lp, 0);
|
44
|
-
const tDType* b = (tDType*)NDL_PTR(lp, 1);
|
45
|
-
tDType* c = (tDType*)NDL_PTR(lp, 2);
|
46
|
-
const struct _gemm_options_##tDType* opt = (struct _gemm_options_##tDType*)(lp->opt_ptr);
|
47
|
-
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m;
|
48
|
-
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k;
|
49
|
-
const blasint ldc = opt->n;
|
50
|
-
cblas_##fBlasFunc(
|
51
|
-
|
45
|
+
#define DEF_LINALG_ITER_FUNC(tDType, fBlasFunc) \
|
46
|
+
static void _iter_##fBlasFunc(na_loop_t* const lp) { \
|
47
|
+
const tDType* a = (tDType*)NDL_PTR(lp, 0); \
|
48
|
+
const tDType* b = (tDType*)NDL_PTR(lp, 1); \
|
49
|
+
tDType* c = (tDType*)NDL_PTR(lp, 2); \
|
50
|
+
const struct _gemm_options_##tDType* opt = (struct _gemm_options_##tDType*)(lp->opt_ptr); \
|
51
|
+
const blasint lda = opt->transa == CblasNoTrans ? opt->k : opt->m; \
|
52
|
+
const blasint ldb = opt->transb == CblasNoTrans ? opt->n : opt->k; \
|
53
|
+
const blasint ldc = opt->n; \
|
54
|
+
cblas_##fBlasFunc( \
|
55
|
+
opt->order, opt->transa, opt->transb, opt->m, opt->n, opt->k, opt->alpha, a, lda, b, \
|
56
|
+
ldb, opt->beta, c, ldc \
|
57
|
+
); \
|
52
58
|
}
|
53
59
|
|
54
|
-
#define DEF_LINALG_FUNC(tDType, tNAryClass, fBlasFunc)
|
55
|
-
static VALUE _linalg_blas_##fBlasFunc(int argc, VALUE* argv, VALUE self) {
|
56
|
-
VALUE a = Qnil;
|
57
|
-
VALUE b = Qnil;
|
58
|
-
VALUE c = Qnil;
|
59
|
-
VALUE kw_args = Qnil;
|
60
|
-
rb_scan_args(argc, argv, "21:", &a, &b, &c, &kw_args);
|
61
|
-
|
62
|
-
ID kw_table[5] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"),
|
63
|
-
rb_intern("transa"), rb_intern("transb") };
|
64
|
-
VALUE kw_values[5] = { Qundef, Qundef, Qundef, Qundef, Qundef };
|
65
|
-
rb_get_kwargs(kw_args, kw_table, 0, 5, kw_values);
|
66
|
-
|
67
|
-
if (CLASS_OF(a) != tNAryClass) {
|
68
|
-
a = rb_funcall(tNAryClass, rb_intern("cast"), 1, a);
|
69
|
-
}
|
70
|
-
if (!RTEST(nary_check_contiguous(a))) {
|
71
|
-
a = nary_dup(a);
|
72
|
-
}
|
73
|
-
if (CLASS_OF(b) != tNAryClass) {
|
74
|
-
b = rb_funcall(tNAryClass, rb_intern("cast"), 1, b);
|
75
|
-
}
|
76
|
-
if (!RTEST(nary_check_contiguous(b))) {
|
77
|
-
b = nary_dup(b);
|
78
|
-
}
|
79
|
-
if (!NIL_P(c)) {
|
80
|
-
if (CLASS_OF(c) != tNAryClass) {
|
81
|
-
c = rb_funcall(tNAryClass, rb_intern("cast"), 1, c);
|
82
|
-
}
|
83
|
-
if (!RTEST(nary_check_contiguous(c))) {
|
84
|
-
c = nary_dup(c);
|
85
|
-
}
|
86
|
-
}
|
87
|
-
|
88
|
-
tDType alpha = kw_values[0] != Qundef ? conv_##tDType(kw_values[0]) : one_##tDType();
|
89
|
-
tDType beta = kw_values[1] != Qundef ? conv_##tDType(kw_values[1]) : zero_##tDType();
|
90
|
-
enum CBLAS_ORDER order =
|
91
|
-
|
92
|
-
enum CBLAS_TRANSPOSE
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
const blasint
|
120
|
-
const blasint
|
121
|
-
const blasint
|
122
|
-
const blasint
|
123
|
-
const blasint
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
60
|
+
/*
 * DEF_LINALG_FUNC(tDType, tNAryClass, fBlasFunc)
 *
 * Defines `_linalg_blas_<fBlasFunc>(a, b, [c], alpha:, beta:, order:, transa:, transb:)`,
 * a Ruby-callable wrapper that runs cblas_<fBlasFunc> (through _iter_<fBlasFunc>)
 * to compute C = alpha * op(A) * op(B) + beta * C.
 *
 * - a, b: cast to tNAryClass when they are of another class and duplicated when not
 *   contiguous; both must be 2-dimensional and non-empty.
 * - c (optional): cast/duplicated the same way and then overwritten through an
 *   OVERWRITE ndloop argument; the (possibly converted) c is returned. Because the
 *   iterator writes the m x n result with leading dimension ldc = n, the trailing two
 *   dimensions of c must provide at least m rows and exactly n columns.
 *   When c is omitted, a new m x n tNAryClass array is produced and returned.
 * - alpha / beta default to one_<tDType>() / zero_<tDType>(); order defaults to
 *   CblasRowMajor; transa / transb default to CblasNoTrans.
 *
 * Raises rb_eArgError for wrong rank or empty inputs and nary_eShapeError when the
 * inner dimensions of op(a) and op(b), or the shape of c, do not match.
 *
 * NOTE(review): the iterator derives lda/ldb/ldc assuming row-major storage while a
 * CblasColMajor `order` is passed straight through — confirm callers only use row-major.
 */
#define DEF_LINALG_FUNC(tDType, tNAryClass, fBlasFunc) \
  static VALUE _linalg_blas_##fBlasFunc(int argc, VALUE* argv, VALUE self) { \
    VALUE a = Qnil; \
    VALUE b = Qnil; \
    VALUE c = Qnil; \
    VALUE kw_args = Qnil; \
    rb_scan_args(argc, argv, "21:", &a, &b, &c, &kw_args); \
 \
    ID kw_table[5] = { rb_intern("alpha"), rb_intern("beta"), rb_intern("order"), \
                       rb_intern("transa"), rb_intern("transb") }; \
    VALUE kw_values[5] = { Qundef, Qundef, Qundef, Qundef, Qundef }; \
    rb_get_kwargs(kw_args, kw_table, 0, 5, kw_values); \
 \
    if (CLASS_OF(a) != tNAryClass) { \
      a = rb_funcall(tNAryClass, rb_intern("cast"), 1, a); \
    } \
    if (!RTEST(nary_check_contiguous(a))) { \
      a = nary_dup(a); \
    } \
    if (CLASS_OF(b) != tNAryClass) { \
      b = rb_funcall(tNAryClass, rb_intern("cast"), 1, b); \
    } \
    if (!RTEST(nary_check_contiguous(b))) { \
      b = nary_dup(b); \
    } \
    if (!NIL_P(c)) { \
      if (CLASS_OF(c) != tNAryClass) { \
        c = rb_funcall(tNAryClass, rb_intern("cast"), 1, c); \
      } \
      if (!RTEST(nary_check_contiguous(c))) { \
        c = nary_dup(c); \
      } \
    } \
 \
    tDType alpha = kw_values[0] != Qundef ? conv_##tDType(kw_values[0]) : one_##tDType(); \
    tDType beta = kw_values[1] != Qundef ? conv_##tDType(kw_values[1]) : zero_##tDType(); \
    enum CBLAS_ORDER order = \
      kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor; \
    enum CBLAS_TRANSPOSE transa = \
      kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans; \
    enum CBLAS_TRANSPOSE transb = \
      kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans; \
 \
    narray_t* a_nary = NULL; \
    GetNArray(a, a_nary); \
    narray_t* b_nary = NULL; \
    GetNArray(b, b_nary); \
 \
    if (NA_NDIM(a_nary) != 2) { \
      rb_raise(rb_eArgError, "a must be 2-dimensional"); \
      return Qnil; \
    } \
    if (NA_NDIM(b_nary) != 2) { \
      rb_raise(rb_eArgError, "b must be 2-dimensional"); \
      return Qnil; \
    } \
    if (NA_SIZE(a_nary) == 0) { \
      rb_raise(rb_eArgError, "a must not be empty"); \
      return Qnil; \
    } \
    if (NA_SIZE(b_nary) == 0) { \
      rb_raise(rb_eArgError, "b must not be empty"); \
      return Qnil; \
    } \
 \
    /* m x k = shape of op(a), l x n = shape of op(b); k must equal l. */ \
    const blasint ma = (blasint)NA_SHAPE(a_nary)[0]; \
    const blasint ka = (blasint)NA_SHAPE(a_nary)[1]; \
    const blasint kb = (blasint)NA_SHAPE(b_nary)[0]; \
    const blasint nb = (blasint)NA_SHAPE(b_nary)[1]; \
    const blasint m = transa == CblasNoTrans ? ma : ka; \
    const blasint n = transb == CblasNoTrans ? nb : kb; \
    const blasint k = transa == CblasNoTrans ? ka : ma; \
    const blasint l = transb == CblasNoTrans ? kb : nb; \
 \
    if (k != l) { \
      /* blasint may be 64-bit (ILP64 BLAS), so print through long. */ \
      rb_raise(nary_eShapeError, "shape1[1](=%ld) != shape2[0](=%ld)", (long)k, (long)l); \
      return Qnil; \
    } \
 \
    struct _gemm_options_##tDType opt = { alpha, beta, order, transa, transb, m, n, k }; \
    size_t shape_out[2] = { (size_t)m, (size_t)n }; \
    ndfunc_arg_out_t aout[1] = { { tNAryClass, 2, shape_out } }; \
    VALUE ret = Qnil; \
 \
    if (!NIL_P(c)) { \
      narray_t* c_nary = NULL; \
      GetNArray(c, c_nary); \
      /* The iterator writes an m x n block with ldc == n, so c needs >= m rows and */ \
      /* exactly n columns; a narrower c would let BLAS write past its buffer.      */ \
      const int c_ndim = (int)NA_NDIM(c_nary); \
      if (c_ndim < 2) { \
        rb_raise(rb_eArgError, "c must be at least 2-dimensional"); \
        return Qnil; \
      } \
      const blasint mc = (blasint)NA_SHAPE(c_nary)[c_ndim - 2]; \
      const blasint nc = (blasint)NA_SHAPE(c_nary)[c_ndim - 1]; \
      if (m > mc) { \
        rb_raise(nary_eShapeError, "shape3[0](=%ld) < op(a) rows(=%ld)", (long)mc, (long)m); \
        return Qnil; \
      } \
      if (n != nc) { \
        rb_raise(nary_eShapeError, "shape3[1](=%ld) != op(b) columns(=%ld)", (long)nc, \
                 (long)n); \
        return Qnil; \
      } \
      ndfunc_arg_in_t ain[3] = { { tNAryClass, 2 }, { tNAryClass, 2 }, { OVERWRITE, 2 } }; \
      ndfunc_t ndf = { _iter_##fBlasFunc, NO_LOOP, 3, 0, ain, aout }; \
      na_ndloop3(&ndf, &opt, 3, a, b, c); \
      ret = c; \
    } else { \
      c = INT2NUM(0); \
      ndfunc_arg_in_t ain[3] = { { tNAryClass, 2 }, { tNAryClass, 2 }, { sym_init, 0 } }; \
      ndfunc_t ndf = { _iter_##fBlasFunc, NO_LOOP, 3, 1, ain, aout }; \
      ret = na_ndloop3(&ndf, &opt, 3, a, b, c); \
    } \
 \
    RB_GC_GUARD(a); \
    RB_GC_GUARD(b); \
    RB_GC_GUARD(c); \
 \
    return ret; \
  }
|
160
169
|
|
161
170
|
DEF_LINALG_OPTIONS(double)
|