numo-linalg-alt 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +27 -0
- data/README.md +106 -0
- data/ext/numo/linalg/blas/dot.c +72 -0
- data/ext/numo/linalg/blas/dot.h +13 -0
- data/ext/numo/linalg/blas/dot_sub.c +71 -0
- data/ext/numo/linalg/blas/dot_sub.h +13 -0
- data/ext/numo/linalg/blas/gemm.c +184 -0
- data/ext/numo/linalg/blas/gemm.h +16 -0
- data/ext/numo/linalg/blas/gemv.c +161 -0
- data/ext/numo/linalg/blas/gemv.h +16 -0
- data/ext/numo/linalg/blas/nrm2.c +67 -0
- data/ext/numo/linalg/blas/nrm2.h +13 -0
- data/ext/numo/linalg/converter.c +67 -0
- data/ext/numo/linalg/converter.h +23 -0
- data/ext/numo/linalg/extconf.rb +99 -0
- data/ext/numo/linalg/lapack/geev.c +152 -0
- data/ext/numo/linalg/lapack/geev.h +15 -0
- data/ext/numo/linalg/lapack/gelsd.c +92 -0
- data/ext/numo/linalg/lapack/gelsd.h +15 -0
- data/ext/numo/linalg/lapack/geqrf.c +72 -0
- data/ext/numo/linalg/lapack/geqrf.h +15 -0
- data/ext/numo/linalg/lapack/gesdd.c +108 -0
- data/ext/numo/linalg/lapack/gesdd.h +15 -0
- data/ext/numo/linalg/lapack/gesv.c +99 -0
- data/ext/numo/linalg/lapack/gesv.h +15 -0
- data/ext/numo/linalg/lapack/gesvd.c +152 -0
- data/ext/numo/linalg/lapack/gesvd.h +15 -0
- data/ext/numo/linalg/lapack/getrf.c +71 -0
- data/ext/numo/linalg/lapack/getrf.h +15 -0
- data/ext/numo/linalg/lapack/getri.c +82 -0
- data/ext/numo/linalg/lapack/getri.h +15 -0
- data/ext/numo/linalg/lapack/getrs.c +110 -0
- data/ext/numo/linalg/lapack/getrs.h +15 -0
- data/ext/numo/linalg/lapack/heev.c +71 -0
- data/ext/numo/linalg/lapack/heev.h +15 -0
- data/ext/numo/linalg/lapack/heevd.c +71 -0
- data/ext/numo/linalg/lapack/heevd.h +15 -0
- data/ext/numo/linalg/lapack/heevr.c +111 -0
- data/ext/numo/linalg/lapack/heevr.h +15 -0
- data/ext/numo/linalg/lapack/hegv.c +94 -0
- data/ext/numo/linalg/lapack/hegv.h +15 -0
- data/ext/numo/linalg/lapack/hegvd.c +94 -0
- data/ext/numo/linalg/lapack/hegvd.h +15 -0
- data/ext/numo/linalg/lapack/hegvx.c +133 -0
- data/ext/numo/linalg/lapack/hegvx.h +15 -0
- data/ext/numo/linalg/lapack/hetrf.c +68 -0
- data/ext/numo/linalg/lapack/hetrf.h +15 -0
- data/ext/numo/linalg/lapack/lange.c +66 -0
- data/ext/numo/linalg/lapack/lange.h +15 -0
- data/ext/numo/linalg/lapack/orgqr.c +79 -0
- data/ext/numo/linalg/lapack/orgqr.h +15 -0
- data/ext/numo/linalg/lapack/potrf.c +70 -0
- data/ext/numo/linalg/lapack/potrf.h +15 -0
- data/ext/numo/linalg/lapack/potri.c +70 -0
- data/ext/numo/linalg/lapack/potri.h +15 -0
- data/ext/numo/linalg/lapack/potrs.c +94 -0
- data/ext/numo/linalg/lapack/potrs.h +15 -0
- data/ext/numo/linalg/lapack/syev.c +71 -0
- data/ext/numo/linalg/lapack/syev.h +15 -0
- data/ext/numo/linalg/lapack/syevd.c +71 -0
- data/ext/numo/linalg/lapack/syevd.h +15 -0
- data/ext/numo/linalg/lapack/syevr.c +111 -0
- data/ext/numo/linalg/lapack/syevr.h +15 -0
- data/ext/numo/linalg/lapack/sygv.c +93 -0
- data/ext/numo/linalg/lapack/sygv.h +15 -0
- data/ext/numo/linalg/lapack/sygvd.c +93 -0
- data/ext/numo/linalg/lapack/sygvd.h +15 -0
- data/ext/numo/linalg/lapack/sygvx.c +133 -0
- data/ext/numo/linalg/lapack/sygvx.h +15 -0
- data/ext/numo/linalg/lapack/sytrf.c +72 -0
- data/ext/numo/linalg/lapack/sytrf.h +15 -0
- data/ext/numo/linalg/lapack/trtrs.c +99 -0
- data/ext/numo/linalg/lapack/trtrs.h +15 -0
- data/ext/numo/linalg/lapack/ungqr.c +79 -0
- data/ext/numo/linalg/lapack/ungqr.h +15 -0
- data/ext/numo/linalg/linalg.c +290 -0
- data/ext/numo/linalg/linalg.h +85 -0
- data/ext/numo/linalg/util.c +95 -0
- data/ext/numo/linalg/util.h +17 -0
- data/lib/numo/linalg/version.rb +10 -0
- data/lib/numo/linalg.rb +1309 -0
- data/vendor/tmp/.gitkeep +0 -0
- metadata +146 -0
@@ -0,0 +1,161 @@
|
|
1
|
+
#include "gemv.h"
|
2
|
+
|
3
|
+
/*
 * DEF_LINALG_OPTIONS(tDType) declares `struct _gemv_options_<tDType>`, the
 * bundle of scalar arguments the Ruby wrapper hands to the ndloop iterator
 * through the option pointer (lp->opt_ptr). The wrapper fills it with a
 * positional initializer, so the field order must not change.
 */
#define DEF_LINALG_OPTIONS(tDType) \
  struct _gemv_options_##tDType {  \
    tDType alpha, beta;            \
    enum CBLAS_ORDER order;        \
    enum CBLAS_TRANSPOSE trans;    \
    blasint m, n;                  \
  };
