numo-tiny_linalg 0.0.4 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -18,7 +18,7 @@ struct SSyGvx {
18
18
  }
19
19
  };
20
20
 
21
- template <int nary_dtype_id, typename DType, typename FncType>
21
+ template <int nary_dtype_id, typename dtype, class LapackFn>
22
22
  class SyGvx {
23
23
  public:
24
24
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,18 +32,18 @@ private:
32
32
  char jobz;
33
33
  char range;
34
34
  char uplo;
35
- DType vl;
36
- DType vu;
35
+ dtype vl;
36
+ dtype vu;
37
37
  lapack_int il;
38
38
  lapack_int iu;
39
39
  };
40
40
 
41
41
  static void iter_sygvx(na_loop_t* const lp) {
42
- DType* a = (DType*)NDL_PTR(lp, 0);
43
- DType* b = (DType*)NDL_PTR(lp, 1);
42
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
43
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
44
44
  int* m = (int*)NDL_PTR(lp, 2);
45
- DType* w = (DType*)NDL_PTR(lp, 3);
46
- DType* z = (DType*)NDL_PTR(lp, 4);
45
+ dtype* w = (dtype*)NDL_PTR(lp, 3);
46
+ dtype* z = (dtype*)NDL_PTR(lp, 4);
47
47
  int* ifail = (int*)NDL_PTR(lp, 5);
48
48
  int* info = (int*)NDL_PTR(lp, 6);
49
49
  sygvx_opt* opt = (sygvx_opt*)(lp->opt_ptr);
@@ -51,8 +51,8 @@ private:
51
51
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
52
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
53
  const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
- const DType abstol = 0.0;
55
- const lapack_int i = FncType().call(
54
+ const dtype abstol = 0.0;
55
+ const lapack_int i = LapackFn().call(
56
56
  opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
57
  opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
58
  *info = static_cast<int>(i);
@@ -69,15 +69,15 @@ private:
69
69
  rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
70
70
  VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
71
71
  rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
72
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
73
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
74
- const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
75
- const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
76
- const DType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
77
- const DType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
72
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
73
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
74
+ const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A';
75
+ const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U';
76
+ const dtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
77
+ const dtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
78
78
  const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
79
79
  const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
80
- const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
80
+ const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
81
81
 
82
82
  if (CLASS_OF(a_vnary) != nary_dtype) {
83
83
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -131,62 +131,6 @@ private:
131
131
 
132
132
  return ret;
133
133
  }
134
-
135
- static lapack_int get_itype(VALUE val) {
136
- const lapack_int itype = NUM2INT(val);
137
-
138
- if (itype != 1 && itype != 2 && itype != 3) {
139
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
140
- }
141
-
142
- return itype;
143
- }
144
-
145
- static char get_jobz(VALUE val) {
146
- const char jobz = NUM2CHR(val);
147
-
148
- if (jobz != 'N' && jobz != 'V') {
149
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
150
- }
151
-
152
- return jobz;
153
- }
154
-
155
- static char get_range(VALUE val) {
156
- const char range = NUM2CHR(val);
157
-
158
- if (range != 'A' && range != 'V' && range != 'I') {
159
- rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
160
- }
161
-
162
- return range;
163
- }
164
-
165
- static char get_uplo(VALUE val) {
166
- const char uplo = NUM2CHR(val);
167
-
168
- if (uplo != 'U' && uplo != 'L') {
169
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
170
- }
171
-
172
- return uplo;
173
- }
174
-
175
- static int get_matrix_layout(VALUE val) {
176
- const char option = NUM2CHR(val);
177
-
178
- switch (option) {
179
- case 'r':
180
- case 'R':
181
- break;
182
- case 'c':
183
- case 'C':
184
- rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major.");
185
- break;
186
- }
187
-
188
- return LAPACK_ROW_MAJOR;
189
- }
190
134
  };
191
135
 
192
136
  } // namespace TinyLinalg
@@ -14,7 +14,7 @@ struct CUngQr {
14
14
  }
15
15
  };
16
16
 
17
- template <int nary_dtype_id, typename DType, typename FncType>
17
+ template <int nary_dtype_id, typename dtype, class LapackFn>
18
18
  class UngQr {
19
19
  public:
20
20
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -27,15 +27,15 @@ private:
27
27
  };
28
28
 
