nmatrix-gemv 0.0.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +29 -0
- data/.rspec +2 -0
- data/.travis.yml +14 -0
- data/Gemfile +7 -0
- data/README.md +29 -0
- data/Rakefile +225 -0
- data/ext/nmatrix_gemv/binary_format.txt +53 -0
- data/ext/nmatrix_gemv/data/complex.h +399 -0
- data/ext/nmatrix_gemv/data/data.cpp +298 -0
- data/ext/nmatrix_gemv/data/data.h +771 -0
- data/ext/nmatrix_gemv/data/meta.h +70 -0
- data/ext/nmatrix_gemv/data/rational.h +436 -0
- data/ext/nmatrix_gemv/data/ruby_object.h +471 -0
- data/ext/nmatrix_gemv/extconf.rb +254 -0
- data/ext/nmatrix_gemv/math.cpp +1639 -0
- data/ext/nmatrix_gemv/math/asum.h +143 -0
- data/ext/nmatrix_gemv/math/geev.h +82 -0
- data/ext/nmatrix_gemv/math/gemm.h +271 -0
- data/ext/nmatrix_gemv/math/gemv.h +212 -0
- data/ext/nmatrix_gemv/math/ger.h +96 -0
- data/ext/nmatrix_gemv/math/gesdd.h +80 -0
- data/ext/nmatrix_gemv/math/gesvd.h +78 -0
- data/ext/nmatrix_gemv/math/getf2.h +86 -0
- data/ext/nmatrix_gemv/math/getrf.h +240 -0
- data/ext/nmatrix_gemv/math/getri.h +108 -0
- data/ext/nmatrix_gemv/math/getrs.h +129 -0
- data/ext/nmatrix_gemv/math/idamax.h +86 -0
- data/ext/nmatrix_gemv/math/inc.h +47 -0
- data/ext/nmatrix_gemv/math/laswp.h +165 -0
- data/ext/nmatrix_gemv/math/long_dtype.h +52 -0
- data/ext/nmatrix_gemv/math/math.h +1069 -0
- data/ext/nmatrix_gemv/math/nrm2.h +181 -0
- data/ext/nmatrix_gemv/math/potrs.h +129 -0
- data/ext/nmatrix_gemv/math/rot.h +141 -0
- data/ext/nmatrix_gemv/math/rotg.h +115 -0
- data/ext/nmatrix_gemv/math/scal.h +73 -0
- data/ext/nmatrix_gemv/math/swap.h +73 -0
- data/ext/nmatrix_gemv/math/trsm.h +387 -0
- data/ext/nmatrix_gemv/nm_memory.h +60 -0
- data/ext/nmatrix_gemv/nmatrix_gemv.cpp +90 -0
- data/ext/nmatrix_gemv/nmatrix_gemv.h +374 -0
- data/ext/nmatrix_gemv/ruby_constants.cpp +153 -0
- data/ext/nmatrix_gemv/ruby_constants.h +107 -0
- data/ext/nmatrix_gemv/ruby_nmatrix.c +84 -0
- data/ext/nmatrix_gemv/ttable_helper.rb +122 -0
- data/ext/nmatrix_gemv/types.h +54 -0
- data/ext/nmatrix_gemv/util/util.h +78 -0
- data/lib/nmatrix-gemv.rb +43 -0
- data/lib/nmatrix_gemv/blas.rb +85 -0
- data/lib/nmatrix_gemv/nmatrix_gemv.rb +35 -0
- data/lib/nmatrix_gemv/rspec.rb +75 -0
- data/nmatrix-gemv.gemspec +31 -0
- data/spec/blas_spec.rb +154 -0
- data/spec/spec_helper.rb +128 -0
- metadata +186 -0
@@ -0,0 +1,143 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == asum.h
|
25
|
+
//
|
26
|
+
// CBLAS asum function
|
27
|
+
//
|
28
|
+
|
29
|
+
/*
|
30
|
+
* Automatically Tuned Linear Algebra Software v3.8.4
|
31
|
+
* (C) Copyright 1999 R. Clint Whaley
|
32
|
+
*
|
33
|
+
* Redistribution and use in source and binary forms, with or without
|
34
|
+
* modification, are permitted provided that the following conditions
|
35
|
+
* are met:
|
36
|
+
* 1. Redistributions of source code must retain the above copyright
|
37
|
+
* notice, this list of conditions and the following disclaimer.
|
38
|
+
* 2. Redistributions in binary form must reproduce the above copyright
|
39
|
+
* notice, this list of conditions, and the following disclaimer in the
|
40
|
+
* documentation and/or other materials provided with the distribution.
|
41
|
+
* 3. The name of the ATLAS group or the names of its contributers may
|
42
|
+
* not be used to endorse or promote products derived from this
|
43
|
+
* software without specific written permission.
|
44
|
+
*
|
45
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
46
|
+
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
47
|
+
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
48
|
+
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
|
49
|
+
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
50
|
+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
51
|
+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
52
|
+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
53
|
+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
54
|
+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
55
|
+
* POSSIBILITY OF SUCH DAMAGE.
|
56
|
+
*
|
57
|
+
*/
|
58
|
+
|
59
|
+
#ifndef ASUM_H
|
60
|
+
# define ASUM_H
|
61
|
+
|
62
|
+
|
63
|
+
namespace nm { namespace math {
|
64
|
+
|
65
|
+
/*
 * Level 1 BLAS asum: returns the sum of |X[k*incX]| for k = 0 .. N-1, accumulated in ReturnDType.
 *
 * The generic version reduces each element with std::abs. The complex specializations further down
 * instead sum |Re| + |Im| per element (the usual BLAS scasum/dzasum definition).
 *
 * Intended (input -> return) pairings:
 *   int        -> int
 *   float      -> float or double
 *   double     -> double
 *   complex64  -> float or double
 *   complex128 -> double
 *   rational   -> rational
 *
 * A non-positive N or a non-positive stride incX yields 0, as in reference BLAS.
 */
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
  ReturnDType total = 0;