|
12
|
+
|
13
|
+
/*
 * DEF_LINALG_ITER_FUNC(tDType, fBlasFunc) defines the ndloop iterator
 * _iter_<fBlasFunc> for real element types. It calls cblas_<fBlasFunc>
 * (y := alpha * op(A) * x + beta * y) with operand 0 = a, operand 1 = x and
 * operand 2 = y (written in place), using the scalars stored in
 * struct _gemv_options_<tDType> (m = rows of a, n = columns of a).
 *
 * The leading dimension must follow the storage order: n for a row-major
 * view, m for a column-major view. Always using n made CBLAS reject
 * column-major calls when m > n and read past the end of `a` when n > m.
 */
#define DEF_LINALG_ITER_FUNC(tDType, fBlasFunc)                                                  \
  static void _iter_##fBlasFunc(na_loop_t* const lp) {                                           \
    const tDType* a = (tDType*)NDL_PTR(lp, 0);                                                   \
    const tDType* x = (tDType*)NDL_PTR(lp, 1);                                                   \
    tDType* y = (tDType*)NDL_PTR(lp, 2);                                                         \
    const struct _gemv_options_##tDType* opt = (struct _gemv_options_##tDType*)(lp->opt_ptr);    \
    const blasint lda = opt->order == CblasColMajor ? opt->m : opt->n;                           \
    cblas_##fBlasFunc(opt->order, opt->trans, opt->m, opt->n,                                    \
                      opt->alpha, a, lda, x, 1, opt->beta, y, 1);                                \
  }
|
23
|
+
|
24
|
+
/*
 * DEF_LINALG_ITER_FUNC_COMPLEX(tDType, fBlasFunc) is the complex-type twin
 * of DEF_LINALG_ITER_FUNC: the complex CBLAS gemv routines take alpha and
 * beta by pointer, so their addresses are passed instead of their values.
 * Operand 0 = a, operand 1 = x, operand 2 = y (written in place).
 *
 * The leading dimension is n (columns) for a row-major view and m (rows) for
 * a column-major view. CBLAS requires lda >= m in column-major mode, and any
 * larger value would index past the end of `a`.
 */
#define DEF_LINALG_ITER_FUNC_COMPLEX(tDType, fBlasFunc)                                          \
  static void _iter_##fBlasFunc(na_loop_t* const lp) {                                           \
    const tDType* a = (tDType*)NDL_PTR(lp, 0);                                                   \
    const tDType* x = (tDType*)NDL_PTR(lp, 1);                                                   \
    tDType* y = (tDType*)NDL_PTR(lp, 2);                                                         \
    const struct _gemv_options_##tDType* opt = (struct _gemv_options_##tDType*)(lp->opt_ptr);    \
    const blasint lda = opt->order == CblasColMajor ? opt->m : opt->n;                           \
    cblas_##fBlasFunc(opt->order, opt->trans, opt->m, opt->n,                                    \
                      &opt->alpha, a, lda, x, 1, &opt->beta, y, 1);                              \
  }
|
34
|
+
|
35
|
+
/*
 * DEF_LINALG_FUNC(tDType, tNAryClass, fBlasFunc) defines the Ruby method body
 * _linalg_blas_<fBlasFunc>(a, x, y = nil, alpha:, beta:, order:, trans:).
 *
 * - a (2-D) and x (1-D) are cast to tNAryClass, copied when not contiguous,
 *   and must be non-empty; ArgumentError otherwise.
 * - alpha defaults to one_<tDType>(), beta to zero_<tDType>(), order to
 *   CblasRowMajor and trans to CblasNoTrans.
 * - With trans == CblasNoTrans the output length m is a's row count and x
 *   must have a's column count elements; otherwise the roles are swapped.
 *   A mismatch raises nary_eShapeError.
 * - If y is given it must be 1-D with at least m elements (ArgumentError /
 *   nary_eShapeError otherwise). The result is written into y, or into the
 *   cast/contiguous copy made of it, and that array is returned. Without y,
 *   a new length-m tNAryClass array, initialized from 0 via sym_init, is
 *   returned.
 *
 * Checking y's rank guards the NA_SHAPE(y)[0] read. blasint values are
 * widened to long for "%ld" because blasint may be a 64-bit integer.
 */