29
29
  static void iter_ungqr(na_loop_t* const lp) {
30
- DType* a = (DType*)NDL_PTR(lp, 0);
31
- DType* tau = (DType*)NDL_PTR(lp, 1);
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* tau = (dtype*)NDL_PTR(lp, 1);
32
32
  int* info = (int*)NDL_PTR(lp, 2);
33
33
  ungqr_opt* opt = (ungqr_opt*)(lp->opt_ptr);
34
34
  const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
35
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
36
  const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
37
  const lapack_int lda = n;
38
- const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
38
+ const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
39
  *info = static_cast<int>(i);
40
40
  }
41
41
 
@@ -49,7 +49,7 @@ private:
49
49
  ID kw_table[1] = { rb_intern("order") };
50
50
  VALUE kw_values[1] = { Qundef };
51
51
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
- const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
52
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
53
 
54
54
  if (CLASS_OF(a_vnary) != nary_dtype) {
55
55
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -90,26 +90,6 @@ private:
90
90
 
91
91
  return ret;
92
92
  }
93
-
94
- static int get_matrix_layout(VALUE val) {
95
- const char* option_str = StringValueCStr(val);
96
-
97
- if (std::strlen(option_str) > 0) {
98
- switch (option_str[0]) {
99
- case 'r':
100
- case 'R':
101
- break;
102
- case 'c':
103
- case 'C':
104
- rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
- break;
106
- }
107
- }
108
-
109
- RB_GC_GUARD(val);
110
-
111
- return LAPACK_ROW_MAJOR;
112
- }
113
93
  };
114
94
 
115
95
  } // namespace TinyLinalg
@@ -1,10 +1,43 @@
1
+ /**
2
+ * Copyright (c) 2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
1
31
  #include "tiny_linalg.hpp"
32
+
33
+ #include "converter.hpp"
34
+ #include "util.hpp"
35
+
2
36
  #include "blas/dot.hpp"
3
37
  #include "blas/dot_sub.hpp"
4
38
  #include "blas/gemm.hpp"
5
39
  #include "blas/gemv.hpp"
6
40
  #include "blas/nrm2.hpp"
7
- #include "converter.hpp"
8
41
  #include "lapack/geqrf.hpp"
9
42
  #include "lapack/gesdd.hpp"
10
43
  #include "lapack/gesv.hpp"
@@ -197,11 +230,13 @@ extern "C" void Init_tiny_linalg(void) {
197
230
  /**
198
231
  * Document-module: Numo::TinyLinalg::Blas
199
232
  * Numo::TinyLinalg::Blas is wrapper module of BLAS functions.
233
+ * @!visibility private
200
234
  */
201
235
  rb_mTinyLinalgBlas = rb_define_module_under(rb_mTinyLinalg, "Blas");
202
236
  /**
203
237
  * Document-module: Numo::TinyLinalg::Lapack
204
238
  * Numo::TinyLinalg::Lapack is wrapper module of LAPACK functions.
239
+ * @!visibility private
205
240
  */
206
241
  rb_mTinyLinalgLapack = rb_define_module_under(rb_mTinyLinalg, "Lapack");
207
242
 
