nmatrix 0.0.6 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -0
  3. data/Gemfile +5 -0
  4. data/History.txt +97 -0
  5. data/Manifest.txt +34 -7
  6. data/README.rdoc +13 -13
  7. data/Rakefile +36 -26
  8. data/ext/nmatrix/data/data.cpp +15 -2
  9. data/ext/nmatrix/data/data.h +4 -0
  10. data/ext/nmatrix/data/ruby_object.h +5 -14
  11. data/ext/nmatrix/extconf.rb +3 -2
  12. data/ext/nmatrix/{util/math.cpp → math.cpp} +296 -6
  13. data/ext/nmatrix/math/asum.h +143 -0
  14. data/ext/nmatrix/math/geev.h +82 -0
  15. data/ext/nmatrix/math/gemm.h +267 -0
  16. data/ext/nmatrix/math/gemv.h +208 -0
  17. data/ext/nmatrix/math/ger.h +96 -0
  18. data/ext/nmatrix/math/gesdd.h +80 -0
  19. data/ext/nmatrix/math/gesvd.h +78 -0
  20. data/ext/nmatrix/math/getf2.h +86 -0
  21. data/ext/nmatrix/math/getrf.h +240 -0
  22. data/ext/nmatrix/math/getri.h +107 -0
  23. data/ext/nmatrix/math/getrs.h +125 -0
  24. data/ext/nmatrix/math/idamax.h +86 -0
  25. data/ext/nmatrix/{util → math}/lapack.h +60 -356
  26. data/ext/nmatrix/math/laswp.h +165 -0
  27. data/ext/nmatrix/math/long_dtype.h +52 -0
  28. data/ext/nmatrix/math/math.h +1154 -0
  29. data/ext/nmatrix/math/nrm2.h +181 -0
  30. data/ext/nmatrix/math/potrs.h +125 -0
  31. data/ext/nmatrix/math/rot.h +141 -0
  32. data/ext/nmatrix/math/rotg.h +115 -0
  33. data/ext/nmatrix/math/scal.h +73 -0
  34. data/ext/nmatrix/math/swap.h +73 -0
  35. data/ext/nmatrix/math/trsm.h +383 -0
  36. data/ext/nmatrix/nmatrix.cpp +176 -152
  37. data/ext/nmatrix/nmatrix.h +1 -2
  38. data/ext/nmatrix/ruby_constants.cpp +9 -4
  39. data/ext/nmatrix/ruby_constants.h +1 -0
  40. data/ext/nmatrix/storage/dense.cpp +57 -41
  41. data/ext/nmatrix/storage/list.cpp +52 -50
  42. data/ext/nmatrix/storage/storage.cpp +59 -43
  43. data/ext/nmatrix/storage/yale.cpp +352 -333
  44. data/ext/nmatrix/storage/yale.h +4 -0
  45. data/lib/nmatrix.rb +2 -2
  46. data/lib/nmatrix/blas.rb +4 -4
  47. data/lib/nmatrix/enumerate.rb +241 -0
  48. data/lib/nmatrix/lapack.rb +54 -1
  49. data/lib/nmatrix/math.rb +462 -0
  50. data/lib/nmatrix/nmatrix.rb +210 -486
  51. data/lib/nmatrix/nvector.rb +0 -62
  52. data/lib/nmatrix/rspec.rb +75 -0
  53. data/lib/nmatrix/shortcuts.rb +136 -108
  54. data/lib/nmatrix/version.rb +1 -1
  55. data/spec/blas_spec.rb +20 -12
  56. data/spec/elementwise_spec.rb +22 -13
  57. data/spec/io_spec.rb +1 -0
  58. data/spec/lapack_spec.rb +197 -0
  59. data/spec/nmatrix_spec.rb +39 -38
  60. data/spec/nvector_spec.rb +3 -9
  61. data/spec/rspec_monkeys.rb +29 -0
  62. data/spec/rspec_spec.rb +34 -0
  63. data/spec/shortcuts_spec.rb +14 -16
  64. data/spec/slice_spec.rb +242 -186
  65. data/spec/spec_helper.rb +19 -0
  66. metadata +33 -5
  67. data/ext/nmatrix/util/math.h +0 -2612
@@ -0,0 +1,82 @@
+ /////////////////////////////////////////////////////////////////////
+ // = NMatrix
+ //
+ // A linear algebra library for scientific computation in Ruby.
+ // NMatrix is part of SciRuby.
+ //
+ // NMatrix was originally inspired by and derived from NArray, by
+ // Masahiro Tanaka: http://narray.rubyforge.org
+ //
+ // == Copyright Information
+ //
+ // SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
+ // NMatrix is Copyright (c) 2013, Ruby Science Foundation
+ //
+ // Please see LICENSE.txt for additional copyright notices.
+ //
+ // == Contributing
+ //
+ // By contributing source code to SciRuby, you agree to be bound by
+ // our Contributor Agreement:
+ //
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+ //
+ // == geev.h
+ //
+ // Header file for interface with LAPACK's xGEEV functions.
+ //
+
+ #ifndef GEEV_H
+ #define GEEV_H
+
+ extern "C" {
+   void sgeev_(char* jobvl, char* jobvr, int* n, float* a, int* lda, float* wr, float* wi, float* vl, int* ldvl, float* vr, int* ldvr, float* work, int* lwork, int* info);
+   void dgeev_(char* jobvl, char* jobvr, int* n, double* a, int* lda, double* wr, double* wi, double* vl, int* ldvl, double* vr, int* ldvr, double* work, int* lwork, int* info);
+   void cgeev_(char* jobvl, char* jobvr, int* n, nm::Complex64* a, int* lda, nm::Complex64* w, nm::Complex64* vl, int* ldvl, nm::Complex64* vr, int* ldvr, nm::Complex64* work, int* lwork, float* rwork, int* info);
+   void zgeev_(char* jobvl, char* jobvr, int* n, nm::Complex128* a, int* lda, nm::Complex128* w, nm::Complex128* vl, int* ldvl, nm::Complex128* vr, int* ldvr, nm::Complex128* work, int* lwork, double* rwork, int* info);
+ }
+
+ namespace nm { namespace math {
+
+ template <typename DType, typename CType> // wr
+ inline int geev(char jobvl, char jobvr, int n, DType* a, int lda, DType* w, DType* wi, DType* vl, int ldvl, DType* vr, int ldvr, DType* work, int lwork, CType* rwork) {
+   rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
+   return -1;
+ }
+
+ template <>
+ inline int geev(char jobvl, char jobvr, int n, float* a, int lda, float* w, float* wi, float* vl, int ldvl, float* vr, int ldvr, float* work, int lwork, float* rwork) {
+   int info;
+   sgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
+   return info;
+ }
+
+ template <>
+ inline int geev(char jobvl, char jobvr, int n, double* a, int lda, double* w, double* wi, double* vl, int ldvl, double* vr, int ldvr, double* work, int lwork, double* rwork) {
+   int info;
+   dgeev_(&jobvl, &jobvr, &n, a, &lda, w, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info);
+   return info;
+ }
+
+ template <>
+ inline int geev(char jobvl, char jobvr, int n, Complex64* a, int lda, Complex64* w, Complex64* wi, Complex64* vl, int ldvl, Complex64* vr, int ldvr, Complex64* work, int lwork, float* rwork) {
+   int info;
+   cgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
+   return info;
+ }
+
+ template <>
+ inline int geev(char jobvl, char jobvr, int n, Complex128* a, int lda, Complex128* w, Complex128* wi, Complex128* vl, int ldvl, Complex128* vr, int ldvr, Complex128* work, int lwork, double* rwork) {
+   int info;
+   zgeev_(&jobvl, &jobvr, &n, a, &lda, w, vl, &ldvl, vr, &ldvr, work, &lwork, rwork, &info);
+   return info;
+ }
+
+ template <typename DType, typename CType>
+ inline int lapack_geev(char jobvl, char jobvr, int n, void* a, int lda, void* w, void* wi, void* vl, int ldvl, void* vr, int ldvr, void* work, int lwork, void* rwork) {
+   return geev<DType,CType>(jobvl, jobvr, n, reinterpret_cast<DType*>(a), lda, reinterpret_cast<DType*>(w), reinterpret_cast<DType*>(wi), reinterpret_cast<DType*>(vl), ldvl, reinterpret_cast<DType*>(vr), ldvr, reinterpret_cast<DType*>(work), lwork, reinterpret_cast<CType*>(rwork));
+ }
+
+ }} // end nm::math
+
+ #endif // GEEV_H
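
For context, here is a minimal standalone sketch of how the raw dgeev_ binding declared above is typically driven (eigenvalues only, no eigenvectors). This is illustrative and not part of the diff: it assumes linking against a Fortran LAPACK (e.g. -llapack), and the 2x2 test matrix is invented for the example.

    #include <cstdio>

    extern "C" {  // same Fortran binding as declared in geev.h above
      void dgeev_(char* jobvl, char* jobvr, int* n, double* a, int* lda,
                  double* wr, double* wi, double* vl, int* ldvl,
                  double* vr, int* ldvr, double* work, int* lwork, int* info);
    }

    int main() {
      char jobvl = 'N', jobvr = 'N';       // eigenvalues only, no eigenvectors
      int n = 2, lda = 2, ldvl = 1, ldvr = 1, info;
      double a[4] = {0.0, 1.0, 1.0, 0.0};  // column-major [[0,1],[1,0]]
      double wr[2], wi[2];                 // real/imaginary parts of eigenvalues
      double vl_dummy[1], vr_dummy[1];     // not referenced when jobv* == 'N'
      int lwork = 8;                       // >= 3*n suffices with no eigenvectors
      double work[8];

      dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi,
             vl_dummy, &ldvl, vr_dummy, &ldvr, work, &lwork, &info);

      for (int i = 0; i < n; ++i)
        std::printf("lambda_%d = %g + %gi\n", i, wr[i], wi[i]);  // expect +1 and -1
      return info;
    }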
@@ -0,0 +1,267 @@
+ /////////////////////////////////////////////////////////////////////
+ // = NMatrix
+ //
+ // A linear algebra library for scientific computation in Ruby.
+ // NMatrix is part of SciRuby.
+ //
+ // NMatrix was originally inspired by and derived from NArray, by
+ // Masahiro Tanaka: http://narray.rubyforge.org
+ //
+ // == Copyright Information
+ //
+ // SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
+ // NMatrix is Copyright (c) 2013, Ruby Science Foundation
+ //
+ // Please see LICENSE.txt for additional copyright notices.
+ //
+ // == Contributing
+ //
+ // By contributing source code to SciRuby, you agree to be bound by
+ // our Contributor Agreement:
+ //
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+ //
+ // == gemm.h
+ //
+ // Header file for interface with ATLAS's CBLAS gemm functions and
+ // native templated version of LAPACK's gemm function.
+ //
+
+ #ifndef GEMM_H
+ #define GEMM_H
+
+ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
+ #include <cblas.h>
+ }
+
+
+ namespace nm { namespace math {
+ /*
+  * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
+  *
+  * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
+  *
+  * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
+  *
+  * This version throws no errors. Use gemm<DType> instead for error checking.
+  */
+ template <typename DType>
+ inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                          const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
+ {
+
+   typename LongDType<DType>::type temp;
+
+   // Quick return if possible
+   if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
+
+   // For alpha = 0
+   if (*alpha == 0) {
+     if (*beta == 0) {
+       for (int j = 0; j < N; ++j)
+         for (int i = 0; i < M; ++i) {
+           C[i+j*ldc] = 0;
+         }
+     } else {
+       for (int j = 0; j < N; ++j)
+         for (int i = 0; i < M; ++i) {
+           C[i+j*ldc] *= *beta;
+         }
+     }
+     return;
+   }
+
+   // Start the operations
+   if (TransB == CblasNoTrans) {
+     if (TransA == CblasNoTrans) {
+       // C = alpha*A*B+beta*C
+       for (int j = 0; j < N; ++j) {
+         if (*beta == 0) {
+           for (int i = 0; i < M; ++i) {
+             C[i+j*ldc] = 0;
+           }
+         } else if (*beta != 1) {
+           for (int i = 0; i < M; ++i) {
+             C[i+j*ldc] *= *beta;
+           }
+         }
+
+         for (int l = 0; l < K; ++l) {
+           if (B[l+j*ldb] != 0) {
+             temp = *alpha * B[l+j*ldb];
+             for (int i = 0; i < M; ++i) {
+               C[i+j*ldc] += A[i+l*lda] * temp;
+             }
+           }
+         }
+       }
+
+     } else {
+
+       // C = alpha*A**DType*B + beta*C
+       for (int j = 0; j < N; ++j) {
+         for (int i = 0; i < M; ++i) {
+           temp = 0;
+           for (int l = 0; l < K; ++l) {
+             temp += A[l+i*lda] * B[l+j*ldb];
+           }
+
+           if (*beta == 0) {
+             C[i+j*ldc] = *alpha*temp;
+           } else {
+             C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
+           }
+         }
+       }
+
+     }
+
+   } else if (TransA == CblasNoTrans) {
+
+     // C = alpha*A*B**T + beta*C
+     for (int j = 0; j < N; ++j) {
+       if (*beta == 0) {
+         for (int i = 0; i < M; ++i) {
+           C[i+j*ldc] = 0;
+         }
+       } else if (*beta != 1) {
+         for (int i = 0; i < M; ++i) {
+           C[i+j*ldc] *= *beta;
+         }
+       }
+
+       for (int l = 0; l < K; ++l) {
+         if (B[j+l*ldb] != 0) {
+           temp = *alpha * B[j+l*ldb];
+           for (int i = 0; i < M; ++i) {
+             C[i+j*ldc] += A[i+l*lda] * temp;
+           }
+         }
+       }
+
+     }
+
+   } else {
+
+     // C = alpha*A**DType*B**T + beta*C
+     for (int j = 0; j < N; ++j) {
+       for (int i = 0; i < M; ++i) {
+         temp = 0;
+         for (int l = 0; l < K; ++l) {
+           temp += A[l+i*lda] * B[j+l*ldb];
+         }
+
+         if (*beta == 0) {
+           C[i+j*ldc] = *alpha*temp;
+         } else {
+           C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
+         }
+       }
+     }
+
+   }
+
+   return;
+ }
+
+
+
+ template <typename DType>
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                  const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
+ {
+   if (Order == CblasRowMajor) {
+     if (TransA == CblasNoTrans) {
+       if (lda < std::max(K,1)) {
+         rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
+       }
+     } else {
+       if (lda < std::max(M,1)) { // && TransA == CblasTrans
+         rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
+       }
+     }
+
+     if (TransB == CblasNoTrans) {
+       if (ldb < std::max(N,1)) {
+         rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
+       }
+     } else {
+       if (ldb < std::max(K,1)) {
+         rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
+       }
+     }
+
+     if (ldc < std::max(N,1)) {
+       rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
+     }
+   } else { // CblasColMajor
+     if (TransA == CblasNoTrans) {
+       if (lda < std::max(M,1)) {
+         rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
+       }
+     } else {
+       if (lda < std::max(K,1)) { // && TransA == CblasTrans
+         rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
+       }
+     }
+
+     if (TransB == CblasNoTrans) {
+       if (ldb < std::max(K,1)) {
+         rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
+       }
+     } else {
+       if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
+         rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
+       }
+     }
+
+     if (ldc < std::max(M,1)) {
+       rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d M=%d", ldc, M);
+     }
+   }
+
+   /*
+    * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
+    * we copy C and then subtract to preserve asymmetry.
+    */
+
+   if (A == B && M == N && TransA != TransB && lda == ldb && *beta == 0) {
+     rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
+     /* syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
+        syreflect(CblasUpper, N, C, ldc); */
+   }
+
+   if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
+   else                        gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+
+ }
+
+
+ template <>
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                  const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
+   cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
+ }
+
+ template <>
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                  const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
+   cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
+ }
+
+ template <>
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                  const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
+   cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ }
+
+ template <>
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+                  const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
+   cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ }
+
+ }} // end of namespace nm::math
+
+ #endif // GEMM_H
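
As a quick sanity reference (not part of the diff): the double specialization above is a straight pass-through to cblas_dgemm, so a minimal standalone call looks like the sketch below. It assumes a CBLAS implementation such as ATLAS or OpenBLAS is available to link against (e.g. -lcblas), and the 2x2 operands are invented for illustration.

    extern "C" {  // as in gemm.h, guard the CBLAS header for C++
    #include <cblas.h>
    }
    #include <cstdio>

    int main() {
      // C = alpha*A*B + beta*C, all 2x2, row-major
      double A[4] = {1, 2, 3, 4};
      double B[4] = {5, 6, 7, 8};
      double C[4] = {0, 0, 0, 0};

      cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                  2, 2, 2,     // M, N, K
                  1.0, A, 2,   // alpha, A, lda
                  B, 2,        // B, ldb
                  0.0, C, 2);  // beta, C, ldc

      std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // 19 22 / 43 50
      return 0;
    }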
@@ -0,0 +1,208 @@
+ /////////////////////////////////////////////////////////////////////
+ // = NMatrix
+ //
+ // A linear algebra library for scientific computation in Ruby.
+ // NMatrix is part of SciRuby.
+ //
+ // NMatrix was originally inspired by and derived from NArray, by
+ // Masahiro Tanaka: http://narray.rubyforge.org
+ //
+ // == Copyright Information
+ //
+ // SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
+ // NMatrix is Copyright (c) 2013, Ruby Science Foundation
+ //
+ // Please see LICENSE.txt for additional copyright notices.
+ //
+ // == Contributing
+ //
+ // By contributing source code to SciRuby, you agree to be bound by
+ // our Contributor Agreement:
+ //
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
+ //
+ // == gemv.h
+ //
+ // Header file for interface with ATLAS's CBLAS gemv functions and
+ // native templated version of LAPACK's gemv function.
+ //
+
+ #ifndef GEMV_H
+ #define GEMV_H
+
+ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
+ #include <cblas.h>
+ }
+
+
+ namespace nm { namespace math {
+
+ /*
+  * GEneral Matrix-Vector multiplication: based on dgemv.f from Netlib.
+  *
+  * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
+  *
+  * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
+  */
+ template <typename DType>
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const DType* alpha, const DType* A, const int lda,
+                  const DType* X, const int incX, const DType* beta, DType* Y, const int incY) {
+   int lenX, lenY, i, j;
+   int kx, ky, iy, jx, jy, ix;
+
+   typename LongDType<DType>::type temp;
+
+   // Test the input parameters
+   if (Trans < 111 || Trans > 113) {
+     rb_raise(rb_eArgError, "GEMV: TransA must be CblasNoTrans, CblasTrans, or CblasConjTrans");
+     return false;
+   } else if (lda < std::max(1, N)) {
+     fprintf(stderr, "GEMV: N = %d; got lda=%d", N, lda);
+     rb_raise(rb_eArgError, "GEMV: Expected lda >= max(1, N)");
+     return false;
+   } else if (incX == 0) {
+     rb_raise(rb_eArgError, "GEMV: Expected incX != 0\n");
+     return false;
+   } else if (incY == 0) {
+     rb_raise(rb_eArgError, "GEMV: Expected incY != 0\n");
+     return false;
+   }
+
+   // Quick return if possible
+   if (!M or !N or (*alpha == 0 and *beta == 1)) return true;
+
+   if (Trans == CblasNoTrans) {
+     lenX = N;
+     lenY = M;
+   } else {
+     lenX = M;
+     lenY = N;
+   }
+
+   if (incX > 0) kx = 0;
+   else          kx = (lenX - 1) * -incX;
+
+   if (incY > 0) ky = 0;
+   else          ky = (lenY - 1) * -incY;
+
+   // Start the operations. In this version, the elements of A are accessed sequentially with one pass through A.
+   if (*beta != 1) {
+     if (incY == 1) {
+       if (*beta == 0) {
+         for (i = 0; i < lenY; ++i) {
+           Y[i] = 0;
+         }
+       } else {
+         for (i = 0; i < lenY; ++i) {
+           Y[i] *= *beta;
+         }
+       }
+     } else {
+       iy = ky;
+       if (*beta == 0) {
+         for (i = 0; i < lenY; ++i) {
+           Y[iy] = 0;
+           iy += incY;
+         }
+       } else {
+         for (i = 0; i < lenY; ++i) {
+           Y[iy] *= *beta;
+           iy += incY;
+         }
+       }
+     }
+   }
+
+   if (*alpha == 0) return false;
+
+   if (Trans == CblasNoTrans) {
+
+     // Form y := alpha*A*x + y.
+     jx = kx;
+     if (incY == 1) {
+       for (j = 0; j < N; ++j) {
+         if (X[jx] != 0) {
+           temp = *alpha * X[jx];
+           for (i = 0; i < M; ++i) {
+             Y[i] += A[j+i*lda] * temp;
+           }
+         }
+         jx += incX;
+       }
+     } else {
+       for (j = 0; j < N; ++j) {
+         if (X[jx] != 0) {
+           temp = *alpha * X[jx];
+           iy = ky;
+           for (i = 0; i < M; ++i) {
+             Y[iy] += A[j+i*lda] * temp;
+             iy += incY;
+           }
+         }
+         jx += incX;
+       }
+     }
+
+   } else { // TODO: Check that indices are correct! They're switched for C.
+
+     // Form y := alpha*A**DType*x + y.
+     jy = ky;
+
+     if (incX == 1) {
+       for (j = 0; j < N; ++j) {
+         temp = 0;
+         for (i = 0; i < M; ++i) {
+           temp += A[j+i*lda] * X[i];
+         }
+         Y[jy] += *alpha * temp;
+         jy += incY;
+       }
+     } else {
+       for (j = 0; j < N; ++j) {
+         temp = 0;
+         ix = kx;
+         for (i = 0; i < M; ++i) {
+           temp += A[j+i*lda] * X[ix];
+           ix += incX;
+         }
+
+         Y[jy] += *alpha * temp;
+         jy += incY;
+       }
+     }
+   }
+
+   return true;
+ } // end of GEMV
+
+ template <>
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const float* alpha, const float* A, const int lda,
+                  const float* X, const int incX, const float* beta, float* Y, const int incY) {
+   cblas_sgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
+   return true;
+ }
+
+ template <>
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const double* alpha, const double* A, const int lda,
+                  const double* X, const int incX, const double* beta, double* Y, const int incY) {
+   cblas_dgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
+   return true;
+ }
+
+ template <>
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex64* alpha, const Complex64* A, const int lda,
+                  const Complex64* X, const int incX, const Complex64* beta, Complex64* Y, const int incY) {
+   cblas_cgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+   return true;
+ }
+
+ template <>
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex128* alpha, const Complex128* A, const int lda,
+                  const Complex128* X, const int incX, const Complex128* beta, Complex128* Y, const int incY) {
+   cblas_zgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+   return true;
+ }
+
+ }} // end of namespace nm::math
+
+ #endif // GEMV_H
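
Similarly (and again outside the diff itself), the double specialization of gemv forwards to cblas_dgemv with a hard-coded CblasRowMajor. A minimal standalone equivalent, with made-up 2x3 operands and the same CBLAS linking assumption as above:

    extern "C" {
    #include <cblas.h>
    }
    #include <cstdio>

    int main() {
      // y = alpha*A*x + beta*y, where A is 2x3 row-major
      double A[6] = {1, 2, 3,
                     4, 5, 6};
      double x[3] = {1, 1, 1};
      double y[2] = {0, 0};

      cblas_dgemv(CblasRowMajor, CblasNoTrans,
                  2, 3,        // M, N
                  1.0, A, 3,   // alpha, A, lda
                  x, 1,        // x, incX
                  0.0, y, 1);  // beta, y, incY

      std::printf("%g %g\n", y[0], y[1]);  // row sums: 6 15
      return 0;
    }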