#define DEF_LINALG_FUNC(tDType, tNAryClass, fBlasFunc)                                                     \
  static VALUE _linalg_blas_##fBlasFunc(int argc, VALUE* argv, VALUE self) {                              \
    VALUE a = Qnil;                                                                                       \
    VALUE x = Qnil;                                                                                       \
    VALUE y = Qnil;                                                                                       \
    VALUE kw_args = Qnil;                                                                                 \
    rb_scan_args(argc, argv, "21:", &a, &x, &y, &kw_args);                                                \
                                                                                                          \
    ID kw_table[4] = { rb_intern("alpha"), rb_intern("beta"),                                             \
                       rb_intern("order"), rb_intern("trans") };                                          \
    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };                                              \
    rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);                                                    \
                                                                                                          \
    if (CLASS_OF(a) != tNAryClass) {                                                                      \
      a = rb_funcall(tNAryClass, rb_intern("cast"), 1, a);                                                \
    }                                                                                                     \
    if (!RTEST(nary_check_contiguous(a))) {                                                               \
      a = nary_dup(a);                                                                                    \
    }                                                                                                     \
    if (CLASS_OF(x) != tNAryClass) {                                                                      \
      x = rb_funcall(tNAryClass, rb_intern("cast"), 1, x);                                                \
    }                                                                                                     \
    if (!RTEST(nary_check_contiguous(x))) {                                                               \
      x = nary_dup(x);                                                                                    \
    }                                                                                                     \
    if (!NIL_P(y)) {                                                                                      \
      if (CLASS_OF(y) != tNAryClass) {                                                                    \
        y = rb_funcall(tNAryClass, rb_intern("cast"), 1, y);                                              \
      }                                                                                                   \
      if (!RTEST(nary_check_contiguous(y))) {                                                             \
        y = nary_dup(y);                                                                                  \
      }                                                                                                   \
    }                                                                                                     \
                                                                                                          \
    tDType alpha = kw_values[0] != Qundef ? conv_##tDType(kw_values[0]) : one_##tDType();                 \
    tDType beta = kw_values[1] != Qundef ? conv_##tDType(kw_values[1]) : zero_##tDType();                 \
    enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;      \
    enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;   \
                                                                                                          \
    narray_t* a_nary = NULL;                                                                              \
    GetNArray(a, a_nary);                                                                                 \
    narray_t* x_nary = NULL;                                                                              \
    GetNArray(x, x_nary);                                                                                 \
                                                                                                          \
    if (NA_NDIM(a_nary) != 2) {                                                                           \
      rb_raise(rb_eArgError, "a must be 2-dimensional");                                                  \
      return Qnil;                                                                                        \
    }                                                                                                     \
    if (NA_NDIM(x_nary) != 1) {                                                                           \
      rb_raise(rb_eArgError, "x must be 1-dimensional");                                                  \
      return Qnil;                                                                                        \
    }                                                                                                     \
    if (NA_SIZE(a_nary) == 0) {                                                                           \
      rb_raise(rb_eArgError, "a must not be empty");                                                      \
      return Qnil;                                                                                        \
    }                                                                                                     \
    if (NA_SIZE(x_nary) == 0) {                                                                           \
      rb_raise(rb_eArgError, "x must not be empty");                                                      \
      return Qnil;                                                                                        \
    }                                                                                                     \
                                                                                                          \
    const blasint ma = (blasint)NA_SHAPE(a_nary)[0];                                                      \
    const blasint na = (blasint)NA_SHAPE(a_nary)[1];                                                      \
    const blasint mx = (blasint)NA_SHAPE(x_nary)[0];                                                      \
    /* m: length of the result, n: required length of x (rows/cols of op(a)). */                          \
    const blasint m = trans == CblasNoTrans ? ma : na;                                                    \
    const blasint n = trans == CblasNoTrans ? na : ma;                                                    \
                                                                                                          \
    if (n != mx) {                                                                                        \
      rb_raise(nary_eShapeError, "shape1[1](=%ld) != shape2[0](=%ld)", (long)n, (long)mx);                \
      return Qnil;                                                                                        \
    }                                                                                                     \
                                                                                                          \
    struct _gemv_options_##tDType opt = { alpha, beta, order, trans, ma, na };                            \
    size_t shape_out[1] = { (size_t)(m) };                                                                \
    ndfunc_arg_out_t aout[1] = { { tNAryClass, 1, shape_out } };                                          \
    VALUE ret = Qnil;                                                                                     \
                                                                                                          \
    if (!NIL_P(y)) {                                                                                      \
      narray_t* y_nary = NULL;                                                                            \
      GetNArray(y, y_nary);                                                                               \
      if (NA_NDIM(y_nary) != 1) {                                                                         \
        rb_raise(rb_eArgError, "y must be 1-dimensional");                                                \
        return Qnil;                                                                                      \
      }                                                                                                   \
      const blasint my = (blasint)NA_SHAPE(y_nary)[0];                                                    \
      if (m > my) {                                                                                       \
        rb_raise(nary_eShapeError, "shape3[0](=%ld) < shape1[0](=%ld)", (long)my, (long)m);               \
        return Qnil;                                                                                      \
      }                                                                                                   \
      /* y is both input (scaled by beta) and output: overwrite it in place. */                           \
      ndfunc_arg_in_t ain[3] = { { tNAryClass, 2 }, { tNAryClass, 1 }, { OVERWRITE, 1 } };                \
      ndfunc_t ndf = { _iter_##fBlasFunc, NO_LOOP, 3, 0, ain, aout };                                     \
      na_ndloop3(&ndf, &opt, 3, a, x, y);                                                                 \
      ret = y;                                                                                            \
    } else {                                                                                              \
      /* No y: allocate the output, initialized from the value 0 via sym_init. */                         \
      y = INT2NUM(0);                                                                                     \
      ndfunc_arg_in_t ain[3] = { { tNAryClass, 2 }, { tNAryClass, 1 }, { sym_init, 0 } };                 \
      ndfunc_t ndf = { _iter_##fBlasFunc, NO_LOOP, 3, 1, ain, aout };                                     \
      ret = na_ndloop3(&ndf, &opt, 3, a, x, y);                                                           \
    }                                                                                                     \
                                                                                                          \
    RB_GC_GUARD(a);                                                                                       \
    RB_GC_GUARD(x);                                                                                       \
    RB_GC_GUARD(y);                                                                                       \
                                                                                                          \
    return ret;                                                                                           \
  }
|
137
|
+
|
138
|
+
DEF_LINALG_OPTIONS(double)
|
139
|
+
DEF_LINALG_OPTIONS(float)
|
140
|
+
DEF_LINALG_OPTIONS(dcomplex)
|
141
|
+
DEF_LINALG_OPTIONS(scomplex)
|
142
|
+
DEF_LINALG_ITER_FUNC(double, dgemv)
|
143
|
+
DEF_LINALG_ITER_FUNC(float, sgemv)
|
144
|
+
DEF_LINALG_ITER_FUNC_COMPLEX(dcomplex, zgemv)
|
145
|
+
DEF_LINALG_ITER_FUNC_COMPLEX(scomplex, cgemv)
|
146
|
+
DEF_LINALG_FUNC(double, numo_cDFloat, dgemv)
|
147
|
+
DEF_LINALG_FUNC(float, numo_cSFloat, sgemv)
|
148
|
+
DEF_LINALG_FUNC(dcomplex, numo_cDComplex, zgemv)
|
149
|
+
DEF_LINALG_FUNC(scomplex, numo_cSComplex, cgemv)
|
150
|
+
|
151
|
+
#undef DEF_LINALG_OPTIONS
|
152
|
+
#undef DEF_LINALG_ITER_FUNC
|
153
|
+
#undef DEF_LINALG_ITER_FUNC_COMPLEX
|
154
|
+
#undef DEF_LINALG_FUNC
|
155
|
+
|
156
|
+
/*
 * Registers dgemv, sgemv, zgemv and cgemv as variadic (arity -1) module
 * functions on mBlas, each backed by the matching _linalg_blas_* wrapper.
 */
void define_linalg_blas_gemv(VALUE mBlas) {
#define REGISTER_GEMV(fname) \
  rb_define_module_function(mBlas, #fname, RUBY_METHOD_FUNC(_linalg_blas_##fname), -1)
  REGISTER_GEMV(dgemv);
  REGISTER_GEMV(sgemv);
  REGISTER_GEMV(zgemv);
  REGISTER_GEMV(cgemv);
#undef REGISTER_GEMV
}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_BLAS_GEMV_H
|
2
|
+
#define NUMO_LINALG_ALT_BLAS_GEMV_H 1
|
3
|
+
|
4
|
+
#include <ruby.h>
|
5
|
+
|
6
|
+
#include <cblas.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
#include "../converter.h"
|
12
|
+
#include "../util.h"
|
13
|
+
|
14
|
+
/* Registers the ?gemv module functions (dgemv, sgemv, zgemv, cgemv) on mBlas. */
extern void define_linalg_blas_gemv(VALUE mBlas);
|
15
|
+
|
16
|
+
#endif /* NUMO_LINALG_ALT_BLAS_GEMV_H */
|
@@ -0,0 +1,67 @@
|
|
1
|
+
#include "nrm2.h"
|
2
|
+
|
3
|
+
/*
 * DEF_LINALG_FUNC(tDType, tRtDType, tNAryClass, tRtNAryClass, fBlasFunc)
 * defines two functions for one element type:
 *   _iter_<fBlasFunc>       ndloop callback. It passes the contiguous 1-D
 *                           operand 0 to cblas_<fBlasFunc> and stores the
 *                           returned norm in output operand 1.
 *   _linalg_blas_<fBlasFunc> Ruby method body for (x, keepdims: false).
 *                           Casts x to tNAryClass, requires it to be 1-D and
 *                           non-empty (ArgumentError otherwise), and returns a
 *                           0-D tRtNAryClass result. keepdims sets NDF_KEEP_DIM.
 */
#define DEF_LINALG_FUNC(tDType, tRtDType, tNAryClass, tRtNAryClass, fBlasFunc)        \
  static void _iter_##fBlasFunc(na_loop_t* const lp) {                               \
    tDType* vec = (tDType*)NDL_PTR(lp, 0);                                           \
    tRtDType* norm = (tRtDType*)NDL_PTR(lp, 1);                                      \
    const blasint len = (blasint)NDL_SHAPE(lp, 0)[0];                                \
    *norm = cblas_##fBlasFunc(len, vec, 1);                                          \
  }                                                                                  \
                                                                                     \
  static VALUE _linalg_blas_##fBlasFunc(int argc, VALUE* argv, VALUE self) {         \
    VALUE x = Qnil;                                                                  \
    VALUE kw_args = Qnil;                                                            \
    rb_scan_args(argc, argv, "1:", &x, &kw_args);                                    \
                                                                                     \
    ID kw_table[1] = { rb_intern("keepdims") };                                      \
    VALUE kw_values[1] = { Qundef };                                                 \
    rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);                               \
    const bool keepdims = kw_values[0] == Qundef ? false : RTEST(kw_values[0]);      \
                                                                                     \
    if (CLASS_OF(x) != tNAryClass) {                                                 \
      x = rb_funcall(tNAryClass, rb_intern("cast"), 1, x);                           \
    }                                                                                \
    if (!RTEST(nary_check_contiguous(x))) {                                          \
      x = nary_dup(x);                                                               \
    }                                                                                \
                                                                                     \
    narray_t* x_nary = NULL;                                                         \
    GetNArray(x, x_nary);                                                            \
    if (NA_NDIM(x_nary) != 1) {                                                      \
      rb_raise(rb_eArgError, "x must be 1-dimensional");                             \
      return Qnil;                                                                   \
    }                                                                                \
    if (NA_SIZE(x_nary) == 0) {                                                      \
      rb_raise(rb_eArgError, "x must not be empty");                                 \
      return Qnil;                                                                   \
    }                                                                                \
                                                                                     \
    ndfunc_arg_in_t ain[1] = { { tNAryClass, 1 } };                                  \
    size_t shape_out[1] = { 1 };                                                     \
    ndfunc_arg_out_t aout[1] = { { tRtNAryClass, 0, shape_out } };                   \
    const unsigned int flags = keepdims ? (NO_LOOP | NDF_EXTRACT | NDF_KEEP_DIM)     \
                                        : (NO_LOOP | NDF_EXTRACT);                   \
    ndfunc_t ndf = { _iter_##fBlasFunc, flags, 1, 1, ain, aout };                    \
                                                                                     \
    VALUE result = na_ndloop(&ndf, 1, x);                                            \
    RB_GC_GUARD(x);                                                                  \
    return result;                                                                   \
  }
