nmatrix 0.1.0 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78) hide show
  1. checksums.yaml +4 -4
  2. data/ext/nmatrix/data/complex.h +20 -55
  3. data/ext/nmatrix/data/data.cpp +11 -44
  4. data/ext/nmatrix/data/data.h +174 -311
  5. data/ext/nmatrix/data/meta.h +1 -7
  6. data/ext/nmatrix/data/ruby_object.h +3 -85
  7. data/ext/nmatrix/extconf.rb +2 -73
  8. data/ext/nmatrix/math.cpp +170 -813
  9. data/ext/nmatrix/math/asum.h +2 -25
  10. data/ext/nmatrix/math/{inc.h → cblas_enums.h} +11 -22
  11. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  12. data/ext/nmatrix/math/gemm.h +2 -32
  13. data/ext/nmatrix/math/gemv.h +1 -35
  14. data/ext/nmatrix/math/getrf.h +21 -6
  15. data/ext/nmatrix/math/getrs.h +0 -8
  16. data/ext/nmatrix/math/imax.h +0 -22
  17. data/ext/nmatrix/math/long_dtype.h +0 -3
  18. data/ext/nmatrix/math/math.h +11 -337
  19. data/ext/nmatrix/math/nrm2.h +2 -23
  20. data/ext/nmatrix/math/rot.h +1 -25
  21. data/ext/nmatrix/math/rotg.h +4 -13
  22. data/ext/nmatrix/math/scal.h +0 -22
  23. data/ext/nmatrix/math/trsm.h +0 -55
  24. data/ext/nmatrix/math/util.h +148 -0
  25. data/ext/nmatrix/nmatrix.cpp +0 -14
  26. data/ext/nmatrix/nmatrix.h +92 -84
  27. data/ext/nmatrix/ruby_constants.cpp +0 -2
  28. data/ext/nmatrix/ruby_constants.h +0 -2
  29. data/ext/nmatrix/ruby_nmatrix.c +86 -45
  30. data/ext/nmatrix/storage/dense/dense.cpp +1 -7
  31. data/ext/nmatrix/storage/storage.h +0 -1
  32. data/ext/nmatrix/ttable_helper.rb +0 -6
  33. data/ext/nmatrix/util/io.cpp +1 -1
  34. data/lib/nmatrix.rb +1 -19
  35. data/lib/nmatrix/blas.rb +33 -11
  36. data/lib/nmatrix/io/market.rb +3 -3
  37. data/lib/nmatrix/lapack_core.rb +181 -0
  38. data/lib/nmatrix/lapack_plugin.rb +44 -0
  39. data/lib/nmatrix/math.rb +382 -131
  40. data/lib/nmatrix/monkeys.rb +2 -3
  41. data/lib/nmatrix/nmatrix.rb +166 -13
  42. data/lib/nmatrix/shortcuts.rb +72 -7
  43. data/lib/nmatrix/version.rb +2 -2
  44. data/spec/00_nmatrix_spec.rb +154 -5
  45. data/spec/02_slice_spec.rb +2 -6
  46. data/spec/03_nmatrix_monkeys_spec.rb +7 -1
  47. data/spec/blas_spec.rb +60 -33
  48. data/spec/homogeneous_spec.rb +10 -10
  49. data/spec/lapack_core_spec.rb +482 -0
  50. data/spec/math_spec.rb +436 -52
  51. data/spec/shortcuts_spec.rb +28 -4
  52. data/spec/spec_helper.rb +14 -2
  53. data/spec/utm5940.mtx +83844 -0
  54. metadata +49 -76
  55. data/.gitignore +0 -27
  56. data/.rspec +0 -2
  57. data/.travis.yml +0 -15
  58. data/CONTRIBUTING.md +0 -82
  59. data/Gemfile +0 -2
  60. data/History.txt +0 -677
  61. data/LICENSE.txt +0 -23
  62. data/Manifest.txt +0 -92
  63. data/README.rdoc +0 -150
  64. data/Rakefile +0 -216
  65. data/ext/nmatrix/data/rational.h +0 -440
  66. data/ext/nmatrix/math/geev.h +0 -82
  67. data/ext/nmatrix/math/ger.h +0 -96
  68. data/ext/nmatrix/math/gesdd.h +0 -80
  69. data/ext/nmatrix/math/gesvd.h +0 -78
  70. data/ext/nmatrix/math/getf2.h +0 -86
  71. data/ext/nmatrix/math/getri.h +0 -108
  72. data/ext/nmatrix/math/potrs.h +0 -129
  73. data/ext/nmatrix/math/swap.h +0 -52
  74. data/lib/nmatrix/lapack.rb +0 -240
  75. data/nmatrix.gemspec +0 -55
  76. data/scripts/mac-brew-gcc.sh +0 -50
  77. data/scripts/mac-mavericks-brew-gcc.sh +0 -22
  78. data/spec/lapack_spec.rb +0 -459
