nmatrix-gemv 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +29 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +14 -0
  5. data/Gemfile +7 -0
  6. data/README.md +29 -0
  7. data/Rakefile +225 -0
  8. data/ext/nmatrix_gemv/binary_format.txt +53 -0
  9. data/ext/nmatrix_gemv/data/complex.h +399 -0
  10. data/ext/nmatrix_gemv/data/data.cpp +298 -0
  11. data/ext/nmatrix_gemv/data/data.h +771 -0
  12. data/ext/nmatrix_gemv/data/meta.h +70 -0
  13. data/ext/nmatrix_gemv/data/rational.h +436 -0
  14. data/ext/nmatrix_gemv/data/ruby_object.h +471 -0
  15. data/ext/nmatrix_gemv/extconf.rb +254 -0
  16. data/ext/nmatrix_gemv/math.cpp +1639 -0
  17. data/ext/nmatrix_gemv/math/asum.h +143 -0
  18. data/ext/nmatrix_gemv/math/geev.h +82 -0
  19. data/ext/nmatrix_gemv/math/gemm.h +271 -0
  20. data/ext/nmatrix_gemv/math/gemv.h +212 -0
  21. data/ext/nmatrix_gemv/math/ger.h +96 -0
  22. data/ext/nmatrix_gemv/math/gesdd.h +80 -0
  23. data/ext/nmatrix_gemv/math/gesvd.h +78 -0
  24. data/ext/nmatrix_gemv/math/getf2.h +86 -0
  25. data/ext/nmatrix_gemv/math/getrf.h +240 -0
  26. data/ext/nmatrix_gemv/math/getri.h +108 -0
  27. data/ext/nmatrix_gemv/math/getrs.h +129 -0
  28. data/ext/nmatrix_gemv/math/idamax.h +86 -0
  29. data/ext/nmatrix_gemv/math/inc.h +47 -0
  30. data/ext/nmatrix_gemv/math/laswp.h +165 -0
  31. data/ext/nmatrix_gemv/math/long_dtype.h +52 -0
  32. data/ext/nmatrix_gemv/math/math.h +1069 -0
  33. data/ext/nmatrix_gemv/math/nrm2.h +181 -0
  34. data/ext/nmatrix_gemv/math/potrs.h +129 -0
  35. data/ext/nmatrix_gemv/math/rot.h +141 -0
  36. data/ext/nmatrix_gemv/math/rotg.h +115 -0
  37. data/ext/nmatrix_gemv/math/scal.h +73 -0
  38. data/ext/nmatrix_gemv/math/swap.h +73 -0
  39. data/ext/nmatrix_gemv/math/trsm.h +387 -0
  40. data/ext/nmatrix_gemv/nm_memory.h +60 -0
  41. data/ext/nmatrix_gemv/nmatrix_gemv.cpp +90 -0
  42. data/ext/nmatrix_gemv/nmatrix_gemv.h +374 -0
  43. data/ext/nmatrix_gemv/ruby_constants.cpp +153 -0
  44. data/ext/nmatrix_gemv/ruby_constants.h +107 -0
  45. data/ext/nmatrix_gemv/ruby_nmatrix.c +84 -0
  46. data/ext/nmatrix_gemv/ttable_helper.rb +122 -0
  47. data/ext/nmatrix_gemv/types.h +54 -0
  48. data/ext/nmatrix_gemv/util/util.h +78 -0
  49. data/lib/nmatrix-gemv.rb +43 -0
  50. data/lib/nmatrix_gemv/blas.rb +85 -0
  51. data/lib/nmatrix_gemv/nmatrix_gemv.rb +35 -0
  52. data/lib/nmatrix_gemv/rspec.rb +75 -0
  53. data/nmatrix-gemv.gemspec +31 -0
  54. data/spec/blas_spec.rb +154 -0
  55. data/spec/spec_helper.rb +128 -0
  56. metadata +186 -0