|
54
|
+
|
55
|
+
/* Double precision: real input, then complex input with a real result. */
DEF_LINALG_FUNC(double, double, numo_cDFloat, numo_cDFloat, dnrm2)
DEF_LINALG_FUNC(dcomplex, double, numo_cDComplex, numo_cDFloat, dznrm2)

/* Single precision: real input, then complex input with a real result. */
DEF_LINALG_FUNC(float, float, numo_cSFloat, numo_cSFloat, snrm2)
DEF_LINALG_FUNC(scomplex, float, numo_cSComplex, numo_cSFloat, scnrm2)

#undef DEF_LINALG_FUNC
|
61
|
+
|
62
|
+
/*
 * Registers dnrm2, snrm2, dznrm2 and scnrm2 as variadic (arity -1) module
 * functions on mBlas, each backed by the matching _linalg_blas_* wrapper.
 */
void define_linalg_blas_nrm2(VALUE mBlas) {
#define REGISTER_NRM2(fname) \
  rb_define_module_function(mBlas, #fname, RUBY_METHOD_FUNC(_linalg_blas_##fname), -1)
  REGISTER_NRM2(dnrm2);
  REGISTER_NRM2(snrm2);
  REGISTER_NRM2(dznrm2);
  REGISTER_NRM2(scnrm2);
#undef REGISTER_NRM2
}
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_BLAS_NRM2_H
|
2
|
+
#define NUMO_LINALG_ALT_BLAS_NRM2_H 1
|
3
|
+
|
4
|
+
#include <ruby.h>
|
5
|
+
|
6
|
+
#include <cblas.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
/* Registers the ?nrm2 module functions (dnrm2, snrm2, dznrm2, scnrm2) on mBlas. */
extern void define_linalg_blas_nrm2(VALUE mBlas);
|
12
|
+
|
13
|
+
#endif /* NUMO_LINALG_ALT_BLAS_NRM2_H */
|
@@ -0,0 +1,67 @@
|
|
1
|
+
#include "converter.h"
|
2
|
+
|
3
|
+
/* Converts a Ruby numeric to a C double via NUM2DBL (raises if it cannot). */
double conv_double(VALUE val) {
  const double converted = NUM2DBL(val);
  return converted;
}
|
6
|
+
|
7
|
+
/* Returns 1.0, the default alpha for double-precision BLAS wrappers. */
double one_double(void) {
  const double unit = 1.0;
  return unit;
}
|
10
|
+
|
11
|
+
/* Returns 0.0, the default beta for double-precision BLAS wrappers. */
double zero_double(void) {
  const double additive_identity = 0.0;
  return additive_identity;
}
|
14
|
+
|
15
|
+
/* Converts a Ruby numeric to float by narrowing the result of NUM2DBL. */
float conv_float(VALUE val) {
  const double wide = NUM2DBL(val);
  return (float)wide;
}
|
18
|
+
|
19
|
+
/* Returns 1.0f, the default alpha for single-precision BLAS wrappers. */
float one_float(void) {
  const float unit = 1.0f;
  return unit;
}
|
22
|
+
|
23
|
+
/* Returns 0.0f, the default beta for single-precision BLAS wrappers. */
float zero_float(void) {
  const float additive_identity = 0.0f;
  return additive_identity;
}
|
26
|
+
|
27
|
+
/*
 * Converts a Ruby value to dcomplex. Calls val.real, then val.imag, and
 * converts each part with NUM2DBL, in that order.
 */
dcomplex conv_dcomplex(VALUE val) {
  dcomplex result;
  const VALUE re_part = rb_funcall(val, rb_intern("real"), 0);
  REAL(result) = NUM2DBL(re_part);
  const VALUE im_part = rb_funcall(val, rb_intern("imag"), 0);
  IMAG(result) = NUM2DBL(im_part);
  return result;
}
|
33
|
+
|
34
|
+
dcomplex one_dcomplex(void) {
|
35
|
+
dcomplex z;
|
36
|
+
REAL(z) = 1.0;
|
37
|
+
IMAG(z) = 0.0;
|
38
|
+
return z;
|
39
|
+
}
|
40
|
+
|
41
|
+
dcomplex zero_dcomplex(void) {
|
42
|
+
dcomplex z;
|
43
|
+
REAL(z) = 0.0;
|
44
|
+
IMAG(z) = 0.0;
|
45
|
+
return z;
|
46
|
+
}
|
47
|
+
|
48
|
+
/*
 * Converts a Ruby value to scomplex. Calls val.real, then val.imag, and
 * narrows each NUM2DBL result to float, in that order.
 */
scomplex conv_scomplex(VALUE val) {
  scomplex result;
  const VALUE re_part = rb_funcall(val, rb_intern("real"), 0);
  REAL(result) = (float)NUM2DBL(re_part);
  const VALUE im_part = rb_funcall(val, rb_intern("imag"), 0);
  IMAG(result) = (float)NUM2DBL(im_part);
  return result;
}
|
54
|
+
|
55
|
+
scomplex one_scomplex(void) {
|
56
|
+
scomplex z;
|
57
|
+
REAL(z) = 1.0f;
|
58
|
+
IMAG(z) = 0.0f;
|
59
|
+
return z;
|
60
|
+
}
|
61
|
+
|
62
|
+
scomplex zero_scomplex(void) {
|
63
|
+
scomplex z;
|
64
|
+
REAL(z) = 0.0f;
|
65
|
+
IMAG(z) = 0.0f;
|
66
|
+
return z;
|
67
|
+
}
|
@@ -0,0 +1,23 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_CONVERTER_H
|
2
|
+
#define NUMO_LINALG_ALT_CONVERTER_H 1
|
3
|
+
|
4
|
+
#include <ruby.h>
|
5
|
+
|
6
|
+
#include <cblas.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
|
10
|
+
double conv_double(VALUE val);
|
11
|
+
double one_double(void);
|
12
|
+
double zero_double(void);
|
13
|
+
float conv_float(VALUE val);
|
14
|
+
float one_float(void);
|
15
|
+
float zero_float(void);
|
16
|
+
dcomplex conv_dcomplex(VALUE val);
|
17
|
+
dcomplex one_dcomplex(void);
|
18
|
+
dcomplex zero_dcomplex(void);
|
19
|
+
scomplex conv_scomplex(VALUE val);
|
20
|
+
scomplex one_scomplex(void);
|
21
|
+
scomplex zero_scomplex(void);
|
22
|
+
|
23
|
+
#endif /* NUMO_LINALG_ALT_CONVERTER_H */
|
@@ -0,0 +1,99 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'mkmf'
|
4
|
+
require 'numo/narray'
|
5
|
+
require 'open-uri'
|
6
|
+
require 'etc'
|
7
|
+
require 'fileutils'
|
8
|
+
require 'open3'
|
9
|
+
require 'digest/md5'
|
10
|
+
require 'rubygems/package'
|
11
|
+
|
12
|
+
# Put the first $LOAD_PATH entry that contains numo/numo/narray.h on the
# include path, then require the header to be usable.
narray_dir = $LOAD_PATH.find { |lp| File.exist?(File.join(lp, 'numo/numo/narray.h')) }
$INCFLAGS = "-I#{narray_dir}/numo #{$INCFLAGS}" if narray_dir

abort 'numo/narray.h is not found' unless have_header('numo/narray.h')

on_windows = RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)