@@ -249,26 +284,26 @@ extern "C" void Init_tiny_linalg(void) {
249
284
  TinyLinalg::Nrm2<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SNrm2>::define_module_function(rb_mTinyLinalgBlas, "snrm2");
250
285
  TinyLinalg::Nrm2<TinyLinalg::numo_cDComplexId, double, TinyLinalg::DZNrm2>::define_module_function(rb_mTinyLinalgBlas, "dznrm2");
251
286
  TinyLinalg::Nrm2<TinyLinalg::numo_cSComplexId, float, TinyLinalg::SCNrm2>::define_module_function(rb_mTinyLinalgBlas, "scnrm2");
252
- TinyLinalg::GESV<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGESV>::define_module_function(rb_mTinyLinalgLapack, "dgesv");
253
- TinyLinalg::GESV<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGESV>::define_module_function(rb_mTinyLinalgLapack, "sgesv");
254
- TinyLinalg::GESV<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGESV>::define_module_function(rb_mTinyLinalgLapack, "zgesv");
255
- TinyLinalg::GESV<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGESV>::define_module_function(rb_mTinyLinalgLapack, "cgesv");
256
- TinyLinalg::GESVD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESVD>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
257
- TinyLinalg::GESVD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESVD>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
258
- TinyLinalg::GESVD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESVD>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
259
- TinyLinalg::GESVD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESVD>::define_module_function(rb_mTinyLinalgLapack, "cgesvd");
260
- TinyLinalg::GESDD<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGESDD>::define_module_function(rb_mTinyLinalgLapack, "dgesdd");
261
- TinyLinalg::GESDD<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGESDD>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
262
- TinyLinalg::GESDD<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGESDD>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
263
- TinyLinalg::GESDD<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGESDD>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
264
- TinyLinalg::GETRF<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRF>::define_module_function(rb_mTinyLinalgLapack, "dgetrf");
265
- TinyLinalg::GETRF<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRF>::define_module_function(rb_mTinyLinalgLapack, "sgetrf");
266
- TinyLinalg::GETRF<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRF>::define_module_function(rb_mTinyLinalgLapack, "zgetrf");
267
- TinyLinalg::GETRF<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRF>::define_module_function(rb_mTinyLinalgLapack, "cgetrf");
268
- TinyLinalg::GETRI<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGETRI>::define_module_function(rb_mTinyLinalgLapack, "dgetri");
269
- TinyLinalg::GETRI<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGETRI>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
270
- TinyLinalg::GETRI<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGETRI>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
271
- TinyLinalg::GETRI<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGETRI>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
287
+ TinyLinalg::GeSv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeSv>::define_module_function(rb_mTinyLinalgLapack, "dgesv");
288
+ TinyLinalg::GeSv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeSv>::define_module_function(rb_mTinyLinalgLapack, "sgesv");
289
+ TinyLinalg::GeSv<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeSv>::define_module_function(rb_mTinyLinalgLapack, "zgesv");
290
+ TinyLinalg::GeSv<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeSv>::define_module_function(rb_mTinyLinalgLapack, "cgesv");
291
+ TinyLinalg::GeSvd<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGeSvd>::define_module_function(rb_mTinyLinalgLapack, "dgesvd");
292
+ TinyLinalg::GeSvd<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGeSvd>::define_module_function(rb_mTinyLinalgLapack, "sgesvd");
293
+ TinyLinalg::GeSvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGeSvd>::define_module_function(rb_mTinyLinalgLapack, "zgesvd");
294
+ TinyLinalg::GeSvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGeSvd>::define_module_function(rb_mTinyLinalgLapack, "cgesvd");
295
+ TinyLinalg::GeSdd<TinyLinalg::numo_cDFloatId, TinyLinalg::numo_cDFloatId, double, double, TinyLinalg::DGeSdd>::define_module_function(rb_mTinyLinalgLapack, "dgesdd");
296
+ TinyLinalg::GeSdd<TinyLinalg::numo_cSFloatId, TinyLinalg::numo_cSFloatId, float, float, TinyLinalg::SGeSdd>::define_module_function(rb_mTinyLinalgLapack, "sgesdd");
297
+ TinyLinalg::GeSdd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZGeSdd>::define_module_function(rb_mTinyLinalgLapack, "zgesdd");
298
+ TinyLinalg::GeSdd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CGeSdd>::define_module_function(rb_mTinyLinalgLapack, "cgesdd");
299
+ TinyLinalg::GeTrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeTrf>::define_module_function(rb_mTinyLinalgLapack, "dgetrf");
300
+ TinyLinalg::GeTrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTrf>::define_module_function(rb_mTinyLinalgLapack, "sgetrf");
301
+ TinyLinalg::GeTrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTrf>::define_module_function(rb_mTinyLinalgLapack, "zgetrf");
302
+ TinyLinalg::GeTrf<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTrf>::define_module_function(rb_mTinyLinalgLapack, "cgetrf");
303
+ TinyLinalg::GeTri<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeTri>::define_module_function(rb_mTinyLinalgLapack, "dgetri");
304
+ TinyLinalg::GeTri<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeTri>::define_module_function(rb_mTinyLinalgLapack, "sgetri");
305
+ TinyLinalg::GeTri<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeTri>::define_module_function(rb_mTinyLinalgLapack, "zgetri");
306
+ TinyLinalg::GeTri<TinyLinalg::numo_cSComplexId, lapack_complex_float, TinyLinalg::CGeTri>::define_module_function(rb_mTinyLinalgLapack, "cgetri");
272
307
  TinyLinalg::GeQrf<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DGeQrf>::define_module_function(rb_mTinyLinalgLapack, "dgeqrf");
273
308
  TinyLinalg::GeQrf<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SGeQrf>::define_module_function(rb_mTinyLinalgLapack, "sgeqrf");
274
309
  TinyLinalg::GeQrf<TinyLinalg::numo_cDComplexId, lapack_complex_double, TinyLinalg::ZGeQrf>::define_module_function(rb_mTinyLinalgLapack, "zgeqrf");
@@ -1,17 +1,41 @@
1
+ /**
2
+ * Copyright (c) 2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
1
31
  #ifndef NUMO_TINY_LINALG_H
2
32
  #define NUMO_TINY_LINALG_H 1
3
33
 
4
- #if defined(TINYLINALG_USE_ACCELERATE)
5
- #include <Accelerate/Accelerate.h>
6
- #else
7
34
  #include <cblas.h>
8
35
  #include <lapacke.h>
9
- #endif
10
36
 
11
37
  #include <string>
12
38
 
13
- #include <cstring>
14
-
15
39
  #include <ruby.h>
16
40
 
17
41
  #include <numo/narray.h>
@@ -0,0 +1,100 @@
1
+ namespace TinyLinalg {
2
+
3
+ class Util {
4
+ public:
5
+ static lapack_int get_itype(VALUE val) {
6
+ const lapack_int itype = NUM2INT(val);
7
+
8
+ if (itype != 1 && itype != 2 && itype != 3) {
9
+ rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
10
+ }
11
+
12
+ return itype;
13
+ }
14
+
15
+ static char get_jobz(VALUE val) {
16
+ const char jobz = NUM2CHR(val);
17
+
18
+ if (jobz != 'N' && jobz != 'V') {
19
+ rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
20
+ }
21
+
22
+ return jobz;
23
+ }
24
+
25
+ static char get_range(VALUE val) {
26
+ const char range = NUM2CHR(val);
27
+
28
+ if (range != 'A' && range != 'V' && range != 'I') {
29
+ rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
30
+ }
31
+
32
+ return range;
33
+ }
34
+
35
+ static char get_uplo(VALUE val) {
36
+ const char uplo = NUM2CHR(val);
37
+
38
+ if (uplo != 'U' && uplo != 'L') {
39
+ rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
40
+ }
41
+
42
+ return uplo;
43
+ }
44
+
45
+ static int get_matrix_layout(VALUE val) {
46
+ const char option = NUM2CHR(val);
47
+
48
+ switch (option) {
49
+ case 'r':
50
+ case 'R':
51
+ break;
52
+ case 'c':
53
+ case 'C':
54
+ rb_warn("Numo::TinyLinalg does not support column major.");
55
+ break;
56
+ }
57
+
58
+ return LAPACK_ROW_MAJOR;
59
+ }
60
+
61
+ static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
62
+ const char option = NUM2CHR(val);
63
+ enum CBLAS_TRANSPOSE res = CblasNoTrans;
64
+
65
+ switch (option) {
66
+ case 'n':
67
+ case 'N':
68
+ res = CblasNoTrans;
69
+ break;
70
+ case 't':
71
+ case 'T':
72
+ res = CblasTrans;
73
+ break;
74
+ case 'c':
75
+ case 'C':
76
+ res = CblasConjTrans;
77
+ break;
78
+ }
79
+
80
+ return res;
81
+ }
82
+
83
+ static enum CBLAS_ORDER get_cblas_order(VALUE val) {
84
+ const char option = NUM2CHR(val);
85
+
86
+ switch (option) {
87
+ case 'r':
88
+ case 'R':
89
+ break;
90
+ case 'c':
91
+ case 'C':
92
+ rb_warn("Numo::TinyLinalg does not support column major.");
93
+ break;
94
+ }
95
+
96
+ return CblasRowMajor;
97
+ }
98
+ };
99
+
100
+ } // namespace TinyLinalg
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::TinyLinalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
6
6
  module TinyLinalg
7
7
  # The version of Numo::TinyLinalg you install.
8
- VERSION = '0.0.4'
8
+ VERSION = '0.1.1'
9
9
  end
10
10
  end