@@ -0,0 +1,143 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == asum.h
25
+ //
26
+ // CBLAS asum function
27
+ //
28
+
29
+ /*
30
+ * Automatically Tuned Linear Algebra Software v3.8.4
31
+ * (C) Copyright 1999 R. Clint Whaley
32
+ *
33
+ * Redistribution and use in source and binary forms, with or without
34
+ * modification, are permitted provided that the following conditions
35
+ * are met:
36
+ * 1. Redistributions of source code must retain the above copyright
37
+ * notice, this list of conditions and the following disclaimer.
38
+ * 2. Redistributions in binary form must reproduce the above copyright
39
+ * notice, this list of conditions, and the following disclaimer in the
40
+ * documentation and/or other materials provided with the distribution.
41
+ * 3. The name of the ATLAS group or the names of its contributers may
42
+ * not be used to endorse or promote products derived from this
43
+ * software without specific written permission.
44
+ *
45
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
46
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
47
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
48
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
49
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
50
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
51
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
52
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
53
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
54
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
55
+ * POSSIBILITY OF SUCH DAMAGE.
56
+ *
57
+ */
58
+
59
+ #ifndef ASUM_H
60
+ # define ASUM_H
61
+
62
+
63
+ namespace nm { namespace math {
64
+
65
+ /*
66
+ * Level 1 BLAS routine which sums the absolute values of a vector's contents. If the vector consists of complex values,
67
+ * the routine sums the absolute values of the real and imaginary components as well.
68
+ *
69
+ * So, based on input types, these are the valid return types:
70
+ * int -> int
71
+ * float -> float or double
72
+ * double -> double
73
+ * complex64 -> float or double
74
+ * complex128 -> double
75
+ * rational -> rational
76
+ */
77
/*
 * Generic absolute-value sum (BLAS Level 1 xASUM) over N elements of X,
 * stepping by incX. Returns 0 when N <= 0 or incX <= 0, matching the
 * reference BLAS convention of doing no work for non-positive arguments.
 */
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
  ReturnDType total = 0;
  if (N > 0 && incX > 0) {
    const DType* elem = X;
    for (int count = 0; count < N; ++count, elem += incX) {
      total += std::abs(*elem);
    }
  }
  return total;
}
87
+
88
+
89
+ #if defined HAVE_CBLAS_H || defined HAVE_ATLAS_CBLAS_H
90
+ template <>
91
+ inline float asum(const int N, const float* X, const int incX) {
92
+ return cblas_sasum(N, X, incX);
93
+ }
94
+
95
+ template <>
96
+ inline double asum(const int N, const double* X, const int incX) {
97
+ return cblas_dasum(N, X, incX);
98
+ }
99
+
100
+ template <>
101
+ inline float asum(const int N, const Complex64* X, const int incX) {
102
+ return cblas_scasum(N, X, incX);
103
+ }
104
+
105
+ template <>
106
+ inline double asum(const int N, const Complex128* X, const int incX) {
107
+ return cblas_dzasum(N, X, incX);
108
+ }
109
+ #else
110
+ template <>
111
+ inline float asum(const int N, const Complex64* X, const int incX) {
112
+ float sum = 0;
113
+ if ((N > 0) && (incX > 0)) {
114
+ for (int i = 0; i < N; ++i) {
115
+ sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
116
+ }
117
+ }
118
+ return sum;
119
+ }
120
+
121
+ template <>
122
+ inline double asum(const int N, const Complex128* X, const int incX) {
123
+ double sum = 0;
124
+ if ((N > 0) && (incX > 0)) {
125
+ for (int i = 0; i < N; ++i) {
126
+ sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
127
+ }
128
+ }
129
+ return sum;
130
+ }
131
+ #endif
132
+
133
+
134
+ template <typename ReturnDType, typename DType>
135
+ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
136
+ *reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
137
+ }
138
+
139
+
140
+
141
+ }} // end of namespace nm::math
142
+
143
+ #endif // ASUM_H
@@ -0,0 +1,82 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == geev.h
25
+ //
26
+ // Header file for interface with LAPACK's xGEEV functions.
27
+ //
28
+
29
+ #ifndef GEEV_H
30
+ # define GEEV_H
31
+
32
+ extern "C" {
33
+ void sgeev_(char* jobvl, char* jobvr, int* n, float* a, int* lda, float* wr, float* wi, float* vl, int* ldvl, float* vr, int* ldvr, float* work, int* lwork, int* info);
34
+ void dgeev_(char* jobvl, char* jobvr, int* n, double* a, int* lda, double* wr, double* wi, double* vl, int* ldvl, double* vr, int* ldvr, double* work, int* lwork, int* info);
35
+ void cgeev_(char* jobvl, char* jobvr, int* n, nm::Complex64* a, int* lda, nm::Complex64* w, nm::Complex64* vl, int* ldvl, nm::Complex64* vr, int* ldvr, nm::Complex64* work, int* lwork, float* rwork, int* info);
36
+ void zgeev_(char* jobvl, char* jobvr, int* n, nm::Complex128* a, int* lda, nm::Complex128* w, nm::Complex128* vl, int* ldvl, nm::Complex128* vr, int* ldvr, nm::Complex128* work, int* lwork, double* rwork, int* info);
37
+ }
38
+
39
+ namespace nm { namespace math {
40
+
41
+ template <typename DType, typename CType> // wr
42
+ inline int geev(char jobvl, char jobvr, int n, DType* a, int lda, DType* w, DType* wi, DType* vl, int ldvl, DType* vr, int ldvr, DType* work, int lwork, CType* rwork) {
43
+ rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
44
+ return -1;
45
+ }
46
+
47
+ template <>
48
+ inline int geev(char jobvl, char jobvr, int n, float* a, int lda, float* w, float* wi, float* vl, int ldvl, float* vr, int ldvr, float* work, int lwork, float* rwork) {
49
+ int info;
50
+ sgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
51
+ return info;
52
+ }
53
+
54
+ template <>
55
+ inline int geev(char jobvl, char jobvr, int n, double* a, int lda, double* w, double* wi, double* vl, int ldvl, double* vr, int ldvr, double* work, int lwork, double* rwork) {
56
+ int info;
57
+ dgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
58
+ return info;
59
+ }
60
+
61
+ template <>
62
+ inline int geev(char jobvl, char jobvr, int n, Complex64* a, int lda, Complex64* w, Complex64* wi, Complex64* vl, int ldvl, Complex64* vr, int ldvr, Complex64* work, int lwork, float* rwork) {
63
+ int info;
64
+ cgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
65
+ return info;
66
+ }
67
+
68
+ template <>
69
+ inline int geev(char jobvl, char jobvr, int n, Complex128* a, int lda, Complex128* w, Complex128* wi, Complex128* vl, int ldvl, Complex128* vr, int ldvr, Complex128* work, int lwork, double* rwork) {
70
+ int info;
71
+ zgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
72
+ return info;
73
+ }
74
+
75
+ template <typename DType, typename CType>
76
+ inline int lapack_geev(char jobvl, char jobvr, int n, void* a, int lda, void* w, void* wi, void* vl, int ldvl, void* vr, int ldvr, void* work, int lwork, void* rwork) {
77
+ return geev<DType,CType>(jobvl, jobvr, n, reinterpret_cast<DType*>(a), lda, reinterpret_cast<DType*>(w), reinterpret_cast<DType*>(wi), reinterpret_cast<DType*>(vl), ldvl, reinterpret_cast<DType*>(vr), ldvr, reinterpret_cast<DType*>(work), lwork, reinterpret_cast<CType*>(rwork));
78
+ }
79
+
80
+ }} // end nm::math
81
+
82
+ #endif // GEEV_H
@@ -0,0 +1,271 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gemm.h
25
+ //
26
+ // Header file for interface with ATLAS's CBLAS gemm functions and
27
+ // native templated version of LAPACK's gemm function.
28
+ //
29
+
30
+ #ifndef GEMM_H
31
+ # define GEMM_H
32
+
33
+ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
34
+ #if defined HAVE_CBLAS_H
35
+ #include <cblas.h>
36
+ #elif defined HAVE_ATLAS_CBLAS_H
37
+ #include <atlas/cblas.h>
38
+ #endif
39
+ }
40
+
41
+
42
+ namespace nm { namespace math {
43
+ /*
44
+ * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
45
+ *
46
+ * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
47
+ *
48
+ * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
49
+ *
50
+ * This version throws no errors. Use gemm<DType> instead for error checking.
51
+ */
52
+ template <typename DType>
53
+ inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
54
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
55
+ {
56
+
57
+ typename LongDType<DType>::type temp;
58
+
59
+ // Quick return if possible
60
+ if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
61
+
62
+ // For alpha = 0
63
+ if (*alpha == 0) {
64
+ if (*beta == 0) {
65
+ for (int j = 0; j < N; ++j)
66
+ for (int i = 0; i < M; ++i) {
67
+ C[i+j*ldc] = 0;
68
+ }
69
+ } else {
70
+ for (int j = 0; j < N; ++j)
71
+ for (int i = 0; i < M; ++i) {
72
+ C[i+j*ldc] *= *beta;
73
+ }
74
+ }
75
+ return;
76
+ }
77
+
78
+ // Start the operations
79
+ if (TransB == CblasNoTrans) {
80
+ if (TransA == CblasNoTrans) {
81
+ // C = alpha*A*B+beta*C
82
+ for (int j = 0; j < N; ++j) {
83
+ if (*beta == 0) {
84
+ for (int i = 0; i < M; ++i) {
85
+ C[i+j*ldc] = 0;
86
+ }
87
+ } else if (*beta != 1) {
88
+ for (int i = 0; i < M; ++i) {
89
+ C[i+j*ldc] *= *beta;
90
+ }
91
+ }
92
+
93
+ for (int l = 0; l < K; ++l) {
94
+ if (B[l+j*ldb] != 0) {
95
+ temp = *alpha * B[l+j*ldb];
96
+ for (int i = 0; i < M; ++i) {
97
+ C[i+j*ldc] += A[i+l*lda] * temp;
98
+ }
99
+ }
100
+ }
101
+ }
102
+
103
+ } else {
104
+
105
+ // C = alpha*A**DType*B + beta*C
106
+ for (int j = 0; j < N; ++j) {
107
+ for (int i = 0; i < M; ++i) {
108
+ temp = 0;
109
+ for (int l = 0; l < K; ++l) {
110
+ temp += A[l+i*lda] * B[l+j*ldb];
111
+ }
112
+
113
+ if (*beta == 0) {
114
+ C[i+j*ldc] = *alpha*temp;
115
+ } else {
116
+ C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
117
+ }
118
+ }
119
+ }
120
+
121
+ }
122
+
123
+ } else if (TransA == CblasNoTrans) {
124
+
125
+ // C = alpha*A*B**T + beta*C
126
+ for (int j = 0; j < N; ++j) {
127
+ if (*beta == 0) {
128
+ for (int i = 0; i < M; ++i) {
129
+ C[i+j*ldc] = 0;
130
+ }
131
+ } else if (*beta != 1) {
132
+ for (int i = 0; i < M; ++i) {
133
+ C[i+j*ldc] *= *beta;
134
+ }
135
+ }
136
+
137
+ for (int l = 0; l < K; ++l) {
138
+ if (B[j+l*ldb] != 0) {
139
+ temp = *alpha * B[j+l*ldb];
140
+ for (int i = 0; i < M; ++i) {
141
+ C[i+j*ldc] += A[i+l*lda] * temp;
142
+ }
143
+ }
144
+ }
145
+
146
+ }
147
+
148
+ } else {
149
+
150
+ // C = alpha*A**DType*B**T + beta*C
151
+ for (int j = 0; j < N; ++j) {
152
+ for (int i = 0; i < M; ++i) {
153
+ temp = 0;
154
+ for (int l = 0; l < K; ++l) {
155
+ temp += A[l+i*lda] * B[j+l*ldb];
156
+ }
157
+
158
+ if (*beta == 0) {
159
+ C[i+j*ldc] = *alpha*temp;
160
+ } else {
161
+ C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
162
+ }
163
+ }
164
+ }
165
+
166
+ }
167
+
168
+ return;
169
+ }
170
+
171
+
172
+
173
+ template <typename DType>
174
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
175
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
176
+ {
177
+ if (Order == CblasRowMajor) {
178
+ if (TransA == CblasNoTrans) {
179
+ if (lda < std::max(K,1)) {
180
+ rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
181
+ }
182
+ } else {
183
+ if (lda < std::max(M,1)) { // && TransA == CblasTrans
184
+ rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
185
+ }
186
+ }
187
+
188
+ if (TransB == CblasNoTrans) {
189
+ if (ldb < std::max(N,1)) {
190
+ rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
191
+ }
192
+ } else {
193
+ if (ldb < std::max(K,1)) {
194
+ rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
195
+ }
196
+ }
197
+
198
+ if (ldc < std::max(N,1)) {
199
+ rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
200
+ }
201
+ } else { // CblasColMajor
202
+ if (TransA == CblasNoTrans) {
203
+ if (lda < std::max(M,1)) {
204
+ rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
205
+ }
206
+ } else {
207
+ if (lda < std::max(K,1)) { // && TransA == CblasTrans
208
+ rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
209
+ }
210
+ }
211
+
212
+ if (TransB == CblasNoTrans) {
213
+ if (ldb < std::max(K,1)) {
214
+ rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
215
+ }
216
+ } else {
217
+ if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
218
+ rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
219
+ }
220
+ }
221
+
222
+ if (ldc < std::max(M,1)) {
223
+ rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
224
+ }
225
+ }
226
+
227
+ /*
228
+ * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
229
+ * we copy C and then subtract to preserve asymmetry.
230
+ */
231
+
232
+ if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
233
+ rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
234
+ /*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
235
+ syreflect(CblasUpper, N, C, ldc);
236
+ */
237
+ }
238
+
239
+ if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
240
+ else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
241
+
242
+ }
243
+
244
+
245
+ template <>
246
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
247
+ const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
248
+ cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
249
+ }
250
+
251
+ template <>
252
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
253
+ const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
254
+ cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
255
+ }
256
+
257
+ template <>
258
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
259
+ const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
260
+ cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
261
+ }
262
+
263
+ template <>
264
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
265
+ const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
266
+ cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
267
+ }
268
+
269
+ }} // end of namespace nm::math
270
+
271
+ #endif // GEMM_H