# On Windows, also add libnarray.a's directory to the linker search path and
# require the narray library.
if on_windows
  narray_lib_dir = $LOAD_PATH.find { |lp| File.exist?(File.join(lp, 'numo/libnarray.a')) }
  $LDFLAGS = "-L#{narray_lib_dir}/numo #{$LDFLAGS}" if narray_lib_dir
  abort 'libnarray.a is not found' unless have_library('narray', 'nary_new')
end
|
32
|
+
|
33
|
+
# Use a system OpenBLAS/LAPACKE when its libraries and headers are present;
# otherwise download, verify, build and install OpenBLAS under vendor/.
build_openblas = false
unless find_library('openblas', 'LAPACKE_dsyevr')
  build_openblas = true unless have_library('openblas')
  build_openblas = true unless have_library('lapacke')
end
build_openblas = true unless have_header('cblas.h')
build_openblas = true unless have_header('lapacke.h')
build_openblas = true unless have_header('openblas_config.h')

if build_openblas
  warn 'BLAS and LAPACKE APIs are not found. Downloading and Building OpenBLAS...'

  VENDOR_DIR = File.expand_path("#{__dir__}/../../../vendor")
  LINALG_DIR = File.expand_path("#{__dir__}/../../../lib/numo/linalg")
  OPENBLAS_VER = '0.3.30'
  OPENBLAS_KEY = '8db3d57f4d4485c6ae3f21ea465660e7'
  OPENBLAS_URI = "https://github.com/OpenMathLib/OpenBLAS/archive/v#{OPENBLAS_VER}.tar.gz"
  OPENBLAS_TGZ = "#{VENDOR_DIR}/tmp/openblas.tgz"
  openblas_log = "#{VENDOR_DIR}/tmp/openblas.log"

  # The installed_<version> marker is touched after a successful install, so
  # later runs skip the download and build.
  unless File.exist?("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
    URI.parse(OPENBLAS_URI).open { |f| File.binwrite(OPENBLAS_TGZ, f.read) }
    abort('MD5 digest of downloaded OpenBLAS does not match.') if Digest::MD5.file(OPENBLAS_TGZ).to_s != OPENBLAS_KEY

    # The block form of GzipReader.open guarantees the archive stream is closed.
    Zlib::GzipReader.open(OPENBLAS_TGZ) do |gz|
      Gem::Package::TarReader.new(gz) do |tar|
        tar.each do |entry|
          next unless entry.file?

          filename = "#{VENDOR_DIR}/tmp/#{entry.full_name}"
          next if filename == File.dirname(filename)

          FileUtils.mkdir_p("#{VENDOR_DIR}/tmp/#{File.dirname(entry.full_name)}")
          File.binwrite(filename, entry.read)
          File.chmod(entry.header.mode, filename)
        end
      end
    end

    Dir.chdir("#{VENDOR_DIR}/tmp/OpenBLAS-#{OPENBLAS_VER}") do
      # capture2e merges stderr into the captured output, so compiler errors
      # reach openblas.log (capture3 discarded them). The argv form also keeps
      # a VENDOR_DIR containing spaces from being split into several words.
      mk_output, mkstatus = Open3.capture2e('make', "-j#{Etc.nprocessors}", 'shared')
      File.open(openblas_log, 'w') { |f| f.puts(mk_output) }
      abort("Failed to build OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless mkstatus.success?

      ins_output, insstatus = Open3.capture2e('make', 'install', "PREFIX=#{VENDOR_DIR}")
      File.open(openblas_log, 'a') { |f| f.puts(ins_output) }
      abort("Failed to install OpenBLAS. Check the openblas.log file for more details: #{VENDOR_DIR}/tmp/openblas.log") unless insstatus.success?

      FileUtils.cp("#{VENDOR_DIR}/lib/libopenblas.a", LINALG_DIR) if on_windows

      FileUtils.touch("#{VENDOR_DIR}/installed_#{OPENBLAS_VER}")
    end
  end

  abort('libopenblas is not found.') unless find_library('openblas', nil, "#{VENDOR_DIR}/lib")
  abort('openblas_config.h is not found.') unless find_header('openblas_config.h', nil, "#{VENDOR_DIR}/include")
  abort('cblas.h is not found.') unless find_header('cblas.h', nil, "#{VENDOR_DIR}/include")
  abort('lapacke.h is not found.') unless find_header('lapacke.h', nil, "#{VENDOR_DIR}/include")
end
|
89
|
+
|
90
|
+
# On macOS with Ruby >= 3.1, let undefined symbols be resolved at load time,
# provided the linker accepts the flag.
use_dynamic_lookup = RUBY_PLATFORM.include?('darwin') &&
                     Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.1.0') &&
                     try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
$LDFLAGS << ' -Wl,-undefined,dynamic_lookup' if use_dynamic_lookup

# Build every .c file found under the source tree; make locates the files in
# the blas/ and lapack/ subdirectories through VPATH.
$srcs = Dir.glob("#{$srcdir}/**/*.c").map { |src| File.basename(src) }
%w[blas lapack].each { |subdir| $VPATH << "$(srcdir)/#{subdir}" }

create_makefile('numo/linalg/linalg')
|
@@ -0,0 +1,152 @@
|
|
1
|
+
#include "geev.h"
|
2
|
+
|
3
|
+
/*
 * Options forwarded to the geev iterators through the ndloop option pointer.
 * The wrappers fill this struct positionally, so the field order is fixed.
 */
struct _geev_option {
  int matrix_layout; /* layout passed to LAPACKE (defaults to LAPACK_ROW_MAJOR) */
  char jobvl, jobvr; /* 'N' or 'V' for the left / right eigenvectors */
};
|
8
|
+
|
9
|
+
/*
 * Converts the `jobvl:` keyword to a LAPACK job character. NUM2CHR takes
 * the first character of a String argument. Raises ArgumentError unless the
 * result is 'N' (do not compute left eigenvectors) or 'V' (compute them).
 * The function is static because it is used only in this file, and it should
 * not export an underscore-prefixed symbol from the extension.
 */
static char _get_jobvl(VALUE val) {
  const char jobvl = NUM2CHR(val);
  if (jobvl != 'N' && jobvl != 'V') {
    rb_raise(rb_eArgError, "jobvl must be 'N' or 'V'");
  }
  return jobvl;
}
|
16
|
+
|
17
|
+
/*
 * Converts the `jobvr:` keyword to a LAPACK job character. NUM2CHR takes
 * the first character of a String argument. Raises ArgumentError unless the
 * result is 'N' (do not compute right eigenvectors) or 'V' (compute them).
 * The function is static because it is used only in this file, and it should
 * not export an underscore-prefixed symbol from the extension.
 */
static char _get_jobvr(VALUE val) {
  const char jobvr = NUM2CHR(val);
  if (jobvr != 'N' && jobvr != 'V') {
    rb_raise(rb_eArgError, "jobvr must be 'N' or 'V'");
  }
  return jobvr;
}
|
24
|
+
|
25
|
+
/*
 * DEF_LINALG_FUNC(tDType, tNAryClass, fLapackFunc) defines, for a real
 * element type:
 *   _iter_<fLapackFunc>          ndloop callback. It runs
 *                                LAPACKE_<fLapackFunc> on a (overwritten in
 *                                place) and writes wr, wi, vl, vr and info.
 *   _linalg_lapack_<fLapackFunc> Ruby method body for
 *                                (a, order:, jobvl:, jobvr:). order defaults
 *                                to LAPACK_ROW_MAJOR and jobvl/jobvr to 'V'.
 * When a job is 'N' the matching vector output gets shape [n, 1] and a
 * leading dimension of 1; it serves only as a placeholder.
 * a must be 2-D and square. The callback hands LAPACK an n-by-n matrix with
 * lda = n, so a non-square input would make LAPACK access memory past the
 * end of the buffer.
 */
#define DEF_LINALG_FUNC(tDType, tNAryClass, fLapackFunc)                                                     \
  static void _iter_##fLapackFunc(na_loop_t* const lp) {                                                    \
    tDType* a = (tDType*)NDL_PTR(lp, 0);                                                                    \
    tDType* wr = (tDType*)NDL_PTR(lp, 1);                                                                   \
    tDType* wi = (tDType*)NDL_PTR(lp, 2);                                                                   \
    tDType* vl = (tDType*)NDL_PTR(lp, 3);                                                                   \
    tDType* vr = (tDType*)NDL_PTR(lp, 4);                                                                   \
    int* info = (int*)NDL_PTR(lp, 5);                                                                       \
    struct _geev_option* opt = (struct _geev_option*)(lp->opt_ptr);                                         \
    const lapack_int n = (lapack_int)(opt->matrix_layout == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0]          \
                                                                             : NDL_SHAPE(lp, 0)[1]);        \
    const lapack_int lda = n;                                                                               \
    const lapack_int ldvl = (opt->jobvl == 'N') ? 1 : n;                                                    \
    const lapack_int ldvr = (opt->jobvr == 'N') ? 1 : n;                                                    \
    lapack_int i = LAPACKE_##fLapackFunc(opt->matrix_layout, opt->jobvl, opt->jobvr, n, a, lda, wr, wi,    \
                                         vl, ldvl, vr, ldvr);                                               \
    *info = (int)i;                                                                                         \
  }                                                                                                         \
                                                                                                            \
  static VALUE _linalg_lapack_##fLapackFunc(int argc, VALUE* argv, VALUE self) {                            \
    VALUE a_vnary = Qnil;                                                                                   \
    VALUE kw_args = Qnil;                                                                                   \
    rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);                                                     \
    ID kw_table[3] = { rb_intern("order"), rb_intern("jobvl"), rb_intern("jobvr") };                        \
    VALUE kw_values[3] = { Qundef, Qundef, Qundef };                                                        \
    rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);                                                      \
    const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;  \
    const char jobvl = kw_values[1] != Qundef ? _get_jobvl(kw_values[1]) : 'V';                             \
    const char jobvr = kw_values[2] != Qundef ? _get_jobvr(kw_values[2]) : 'V';                             \
                                                                                                            \
    if (CLASS_OF(a_vnary) != tNAryClass) {                                                                  \
      a_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, a_vnary);                                      \
    }                                                                                                       \
    if (!RTEST(nary_check_contiguous(a_vnary))) {                                                           \
      a_vnary = nary_dup(a_vnary);                                                                          \
    }                                                                                                       \
                                                                                                            \
    narray_t* a_nary = NULL;                                                                                \
    GetNArray(a_vnary, a_nary);                                                                             \
    const int n_dims = NA_NDIM(a_nary);                                                                     \
    if (n_dims != 2) {                                                                                      \
      rb_raise(rb_eArgError, "input array a must be 2-dimensional");                                        \
      return Qnil;                                                                                          \
    }                                                                                                       \
    if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {                                                       \
      rb_raise(rb_eArgError, "input array a must be square");                                               \
      return Qnil;                                                                                          \
    }                                                                                                       \
                                                                                                            \
    size_t n = matrix_layout == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];               \
    size_t shape_wr[1] = { n };                                                                             \
    size_t shape_wi[1] = { n };                                                                             \
    size_t shape_vl[2] = { n, (jobvl == 'N') ? 1 : n };                                                     \
    size_t shape_vr[2] = { n, (jobvr == 'N') ? 1 : n };                                                     \
    ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };                                                          \
    ndfunc_arg_out_t aout[5] = { { tNAryClass, 1, shape_wr }, { tNAryClass, 1, shape_wi },                 \
                                 { tNAryClass, 2, shape_vl }, { tNAryClass, 2, shape_vr },                 \
                                 { numo_cInt32, 0 } };                                                      \
    ndfunc_t ndf = { _iter_##fLapackFunc, NO_LOOP | NDF_EXTRACT, 1, 5, ain, aout };                         \
    struct _geev_option opt = { matrix_layout, jobvl, jobvr };                                              \
    VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);                                                         \
                                                                                                            \
    RB_GC_GUARD(a_vnary);                                                                                   \
    return ret;                                                                                             \
  }
