numo-tiny_linalg 0.0.3 → 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -3
- data/ext/numo/tiny_linalg/blas/gemm.hpp +3 -49
- data/ext/numo/tiny_linalg/blas/gemv.hpp +2 -48
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/gesdd.hpp +11 -11
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +10 -30
- data/ext/numo/tiny_linalg/lapack/gesvd.hpp +12 -12
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/getri.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +137 -0
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +136 -0
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +5 -25
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +74 -21
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +30 -6
- data/ext/numo/tiny_linalg/util.hpp +100 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +203 -35
- metadata +9 -2
@@ -0,0 +1,136 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
struct DSyGvx {
|
4
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
5
|
+
lapack_int n, double* a, lapack_int lda, double* b, lapack_int ldb,
|
6
|
+
double vl, double vu, lapack_int il, lapack_int iu,
|
7
|
+
double abstol, lapack_int* m, double* w, double* z, lapack_int ldz, lapack_int* ifail) {
|
8
|
+
return LAPACKE_dsygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
9
|
+
}
|
10
|
+
};
|
11
|
+
|
12
|
+
struct SSyGvx {
|
13
|
+
lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
|
14
|
+
lapack_int n, float* a, lapack_int lda, float* b, lapack_int ldb,
|
15
|
+
float vl, float vu, lapack_int il, lapack_int iu,
|
16
|
+
float abstol, lapack_int* m, float* w, float* z, lapack_int ldz, lapack_int* ifail) {
|
17
|
+
return LAPACKE_ssygvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
21
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
22
|
+
class SyGvx {
|
23
|
+
public:
|
24
|
+
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
25
|
+
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvx), -1);
|
26
|
+
}
|
27
|
+
|
28
|
+
private:
|
29
|
+
struct sygvx_opt {
|
30
|
+
int matrix_layout;
|
31
|
+
lapack_int itype;
|
32
|
+
char jobz;
|
33
|
+
char range;
|
34
|
+
char uplo;
|
35
|
+
dtype vl;
|
36
|
+
dtype vu;
|
37
|
+
lapack_int il;
|
38
|
+
lapack_int iu;
|
39
|
+
};
|
40
|
+
|
41
|
+
static void iter_sygvx(na_loop_t* const lp) {
|
42
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
43
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
44
|
+
int* m = (int*)NDL_PTR(lp, 2);
|
45
|
+
dtype* w = (dtype*)NDL_PTR(lp, 3);
|
46
|
+
dtype* z = (dtype*)NDL_PTR(lp, 4);
|
47
|
+
int* ifail = (int*)NDL_PTR(lp, 5);
|
48
|
+
int* info = (int*)NDL_PTR(lp, 6);
|
49
|
+
sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
|
50
|
+
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
51
|
+
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
|
52
|
+
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
|
53
|
+
const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
|
54
|
+
const dtype abstol = 0.0;
|
55
|
+
const lapack_int i = LapackFn().call(
|
56
|
+
opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
|
57
|
+
opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
|
58
|
+
*info = static_cast<int>(i);
|
59
|
+
}
|
60
|
+
|
61
|
+
static VALUE tiny_linalg_sygvx(int argc, VALUE* argv, VALUE self) {
|
62
|
+
VALUE nary_dtype = NaryTypes[nary_dtype_id];
|
63
|
+
|
64
|
+
VALUE a_vnary = Qnil;
|
65
|
+
VALUE b_vnary = Qnil;
|
66
|
+
VALUE kw_args = Qnil;
|
67
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
|
68
|
+
ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
|
69
|
+
rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
|
70
|
+
VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
|
71
|
+
rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
|
72
|
+
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
|
73
|
+
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
|
74
|
+
const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A';
|
75
|
+
const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U';
|
76
|
+
const dtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
|
77
|
+
const dtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
|
78
|
+
const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
|
79
|
+
const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
|
80
|
+
const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
|
81
|
+
|
82
|
+
if (CLASS_OF(a_vnary) != nary_dtype) {
|
83
|
+
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
84
|
+
}
|
85
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) {
|
86
|
+
a_vnary = nary_dup(a_vnary);
|
87
|
+
}
|
88
|
+
if (CLASS_OF(b_vnary) != nary_dtype) {
|
89
|
+
b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
|
90
|
+
}
|
91
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) {
|
92
|
+
b_vnary = nary_dup(b_vnary);
|
93
|
+
}
|
94
|
+
|
95
|
+
narray_t* a_nary = nullptr;
|
96
|
+
GetNArray(a_vnary, a_nary);
|
97
|
+
if (NA_NDIM(a_nary) != 2) {
|
98
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
|
99
|
+
return Qnil;
|
100
|
+
}
|
101
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
|
102
|
+
rb_raise(rb_eArgError, "input array a must be square");
|
103
|
+
return Qnil;
|
104
|
+
}
|
105
|
+
narray_t* b_nary = nullptr;
|
106
|
+
GetNArray(a_vnary, b_nary);
|
107
|
+
if (NA_NDIM(b_nary) != 2) {
|
108
|
+
rb_raise(rb_eArgError, "input array b must be 2-dimensional");
|
109
|
+
return Qnil;
|
110
|
+
}
|
111
|
+
if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
|
112
|
+
rb_raise(rb_eArgError, "input array b must be square");
|
113
|
+
return Qnil;
|
114
|
+
}
|
115
|
+
|
116
|
+
const size_t n = NA_SHAPE(a_nary)[1];
|
117
|
+
size_t m = range != 'I' ? n : iu - il + 1;
|
118
|
+
size_t w_shape[1] = { m };
|
119
|
+
size_t z_shape[2] = { n, m };
|
120
|
+
size_t ifail_shape[1] = { n };
|
121
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
|
122
|
+
ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_dtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
|
123
|
+
ndfunc_t ndf = { iter_sygvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
|
124
|
+
sygvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
|
125
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
|
126
|
+
VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
|
127
|
+
rb_ary_entry(res, 3), rb_ary_entry(res, 4));
|
128
|
+
|
129
|
+
RB_GC_GUARD(a_vnary);
|
130
|
+
RB_GC_GUARD(b_vnary);
|
131
|
+
|
132
|
+
return ret;
|
133
|
+
}
|
134
|
+
};
|
135
|
+
|
136
|
+
} // namespace TinyLinalg
|
@@ -14,7 +14,7 @@ struct CUngQr {
|
|
14
14
|
}
|
15
15
|
};
|
16
16
|
|
17
|
-
template <int nary_dtype_id, typename
|
17
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
18
18
|
class UngQr {
|
19
19
|
public:
|
20
20
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -27,15 +27,15 @@ private:
|
|
27
27
|
};
|
28
28
|
|
29
29
|
static void iter_ungqr(na_loop_t* const lp) {
|
30
|
-
|
31
|
-
|
30
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
31
|
+
dtype* tau = (dtype*)NDL_PTR(lp, 1);
|
32
32
|
int* info = (int*)NDL_PTR(lp, 2);
|
33
33
|
ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
|
34
34
|
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
35
35
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
36
36
|
const lapack_int k = NDL_SHAPE(lp, 1)[0];
|
37
37
|
const lapack_int lda = n;
|
38
|
-
const lapack_int i =
|
38
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, k, a, lda, tau);
|
39
39
|
*info = static_cast<int>(i);
|
40
40
|
}
|
41
41
|
|
@@ -49,7 +49,7 @@ private:
|
|
49
49
|
ID kw_table[1] = { rb_intern("order") };
|
50
50
|
VALUE kw_values[1] = { Qundef };
|
51
51
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
52
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
52
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
53
53
|
|
54
54
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
55
55
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -90,26 +90,6 @@ private:
|
|
90
90
|
|
91
91
|
return ret;
|
92
92
|
}
|
93
|
-
|
94
|
-
static int get_matrix_layout(VALUE val) {
|
95
|
-
const char* option_str = StringValueCStr(val);
|
96
|
-
|
97
|
-
if (std::strlen(option_str) > 0) {
|
98
|
-
switch (option_str[0]) {
|
99
|
-
case 'r':
|
100
|
-
case 'R':
|
101
|
-
break;
|
102
|
-
case 'c':
|
103
|
-
case 'C':
|
104
|
-
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
105
|
-
break;
|
106
|
-
}
|
107
|
-
}
|
108
|
-
|
109
|
-
RB_GC_GUARD(val);
|
110
|
-
|
111
|
-
return LAPACK_ROW_MAJOR;
|
112
|
-
}
|
113
93
|
};
|
114
94
|
|
115
95
|
} // namespace TinyLinalg
|
@@ -1,17 +1,56 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2023 Atsushi Tatsuma
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* Redistribution and use in source and binary forms, with or without
|
6
|
+
* modification, are permitted provided that the following conditions are met:
|
7
|
+
*
|
8
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
9
|
+
* list of conditions and the following disclaimer.
|
10
|
+
*
|
11
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
12
|
+
* this list of conditions and the following disclaimer in the documentation
|
13
|
+
* and/or other materials provided with the distribution.
|
14
|
+
*
|
15
|
+
* * Neither the name of the copyright holder nor the names of its
|
16
|
+
* contributors may be used to endorse or promote products derived from
|
17
|
+
* this software without specific prior written permission.
|
18
|
+
*
|
19
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29
|
+
*/
|
30
|
+
|
1
31
|
#include "tiny_linalg.hpp"
|
32
|
+
|
33
|
+
#include "converter.hpp"
|
34
|
+
#include "util.hpp"
|
35
|
+
|
2
36
|
#include "blas/dot.hpp"
|
3
37
|
#include "blas/dot_sub.hpp"
|
4
38
|
#include "blas/gemm.hpp"
|
5
39
|
#include "blas/gemv.hpp"
|
6
40
|
#include "blas/nrm2.hpp"
|
7
|
-
#include "converter.hpp"
|
8
41
|
#include "lapack/geqrf.hpp"
|
9
42
|
#include "lapack/gesdd.hpp"
|
10
43
|
#include "lapack/gesv.hpp"
|
11
44
|
#include "lapack/gesvd.hpp"
|
12
45
|
#include "lapack/getrf.hpp"
|
13
46
|
#include "lapack/getri.hpp"
|
47
|
+
#include "lapack/hegv.hpp"
|
48
|
+
#include "lapack/hegvd.hpp"
|
49
|
+
#include "lapack/hegvx.hpp"
|
14
50
|
#include "lapack/orgqr.hpp"
|
51
|
+
#include "lapack/sygv.hpp"
|
52
|
+
#include "lapack/sygvd.hpp"
|
53
|
+
#include "lapack/sygvx.hpp"
|
15
54
|
#include "lapack/ungqr.hpp"
|
16
55
|
|
17
56
|
VALUE rb_mTinyLinalg;
|
@@ -191,11 +230,13 @@ extern "C" void Init_tiny_linalg(void) {
|
|
191
230
|
/**
|
192
231
|
* Document-module: Numo::TinyLinalg::Blas
|
193
232
|
* Numo::TinyLinalg::Blas is a wrapper module for BLAS functions.
|
233
|
+
* @!visibility private
|
194
234
|
*/
|
195
235
|
rb_mTinyLinalgBlas = rb_define_module_under(rb_mTinyLinalg, "Blas");
|
196
236
|
/**
|
197
237
|
* Document-module: Numo::TinyLinalg::Lapack
|
198
238
|
* Numo::TinyLinalg::Lapack is a wrapper module for LAPACK functions.
|
239
|
+
* @!visibility private
|
199
240
|
*/
|
200
241
|
rb_mTinyLinalgLapack = rb_define_module_under(rb_mTinyLinalg, "Lapack");
|
201
242
|
|
@@ -243,26 +284,26 @@ extern "C" void Init_tiny_linalg(void) {
|
|
243
284
|
TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
|
244
285
|
TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
|
245
286
|
TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
|
246
|
-
TinyLinalg::
|
247
|
-
TinyLinalg::
|
248
|
-
TinyLinalg::
|
249
|
-
TinyLinalg::
|
250
|
-
TinyLinalg::
|
251
|
-
TinyLinalg::
|
252
|
-
TinyLinalg::
|
253
|
-
TinyLinalg::
|
254
|
-
TinyLinalg::
|
255
|
-
TinyLinalg::
|
256
|
-
TinyLinalg::
|
257
|
-
TinyLinalg::
|
258
|
-
TinyLinalg::
|
259
|
-
TinyLinalg::
|
260
|
-
TinyLinalg::
|
261
|
-
TinyLinalg::
|
262
|
-
TinyLinalg::
|
263
|
-
TinyLinalg::
|
264
|
-
TinyLinalg::
|
265
|
-
TinyLinalg::
|
287
|
+
TinyLinalg::GeSv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeSv>::define_module_function(rb_mTinyLinalgLapack, "dgesv");
|
288
|
+
TinyLinalg::GeSv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeSv>::define_module_function(rb_mTinyLinalgLapack, "sgesv");
|
289
|
+
TinyLinalg::GeSv<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeSv>::define_module_function(rb_mTinyLinalgLapack, "zgesv");
|
290
|
+
TinyLinalg::GeSv<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeSv>::define_module_function(rb_mTinyLinalgLapack, "cgesv");
|
291
|
+
TinyLinalg::GeSvd<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGeSvd>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
|
292
|
+
TinyLinalg::GeSvd<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGeSvd>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
|
293
|
+
TinyLinalg::GeSvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGeSvd>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
|
294
|
+
TinyLinalg::GeSvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGeSvd>::define_module_function(rb_mTinyLinalgLapack, "cgesvd");
|
295
|
+
TinyLinalg::GeSdd<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGeSdd>::define_module_function(rb_mTinyLinalgLapack, "dgesdd");
|
296
|
+
TinyLinalg::GeSdd<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGeSdd>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
|
297
|
+
TinyLinalg::GeSdd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGeSdd>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
|
298
|
+
TinyLinalg::GeSdd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGeSdd>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
|
299
|
+
TinyLinalg::GeTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeTrf>::define_module_function(rb_mTinyLinalgLapack, "dgetrf");
|
300
|
+
TinyLinalg::GeTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTrf>::define_module_function(rb_mTinyLinalgLapack, "sgetrf");
|
301
|
+
TinyLinalg::GeTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTrf>::define_module_function(rb_mTinyLinalgLapack, "zgetrf");
|
302
|
+
TinyLinalg::GeTrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTrf>::define_module_function(rb_mTinyLinalgLapack, "cgetrf");
|
303
|
+
TinyLinalg::GeTri<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeTri>::define_module_function(rb_mTinyLinalgLapack, "dgetri");
|
304
|
+
TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
|
305
|
+
TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
|
306
|
+
TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
|
266
307
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
|
267
308
|
TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
|
268
309
|
TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
|
@@ -271,6 +312,18 @@ extern "C" void Init_tiny_linalg(void) {
|
|
271
312
|
TinyLinalg::OrgQr<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SOrgQr>::define_module_function(rb_mTinyLinalgLapack, "sorgqr");
|
272
313
|
TinyLinalg::UngQr<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZUngQr>::define_module_function(rb_mTinyLinalgLapack, "zungqr");
|
273
314
|
TinyLinalg::UngQr<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CUngQr>::define_module_function(rb_mTinyLinalgLapack, "cungqr");
|
315
|
+
TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
|
316
|
+
TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
|
317
|
+
TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
|
318
|
+
TinyLinalg::HeGv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGv>::define_module_function(rb_mTinyLinalgLapack, "chegv");
|
319
|
+
TinyLinalg::SyGvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvd>::define_module_function(rb_mTinyLinalgLapack, "dsygvd");
|
320
|
+
TinyLinalg::SyGvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvd>::define_module_function(rb_mTinyLinalgLapack, "ssygvd");
|
321
|
+
TinyLinalg::HeGvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvd>::define_module_function(rb_mTinyLinalgLapack, "zhegvd");
|
322
|
+
TinyLinalg::HeGvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvd>::define_module_function(rb_mTinyLinalgLapack, "chegvd");
|
323
|
+
TinyLinalg::SyGvx<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGvx>::define_module_function(rb_mTinyLinalgLapack, "dsygvx");
|
324
|
+
TinyLinalg::SyGvx<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGvx>::define_module_function(rb_mTinyLinalgLapack, "ssygvx");
|
325
|
+
TinyLinalg::HeGvx<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGvx>::define_module_function(rb_mTinyLinalgLapack, "zhegvx");
|
326
|
+
TinyLinalg::HeGvx<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeGvx>::define_module_function(rb_mTinyLinalgLapack, "chegvx");
|
274
327
|
|
275
328
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "znrm2", "dznrm2");
|
276
329
|
rb_define_alias(rb_singleton_class(rb_mTinyLinalgBlas), "cnrm2", "scnrm2");
|
@@ -1,17 +1,41 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2023 Atsushi Tatsuma
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* Redistribution and use in source and binary forms, with or without
|
6
|
+
* modification, are permitted provided that the following conditions are met:
|
7
|
+
*
|
8
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
9
|
+
* list of conditions and the following disclaimer.
|
10
|
+
*
|
11
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
12
|
+
* this list of conditions and the following disclaimer in the documentation
|
13
|
+
* and/or other materials provided with the distribution.
|
14
|
+
*
|
15
|
+
* * Neither the name of the copyright holder nor the names of its
|
16
|
+
* contributors may be used to endorse or promote products derived from
|
17
|
+
* this software without specific prior written permission.
|
18
|
+
*
|
19
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29
|
+
*/
|
30
|
+
|
1
31
|
#ifndef NUMO_TINY_LINALG_H
|
2
32
|
#define NUMO_TINY_LINALG_H 1
|
3
33
|
|
4
|
-
#if defined(TINYLINALG_USE_ACCELERATE)
|
5
|
-
#include <Accelerate/Accelerate.h>
|
6
|
-
#else
|
7
34
|
#include <cblas.h>
|
8
35
|
#include <lapacke.h>
|
9
|
-
#endif
|
10
36
|
|
11
37
|
#include <string>
|
12
38
|
|
13
|
-
#include <cstring>
|
14
|
-
|
15
39
|
#include <ruby.h>
|
16
40
|
|
17
41
|
#include <numo/narray.h>
|
@@ -0,0 +1,100 @@
|
|
1
|
+
namespace TinyLinalg {
|
2
|
+
|
3
|
+
class Util {
|
4
|
+
public:
|
5
|
+
static lapack_int get_itype(VALUE val) {
|
6
|
+
const lapack_int itype = NUM2INT(val);
|
7
|
+
|
8
|
+
if (itype != 1 && itype != 2 && itype != 3) {
|
9
|
+
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
|
10
|
+
}
|
11
|
+
|
12
|
+
return itype;
|
13
|
+
}
|
14
|
+
|
15
|
+
static char get_jobz(VALUE val) {
|
16
|
+
const char jobz = NUM2CHR(val);
|
17
|
+
|
18
|
+
if (jobz != 'N' && jobz != 'V') {
|
19
|
+
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
|
20
|
+
}
|
21
|
+
|
22
|
+
return jobz;
|
23
|
+
}
|
24
|
+
|
25
|
+
static char get_range(VALUE val) {
|
26
|
+
const char range = NUM2CHR(val);
|
27
|
+
|
28
|
+
if (range != 'A' && range != 'V' && range != 'I') {
|
29
|
+
rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
|
30
|
+
}
|
31
|
+
|
32
|
+
return range;
|
33
|
+
}
|
34
|
+
|
35
|
+
static char get_uplo(VALUE val) {
|
36
|
+
const char uplo = NUM2CHR(val);
|
37
|
+
|
38
|
+
if (uplo != 'U' && uplo != 'L') {
|
39
|
+
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
|
40
|
+
}
|
41
|
+
|
42
|
+
return uplo;
|
43
|
+
}
|
44
|
+
|
45
|
+
static int get_matrix_layout(VALUE val) {
|
46
|
+
const char option = NUM2CHR(val);
|
47
|
+
|
48
|
+
switch (option) {
|
49
|
+
case 'r':
|
50
|
+
case 'R':
|
51
|
+
break;
|
52
|
+
case 'c':
|
53
|
+
case 'C':
|
54
|
+
rb_warn("Numo::TinyLinalg does not support column major.");
|
55
|
+
break;
|
56
|
+
}
|
57
|
+
|
58
|
+
return LAPACK_ROW_MAJOR;
|
59
|
+
}
|
60
|
+
|
61
|
+
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
62
|
+
const char option = NUM2CHR(val);
|
63
|
+
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
64
|
+
|
65
|
+
switch (option) {
|
66
|
+
case 'n':
|
67
|
+
case 'N':
|
68
|
+
res = CblasNoTrans;
|
69
|
+
break;
|
70
|
+
case 't':
|
71
|
+
case 'T':
|
72
|
+
res = CblasTrans;
|
73
|
+
break;
|
74
|
+
case 'c':
|
75
|
+
case 'C':
|
76
|
+
res = CblasConjTrans;
|
77
|
+
break;
|
78
|
+
}
|
79
|
+
|
80
|
+
return res;
|
81
|
+
}
|
82
|
+
|
83
|
+
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
84
|
+
const char option = NUM2CHR(val);
|
85
|
+
|
86
|
+
switch (option) {
|
87
|
+
case 'r':
|
88
|
+
case 'R':
|
89
|
+
break;
|
90
|
+
case 'c':
|
91
|
+
case 'C':
|
92
|
+
rb_warn("Numo::TinyLinalg does not support column major.");
|
93
|
+
break;
|
94
|
+
}
|
95
|
+
|
96
|
+
return CblasRowMajor;
|
97
|
+
}
|
98
|
+
};
|
99
|
+
|
100
|
+
} // namespace TinyLinalg
|