nmatrix-gemv 0.0.3
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- checksums.yaml +7 -0
- data/.gitignore +29 -0
- data/.rspec +2 -0
- data/.travis.yml +14 -0
- data/Gemfile +7 -0
- data/README.md +29 -0
- data/Rakefile +225 -0
- data/ext/nmatrix_gemv/binary_format.txt +53 -0
- data/ext/nmatrix_gemv/data/complex.h +399 -0
- data/ext/nmatrix_gemv/data/data.cpp +298 -0
- data/ext/nmatrix_gemv/data/data.h +771 -0
- data/ext/nmatrix_gemv/data/meta.h +70 -0
- data/ext/nmatrix_gemv/data/rational.h +436 -0
- data/ext/nmatrix_gemv/data/ruby_object.h +471 -0
- data/ext/nmatrix_gemv/extconf.rb +254 -0
- data/ext/nmatrix_gemv/math.cpp +1639 -0
- data/ext/nmatrix_gemv/math/asum.h +143 -0
- data/ext/nmatrix_gemv/math/geev.h +82 -0
- data/ext/nmatrix_gemv/math/gemm.h +271 -0
- data/ext/nmatrix_gemv/math/gemv.h +212 -0
- data/ext/nmatrix_gemv/math/ger.h +96 -0
- data/ext/nmatrix_gemv/math/gesdd.h +80 -0
- data/ext/nmatrix_gemv/math/gesvd.h +78 -0
- data/ext/nmatrix_gemv/math/getf2.h +86 -0
- data/ext/nmatrix_gemv/math/getrf.h +240 -0
- data/ext/nmatrix_gemv/math/getri.h +108 -0
- data/ext/nmatrix_gemv/math/getrs.h +129 -0
- data/ext/nmatrix_gemv/math/idamax.h +86 -0
- data/ext/nmatrix_gemv/math/inc.h +47 -0
- data/ext/nmatrix_gemv/math/laswp.h +165 -0
- data/ext/nmatrix_gemv/math/long_dtype.h +52 -0
- data/ext/nmatrix_gemv/math/math.h +1069 -0
- data/ext/nmatrix_gemv/math/nrm2.h +181 -0
- data/ext/nmatrix_gemv/math/potrs.h +129 -0
- data/ext/nmatrix_gemv/math/rot.h +141 -0
- data/ext/nmatrix_gemv/math/rotg.h +115 -0
- data/ext/nmatrix_gemv/math/scal.h +73 -0
- data/ext/nmatrix_gemv/math/swap.h +73 -0
- data/ext/nmatrix_gemv/math/trsm.h +387 -0
- data/ext/nmatrix_gemv/nm_memory.h +60 -0
- data/ext/nmatrix_gemv/nmatrix_gemv.cpp +90 -0
- data/ext/nmatrix_gemv/nmatrix_gemv.h +374 -0
- data/ext/nmatrix_gemv/ruby_constants.cpp +153 -0
- data/ext/nmatrix_gemv/ruby_constants.h +107 -0
- data/ext/nmatrix_gemv/ruby_nmatrix.c +84 -0
- data/ext/nmatrix_gemv/ttable_helper.rb +122 -0
- data/ext/nmatrix_gemv/types.h +54 -0
- data/ext/nmatrix_gemv/util/util.h +78 -0
- data/lib/nmatrix-gemv.rb +43 -0
- data/lib/nmatrix_gemv/blas.rb +85 -0
- data/lib/nmatrix_gemv/nmatrix_gemv.rb +35 -0
- data/lib/nmatrix_gemv/rspec.rb +75 -0
- data/nmatrix-gemv.gemspec +31 -0
- data/spec/blas_spec.rb +154 -0
- data/spec/spec_helper.rb +128 -0
- metadata +186 -0
data/ext/nmatrix_gemv/math/asum.h
@@ -0,0 +1,143 @@
+/////////////////////////////////////////////////////////////////////
+// = NMatrix
+//
+// A linear algebra library for scientific computation in Ruby.
+// NMatrix is part of SciRuby.
+//
+// NMatrix was originally inspired by and derived from NArray, by
+// Masahiro Tanaka: http://narray.rubyforge.org
+//
+// == Copyright Information
+//
+// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
+// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
+//
+// Please see LICENSE.txt for additional copyright notices.
+//
+// == Contributing
+//
+// By contributing source code to SciRuby, you agree to be bound by
+// our Contributor Agreement:
+//
+// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+//
+// == asum.h
+//
+// CBLAS asum function
+//
+
+/*
+ * Automatically Tuned Linear Algebra Software v3.8.4
+ * (C) Copyright 1999 R. Clint Whaley
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *   1. Redistributions of source code must retain the above copyright
+ *      notice, this list of conditions and the following disclaimer.
+ *   2. Redistributions in binary form must reproduce the above copyright
+ *      notice, this list of conditions, and the following disclaimer in the
+ *      documentation and/or other materials provided with the distribution.
+ *   3. The name of the ATLAS group or the names of its contributers may
+ *      not be used to endorse or promote products derived from this
+ *      software without specific written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+#ifndef ASUM_H
+# define ASUM_H
+
+
+namespace nm { namespace math {
+
+/*
+ * Level 1 BLAS routine which sums the absolute values of a vector's contents. If the vector consists of complex values,
+ * the routine sums the absolute values of the real and imaginary components as well.
+ *
+ * So, based on input types, these are the valid return types:
+ *    int -> int
+ *    float -> float or double
+ *    double -> double
+ *    complex64 -> float or double
+ *    complex128 -> double
+ *    rational -> rational
+ */
+template <typename ReturnDType, typename DType>
+inline ReturnDType asum(const int N, const DType* X, const int incX) {
+  ReturnDType sum = 0;
+  if ((N > 0) && (incX > 0)) {
+    for (int i = 0; i < N; ++i) {
+      sum += std::abs(X[i*incX]);
+    }
+  }
+  return sum;
+}
+
+
+#if defined HAVE_CBLAS_H || defined HAVE_ATLAS_CBLAS_H
+template <>
+inline float asum(const int N, const float* X, const int incX) {
+  return cblas_sasum(N, X, incX);
+}
+
+template <>
+inline double asum(const int N, const double* X, const int incX) {
+  return cblas_dasum(N, X, incX);
+}
+
+template <>
+inline float asum(const int N, const Complex64* X, const int incX) {
+  return cblas_scasum(N, X, incX);
+}
+
+template <>
+inline double asum(const int N, const Complex128* X, const int incX) {
+  return cblas_dzasum(N, X, incX);
+}
+#else
+template <>
+inline float asum(const int N, const Complex64* X, const int incX) {
+  float sum = 0;
+  if ((N > 0) && (incX > 0)) {
+    for (int i = 0; i < N; ++i) {
+      sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
+    }
+  }
+  return sum;
+}
+
+template <>
+inline double asum(const int N, const Complex128* X, const int incX) {
+  double sum = 0;
+  if ((N > 0) && (incX > 0)) {
+    for (int i = 0; i < N; ++i) {
+      sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
+    }
+  }
+  return sum;
+}
+#endif
+
+
+template <typename ReturnDType, typename DType>
+inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
+  *reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
+}
+
+
+
+
+}} // end of namespace nm::math
+
+#endif // NRM2_H
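The generic `asum` above is just a strided absolute-value sum over N elements. As a rough standalone sketch of that behaviour (not code from the gem; it repeats only the reference loop, without the `nm::math` namespace, the complex/rational dtypes, or the CBLAS specializations):

```cpp
#include <cmath>
#include <cstdio>

// Reference behaviour of the generic asum above: sum of |X[i*incX]|
// for i in [0, N), returning 0 for a non-positive N or stride.
template <typename ReturnDType, typename DType>
ReturnDType ref_asum(const int N, const DType* X, const int incX) {
  ReturnDType sum = 0;
  if ((N > 0) && (incX > 0)) {
    for (int i = 0; i < N; ++i) {
      sum += std::abs(X[i * incX]);
    }
  }
  return sum;
}

int main() {
  const double x[6] = { 1.0, -2.0, 3.0, -4.0, 5.0, -6.0 };
  // N = 3 with stride 2 touches x[0], x[2], x[4]: |1| + |3| + |5| = 9.
  std::printf("%f\n", ref_asum<double, double>(3, x, 2));
  return 0;
}
```

When CBLAS headers are detected, the header above bypasses this loop for float, double, and complex dtypes and calls `cblas_sasum`, `cblas_dasum`, `cblas_scasum`, and `cblas_dzasum` instead.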
data/ext/nmatrix_gemv/math/geev.h
@@ -0,0 +1,82 @@
+/////////////////////////////////////////////////////////////////////
+// = NMatrix
+//
+// A linear algebra library for scientific computation in Ruby.
+// NMatrix is part of SciRuby.
+//
+// NMatrix was originally inspired by and derived from NArray, by
+// Masahiro Tanaka: http://narray.rubyforge.org
+//
+// == Copyright Information
+//
+// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
+// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
+//
+// Please see LICENSE.txt for additional copyright notices.
+//
+// == Contributing
+//
+// By contributing source code to SciRuby, you agree to be bound by
+// our Contributor Agreement:
+//
+// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+//
+// == geev.h
+//
+// Header file for interface with LAPACK's xGEEV functions.
+//
+
+#ifndef GEEV_H
+# define GEEV_H
+
+extern "C" {
+  void sgeev_(char* jobvl, char* jobvr, int* n, float* a, int* lda, float* wr, float* wi, float* vl, int* ldvl, float* vr, int* ldvr, float* work, int* lwork, int* info);
+  void dgeev_(char* jobvl, char* jobvr, int* n, double* a, int* lda, double* wr, double* wi, double* vl, int* ldvl, double* vr, int* ldvr, double* work, int* lwork, int* info);
+  void cgeev_(char* jobvl, char* jobvr, int* n, nm::Complex64* a, int* lda, nm::Complex64* w, nm::Complex64* vl, int* ldvl, nm::Complex64* vr, int* ldvr, nm::Complex64* work, int* lwork, float* rwork, int* info);
+  void zgeev_(char* jobvl, char* jobvr, int* n, nm::Complex128* a, int* lda, nm::Complex128* w, nm::Complex128* vl, int* ldvl, nm::Complex128* vr, int* ldvr, nm::Complex128* work, int* lwork, double* rwork, int* info);
+}
+
+namespace nm { namespace math {
+
+template <typename DType, typename CType> // wr
+inline int geev(char jobvl, char jobvr, int n, DType* a, int lda, DType* w, DType* wi, DType* vl, int ldvl, DType* vr, int ldvr, DType* work, int lwork, CType* rwork) {
+  rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
+  return -1;
+}
+
+template <>
+inline int geev(char jobvl, char jobvr, int n, float* a, int lda, float* w, float* wi, float* vl, int ldvl, float* vr, int ldvr, float* work, int lwork, float* rwork) {
+  int info;
+  sgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
+  return info;
+}
+
+template <>
+inline int geev(char jobvl, char jobvr, int n, double* a, int lda, double* w, double* wi, double* vl, int ldvl, double* vr, int ldvr, double* work, int lwork, double* rwork) {
+  int info;
+  dgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
+  return info;
+}
+
+template <>
+inline int geev(char jobvl, char jobvr, int n, Complex64* a, int lda, Complex64* w, Complex64* wi, Complex64* vl, int ldvl, Complex64* vr, int ldvr, Complex64* work, int lwork, float* rwork) {
+  int info;
+  cgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
+  return info;
+}
+
+template <>
+inline int geev(char jobvl, char jobvr, int n, Complex128* a, int lda, Complex128* w, Complex128* wi, Complex128* vl, int ldvl, Complex128* vr, int ldvr, Complex128* work, int lwork, double* rwork) {
+  int info;
+  zgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
+  return info;
+}
+
+template <typename DType, typename CType>
+inline int lapack_geev(char jobvl, char jobvr, int n, void* a, int lda, void* w, void* wi, void* vl, int ldvl, void* vr, int ldvr, void* work, int lwork, void* rwork) {
+  return geev<DType,CType>(jobvl, jobvr, n, reinterpret_cast<DType*>(a), lda, reinterpret_cast<DType*>(w), reinterpret_cast<DType*>(wi), reinterpret_cast<DType*>(vl), ldvl, reinterpret_cast<DType*>(vr), ldvr, reinterpret_cast<DType*>(work), lwork, reinterpret_cast<CType*>(rwork));
+}
+
+}} // end nm::math
+
+#endif // GEEV_H
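geev.h itself only forwards to the Fortran routines declared in its `extern "C"` block. For orientation, here is a minimal sketch of calling `dgeev_` directly with that same signature (column-major data, every argument passed by pointer, and LAPACK's usual `lwork = -1` workspace query); it assumes a LAPACK library is linked and is not part of the gem:

```cpp
#include <cstdio>
#include <vector>

// Same dgeev_ prototype as declared in the header above; link with -llapack.
extern "C" void dgeev_(char* jobvl, char* jobvr, int* n, double* a, int* lda,
                       double* wr, double* wi, double* vl, int* ldvl,
                       double* vr, int* ldvr, double* work, int* lwork,
                       int* info);

int main() {
  // Column-major 2x2 matrix [[0, -1], [1, 0]]; its eigenvalues are +/- i.
  int n = 2, lda = 2, ldvl = 1, ldvr = 1, lwork = -1, info = 0;
  double a[4] = { 0.0, 1.0, -1.0, 0.0 };
  double wr[2], wi[2], vl_dummy[1], vr_dummy[1], wkopt;
  char jobvl = 'N', jobvr = 'N';  // eigenvalues only, no eigenvectors

  // First call with lwork = -1: LAPACK's workspace query reports the
  // optimal work array size in work[0] instead of computing anything.
  dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl_dummy, &ldvl,
         vr_dummy, &ldvr, &wkopt, &lwork, &info);

  lwork = static_cast<int>(wkopt);
  std::vector<double> work(lwork);
  dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl_dummy, &ldvl,
         vr_dummy, &ldvr, work.data(), &lwork, &info);

  if (info == 0)
    for (int i = 0; i < n; ++i)
      std::printf("lambda_%d = %g + %gi\n", i, wr[i], wi[i]);
  return info;
}
```

In the gem, `nm::math::lapack_geev` performs this dispatch generically: it casts the `void*` arguments to dtype-specific pointers and calls the matching `geev` specialization.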
data/ext/nmatrix_gemv/math/gemm.h
@@ -0,0 +1,271 @@
+/////////////////////////////////////////////////////////////////////
+// = NMatrix
+//
+// A linear algebra library for scientific computation in Ruby.
+// NMatrix is part of SciRuby.
+//
+// NMatrix was originally inspired by and derived from NArray, by
+// Masahiro Tanaka: http://narray.rubyforge.org
+//
+// == Copyright Information
+//
+// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
+// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
+//
+// Please see LICENSE.txt for additional copyright notices.
+//
+// == Contributing
+//
+// By contributing source code to SciRuby, you agree to be bound by
+// our Contributor Agreement:
+//
+// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+//
+// == gemm.h
+//
+// Header file for interface with ATLAS's CBLAS gemm functions and
+// native templated version of LAPACK's gemm function.
+//
+
+#ifndef GEMM_H
+# define GEMM_H
+
+extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
+#if defined HAVE_CBLAS_H
+  #include <cblas.h>
+#elif defined HAVE_ATLAS_CBLAS_H
+  #include <atlas/cblas.h>
+#endif
+}
+
+
+namespace nm { namespace math {
+/*
+ * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
+ *
+ * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
+ *
+ * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
+ *
+ * This version throws no errors. Use gemm<DType> instead for error checking.
+ */
+template <typename DType>
+inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                         const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
+{
+
+  typename LongDType<DType>::type temp;
+
+  // Quick return if possible
+  if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
+
+  // For alpha = 0
+  if (*alpha == 0) {
+    if (*beta == 0) {
+      for (int j = 0; j < N; ++j)
+        for (int i = 0; i < M; ++i) {
+          C[i+j*ldc] = 0;
+        }
+    } else {
+      for (int j = 0; j < N; ++j)
+        for (int i = 0; i < M; ++i) {
+          C[i+j*ldc] *= *beta;
+        }
+    }
+    return;
+  }
+
+  // Start the operations
+  if (TransB == CblasNoTrans) {
+    if (TransA == CblasNoTrans) {
+      // C = alpha*A*B+beta*C
+      for (int j = 0; j < N; ++j) {
+        if (*beta == 0) {
+          for (int i = 0; i < M; ++i) {
+            C[i+j*ldc] = 0;
+          }
+        } else if (*beta != 1) {
+          for (int i = 0; i < M; ++i) {
+            C[i+j*ldc] *= *beta;
+          }
+        }
+
+        for (int l = 0; l < K; ++l) {
+          if (B[l+j*ldb] != 0) {
+            temp = *alpha * B[l+j*ldb];
+            for (int i = 0; i < M; ++i) {
+              C[i+j*ldc] += A[i+l*lda] * temp;
+            }
+          }
+        }
+      }
+
+    } else {
+
+      // C = alpha*A**DType*B + beta*C
+      for (int j = 0; j < N; ++j) {
+        for (int i = 0; i < M; ++i) {
+          temp = 0;
+          for (int l = 0; l < K; ++l) {
+            temp += A[l+i*lda] * B[l+j*ldb];
+          }
+
+          if (*beta == 0) {
+            C[i+j*ldc] = *alpha*temp;
+          } else {
+            C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
+          }
+        }
+      }
+
+    }
+
+  } else if (TransA == CblasNoTrans) {
+
+    // C = alpha*A*B**T + beta*C
+    for (int j = 0; j < N; ++j) {
+      if (*beta == 0) {
+        for (int i = 0; i < M; ++i) {
+          C[i+j*ldc] = 0;
+        }
+      } else if (*beta != 1) {
+        for (int i = 0; i < M; ++i) {
+          C[i+j*ldc] *= *beta;
+        }
+      }
+
+      for (int l = 0; l < K; ++l) {
+        if (B[j+l*ldb] != 0) {
+          temp = *alpha * B[j+l*ldb];
+          for (int i = 0; i < M; ++i) {
+            C[i+j*ldc] += A[i+l*lda] * temp;
+          }
+        }
+      }
+
+    }
+
+  } else {
+
+    // C = alpha*A**DType*B**T + beta*C
+    for (int j = 0; j < N; ++j) {
+      for (int i = 0; i < M; ++i) {
+        temp = 0;
+        for (int l = 0; l < K; ++l) {
+          temp += A[l+i*lda] * B[j+l*ldb];
+        }
+
+        if (*beta == 0) {
+          C[i+j*ldc] = *alpha*temp;
+        } else {
+          C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
+        }
+      }
+    }
+
+  }
+
+  return;
+}
+
+
+
+
+template <typename DType>
+inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                 const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
+{
+  if (Order == CblasRowMajor) {
+    if (TransA == CblasNoTrans) {
+      if (lda < std::max(K,1)) {
+        rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
+      }
+    } else {
+      if (lda < std::max(M,1)) { // && TransA == CblasTrans
+        rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
+      }
+    }
+
+    if (TransB == CblasNoTrans) {
+      if (ldb < std::max(N,1)) {
+        rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
+      }
+    } else {
+      if (ldb < std::max(K,1)) {
+        rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
+      }
+    }
+
+    if (ldc < std::max(N,1)) {
+      rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
+    }
+  } else { // CblasColMajor
+    if (TransA == CblasNoTrans) {
+      if (lda < std::max(M,1)) {
+        rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
+      }
+    } else {
+      if (lda < std::max(K,1)) { // && TransA == CblasTrans
+        rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
+      }
+    }
+
+    if (TransB == CblasNoTrans) {
+      if (ldb < std::max(K,1)) {
+        rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
+      }
+    } else {
+      if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
+        rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
+      }
+    }
+
+    if (ldc < std::max(M,1)) {
+      rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
+    }
+  }
+
+  /*
+   * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
+   * we copy C and then subtract to preserve asymmetry.
+   */
+
+  if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
+    rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
+    /*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
+    syreflect(CblasUpper, N, C, ldc);
+    */
+  }
+
+  if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
+  else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);

+}
+
+
+template <>
+inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                 const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
+  cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
+}
+
+template <>
+inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                 const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
+  cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
+}
+
+template <>
+inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                 const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
+  cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+template <>
+inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                 const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
+  cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+}} // end of namespace nm::math
+
+#endif // GEMM_H
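One detail worth noting in `gemm` above: the row-major case is handled by swapping arguments rather than transposing any data, because a row-major `C = A*B` occupies the same memory as a column-major `C^T = B^T * A^T`. Below is a small standalone sketch of that identity (the `colmajor_gemm` helper is hypothetical and not code from the gem; it reimplements only the NoTrans/NoTrans, beta = 0 reference loop):

```cpp
#include <cstdio>

// Standalone copy of the column-major reference loop above
// (NoTrans/NoTrans, beta = 0): C = alpha * A * B in column-major storage.
static void colmajor_gemm(int M, int N, int K, double alpha,
                          const double* A, int lda,
                          const double* B, int ldb, double* C, int ldc) {
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) C[i + j * ldc] = 0.0;
    for (int l = 0; l < K; ++l) {
      double temp = alpha * B[l + j * ldb];
      for (int i = 0; i < M; ++i) C[i + j * ldc] += A[i + l * lda] * temp;
    }
  }
}

int main() {
  // Row-major 2x2 matrices: A = [[1,2],[3,4]], B = [[5,6],[7,8]].
  const double A[4] = { 1, 2, 3, 4 };
  const double B[4] = { 5, 6, 7, 8 };
  double C[4];  // row-major result

  // Row-major C = A*B is the same memory as column-major C^T = B^T * A^T,
  // which is exactly the argument swap gemm() performs for CblasRowMajor:
  // pass B where A goes, A where B goes, and exchange M and N.
  colmajor_gemm(2, 2, 2, 1.0, B, 2, A, 2, C, 2);

  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);
  return 0;
}
```

The printed result is the row-major product [[19, 22], [43, 50]], matching what `gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, ...)` computes without ever reordering the input buffers.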