|
82
|
+
|
83
|
+
/*
 * DEF_LINALG_FUNC_COMPLEX(tDType, tNAryClass, fLapackFunc) is the complex
 * counterpart of DEF_LINALG_FUNC. The eigenvalues come back in a single
 * complex array w instead of wr/wi, so the outputs are w, vl, vr and info.
 * order defaults to LAPACK_ROW_MAJOR and jobvl/jobvr to 'V'. A job of 'N'
 * yields an [n, 1] placeholder output with leading dimension 1.
 * a must be 2-D and square. The callback passes an n-by-n matrix with
 * lda = n, so a non-square input would make LAPACK access memory past the
 * end of the buffer.
 */
#define DEF_LINALG_FUNC_COMPLEX(tDType, tNAryClass, fLapackFunc)                                             \
  static void _iter_##fLapackFunc(na_loop_t* const lp) {                                                    \
    tDType* a = (tDType*)NDL_PTR(lp, 0);                                                                    \
    tDType* w = (tDType*)NDL_PTR(lp, 1);                                                                    \
    tDType* vl = (tDType*)NDL_PTR(lp, 2);                                                                   \
    tDType* vr = (tDType*)NDL_PTR(lp, 3);                                                                   \
    int* info = (int*)NDL_PTR(lp, 4);                                                                       \
    struct _geev_option* opt = (struct _geev_option*)(lp->opt_ptr);                                         \
    const lapack_int n = (lapack_int)(opt->matrix_layout == LAPACK_ROW_MAJOR ? NDL_SHAPE(lp, 0)[0]          \
                                                                             : NDL_SHAPE(lp, 0)[1]);        \
    const lapack_int lda = n;                                                                               \
    const lapack_int ldvl = (opt->jobvl == 'N') ? 1 : n;                                                    \
    const lapack_int ldvr = (opt->jobvr == 'N') ? 1 : n;                                                    \
    lapack_int i = LAPACKE_##fLapackFunc(opt->matrix_layout, opt->jobvl, opt->jobvr, n, a, lda, w, vl,     \
                                         ldvl, vr, ldvr);                                                   \
    *info = (int)i;                                                                                         \
  }                                                                                                         \
                                                                                                            \
  static VALUE _linalg_lapack_##fLapackFunc(int argc, VALUE* argv, VALUE self) {                            \
    VALUE a_vnary = Qnil;                                                                                   \
    VALUE kw_args = Qnil;                                                                                   \
    rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);                                                     \
    ID kw_table[3] = { rb_intern("order"), rb_intern("jobvl"), rb_intern("jobvr") };                        \
    VALUE kw_values[3] = { Qundef, Qundef, Qundef };                                                        \
    rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);                                                      \
    const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;  \
    const char jobvl = kw_values[1] != Qundef ? _get_jobvl(kw_values[1]) : 'V';                             \
    const char jobvr = kw_values[2] != Qundef ? _get_jobvr(kw_values[2]) : 'V';                             \
                                                                                                            \
    if (CLASS_OF(a_vnary) != tNAryClass) {                                                                  \
      a_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, a_vnary);                                      \
    }                                                                                                       \
    if (!RTEST(nary_check_contiguous(a_vnary))) {                                                           \
      a_vnary = nary_dup(a_vnary);                                                                          \
    }                                                                                                       \
                                                                                                            \
    narray_t* a_nary = NULL;                                                                                \
    GetNArray(a_vnary, a_nary);                                                                             \
    const int n_dims = NA_NDIM(a_nary);                                                                     \
    if (n_dims != 2) {                                                                                      \
      rb_raise(rb_eArgError, "input array a must be 2-dimensional");                                        \
      return Qnil;                                                                                          \
    }                                                                                                       \
    if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {                                                       \
      rb_raise(rb_eArgError, "input array a must be square");                                               \
      return Qnil;                                                                                          \
    }                                                                                                       \
                                                                                                            \
    size_t n = matrix_layout == LAPACK_ROW_MAJOR ? NA_SHAPE(a_nary)[0] : NA_SHAPE(a_nary)[1];               \
    size_t shape_w[1] = { n };                                                                              \
    size_t shape_vl[2] = { n, (jobvl == 'N') ? 1 : n };                                                     \
    size_t shape_vr[2] = { n, (jobvr == 'N') ? 1 : n };                                                     \
    ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };                                                          \
    ndfunc_arg_out_t aout[4] = { { tNAryClass, 1, shape_w }, { tNAryClass, 2, shape_vl },                  \
                                 { tNAryClass, 2, shape_vr }, { numo_cInt32, 0 } };                         \
    ndfunc_t ndf = { _iter_##fLapackFunc, NO_LOOP | NDF_EXTRACT, 1, 4, ain, aout };                         \
    struct _geev_option opt = { matrix_layout, jobvl, jobvr };                                              \
    VALUE ret = na_ndloop3(&ndf, &opt, 1, a_vnary);                                                         \
                                                                                                            \
    RB_GC_GUARD(a_vnary);                                                                                   \
    return ret;                                                                                             \
  }
