numo-tiny_linalg 0.0.4 → 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -1
- data/README.md +3 -3
- data/ext/numo/tiny_linalg/blas/gemm.hpp +3 -49
- data/ext/numo/tiny_linalg/blas/gemv.hpp +2 -48
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/gesdd.hpp +11 -11
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +10 -30
- data/ext/numo/tiny_linalg/lapack/gesvd.hpp +12 -12
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/getri.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +16 -72
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +16 -72
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +5 -25
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +56 -21
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +30 -6
- data/ext/numo/tiny_linalg/util.hpp +100 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +159 -35
- metadata +2 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 7a298eaec7ee7338e4856ac50b22a06bf774073456a01f5ffb2a405cab632f7a
|
4
|
+
data.tar.gz: a44ae42f723ee7c9a1af6c2b1a01e6257f1fc417a5cd2a22cb6be51ac0589c1c
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: fce33aa331257bc37d7a972e09a0cf1d9328046459c9de6b9feaf8341d0d4545d202b90d6bb0e5e6ee097b4c9c58a4bc76fc0224fd87dc10818c8f566259ab7e
|
7
|
+
data.tar.gz: 6da516469d0fdfb47884fea2d9fc2483b5ea066063ce1681afc8e9319d1cfcc6863f59e96a4ce1d7cb22ceb159a8b2f8f6e41478958ea245a5eaaf209753ea2c
|
data/CHANGELOG.md
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
-
## [[0.0
|
3
|
+
## [[0.1.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.4...v0.1.0)] - 2023-08-06
|
4
|
+
- Refactor codes and update documentations.
|
5
|
+
|
6
|
+
## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
|
4
7
|
- Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
|
5
8
|
- Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
|
6
9
|
- Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
|
data/README.md
CHANGED
@@ -5,11 +5,11 @@
|
|
5
5
|
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/numo-tiny_linalg/blob/main/LICENSE.txt)
|
6
6
|
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/numo-tiny_linalg/doc/)
|
7
7
|
|
8
|
-
Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
9
|
-
|
10
|
-
This gem is still **under development** and may undergo many changes in the future.
|
8
|
+
Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
|
9
|
+
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
|
11
10
|
|
12
11
|
## Installation
|
12
|
+
Unlike Numo::Linalg, Numo::TinyLinalg only supports OpenBLAS as a backend library for BLAS and LAPACK.
|
13
13
|
|
14
14
|
Install the OpenBlas.
|
15
15
|
|
@@ -102,9 +102,9 @@ private:
|
|
102
102
|
|
103
103
|
dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
|
104
104
|
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
|
105
|
-
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
|
106
|
-
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
107
|
-
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans;
|
105
|
+
enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor;
|
106
|
+
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
107
|
+
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? Util().get_cblas_trans(kw_values[4]) : CblasNoTrans;
|
108
108
|
|
109
109
|
narray_t* a_nary = NULL;
|
110
110
|
GetNArray(a, a_nary);
|
@@ -172,52 +172,6 @@ private:
|
|
172
172
|
|
173
173
|
return ret;
|
174
174
|
};
|
175
|
-
|
176
|
-
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
177
|
-
const char* option_str = StringValueCStr(val);
|
178
|
-
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
179
|
-
|
180
|
-
if (std::strlen(option_str) > 0) {
|
181
|
-
switch (option_str[0]) {
|
182
|
-
case 'n':
|
183
|
-
case 'N':
|
184
|
-
res = CblasNoTrans;
|
185
|
-
break;
|
186
|
-
case 't':
|
187
|
-
case 'T':
|
188
|
-
res = CblasTrans;
|
189
|
-
break;
|
190
|
-
case 'c':
|
191
|
-
case 'C':
|
192
|
-
res = CblasConjTrans;
|
193
|
-
break;
|
194
|
-
}
|
195
|
-
}
|
196
|
-
|
197
|
-
RB_GC_GUARD(val);
|
198
|
-
|
199
|
-
return res;
|
200
|
-
}
|
201
|
-
|
202
|
-
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
203
|
-
const char* option_str = StringValueCStr(val);
|
204
|
-
|
205
|
-
if (std::strlen(option_str) > 0) {
|
206
|
-
switch (option_str[0]) {
|
207
|
-
case 'r':
|
208
|
-
case 'R':
|
209
|
-
break;
|
210
|
-
case 'c':
|
211
|
-
case 'C':
|
212
|
-
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
|
213
|
-
break;
|
214
|
-
}
|
215
|
-
}
|
216
|
-
|
217
|
-
RB_GC_GUARD(val);
|
218
|
-
|
219
|
-
return CblasRowMajor;
|
220
|
-
}
|
221
175
|
};
|
222
176
|
|
223
177
|
} // namespace TinyLinalg
|
@@ -94,8 +94,8 @@ private:
|
|
94
94
|
|
95
95
|
dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
|
96
96
|
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
|
97
|
-
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
|
98
|
-
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
97
|
+
enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor;
|
98
|
+
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
99
99
|
|
100
100
|
narray_t* a_nary = NULL;
|
101
101
|
GetNArray(a, a_nary);
|
@@ -160,52 +160,6 @@ private:
|
|
160
160
|
|
161
161
|
return ret;
|
162
162
|
}
|
163
|
-
|
164
|
-
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
165
|
-
const char* option_str = StringValueCStr(val);
|
166
|
-
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
167
|
-
|
168
|
-
if (std::strlen(option_str) > 0) {
|
169
|
-
switch (option_str[0]) {
|
170
|
-
case 'n':
|
171
|
-
case 'N':
|
172
|
-
res = CblasNoTrans;
|
173
|
-
break;
|
174
|
-
case 't':
|
175
|
-
case 'T':
|
176
|
-
res = CblasTrans;
|
177
|
-
break;
|
178
|
-
case 'c':
|
179
|
-
case 'C':
|
180
|
-
res = CblasConjTrans;
|
181
|
-
break;
|
182
|
-
}
|
183
|
-
}
|
184
|
-
|
185
|
-
RB_GC_GUARD(val);
|
186
|
-
|
187
|
-
return res;
|
188
|
-
}
|
189
|
-
|
190
|
-
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
191
|
-
const char* option_str = StringValueCStr(val);
|
192
|
-
|
193
|
-
if (std::strlen(option_str) > 0) {
|
194
|
-
switch (option_str[0]) {
|
195
|
-
case 'r':
|
196
|
-
case 'R':
|
197
|
-
break;
|
198
|
-
case 'c':
|
199
|
-
case 'C':
|
200
|
-
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
|
201
|
-
break;
|
202
|
-
}
|
203
|
-
}
|
204
|
-
|
205
|
-
RB_GC_GUARD(val);
|
206
|
-
|
207
|
-
return CblasRowMajor;
|
208
|
-
}
|
209
163
|
};
|
210
164
|
|
211
165
|
} // namespace TinyLinalg
|
@@ -28,7 +28,7 @@ struct CGeQrf {
|
|
28
28
|
}
|
29
29
|
};
|
30
30
|
|
31
|
-
template <int nary_dtype_id, typename
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
32
|
class GeQrf {
|
33
33
|
public:
|
34
34
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -41,14 +41,14 @@ private:
|
|
41
41
|
};
|
42
42
|
|
43
43
|
static void iter_geqrf(na_loop_t* const lp) {
|
44
|
-
|
45
|
-
|
44
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
45
|
+
dtype* tau = (dtype*)NDL_PTR(lp, 1);
|
46
46
|
int* info = (int*)NDL_PTR(lp, 2);
|
47
47
|
geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
|
48
48
|
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
49
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
50
|
const lapack_int lda = n;
|
51
|
-
const lapack_int i =
|
51
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, a, lda, tau);
|
52
52
|
*info = static_cast<int>(i);
|
53
53
|
}
|
54
54
|
|
@@ -61,7 +61,7 @@ private:
|
|
61
61
|
ID kw_table[1] = { rb_intern("order") };
|
62
62
|
VALUE kw_values[1] = { Qundef };
|
63
63
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
65
|
|
66
66
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
67
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -93,26 +93,6 @@ private:
|
|
93
93
|
|
94
94
|
return ret;
|
95
95
|
}
|
96
|
-
|
97
|
-
static int get_matrix_layout(VALUE val) {
|
98
|
-
const char* option_str = StringValueCStr(val);
|
99
|
-
|
100
|
-
if (std::strlen(option_str) > 0) {
|
101
|
-
switch (option_str[0]) {
|
102
|
-
case 'r':
|
103
|
-
case 'R':
|
104
|
-
break;
|
105
|
-
case 'c':
|
106
|
-
case 'C':
|
107
|
-
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
108
|
-
break;
|
109
|
-
}
|
110
|
-
}
|
111
|
-
|
112
|
-
RB_GC_GUARD(val);
|
113
|
-
|
114
|
-
return LAPACK_ROW_MAJOR;
|
115
|
-
}
|
116
96
|
};
|
117
97
|
|
118
98
|
} // namespace TinyLinalg
|
@@ -1,35 +1,35 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeSdd {
|
4
4
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
5
5
|
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt) {
|
6
6
|
return LAPACKE_dgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
7
7
|
};
|
8
8
|
};
|
9
9
|
|
10
|
-
struct
|
10
|
+
struct SGeSdd {
|
11
11
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
12
12
|
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt) {
|
13
13
|
return LAPACKE_sgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
14
14
|
};
|
15
15
|
};
|
16
16
|
|
17
|
-
struct
|
17
|
+
struct ZGeSdd {
|
18
18
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
19
19
|
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt) {
|
20
20
|
return LAPACKE_zgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
21
21
|
};
|
22
22
|
};
|
23
23
|
|
24
|
-
struct
|
24
|
+
struct CGeSdd {
|
25
25
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
26
26
|
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt) {
|
27
27
|
return LAPACKE_cgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
28
28
|
};
|
29
29
|
};
|
30
30
|
|
31
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
32
|
-
class
|
31
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
32
|
+
class GeSdd {
|
33
33
|
public:
|
34
34
|
static void define_module_function(VALUE mLapack, const char* mf_name) {
|
35
35
|
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesdd), -1);
|
@@ -42,10 +42,10 @@ private:
|
|
42
42
|
};
|
43
43
|
|
44
44
|
static void iter_gesdd(na_loop_t* const lp) {
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
45
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
46
|
+
rtype* s = (rtype*)NDL_PTR(lp, 1);
|
47
|
+
dtype* u = (dtype*)NDL_PTR(lp, 2);
|
48
|
+
dtype* vt = (dtype*)NDL_PTR(lp, 3);
|
49
49
|
int* info = (int*)NDL_PTR(lp, 4);
|
50
50
|
gesdd_opt* opt = (gesdd_opt*)(lp->opt_ptr);
|
51
51
|
|
@@ -56,7 +56,7 @@ private:
|
|
56
56
|
const lapack_int ldu = opt->jobz == 'S' ? min_mn : m;
|
57
57
|
const lapack_int ldvt = opt->jobz == 'S' ? min_mn : n;
|
58
58
|
|
59
|
-
lapack_int i =
|
59
|
+
lapack_int i = LapackFn().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
60
60
|
*info = static_cast<int>(i);
|
61
61
|
};
|
62
62
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeSv {
|
4
4
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
5
5
|
double* a, lapack_int lda, lapack_int* ipiv,
|
6
6
|
double* b, lapack_int ldb) {
|
@@ -8,7 +8,7 @@ struct DGESV {
|
|
8
8
|
}
|
9
9
|
};
|
10
10
|
|
11
|
-
struct
|
11
|
+
struct SGeSv {
|
12
12
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
13
13
|
float* a, lapack_int lda, lapack_int* ipiv,
|
14
14
|
float* b, lapack_int ldb) {
|
@@ -16,7 +16,7 @@ struct SGESV {
|
|
16
16
|
}
|
17
17
|
};
|
18
18
|
|
19
|
-
struct
|
19
|
+
struct ZGeSv {
|
20
20
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
21
21
|
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv,
|
22
22
|
lapack_complex_double* b, lapack_int ldb) {
|
@@ -24,7 +24,7 @@ struct ZGESV {
|
|
24
24
|
}
|
25
25
|
};
|
26
26
|
|
27
|
-
struct
|
27
|
+
struct CGeSv {
|
28
28
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
29
29
|
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv,
|
30
30
|
lapack_complex_float* b, lapack_int ldb) {
|
@@ -32,8 +32,8 @@ struct CGESV {
|
|
32
32
|
}
|
33
33
|
};
|
34
34
|
|
35
|
-
template <int nary_dtype_id, typename
|
36
|
-
class
|
35
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
36
|
+
class GeSv {
|
37
37
|
public:
|
38
38
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
39
39
|
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_gesv), -1);
|
@@ -45,8 +45,8 @@ private:
|
|
45
45
|
};
|
46
46
|
|
47
47
|
static void iter_gesv(na_loop_t* const lp) {
|
48
|
-
|
49
|
-
|
48
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
49
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
50
50
|
int* ipiv = (int*)NDL_PTR(lp, 2);
|
51
51
|
int* info = (int*)NDL_PTR(lp, 3);
|
52
52
|
gesv_opt* opt = (gesv_opt*)(lp->opt_ptr);
|
@@ -54,7 +54,7 @@ private:
|
|
54
54
|
const lapack_int nhrs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
55
55
|
const lapack_int lda = n;
|
56
56
|
const lapack_int ldb = nhrs;
|
57
|
-
const lapack_int i =
|
57
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, n, nhrs, a, lda, ipiv, b, ldb);
|
58
58
|
*info = static_cast<int>(i);
|
59
59
|
}
|
60
60
|
|
@@ -71,7 +71,7 @@ private:
|
|
71
71
|
|
72
72
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
73
73
|
|
74
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
74
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
75
75
|
|
76
76
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
77
77
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -123,26 +123,6 @@ private:
|
|
123
123
|
|
124
124
|
return ret;
|
125
125
|
}
|
126
|
-
|
127
|
-
static int get_matrix_layout(VALUE val) {
|
128
|
-
const char* option_str = StringValueCStr(val);
|
129
|
-
|
130
|
-
if (std::strlen(option_str) > 0) {
|
131
|
-
switch (option_str[0]) {
|
132
|
-
case 'r':
|
133
|
-
case 'R':
|
134
|
-
break;
|
135
|
-
case 'c':
|
136
|
-
case 'C':
|
137
|
-
rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major.");
|
138
|
-
break;
|
139
|
-
}
|
140
|
-
}
|
141
|
-
|
142
|
-
RB_GC_GUARD(val);
|
143
|
-
|
144
|
-
return LAPACK_ROW_MAJOR;
|
145
|
-
}
|
146
126
|
};
|
147
127
|
|
148
128
|
} // namespace TinyLinalg
|
@@ -1,6 +1,6 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeSvd {
|
4
4
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
5
5
|
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt,
|
6
6
|
double* superb) {
|
@@ -8,7 +8,7 @@ struct DGESVD {
|
|
8
8
|
};
|
9
9
|
};
|
10
10
|
|
11
|
-
struct
|
11
|
+
struct SGeSvd {
|
12
12
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
13
13
|
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt,
|
14
14
|
float* superb) {
|
@@ -16,7 +16,7 @@ struct SGESVD {
|
|
16
16
|
};
|
17
17
|
};
|
18
18
|
|
19
|
-
struct
|
19
|
+
struct ZGeSvd {
|
20
20
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
21
21
|
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt,
|
22
22
|
double* superb) {
|
@@ -24,7 +24,7 @@ struct ZGESVD {
|
|
24
24
|
};
|
25
25
|
};
|
26
26
|
|
27
|
-
struct
|
27
|
+
struct CGeSvd {
|
28
28
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
29
29
|
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt,
|
30
30
|
float* superb) {
|
@@ -32,8 +32,8 @@ struct CGESVD {
|
|
32
32
|
};
|
33
33
|
};
|
34
34
|
|
35
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
36
|
-
class
|
35
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
36
|
+
class GeSvd {
|
37
37
|
public:
|
38
38
|
static void define_module_function(VALUE mLapack, const char* mf_name) {
|
39
39
|
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesvd), -1);
|
@@ -47,10 +47,10 @@ private:
|
|
47
47
|
};
|
48
48
|
|
49
49
|
static void iter_gesvd(na_loop_t* const lp) {
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
50
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
51
|
+
rtype* s = (rtype*)NDL_PTR(lp, 1);
|
52
|
+
dtype* u = (dtype*)NDL_PTR(lp, 2);
|
53
|
+
dtype* vt = (dtype*)NDL_PTR(lp, 3);
|
54
54
|
int* info = (int*)NDL_PTR(lp, 4);
|
55
55
|
gesvd_opt* opt = (gesvd_opt*)(lp->opt_ptr);
|
56
56
|
|
@@ -61,9 +61,9 @@ private:
|
|
61
61
|
const lapack_int ldu = opt->jobu == 'A' ? m : min_mn;
|
62
62
|
const lapack_int ldvt = n;
|
63
63
|
|
64
|
-
|
64
|
+
rtype* superb = (rtype*)ruby_xmalloc(min_mn * sizeof(rtype));
|
65
65
|
|
66
|
-
lapack_int i =
|
66
|
+
lapack_int i = LapackFn().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
67
67
|
*info = static_cast<int>(i);
|
68
68
|
|
69
69
|
ruby_xfree(superb);
|
@@ -1,35 +1,35 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeTrf {
|
4
4
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
5
5
|
double* a, lapack_int lda, lapack_int* ipiv) {
|
6
6
|
return LAPACKE_dgetrf(matrix_layout, m, n, a, lda, ipiv);
|
7
7
|
}
|
8
8
|
};
|
9
9
|
|
10
|
-
struct
|
10
|
+
struct SGeTrf {
|
11
11
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
12
12
|
float* a, lapack_int lda, lapack_int* ipiv) {
|
13
13
|
return LAPACKE_sgetrf(matrix_layout, m, n, a, lda, ipiv);
|
14
14
|
}
|
15
15
|
};
|
16
16
|
|
17
|
-
struct
|
17
|
+
struct ZGeTrf {
|
18
18
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
19
19
|
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv) {
|
20
20
|
return LAPACKE_zgetrf(matrix_layout, m, n, a, lda, ipiv);
|
21
21
|
}
|
22
22
|
};
|
23
23
|
|
24
|
-
struct
|
24
|
+
struct CGeTrf {
|
25
25
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
26
26
|
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv) {
|
27
27
|
return LAPACKE_cgetrf(matrix_layout, m, n, a, lda, ipiv);
|
28
28
|
}
|
29
29
|
};
|
30
30
|
|
31
|
-
template <int nary_dtype_id, typename
|
32
|
-
class
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
|
+
class GeTrf {
|
33
33
|
public:
|
34
34
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
35
|
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getrf), -1);
|
@@ -41,14 +41,14 @@ private:
|
|
41
41
|
};
|
42
42
|
|
43
43
|
static void iter_getrf(na_loop_t* const lp) {
|
44
|
-
|
44
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
45
45
|
int* ipiv = (int*)NDL_PTR(lp, 1);
|
46
46
|
int* info = (int*)NDL_PTR(lp, 2);
|
47
47
|
getrf_opt* opt = (getrf_opt*)(lp->opt_ptr);
|
48
48
|
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
49
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
50
|
const lapack_int lda = n;
|
51
|
-
const lapack_int i =
|
51
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, a, lda, ipiv);
|
52
52
|
*info = static_cast<int>(i);
|
53
53
|
}
|
54
54
|
|
@@ -61,7 +61,7 @@ private:
|
|
61
61
|
ID kw_table[1] = { rb_intern("order") };
|
62
62
|
VALUE kw_values[1] = { Qundef };
|
63
63
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
65
|
|
66
66
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
67
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -93,26 +93,6 @@ private:
|
|
93
93
|
|
94
94
|
return ret;
|
95
95
|
}
|
96
|
-
|
97
|
-
static int get_matrix_layout(VALUE val) {
|
98
|
-
const char* option_str = StringValueCStr(val);
|
99
|
-
|
100
|
-
if (std::strlen(option_str) > 0) {
|
101
|
-
switch (option_str[0]) {
|
102
|
-
case 'r':
|
103
|
-
case 'R':
|
104
|
-
break;
|
105
|
-
case 'c':
|
106
|
-
case 'C':
|
107
|
-
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
108
|
-
break;
|
109
|
-
}
|
110
|
-
}
|
111
|
-
|
112
|
-
RB_GC_GUARD(val);
|
113
|
-
|
114
|
-
return LAPACK_ROW_MAJOR;
|
115
|
-
}
|
116
96
|
};
|
117
97
|
|
118
98
|
} // namespace TinyLinalg
|
@@ -1,31 +1,31 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeTri {
|
4
4
|
lapack_int call(int matrix_layout, lapack_int n, double* a, lapack_int lda, const lapack_int* ipiv) {
|
5
5
|
return LAPACKE_dgetri(matrix_layout, n, a, lda, ipiv);
|
6
6
|
}
|
7
7
|
};
|
8
8
|
|
9
|
-
struct
|
9
|
+
struct SGeTri {
|
10
10
|
lapack_int call(int matrix_layout, lapack_int n, float* a, lapack_int lda, const lapack_int* ipiv) {
|
11
11
|
return LAPACKE_sgetri(matrix_layout, n, a, lda, ipiv);
|
12
12
|
}
|
13
13
|
};
|
14
14
|
|
15
|
-
struct
|
15
|
+
struct ZGeTri {
|
16
16
|
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_double* a, lapack_int lda, const lapack_int* ipiv) {
|
17
17
|
return LAPACKE_zgetri(matrix_layout, n, a, lda, ipiv);
|
18
18
|
}
|
19
19
|
};
|
20
20
|
|
21
|
-
struct
|
21
|
+
struct CGeTri {
|
22
22
|
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_float* a, lapack_int lda, const lapack_int* ipiv) {
|
23
23
|
return LAPACKE_cgetri(matrix_layout, n, a, lda, ipiv);
|
24
24
|
}
|
25
25
|
};
|
26
26
|
|
27
|
-
template <int nary_dtype_id, typename
|
28
|
-
class
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
28
|
+
class GeTri {
|
29
29
|
public:
|
30
30
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
31
|
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getri), -1);
|
@@ -37,13 +37,13 @@ private:
|
|
37
37
|
};
|
38
38
|
|
39
39
|
static void iter_getri(na_loop_t* const lp) {
|
40
|
-
|
40
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
41
41
|
lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1);
|
42
42
|
int* info = (int*)NDL_PTR(lp, 2);
|
43
43
|
getri_opt* opt = (getri_opt*)(lp->opt_ptr);
|
44
44
|
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
45
45
|
const lapack_int lda = n;
|
46
|
-
const lapack_int i =
|
46
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, n, a, lda, ipiv);
|
47
47
|
*info = static_cast<int>(i);
|
48
48
|
}
|
49
49
|
|
@@ -57,7 +57,7 @@ private:
|
|
57
57
|
ID kw_table[1] = { rb_intern("order") };
|
58
58
|
VALUE kw_values[1] = { Qundef };
|
59
59
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
60
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
60
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
61
61
|
|
62
62
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
63
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -102,26 +102,6 @@ private:
|
|
102
102
|
|
103
103
|
return ret;
|
104
104
|
}
|
105
|
-
|
106
|
-
static int get_matrix_layout(VALUE val) {
|
107
|
-
const char* option_str = StringValueCStr(val);
|
108
|
-
|
109
|
-
if (std::strlen(option_str) > 0) {
|
110
|
-
switch (option_str[0]) {
|
111
|
-
case 'r':
|
112
|
-
case 'R':
|
113
|
-
break;
|
114
|
-
case 'c':
|
115
|
-
case 'C':
|
116
|
-
rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major.");
|
117
|
-
break;
|
118
|
-
}
|
119
|
-
}
|
120
|
-
|
121
|
-
RB_GC_GUARD(val);
|
122
|
-
|
123
|
-
return LAPACK_ROW_MAJOR;
|
124
|
-
}
|
125
105
|
};
|
126
106
|
|
127
107
|
} // namespace TinyLinalg
|