@@ -57,7 +57,7 @@
57
57
  */
58
58
 
59
59
  #ifndef ASUM_H
60
- # define ASUM_H
60
+ #define ASUM_H
61
61
 
62
62
 
63
63
  namespace nm { namespace math {
@@ -72,7 +72,6 @@ namespace nm { namespace math {
72
72
  * double -> double
73
73
  * complex64 -> float or double
74
74
  * complex128 -> double
75
- * rational -> rational
76
75
  */
77
76
  template <typename ReturnDType, typename DType>
78
77
  inline ReturnDType asum(const int N, const DType* X, const int incX) {
@@ -86,27 +85,6 @@ inline ReturnDType asum(const int N, const DType* X, const int incX) {
86
85
  }
87
86
 
88
87
 
89
- #if defined HAVE_CBLAS_H || defined HAVE_ATLAS_CBLAS_H
90
- template <>
91
- inline float asum(const int N, const float* X, const int incX) {
92
- return cblas_sasum(N, X, incX);
93
- }
94
-
95
- template <>
96
- inline double asum(const int N, const double* X, const int incX) {
97
- return cblas_dasum(N, X, incX);
98
- }
99
-
100
- template <>
101
- inline float asum(const int N, const Complex64* X, const int incX) {
102
- return cblas_scasum(N, X, incX);
103
- }
104
-
105
- template <>
106
- inline double asum(const int N, const Complex128* X, const int incX) {
107
- return cblas_dzasum(N, X, incX);
108
- }
109
- #else
110
88
  template <>
111
89
  inline float asum(const int N, const Complex64* X, const int incX) {
112
90
  float sum = 0;
@@ -128,7 +106,6 @@ inline double asum(const int N, const Complex128* X, const int incX) {
128
106
  }
129
107
  return sum;
130
108
  }
131
- #endif
132
109
 
133
110
 
134
111
  template <typename ReturnDType, typename DType>
@@ -140,4 +117,4 @@ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
140
117
 
141
118
  }} // end of namespace nm::math
142
119
 
143
- #endif // NRM2_H
120
+ #endif // ASUM_H
@@ -9,8 +9,8 @@
9
9
  //
10
10
  // == Copyright Information
11
11
  //
12
- // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
- // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
12
+ // SciRuby is Copyright (c) 2010 - 2015, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2015, John Woods and the Ruby Science Foundation
14
14
  //
15
15
  // Please see LICENSE.txt for additional copyright notices.
16
16
  //
@@ -21,27 +21,16 @@
21
21
  //
22
22
  // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
23
  //
24
- // == inc.h
24
+ // == cblas_enums.h
25
25
  //
26
- // Includes needed for LAPACK, CLAPACK, and CBLAS functions.
26
+ // CBLAS definitions for when CBLAS is not available.
27
27
  //
28
28
 
29
- #ifndef INC_H
30
- # define INC_H
31
-
32
-
33
- extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
34
- #if defined HAVE_CBLAS_H
35
- #include <cblas.h>
36
- #elif defined HAVE_ATLAS_CBLAS_H
37
- #include <atlas/cblas.h>
29
+ #ifndef CBLAS_ENUM_DEFINED_H
30
+ #define CBLAS_ENUM_DEFINED_H
31
+ enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
32
+ enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
33
+ enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
34
+ enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
35
+ enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
38
36
  #endif
39
-
40
- #if defined HAVE_CLAPACK_H
41
- #include <clapack.h>
42
- #elif defined HAVE_ATLAS_CLAPACK_H
43
- #include <atlas/clapack.h>
44
- #endif
45
- }
46
-
47
- #endif // INC_H
@@ -0,0 +1,507 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == cblas_templates_core.h
25
+ //
26
+ // This header file is not used by the main nmatrix gem but has to be stored
27
+ // in this directory so that it can be shared between nmatrix-atlas and
28
+ // nmatrix-lapack.
29
+ //
30
+
31
+ //This is not a normal header file so we don't use an include guard.
32
+ //See ext/nmatrix_atlas/math_atlas/cblas_templates_atlas.h for how
33
+ //to use.
34
+
35
+ //Below are the BLAS functions for which we have internal implementations.
36
+ //The internal implementations are defined in the ext/nmatrix/math directory
37
+ //and are the non-specialized
38
+ //forms of the template functions nm::math::whatever().
39
+ //They are called below for non-BLAS
40
+ //types in the non-specialized form of the template nm::math::something_else::whatever().
41
+ //The specialized forms call the appropriate cblas functions.
42
+
43
+ //For all functions besides herk, we also define the cblas_whatever() template
44
+ //functions below, which just cast
45
+ //their arguments to the appropriate types.
46
+
47
+ //rotg
48
+ template <typename DType>
49
+ inline void rotg(DType* a, DType* b, DType* c, DType* s) {
50
+ nm::math::rotg(a, b, c, s);
51
+ }
52
+
53
+ template <>
54
+ inline void rotg(float* a, float* b, float* c, float* s) {
55
+ cblas_srotg(a, b, c, s);
56
+ }
57
+
58
+ template <>
59
+ inline void rotg(double* a, double* b, double* c, double* s) {
60
+ cblas_drotg(a, b, c, s);
61
+ }
62
+
63
+ //Complex versions of rot and rotg are available in the ATLAS (and Intel)
64
+ //version of CBLAS, but not part
65
+ //of the reference implementation or OpenBLAS, so we omit them here
66
+ //and fall back to the generic internal implementation.
67
+ //Another option would be to directly call the Fortran functions, e.g. ZROTG,
68
+ //which for some reason are a part of the standard.
69
+ //We can still define complex specializations of these functions in an ATLAS-specific
70
+ //header.
71
+
72
+ template <typename DType>
73
+ inline void cblas_rotg(void* a, void* b, void* c, void* s) {
74
+ rotg<DType>(static_cast<DType*>(a), static_cast<DType*>(b), static_cast<DType*>(c), static_cast<DType*>(s));
75
+ }
76
+
77
+ //rot
78
+ template <typename DType, typename CSDType>
79
+ inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) {
80
+ nm::math::rot<DType,CSDType>(N, X, incX, Y, incY, c, s);
81
+ }
82
+
83
+ template <>
84
+ inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) {
85
+ cblas_srot(N, X, incX, Y, incY, (float)c, (float)s);
86
+ }
87
+
88
+ template <>
89
+ inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) {
90
+ cblas_drot(N, X, incX, Y, incY, c, s);
91
+ }
92
+
93
+ template <typename DType, typename CSDType>
94
+ inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) {
95
+ rot<DType,CSDType>(N, static_cast<DType*>(X), incX, static_cast<DType*>(Y), incY,
96
+ *static_cast<const CSDType*>(c), *static_cast<const CSDType*>(s));
97
+ }
98
+
99
+ /*
100
+ * Level 1 BLAS routine which sums the absolute values of a vector's contents. If the vector consists of complex values,
101
+ * the routine sums the absolute values of the real and imaginary components as well.
102
+ *
103
+ * So, based on input types, these are the valid return types:
104
+ * int -> int
105
+ * float -> float or double
106
+ * double -> double
107
+ * complex64 -> float or double
108
+ * complex128 -> double
109
+ */
110
+ template <typename ReturnDType, typename DType>
111
+ inline ReturnDType asum(const int N, const DType* X, const int incX) {
112
+ return nm::math::asum<ReturnDType,DType>(N,X,incX);
113
+ }
114
+
115
+
116
+ template <>
117
+ inline float asum(const int N, const float* X, const int incX) {
118
+ return cblas_sasum(N, X, incX);
119
+ }
120
+
121
+ template <>
122
+ inline double asum(const int N, const double* X, const int incX) {
123
+ return cblas_dasum(N, X, incX);
124
+ }
125
+
126
+ template <>
127
+ inline float asum(const int N, const Complex64* X, const int incX) {
128
+ return cblas_scasum(N, X, incX);
129
+ }
130
+
131
+ template <>
132
+ inline double asum(const int N, const Complex128* X, const int incX) {
133
+ return cblas_dzasum(N, X, incX);
134
+ }
135
+
136
+
137
+ template <typename ReturnDType, typename DType>
138
+ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
139
+ *static_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, static_cast<const DType*>(X), incX );
140
+ }
141
+
142
+ /*
143
+ * Level 1 BLAS routine which returns the 2-norm of an n-vector x.
144
+ *
145
+ * Based on input types, these are the valid return types:
146
+ * int -> int
147
+ * float -> float or double
148
+ * double -> double
149
+ * complex64 -> float or double
150
+ * complex128 -> double
151
+ */
152
+ template <typename ReturnDType, typename DType>
153
+ inline ReturnDType nrm2(const int N, const DType* X, const int incX) {
154
+ return nm::math::nrm2<ReturnDType,DType>(N, X, incX);
155
+ }
156
+
157
+
158
+ template <>
159
+ inline float nrm2(const int N, const float* X, const int incX) {
160
+ return cblas_snrm2(N, X, incX);
161
+ }
162
+
163
+ template <>
164
+ inline double nrm2(const int N, const double* X, const int incX) {
165
+ return cblas_dnrm2(N, X, incX);
166
+ }
167
+
168
+ template <>
169
+ inline float nrm2(const int N, const Complex64* X, const int incX) {
170
+ return cblas_scnrm2(N, X, incX);
171
+ }
172
+
173
+ template <>
174
+ inline double nrm2(const int N, const Complex128* X, const int incX) {
175
+ return cblas_dznrm2(N, X, incX);
176
+ }
177
+
178
+ template <typename ReturnDType, typename DType>
179
+ inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) {
180
+ *static_cast<ReturnDType*>( result ) = nrm2<ReturnDType, DType>( N, static_cast<const DType*>(X), incX );
181
+ }
182
+
183
+ //imax
184
+ template<typename DType>
185
+ inline int imax(const int n, const DType *x, const int incx) {
186
+ return nm::math::imax(n, x, incx);
187
+ }
188
+
189
+ template<>
190
+ inline int imax(const int n, const float* x, const int incx) {
191
+ return cblas_isamax(n, x, incx);
192
+ }
193
+
194
+ template<>
195
+ inline int imax(const int n, const double* x, const int incx) {
196
+ return cblas_idamax(n, x, incx);
197
+ }
198
+
199
+ template<>
200
+ inline int imax(const int n, const Complex64* x, const int incx) {
201
+ return cblas_icamax(n, x, incx);
202
+ }
203
+
204
+ template <>
205
+ inline int imax(const int n, const Complex128* x, const int incx) {
206
+ return cblas_izamax(n, x, incx);
207
+ }
208
+
209
+ template<typename DType>
210
+ inline int cblas_imax(const int n, const void* x, const int incx) {
211
+ return imax<DType>(n, static_cast<const DType*>(x), incx);
212
+ }
213
+
214
+ //scal
215
+ template <typename DType>
216
+ inline void scal(const int n, const DType scalar, DType* x, const int incx) {
217
+ nm::math::scal(n, scalar, x, incx);
218
+ }
219
+
220
+ template <>
221
+ inline void scal(const int n, const float scalar, float* x, const int incx) {
222
+ cblas_sscal(n, scalar, x, incx);
223
+ }
224
+
225
+ template <>
226
+ inline void scal(const int n, const double scalar, double* x, const int incx) {
227
+ cblas_dscal(n, scalar, x, incx);
228
+ }
229
+
230
+ template <>
231
+ inline void scal(const int n, const Complex64 scalar, Complex64* x, const int incx) {
232
+ cblas_cscal(n, &scalar, x, incx);
233
+ }
234
+
235
+ template <>
236
+ inline void scal(const int n, const Complex128 scalar, Complex128* x, const int incx) {
237
+ cblas_zscal(n, &scalar, x, incx);
238
+ }
239
+
240
+ template <typename DType>
241
+ inline void cblas_scal(const int n, const void* scalar, void* x, const int incx) {
242
+ scal<DType>(n, *static_cast<const DType*>(scalar), static_cast<DType*>(x), incx);
243
+ }
244
+
245
+ //gemv
246
+ template <typename DType>
247
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const DType* alpha, const DType* A, const int lda,
248
+ const DType* X, const int incX, const DType* beta, DType* Y, const int incY) {
249
+ return nm::math::gemv(Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
250
+ }
251
+
252
+ template <>
253
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const float* alpha, const float* A, const int lda,
254
+ const float* X, const int incX, const float* beta, float* Y, const int incY) {
255
+ cblas_sgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
256
+ return true;
257
+ }
258
+
259
+ template <>
260
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const double* alpha, const double* A, const int lda,
261
+ const double* X, const int incX, const double* beta, double* Y, const int incY) {
262
+ cblas_dgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
263
+ return true;
264
+ }
265
+
266
+ template <>
267
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex64* alpha, const Complex64* A, const int lda,
268
+ const Complex64* X, const int incX, const Complex64* beta, Complex64* Y, const int incY) {
269
+ cblas_cgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
270
+ return true;
271
+ }
272
+
273
+ template <>
274
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex128* alpha, const Complex128* A, const int lda,
275
+ const Complex128* X, const int incX, const Complex128* beta, Complex128* Y, const int incY) {
276
+ cblas_zgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
277
+ return true;
278
+ }
279
+
280
+ template <typename DType>
281
+ inline static bool cblas_gemv(const enum CBLAS_TRANSPOSE trans,
282
+ const int m, const int n,
283
+ const void* alpha,
284
+ const void* a, const int lda,
285
+ const void* x, const int incx,
286
+ const void* beta,
287
+ void* y, const int incy)
288
+ {
289
+ return gemv<DType>(trans,
290
+ m, n, static_cast<const DType*>(alpha),
291
+ static_cast<const DType*>(a), lda,
292
+ static_cast<const DType*>(x), incx, static_cast<const DType*>(beta),
293
+ static_cast<DType*>(y), incy);
294
+ }
295
+
296
+ //gemm
297
+ template <typename DType>
298
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
299
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
300
+ {
301
+ nm::math::gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
302
+ }
303
+
304
+ template <>
305
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
306
+ const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
307
+ cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
308
+ }
309
+
310
+ template <>
311
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
312
+ const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
313
+ cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
314
+ }
315
+
316
+ template <>
317
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
318
+ const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
319
+ cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
320
+ }
321
+
322
+ template <>
323
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
324
+ const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
325
+ cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
326
+ }
327
+
328
+ template <typename DType>
329
+ inline static void cblas_gemm(const enum CBLAS_ORDER order,
330
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
331
+ int m, int n, int k,
332
+ void* alpha,
333
+ void* a, int lda,
334
+ void* b, int ldb,
335
+ void* beta,
336
+ void* c, int ldc)
337
+ {
338
+ gemm<DType>(order, trans_a, trans_b, m, n, k, static_cast<DType*>(alpha),
339
+ static_cast<DType*>(a), lda,
340
+ static_cast<DType*>(b), ldb, static_cast<DType*>(beta),
341
+ static_cast<DType*>(c), ldc);
342
+ }
343
+
344
+ //trsm
345
+ template <typename DType, typename = typename std::enable_if<!std::is_integral<DType>::value>::type>
346
+ inline void trsm(const enum CBLAS_ORDER order,
347
+ const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
348
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
349
+ const int m, const int n, const DType alpha, const DType* a,
350
+ const int lda, DType* b, const int ldb)
351
+ {
352
+ nm::math::trsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
353
+ }
354
+
355
+ template <>
356
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
357
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
358
+ const int m, const int n, const float alpha, const float* a,
359
+ const int lda, float* b, const int ldb)
360
+ {
361
+ cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
362
+ }
363
+
364
+ template <>
365
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
366
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
367
+ const int m, const int n, const double alpha, const double* a,
368
+ const int lda, double* b, const int ldb)
369
+ {
370
+ cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
371
+ }
372
+
373
+
374
+ template <>
375
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
376
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
377
+ const int m, const int n, const Complex64 alpha, const Complex64* a,
378
+ const int lda, Complex64* b, const int ldb)
379
+ {
380
+ cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, &alpha, a, lda, b, ldb);
381
+ }
382
+
383
+ template <>
384
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
385
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
386
+ const int m, const int n, const Complex128 alpha, const Complex128* a,
387
+ const int lda, Complex128* b, const int ldb)
388
+ {
389
+ cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, &alpha, a, lda, b, ldb);
390
+ }
391
+
392
+ template <typename DType>
393
+ inline static void cblas_trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
394
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
395
+ const int m, const int n, const void* alpha, const void* a,
396
+ const int lda, void* b, const int ldb)
397
+ {
398
+ trsm<DType>(order, side, uplo, trans_a, diag, m, n, *static_cast<const DType*>(alpha),
399
+ static_cast<const DType*>(a), lda, static_cast<DType*>(b), ldb);
400
+ }
401
+
402
+ //Below are BLAS functions that we don't have an internal implementation for.
403
+ //In these cases the non-specialized form just raises an error.
404
+
405
+ //syrk
406
+ template <typename DType>
407
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
408
+ const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
409
+ rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
410
+ }
411
+
412
+ template <>
413
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
414
+ const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
415
+ cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
416
+ }
417
+
418
+ template <>
419
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
420
+ const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
421
+ cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
422
+ }
423
+
424
+ template <>
425
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
426
+ const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
427
+ cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
428
+ }
429
+
430
+ template <>
431
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
432
+ const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
433
+ cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
434
+ }
435
+
436
+ template <typename DType>
437
+ inline static void cblas_syrk(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans,
438
+ const int n, const int k, const void* alpha,
439
+ const void* A, const int lda, const void* beta, void* C, const int ldc)
440
+ {
441
+ syrk<DType>(order, uplo, trans, n, k, static_cast<const DType*>(alpha),
442
+ static_cast<const DType*>(A), lda, static_cast<const DType*>(beta), static_cast<DType*>(C), ldc);
443
+ }
444
+
445
+ //herk
446
+ template <typename DType>
447
+ inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
448
+ const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
449
+ rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
450
+ }
451
+
452
+ template <>
453
+ inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
454
+ const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
455
+ cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
456
+ }
457
+
458
+ template <>
459
+ inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
460
+ const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
461
+ cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
462
+ }
463
+
464
+ //trmm
465
+ template <typename DType>
466
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
467
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
468
+ const DType* A, const int lda, DType* B, const int ldb) {
469
+ rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
470
+ }
471
+
472
+ template <>
473
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
474
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
475
+ const float* A, const int lda, float* B, const int ldb) {
476
+ cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
477
+ }
478
+
479
+ template <>
480
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
481
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
482
+ const double* A, const int lda, double* B, const int ldb) {
483
+ cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
484
+ }
485
+
486
+ template <>
487
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
488
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
489
+ const Complex64* A, const int lda, Complex64* B, const int ldb) {
490
+ cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
491
+ }
492
+
493
+ template <>
494
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
495
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
496
+ const Complex128* A, const int lda, Complex128* B, const int ldb) {
497
+ cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
498
+ }
499
+
500
+ template <typename DType>
501
+ inline static void cblas_trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
502
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const void* alpha,
503
+ const void* A, const int lda, void* B, const int ldb)
504
+ {
505
+ trmm<DType>(order, side, uplo, ta, diag, m, n, static_cast<const DType*>(alpha),
506
+ static_cast<const DType*>(A), lda, static_cast<DType*>(B), ldb);
507
+ }