|
138
|
+
|
139
|
+
/* Double precision: real (dgeev) and complex (zgeev). */
DEF_LINALG_FUNC(double, numo_cDFloat, dgeev)
DEF_LINALG_FUNC_COMPLEX(lapack_complex_double, numo_cDComplex, zgeev)

/* Single precision: real (sgeev) and complex (cgeev). */
DEF_LINALG_FUNC(float, numo_cSFloat, sgeev)
DEF_LINALG_FUNC_COMPLEX(lapack_complex_float, numo_cSComplex, cgeev)

/* The generator macros are private to this translation unit. */
#undef DEF_LINALG_FUNC_COMPLEX
#undef DEF_LINALG_FUNC
|
146
|
+
|
147
|
+
/*
 * Registers dgeev, sgeev, zgeev and cgeev as variadic (arity -1) module
 * functions on mLapack, each backed by the matching _linalg_lapack_* wrapper.
 */
void define_linalg_lapack_geev(VALUE mLapack) {
#define REGISTER_GEEV(fname) \
  rb_define_module_function(mLapack, #fname, RUBY_METHOD_FUNC(_linalg_lapack_##fname), -1)
  REGISTER_GEEV(dgeev);
  REGISTER_GEEV(sgeev);
  REGISTER_GEEV(zgeev);
  REGISTER_GEEV(cgeev);
#undef REGISTER_GEEV
}
|
@@ -0,0 +1,15 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_LAPACK_GEEV_H
|
2
|
+
#define NUMO_LINALG_ALT_LAPACK_GEEV_H 1
|
3
|
+
|
4
|
+
#include <lapacke.h>
|
5
|
+
|
6
|
+
#include <ruby.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
#include "../util.h"
|
12
|
+
|
13
|
+
/* Registers the ?geev module functions (dgeev, sgeev, zgeev, cgeev) on mLapack. */
extern void define_linalg_lapack_geev(VALUE mLapack);
|
14
|
+
|
15
|
+
#endif /* NUMO_LINALG_ALT_LAPACK_GEEV_H */
|