  // Guard clause: nothing to add for an empty vector or a non-positive stride.
  if (N <= 0 || incX <= 0) return total;

  for (int k = 0; k < N; ++k) {
    total += std::abs(X[k * incX]);
  }
  return total;
}
|
87
|
+
|
88
|
+
|
89
|
+
#if defined HAVE_CBLAS_H || defined HAVE_ATLAS_CBLAS_H
|
90
|
+
template <>
|
91
|
+
inline float asum(const int N, const float* X, const int incX) {
|
92
|
+
return cblas_sasum(N, X, incX);
|
93
|
+
}
|
94
|
+
|
95
|
+
template <>
|
96
|
+
inline double asum(const int N, const double* X, const int incX) {
|
97
|
+
return cblas_dasum(N, X, incX);
|
98
|
+
}
|
99
|
+
|
100
|
+
template <>
|
101
|
+
inline float asum(const int N, const Complex64* X, const int incX) {
|
102
|
+
return cblas_scasum(N, X, incX);
|
103
|
+
}
|
104
|
+
|
105
|
+
template <>
|
106
|
+
inline double asum(const int N, const Complex128* X, const int incX) {
|
107
|
+
return cblas_dzasum(N, X, incX);
|
108
|
+
}
|
109
|
+
#else
|
110
|
+
template <>
|
111
|
+
inline float asum(const int N, const Complex64* X, const int incX) {
|
112
|
+
float sum = 0;
|
113
|
+
if ((N > 0) && (incX > 0)) {
|
114
|
+
for (int i = 0; i < N; ++i) {
|
115
|
+
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
|
116
|
+
}
|
117
|
+
}
|
118
|
+
return sum;
|
119
|
+
}
|
120
|
+
|
121
|
+
template <>
|
122
|
+
inline double asum(const int N, const Complex128* X, const int incX) {
|
123
|
+
double sum = 0;
|
124
|
+
if ((N > 0) && (incX > 0)) {
|
125
|
+
for (int i = 0; i < N; ++i) {
|
126
|
+
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
|
127
|
+
}
|
128
|
+
}
|
129
|
+
return sum;
|
130
|
+
}
|
131
|
+
#endif
|
132
|
+
|
133
|
+
|
134
|
+
template <typename ReturnDType, typename DType>
|
135
|
+
inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
|
136
|
+
*reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
|
137
|
+
}
|
138
|
+
|
139
|
+
|
140
|
+
|
141
|
+
}} // end of namespace nm::math
|
142
|
+
|
143
|
+
#endif // ASUM_H
|
@@ -0,0 +1,82 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == geev.h
|
25
|
+
//
|
26
|
+
// Header file for interface with LAPACK's xGEEV functions.
|
27
|
+
//
|
28
|
+
|
29
|
+
#ifndef GEEV_H
|
30
|
+
# define GEEV_H
|
31
|
+
|
32
|
+
extern "C" {
|
33
|
+
void sgeev_(char* jobvl, char* jobvr, int* n, float* a, int* lda, float* wr, float* wi, float* vl, int* ldvl, float* vr, int* ldvr, float* work, int* lwork, int* info);
|
34
|
+
void dgeev_(char* jobvl, char* jobvr, int* n, double* a, int* lda, double* wr, double* wi, double* vl, int* ldvl, double* vr, int* ldvr, double* work, int* lwork, int* info);
|
35
|
+
void cgeev_(char* jobvl, char* jobvr, int* n, nm::Complex64* a, int* lda, nm::Complex64* w, nm::Complex64* vl, int* ldvl, nm::Complex64* vr, int* ldvr, nm::Complex64* work, int* lwork, float* rwork, int* info);
|
36
|
+
void zgeev_(char* jobvl, char* jobvr, int* n, nm::Complex128* a, int* lda, nm::Complex128* w, nm::Complex128* vl, int* ldvl, nm::Complex128* vr, int* ldvr, nm::Complex128* work, int* lwork, double* rwork, int* info);
|
37
|
+
}
|
38
|
+
|
39
|
+
namespace nm { namespace math {
|
40
|
+
|
41
|
+
template <typename DType, typename CType> // wr
|
42
|
+
inline int geev(char jobvl, char jobvr, int n, DType* a, int lda, DType* w, DType* wi, DType* vl, int ldvl, DType* vr, int ldvr, DType* work, int lwork, CType* rwork) {
|
43
|
+
rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
|
44
|
+
return -1;
|
45
|
+
}
|
46
|
+
|
47
|
+
template <>
|
48
|
+
inline int geev(char jobvl, char jobvr, int n, float* a, int lda, float* w, float* wi, float* vl, int ldvl, float* vr, int ldvr, float* work, int lwork, float* rwork) {
|
49
|
+
int info;
|
50
|
+
sgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
|
51
|
+
return info;
|
52
|
+
}
|
53
|
+
|
54
|
+
template <>
|
55
|
+
inline int geev(char jobvl, char jobvr, int n, double* a, int lda, double* w, double* wi, double* vl, int ldvl, double* vr, int ldvr, double* work, int lwork, double* rwork) {
|
56
|
+
int info;
|
57
|
+
dgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
|
58
|
+
return info;
|
59
|
+
}
|
60
|
+
|
61
|
+
template <>
|
62
|
+
inline int geev(char jobvl, char jobvr, int n, Complex64* a, int lda, Complex64* w, Complex64* wi, Complex64* vl, int ldvl, Complex64* vr, int ldvr, Complex64* work, int lwork, float* rwork) {
|
63
|
+
int info;
|
64
|
+
cgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
|
65
|
+
return info;
|
66
|
+
}
|
67
|
+
|
68
|
+
template <>
|
69
|
+
inline int geev(char jobvl, char jobvr, int n, Complex128* a, int lda, Complex128* w, Complex128* wi, Complex128* vl, int ldvl, Complex128* vr, int ldvr, Complex128* work, int lwork, double* rwork) {
|
70
|
+
int info;
|
71
|
+
zgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
|
72
|
+
return info;
|
73
|
+
}
|
74
|
+
|
75
|
+
template <typename DType, typename CType>
|
76
|
+
inline int lapack_geev(char jobvl, char jobvr, int n, void* a, int lda, void* w, void* wi, void* vl, int ldvl, void* vr, int ldvr, void* work, int lwork, void* rwork) {
|
77
|
+
return geev<DType,CType>(jobvl, jobvr, n, reinterpret_cast<DType*>(a), lda, reinterpret_cast<DType*>(w), reinterpret_cast<DType*>(wi), reinterpret_cast<DType*>(vl), ldvl, reinterpret_cast<DType*>(vr), ldvr, reinterpret_cast<DType*>(work), lwork, reinterpret_cast<CType*>(rwork));
|
78
|
+
}
|
79
|
+
|
80
|
+
}} // end nm::math
|
81
|
+
|
82
|
+
#endif // GEEV_H
|
@@ -0,0 +1,271 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == gemm.h
|
25
|
+
//
|
26
|
+
// Header file for interface with ATLAS's CBLAS gemm functions and
|
27
|
+
// native templated version of BLAS's gemm function.
|
28
|
+
//
|
29
|
+
|
30
|
+
#ifndef GEMM_H
|
31
|
+
# define GEMM_H
|
32
|
+
|
33
|
+
extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
|
34
|
+
#if defined HAVE_CBLAS_H
|
35
|
+
#include <cblas.h>
|
36
|
+
#elif defined HAVE_ATLAS_CBLAS_H
|
37
|
+
#include <atlas/cblas.h>
|
38
|
+
#endif
|
39
|
+
}
|
40
|
+
|
41
|
+
|
42
|
+
namespace nm { namespace math {
|
43
|
+
/*
|
44
|
+
* GEneral Matrix Multiplication: based on dgemm.f from Netlib.
|
45
|
+
*
|
46
|
+
* This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
|
47
|
+
*
|
48
|
+
* Template parameters: LT -- long version of type T. Type T is the matrix dtype.
|
49
|
+
*
|
50
|
+
* This version throws no errors. Use gemm<DType> instead for error checking.
|
51
|
+
*/
|
52
|
+
template <typename DType>
|
53
|
+
inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
54
|
+
const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
|
55
|
+
{
|
56
|
+
|
57
|
+
typename LongDType<DType>::type temp;
|
58
|
+
|
59
|
+
// Quick return if possible
|
60
|
+
if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
|
61
|
+
|
62
|
+
// For alpha = 0
|
63
|
+
if (*alpha == 0) {
|
64
|
+
if (*beta == 0) {
|
65
|
+
for (int j = 0; j < N; ++j)
|
66
|
+
for (int i = 0; i < M; ++i) {
|
67
|
+
C[i+j*ldc] = 0;
|
68
|
+
}
|
69
|
+
} else {
|
70
|
+
for (int j = 0; j < N; ++j)
|
71
|
+
for (int i = 0; i < M; ++i) {
|
72
|
+
C[i+j*ldc] *= *beta;
|
73
|
+
}
|
74
|
+
}
|
75
|
+
return;
|
76
|
+
}
|
77
|
+
|
78
|
+
// Start the operations
|
79
|
+
if (TransB == CblasNoTrans) {
|
80
|
+
if (TransA == CblasNoTrans) {
|
81
|
+
// C = alpha*A*B+beta*C
|
82
|
+
for (int j = 0; j < N; ++j) {
|
83
|
+
if (*beta == 0) {
|
84
|
+
for (int i = 0; i < M; ++i) {
|
85
|
+
C[i+j*ldc] = 0;
|
86
|
+
}
|
87
|
+
} else if (*beta != 1) {
|
88
|
+
for (int i = 0; i < M; ++i) {
|
89
|
+
C[i+j*ldc] *= *beta;
|
90
|
+
}
|
91
|
+
}
|
92
|
+
|
93
|
+
for (int l = 0; l < K; ++l) {
|
94
|
+
if (B[l+j*ldb] != 0) {
|
95
|
+
temp = *alpha * B[l+j*ldb];
|
96
|
+
for (int i = 0; i < M; ++i) {
|
97
|
+
C[i+j*ldc] += A[i+l*lda] * temp;
|
98
|
+
}
|
99
|
+
}
|
100
|
+
}
|
101
|
+
}
|
102
|
+
|
103
|
+
} else {
|
104
|
+
|
105
|
+
// C = alpha*A**DType*B + beta*C
|
106
|
+
for (int j = 0; j < N; ++j) {
|
107
|
+
for (int i = 0; i < M; ++i) {
|
108
|
+
temp = 0;
|
109
|
+
for (int l = 0; l < K; ++l) {
|
110
|
+
temp += A[l+i*lda] * B[l+j*ldb];
|
111
|
+
}
|
112
|
+
|
113
|
+
if (*beta == 0) {
|
114
|
+
C[i+j*ldc] = *alpha*temp;
|
115
|
+
} else {
|
116
|
+
C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
|
117
|
+
}
|
118
|
+
}
|
119
|
+
}
|
120
|
+
|
121
|
+
}
|
122
|
+
|
123
|
+
} else if (TransA == CblasNoTrans) {
|
124
|
+
|
125
|
+
// C = alpha*A*B**T + beta*C
|
126
|
+
for (int j = 0; j < N; ++j) {
|
127
|
+
if (*beta == 0) {
|
128
|
+
for (int i = 0; i < M; ++i) {
|
129
|
+
C[i+j*ldc] = 0;
|
130
|
+
}
|
131
|
+
} else if (*beta != 1) {
|
132
|
+
for (int i = 0; i < M; ++i) {
|
133
|
+
C[i+j*ldc] *= *beta;
|
134
|
+
}
|
135
|
+
}
|
136
|
+
|
137
|
+
for (int l = 0; l < K; ++l) {
|
138
|
+
if (B[j+l*ldb] != 0) {
|
139
|
+
temp = *alpha * B[j+l*ldb];
|
140
|
+
for (int i = 0; i < M; ++i) {
|
141
|
+
C[i+j*ldc] += A[i+l*lda] * temp;
|
142
|
+
}
|
143
|
+
}
|
144
|
+
}
|
145
|
+
|
146
|
+
}
|
147
|
+
|
148
|
+
} else {
|
149
|
+
|
150
|
+
// C = alpha*A**DType*B**T + beta*C
|
151
|
+
for (int j = 0; j < N; ++j) {
|
152
|
+
for (int i = 0; i < M; ++i) {
|
153
|
+
temp = 0;
|
154
|
+
for (int l = 0; l < K; ++l) {
|
155
|
+
temp += A[l+i*lda] * B[j+l*ldb];
|
156
|
+
}
|
157
|
+
|
158
|
+
if (*beta == 0) {
|
159
|
+
C[i+j*ldc] = *alpha*temp;
|
160
|
+
} else {
|
161
|
+
C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
|
162
|
+
}
|
163
|
+
}
|
164
|
+
}
|
165
|
+
|
166
|
+
}
|
167
|
+
|
168
|
+
return;
|
169
|
+
}
|
170
|
+
|
171
|
+
|
172
|
+
|
173
|
+
template <typename DType>
|
174
|
+
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
175
|
+
const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
|
176
|
+
{
|
177
|
+
if (Order == CblasRowMajor) {
|
178
|
+
if (TransA == CblasNoTrans) {
|
179
|
+
if (lda < std::max(K,1)) {
|
180
|
+
rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
|
181
|
+
}
|
182
|
+
} else {
|
183
|
+
if (lda < std::max(M,1)) { // && TransA == CblasTrans
|
184
|
+
rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
|
185
|
+
}
|
186
|
+
}
|
187
|
+
|
188
|
+
if (TransB == CblasNoTrans) {
|
189
|
+
if (ldb < std::max(N,1)) {
|
190
|
+
rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
|
191
|
+
}
|
192
|
+
} else {
|
193
|
+
if (ldb < std::max(K,1)) {
|
194
|
+
rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
|
195
|
+
}
|
196
|
+
}
|
197
|
+
|
198
|
+
if (ldc < std::max(N,1)) {
|
199
|
+
rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
|
200
|
+
}
|
201
|
+
} else { // CblasColMajor
|
202
|
+
if (TransA == CblasNoTrans) {
|
203
|
+
if (lda < std::max(M,1)) {
|
204
|
+
rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
|
205
|
+
}
|
206
|
+
} else {
|
207
|
+
if (lda < std::max(K,1)) { // && TransA == CblasTrans
|
208
|
+
rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
|
209
|
+
}
|
210
|
+
}
|
211
|
+
|
212
|
+
if (TransB == CblasNoTrans) {
|
213
|
+
if (ldb < std::max(K,1)) {
|
214
|
+
rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
|
215
|
+
}
|
216
|
+
} else {
|
217
|
+
if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
|
218
|
+
rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
|
219
|
+
}
|
220
|
+
}
|
221
|
+
|
222
|
+
if (ldc < std::max(M,1)) {
|
223
|
+
rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
|
224
|
+
}
|
225
|
+
}
|
226
|
+
|
227
|
+
/*
|
228
|
+
* Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
|
229
|
+
* we copy C and then subtract to preserve asymmetry.
|
230
|
+
*/
|
231
|
+
|
232
|
+
if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
|
233
|
+
rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
|
234
|
+
/*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
|
235
|
+
syreflect(CblasUpper, N, C, ldc);
|
236
|
+
*/
|
237
|
+
}
|
238
|
+
|
239
|
+
if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
|
240
|
+
else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
241
|
+
|
242
|
+
}
|
243
|
+
|
244
|
+
|
245
|
+
template <>
|
246
|
+
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
247
|
+
const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
|
248
|
+
cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
|
249
|
+
}
|
250
|
+
|
251
|
+
template <>
|
252
|
+
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
253
|
+
const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
|
254
|
+
cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
|
255
|
+
}
|
256
|
+
|
257
|
+
template <>
|
258
|
+
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
259
|
+
const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
|
260
|
+
cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
261
|
+
}
|
262
|
+
|
263
|
+
template <>
|
264
|
+
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
265
|
+
const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
|
266
|
+
cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
267
|
+
}
|
268
|
+
|
269
|
+
}} // end of namespace nm::math
|
270
|
+
|
271
|
+
#endif // GEMM_H
|