numo-tiny_linalg 0.0.4 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +3 -3
- data/ext/numo/tiny_linalg/blas/gemm.hpp +3 -49
- data/ext/numo/tiny_linalg/blas/gemv.hpp +2 -48
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/gesdd.hpp +11 -11
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +10 -30
- data/ext/numo/tiny_linalg/lapack/gesvd.hpp +12 -12
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/getri.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +16 -72
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +9 -55
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +16 -72
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +5 -25
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +56 -21
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +30 -6
- data/ext/numo/tiny_linalg/util.hpp +100 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +175 -52
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 81e69278b32b0f565e3f5b5e1eae1738a57556773e9306bb3b0c8450c9fb75ea
|
4
|
+
data.tar.gz: de6daa349cd0d16b14f9c133ad451fac8bacbcd42cca682be68509bd5d8fc416
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 72894605f074e09f47a37834ba1c2b3a3e504e149742df3290fd4c4e8042b1ac0c70d11bd50e4ab83e19b07181d8a656ab57afcaa47de53794fc28a298e1e768
|
7
|
+
data.tar.gz: c896b561fcbeb0ff19196548a3a0f2f6a06098d706fbfe95b785ffd850b139e87ea5d62c1144cde4c456a7ef5bb2b79422b5c2e0ab909df76f5ea9fa4a4bd446
|
data/CHANGELOG.md
CHANGED
@@ -1,6 +1,12 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
-
## [[0.
|
3
|
+
## [[0.1.1](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.1.0...v0.1.1)] - 2023-08-07
|
4
|
+
- Fix method of getting start and end of eigenvalue range from vals_range argument of TinyLinalg.eigh.
|
5
|
+
|
6
|
+
## [[0.1.0](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.4...v0.1.0)] - 2023-08-06
|
7
|
+
- Refactor code and update documentation.
|
8
|
+
|
9
|
+
## [[0.0.4](https://github.com/yoshoku/numo-tiny_linalg/compare/v0.0.3...v0.0.4)] - 2023-08-06
|
4
10
|
- Add dsygv, ssygv, zhegv, and chegv module functions to TinyLinalg::Lapack.
|
5
11
|
- Add dsygvd, ssygvd, zhegvd, and chegvd module functions to TinyLinalg::Lapack.
|
6
12
|
- Add dsygvx, ssygvx, zhegvx, and chegvx module functions to TinyLinalg::Lapack.
|
data/README.md
CHANGED
@@ -5,11 +5,11 @@
|
|
5
5
|
[](https://github.com/yoshoku/numo-tiny_linalg/blob/main/LICENSE.txt)
|
6
6
|
[](https://yoshoku.github.io/numo-tiny_linalg/doc/)
|
7
7
|
|
8
|
-
Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
9
|
-
|
10
|
-
This gem is still **under development** and may undergo many changes in the future.
|
8
|
+
Numo::TinyLinalg is a subset library from [Numo::Linalg](https://github.com/ruby-numo/numo-linalg) consisting only of methods used in Machine Learning algorithms.
|
9
|
+
The functions Numo::TinyLinalg supports are dot, det, eigh, inv, pinv, qr, solve, and svd.
|
11
10
|
|
12
11
|
## Installation
|
12
|
+
Unlike Numo::Linalg, Numo::TinyLinalg only supports OpenBLAS as a backend library for BLAS and LAPACK.
|
13
13
|
|
14
14
|
Install OpenBLAS.
|
15
15
|
|
@@ -102,9 +102,9 @@ private:
|
|
102
102
|
|
103
103
|
dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
|
104
104
|
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
|
105
|
-
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
|
106
|
-
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
107
|
-
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans;
|
105
|
+
enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor;
|
106
|
+
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
107
|
+
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? Util().get_cblas_trans(kw_values[4]) : CblasNoTrans;
|
108
108
|
|
109
109
|
narray_t* a_nary = NULL;
|
110
110
|
GetNArray(a, a_nary);
|
@@ -172,52 +172,6 @@ private:
|
|
172
172
|
|
173
173
|
return ret;
|
174
174
|
};
|
175
|
-
|
176
|
-
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
177
|
-
const char* option_str = StringValueCStr(val);
|
178
|
-
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
179
|
-
|
180
|
-
if (std::strlen(option_str) > 0) {
|
181
|
-
switch (option_str[0]) {
|
182
|
-
case 'n':
|
183
|
-
case 'N':
|
184
|
-
res = CblasNoTrans;
|
185
|
-
break;
|
186
|
-
case 't':
|
187
|
-
case 'T':
|
188
|
-
res = CblasTrans;
|
189
|
-
break;
|
190
|
-
case 'c':
|
191
|
-
case 'C':
|
192
|
-
res = CblasConjTrans;
|
193
|
-
break;
|
194
|
-
}
|
195
|
-
}
|
196
|
-
|
197
|
-
RB_GC_GUARD(val);
|
198
|
-
|
199
|
-
return res;
|
200
|
-
}
|
201
|
-
|
202
|
-
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
203
|
-
const char* option_str = StringValueCStr(val);
|
204
|
-
|
205
|
-
if (std::strlen(option_str) > 0) {
|
206
|
-
switch (option_str[0]) {
|
207
|
-
case 'r':
|
208
|
-
case 'R':
|
209
|
-
break;
|
210
|
-
case 'c':
|
211
|
-
case 'C':
|
212
|
-
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
|
213
|
-
break;
|
214
|
-
}
|
215
|
-
}
|
216
|
-
|
217
|
-
RB_GC_GUARD(val);
|
218
|
-
|
219
|
-
return CblasRowMajor;
|
220
|
-
}
|
221
175
|
};
|
222
176
|
|
223
177
|
} // namespace TinyLinalg
|
@@ -94,8 +94,8 @@ private:
|
|
94
94
|
|
95
95
|
dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
|
96
96
|
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
|
97
|
-
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
|
98
|
-
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
97
|
+
enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor;
|
98
|
+
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans;
|
99
99
|
|
100
100
|
narray_t* a_nary = NULL;
|
101
101
|
GetNArray(a, a_nary);
|
@@ -160,52 +160,6 @@ private:
|
|
160
160
|
|
161
161
|
return ret;
|
162
162
|
}
|
163
|
-
|
164
|
-
static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
|
165
|
-
const char* option_str = StringValueCStr(val);
|
166
|
-
enum CBLAS_TRANSPOSE res = CblasNoTrans;
|
167
|
-
|
168
|
-
if (std::strlen(option_str) > 0) {
|
169
|
-
switch (option_str[0]) {
|
170
|
-
case 'n':
|
171
|
-
case 'N':
|
172
|
-
res = CblasNoTrans;
|
173
|
-
break;
|
174
|
-
case 't':
|
175
|
-
case 'T':
|
176
|
-
res = CblasTrans;
|
177
|
-
break;
|
178
|
-
case 'c':
|
179
|
-
case 'C':
|
180
|
-
res = CblasConjTrans;
|
181
|
-
break;
|
182
|
-
}
|
183
|
-
}
|
184
|
-
|
185
|
-
RB_GC_GUARD(val);
|
186
|
-
|
187
|
-
return res;
|
188
|
-
}
|
189
|
-
|
190
|
-
static enum CBLAS_ORDER get_cblas_order(VALUE val) {
|
191
|
-
const char* option_str = StringValueCStr(val);
|
192
|
-
|
193
|
-
if (std::strlen(option_str) > 0) {
|
194
|
-
switch (option_str[0]) {
|
195
|
-
case 'r':
|
196
|
-
case 'R':
|
197
|
-
break;
|
198
|
-
case 'c':
|
199
|
-
case 'C':
|
200
|
-
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
|
201
|
-
break;
|
202
|
-
}
|
203
|
-
}
|
204
|
-
|
205
|
-
RB_GC_GUARD(val);
|
206
|
-
|
207
|
-
return CblasRowMajor;
|
208
|
-
}
|
209
163
|
};
|
210
164
|
|
211
165
|
} // namespace TinyLinalg
|
@@ -28,7 +28,7 @@ struct CGeQrf {
|
|
28
28
|
}
|
29
29
|
};
|
30
30
|
|
31
|
-
template <int nary_dtype_id, typename
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
32
|
class GeQrf {
|
33
33
|
public:
|
34
34
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
@@ -41,14 +41,14 @@ private:
|
|
41
41
|
};
|
42
42
|
|
43
43
|
static void iter_geqrf(na_loop_t* const lp) {
|
44
|
-
|
45
|
-
|
44
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
45
|
+
dtype* tau = (dtype*)NDL_PTR(lp, 1);
|
46
46
|
int* info = (int*)NDL_PTR(lp, 2);
|
47
47
|
geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
|
48
48
|
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
49
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
50
|
const lapack_int lda = n;
|
51
|
-
const lapack_int i =
|
51
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, a, lda, tau);
|
52
52
|
*info = static_cast<int>(i);
|
53
53
|
}
|
54
54
|
|
@@ -61,7 +61,7 @@ private:
|
|
61
61
|
ID kw_table[1] = { rb_intern("order") };
|
62
62
|
VALUE kw_values[1] = { Qundef };
|
63
63
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
65
|
|
66
66
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
67
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -93,26 +93,6 @@ private:
|
|
93
93
|
|
94
94
|
return ret;
|
95
95
|
}
|
96
|
-
|
97
|
-
static int get_matrix_layout(VALUE val) {
|
98
|
-
const char* option_str = StringValueCStr(val);
|
99
|
-
|
100
|
-
if (std::strlen(option_str) > 0) {
|
101
|
-
switch (option_str[0]) {
|
102
|
-
case 'r':
|
103
|
-
case 'R':
|
104
|
-
break;
|
105
|
-
case 'c':
|
106
|
-
case 'C':
|
107
|
-
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
108
|
-
break;
|
109
|
-
}
|
110
|
-
}
|
111
|
-
|
112
|
-
RB_GC_GUARD(val);
|
113
|
-
|
114
|
-
return LAPACK_ROW_MAJOR;
|
115
|
-
}
|
116
96
|
};
|
117
97
|
|
118
98
|
} // namespace TinyLinalg
|
@@ -1,35 +1,35 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeSdd {
|
4
4
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
5
5
|
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt) {
|
6
6
|
return LAPACKE_dgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
7
7
|
};
|
8
8
|
};
|
9
9
|
|
10
|
-
struct
|
10
|
+
struct SGeSdd {
|
11
11
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
12
12
|
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt) {
|
13
13
|
return LAPACKE_sgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
14
14
|
};
|
15
15
|
};
|
16
16
|
|
17
|
-
struct
|
17
|
+
struct ZGeSdd {
|
18
18
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
19
19
|
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt) {
|
20
20
|
return LAPACKE_zgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
21
21
|
};
|
22
22
|
};
|
23
23
|
|
24
|
-
struct
|
24
|
+
struct CGeSdd {
|
25
25
|
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
|
26
26
|
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt) {
|
27
27
|
return LAPACKE_cgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
28
28
|
};
|
29
29
|
};
|
30
30
|
|
31
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
32
|
-
class
|
31
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
32
|
+
class GeSdd {
|
33
33
|
public:
|
34
34
|
static void define_module_function(VALUE mLapack, const char* mf_name) {
|
35
35
|
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesdd), -1);
|
@@ -42,10 +42,10 @@ private:
|
|
42
42
|
};
|
43
43
|
|
44
44
|
static void iter_gesdd(na_loop_t* const lp) {
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
45
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
46
|
+
rtype* s = (rtype*)NDL_PTR(lp, 1);
|
47
|
+
dtype* u = (dtype*)NDL_PTR(lp, 2);
|
48
|
+
dtype* vt = (dtype*)NDL_PTR(lp, 3);
|
49
49
|
int* info = (int*)NDL_PTR(lp, 4);
|
50
50
|
gesdd_opt* opt = (gesdd_opt*)(lp->opt_ptr);
|
51
51
|
|
@@ -56,7 +56,7 @@ private:
|
|
56
56
|
const lapack_int ldu = opt->jobz == 'S' ? min_mn : m;
|
57
57
|
const lapack_int ldvt = opt->jobz == 'S' ? min_mn : n;
|
58
58
|
|
59
|
-
lapack_int i =
|
59
|
+
lapack_int i = LapackFn().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
|
60
60
|
*info = static_cast<int>(i);
|
61
61
|
};
|
62
62
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeSv {
|
4
4
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
5
5
|
double* a, lapack_int lda, lapack_int* ipiv,
|
6
6
|
double* b, lapack_int ldb) {
|
@@ -8,7 +8,7 @@ struct DGESV {
|
|
8
8
|
}
|
9
9
|
};
|
10
10
|
|
11
|
-
struct
|
11
|
+
struct SGeSv {
|
12
12
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
13
13
|
float* a, lapack_int lda, lapack_int* ipiv,
|
14
14
|
float* b, lapack_int ldb) {
|
@@ -16,7 +16,7 @@ struct SGESV {
|
|
16
16
|
}
|
17
17
|
};
|
18
18
|
|
19
|
-
struct
|
19
|
+
struct ZGeSv {
|
20
20
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
21
21
|
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv,
|
22
22
|
lapack_complex_double* b, lapack_int ldb) {
|
@@ -24,7 +24,7 @@ struct ZGESV {
|
|
24
24
|
}
|
25
25
|
};
|
26
26
|
|
27
|
-
struct
|
27
|
+
struct CGeSv {
|
28
28
|
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
|
29
29
|
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv,
|
30
30
|
lapack_complex_float* b, lapack_int ldb) {
|
@@ -32,8 +32,8 @@ struct CGESV {
|
|
32
32
|
}
|
33
33
|
};
|
34
34
|
|
35
|
-
template <int nary_dtype_id, typename
|
36
|
-
class
|
35
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
36
|
+
class GeSv {
|
37
37
|
public:
|
38
38
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
39
39
|
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_gesv), -1);
|
@@ -45,8 +45,8 @@ private:
|
|
45
45
|
};
|
46
46
|
|
47
47
|
static void iter_gesv(na_loop_t* const lp) {
|
48
|
-
|
49
|
-
|
48
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
49
|
+
dtype* b = (dtype*)NDL_PTR(lp, 1);
|
50
50
|
int* ipiv = (int*)NDL_PTR(lp, 2);
|
51
51
|
int* info = (int*)NDL_PTR(lp, 3);
|
52
52
|
gesv_opt* opt = (gesv_opt*)(lp->opt_ptr);
|
@@ -54,7 +54,7 @@ private:
|
|
54
54
|
const lapack_int nhrs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
|
55
55
|
const lapack_int lda = n;
|
56
56
|
const lapack_int ldb = nhrs;
|
57
|
-
const lapack_int i =
|
57
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, n, nhrs, a, lda, ipiv, b, ldb);
|
58
58
|
*info = static_cast<int>(i);
|
59
59
|
}
|
60
60
|
|
@@ -71,7 +71,7 @@ private:
|
|
71
71
|
|
72
72
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
73
73
|
|
74
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
74
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
75
75
|
|
76
76
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
77
77
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -123,26 +123,6 @@ private:
|
|
123
123
|
|
124
124
|
return ret;
|
125
125
|
}
|
126
|
-
|
127
|
-
static int get_matrix_layout(VALUE val) {
|
128
|
-
const char* option_str = StringValueCStr(val);
|
129
|
-
|
130
|
-
if (std::strlen(option_str) > 0) {
|
131
|
-
switch (option_str[0]) {
|
132
|
-
case 'r':
|
133
|
-
case 'R':
|
134
|
-
break;
|
135
|
-
case 'c':
|
136
|
-
case 'C':
|
137
|
-
rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major.");
|
138
|
-
break;
|
139
|
-
}
|
140
|
-
}
|
141
|
-
|
142
|
-
RB_GC_GUARD(val);
|
143
|
-
|
144
|
-
return LAPACK_ROW_MAJOR;
|
145
|
-
}
|
146
126
|
};
|
147
127
|
|
148
128
|
} // namespace TinyLinalg
|
@@ -1,6 +1,6 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeSvd {
|
4
4
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
5
5
|
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt,
|
6
6
|
double* superb) {
|
@@ -8,7 +8,7 @@ struct DGESVD {
|
|
8
8
|
};
|
9
9
|
};
|
10
10
|
|
11
|
-
struct
|
11
|
+
struct SGeSvd {
|
12
12
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
13
13
|
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt,
|
14
14
|
float* superb) {
|
@@ -16,7 +16,7 @@ struct SGESVD {
|
|
16
16
|
};
|
17
17
|
};
|
18
18
|
|
19
|
-
struct
|
19
|
+
struct ZGeSvd {
|
20
20
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
21
21
|
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt,
|
22
22
|
double* superb) {
|
@@ -24,7 +24,7 @@ struct ZGESVD {
|
|
24
24
|
};
|
25
25
|
};
|
26
26
|
|
27
|
-
struct
|
27
|
+
struct CGeSvd {
|
28
28
|
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
|
29
29
|
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt,
|
30
30
|
float* superb) {
|
@@ -32,8 +32,8 @@ struct CGESVD {
|
|
32
32
|
};
|
33
33
|
};
|
34
34
|
|
35
|
-
template <int nary_dtype_id, int nary_rtype_id, typename
|
36
|
-
class
|
35
|
+
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
|
36
|
+
class GeSvd {
|
37
37
|
public:
|
38
38
|
static void define_module_function(VALUE mLapack, const char* mf_name) {
|
39
39
|
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesvd), -1);
|
@@ -47,10 +47,10 @@ private:
|
|
47
47
|
};
|
48
48
|
|
49
49
|
static void iter_gesvd(na_loop_t* const lp) {
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
50
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
51
|
+
rtype* s = (rtype*)NDL_PTR(lp, 1);
|
52
|
+
dtype* u = (dtype*)NDL_PTR(lp, 2);
|
53
|
+
dtype* vt = (dtype*)NDL_PTR(lp, 3);
|
54
54
|
int* info = (int*)NDL_PTR(lp, 4);
|
55
55
|
gesvd_opt* opt = (gesvd_opt*)(lp->opt_ptr);
|
56
56
|
|
@@ -61,9 +61,9 @@ private:
|
|
61
61
|
const lapack_int ldu = opt->jobu == 'A' ? m : min_mn;
|
62
62
|
const lapack_int ldvt = n;
|
63
63
|
|
64
|
-
|
64
|
+
rtype* superb = (rtype*)ruby_xmalloc(min_mn * sizeof(rtype));
|
65
65
|
|
66
|
-
lapack_int i =
|
66
|
+
lapack_int i = LapackFn().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
|
67
67
|
*info = static_cast<int>(i);
|
68
68
|
|
69
69
|
ruby_xfree(superb);
|
@@ -1,35 +1,35 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeTrf {
|
4
4
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
5
5
|
double* a, lapack_int lda, lapack_int* ipiv) {
|
6
6
|
return LAPACKE_dgetrf(matrix_layout, m, n, a, lda, ipiv);
|
7
7
|
}
|
8
8
|
};
|
9
9
|
|
10
|
-
struct
|
10
|
+
struct SGeTrf {
|
11
11
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
12
12
|
float* a, lapack_int lda, lapack_int* ipiv) {
|
13
13
|
return LAPACKE_sgetrf(matrix_layout, m, n, a, lda, ipiv);
|
14
14
|
}
|
15
15
|
};
|
16
16
|
|
17
|
-
struct
|
17
|
+
struct ZGeTrf {
|
18
18
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
19
19
|
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv) {
|
20
20
|
return LAPACKE_zgetrf(matrix_layout, m, n, a, lda, ipiv);
|
21
21
|
}
|
22
22
|
};
|
23
23
|
|
24
|
-
struct
|
24
|
+
struct CGeTrf {
|
25
25
|
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
|
26
26
|
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv) {
|
27
27
|
return LAPACKE_cgetrf(matrix_layout, m, n, a, lda, ipiv);
|
28
28
|
}
|
29
29
|
};
|
30
30
|
|
31
|
-
template <int nary_dtype_id, typename
|
32
|
-
class
|
31
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
32
|
+
class GeTrf {
|
33
33
|
public:
|
34
34
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
35
35
|
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getrf), -1);
|
@@ -41,14 +41,14 @@ private:
|
|
41
41
|
};
|
42
42
|
|
43
43
|
static void iter_getrf(na_loop_t* const lp) {
|
44
|
-
|
44
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
45
45
|
int* ipiv = (int*)NDL_PTR(lp, 1);
|
46
46
|
int* info = (int*)NDL_PTR(lp, 2);
|
47
47
|
getrf_opt* opt = (getrf_opt*)(lp->opt_ptr);
|
48
48
|
const lapack_int m = NDL_SHAPE(lp, 0)[0];
|
49
49
|
const lapack_int n = NDL_SHAPE(lp, 0)[1];
|
50
50
|
const lapack_int lda = n;
|
51
|
-
const lapack_int i =
|
51
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, a, lda, ipiv);
|
52
52
|
*info = static_cast<int>(i);
|
53
53
|
}
|
54
54
|
|
@@ -61,7 +61,7 @@ private:
|
|
61
61
|
ID kw_table[1] = { rb_intern("order") };
|
62
62
|
VALUE kw_values[1] = { Qundef };
|
63
63
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
64
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
64
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
65
65
|
|
66
66
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
67
67
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -93,26 +93,6 @@ private:
|
|
93
93
|
|
94
94
|
return ret;
|
95
95
|
}
|
96
|
-
|
97
|
-
static int get_matrix_layout(VALUE val) {
|
98
|
-
const char* option_str = StringValueCStr(val);
|
99
|
-
|
100
|
-
if (std::strlen(option_str) > 0) {
|
101
|
-
switch (option_str[0]) {
|
102
|
-
case 'r':
|
103
|
-
case 'R':
|
104
|
-
break;
|
105
|
-
case 'c':
|
106
|
-
case 'C':
|
107
|
-
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
|
108
|
-
break;
|
109
|
-
}
|
110
|
-
}
|
111
|
-
|
112
|
-
RB_GC_GUARD(val);
|
113
|
-
|
114
|
-
return LAPACK_ROW_MAJOR;
|
115
|
-
}
|
116
96
|
};
|
117
97
|
|
118
98
|
} // namespace TinyLinalg
|
@@ -1,31 +1,31 @@
|
|
1
1
|
namespace TinyLinalg {
|
2
2
|
|
3
|
-
struct
|
3
|
+
struct DGeTri {
|
4
4
|
lapack_int call(int matrix_layout, lapack_int n, double* a, lapack_int lda, const lapack_int* ipiv) {
|
5
5
|
return LAPACKE_dgetri(matrix_layout, n, a, lda, ipiv);
|
6
6
|
}
|
7
7
|
};
|
8
8
|
|
9
|
-
struct
|
9
|
+
struct SGeTri {
|
10
10
|
lapack_int call(int matrix_layout, lapack_int n, float* a, lapack_int lda, const lapack_int* ipiv) {
|
11
11
|
return LAPACKE_sgetri(matrix_layout, n, a, lda, ipiv);
|
12
12
|
}
|
13
13
|
};
|
14
14
|
|
15
|
-
struct
|
15
|
+
struct ZGeTri {
|
16
16
|
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_double* a, lapack_int lda, const lapack_int* ipiv) {
|
17
17
|
return LAPACKE_zgetri(matrix_layout, n, a, lda, ipiv);
|
18
18
|
}
|
19
19
|
};
|
20
20
|
|
21
|
-
struct
|
21
|
+
struct CGeTri {
|
22
22
|
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_float* a, lapack_int lda, const lapack_int* ipiv) {
|
23
23
|
return LAPACKE_cgetri(matrix_layout, n, a, lda, ipiv);
|
24
24
|
}
|
25
25
|
};
|
26
26
|
|
27
|
-
template <int nary_dtype_id, typename
|
28
|
-
class
|
27
|
+
template <int nary_dtype_id, typename dtype, class LapackFn>
|
28
|
+
class GeTri {
|
29
29
|
public:
|
30
30
|
static void define_module_function(VALUE mLapack, const char* fnc_name) {
|
31
31
|
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getri), -1);
|
@@ -37,13 +37,13 @@ private:
|
|
37
37
|
};
|
38
38
|
|
39
39
|
static void iter_getri(na_loop_t* const lp) {
|
40
|
-
|
40
|
+
dtype* a = (dtype*)NDL_PTR(lp, 0);
|
41
41
|
lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1);
|
42
42
|
int* info = (int*)NDL_PTR(lp, 2);
|
43
43
|
getri_opt* opt = (getri_opt*)(lp->opt_ptr);
|
44
44
|
const lapack_int n = NDL_SHAPE(lp, 0)[0];
|
45
45
|
const lapack_int lda = n;
|
46
|
-
const lapack_int i =
|
46
|
+
const lapack_int i = LapackFn().call(opt->matrix_layout, n, a, lda, ipiv);
|
47
47
|
*info = static_cast<int>(i);
|
48
48
|
}
|
49
49
|
|
@@ -57,7 +57,7 @@ private:
|
|
57
57
|
ID kw_table[1] = { rb_intern("order") };
|
58
58
|
VALUE kw_values[1] = { Qundef };
|
59
59
|
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
60
|
-
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
60
|
+
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
|
61
61
|
|
62
62
|
if (CLASS_OF(a_vnary) != nary_dtype) {
|
63
63
|
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
|
@@ -102,26 +102,6 @@ private:
|
|
102
102
|
|
103
103
|
return ret;
|
104
104
|
}
|
105
|
-
|
106
|
-
static int get_matrix_layout(VALUE val) {
|
107
|
-
const char* option_str = StringValueCStr(val);
|
108
|
-
|
109
|
-
if (std::strlen(option_str) > 0) {
|
110
|
-
switch (option_str[0]) {
|
111
|
-
case 'r':
|
112
|
-
case 'R':
|
113
|
-
break;
|
114
|
-
case 'c':
|
115
|
-
case 'C':
|
116
|
-
rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major.");
|
117
|
-
break;
|
118
|
-
}
|
119
|
-
}
|
120
|
-
|
121
|
-
RB_GC_GUARD(val);
|
122
|
-
|
123
|
-
return LAPACK_ROW_MAJOR;
|
124
|
-
}
|
125
105
|
};
|
126
106
|
|
127
107
|
} // namespace TinyLinalg
|