nmatrix-fftw 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74) hide show
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/data/complex.h +388 -0
  3. data/ext/nmatrix/data/data.h +652 -0
  4. data/ext/nmatrix/data/meta.h +64 -0
  5. data/ext/nmatrix/data/ruby_object.h +389 -0
  6. data/ext/nmatrix/math/asum.h +120 -0
  7. data/ext/nmatrix/math/cblas_enums.h +36 -0
  8. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  9. data/ext/nmatrix/math/gemm.h +241 -0
  10. data/ext/nmatrix/math/gemv.h +178 -0
  11. data/ext/nmatrix/math/getrf.h +255 -0
  12. data/ext/nmatrix/math/getrs.h +121 -0
  13. data/ext/nmatrix/math/imax.h +79 -0
  14. data/ext/nmatrix/math/laswp.h +165 -0
  15. data/ext/nmatrix/math/long_dtype.h +49 -0
  16. data/ext/nmatrix/math/math.h +745 -0
  17. data/ext/nmatrix/math/nrm2.h +160 -0
  18. data/ext/nmatrix/math/rot.h +117 -0
  19. data/ext/nmatrix/math/rotg.h +106 -0
  20. data/ext/nmatrix/math/scal.h +71 -0
  21. data/ext/nmatrix/math/trsm.h +332 -0
  22. data/ext/nmatrix/math/util.h +148 -0
  23. data/ext/nmatrix/nm_memory.h +60 -0
  24. data/ext/nmatrix/nmatrix.h +438 -0
  25. data/ext/nmatrix/ruby_constants.h +106 -0
  26. data/ext/nmatrix/storage/common.h +177 -0
  27. data/ext/nmatrix/storage/dense/dense.h +129 -0
  28. data/ext/nmatrix/storage/list/list.h +138 -0
  29. data/ext/nmatrix/storage/storage.h +99 -0
  30. data/ext/nmatrix/storage/yale/class.h +1139 -0
  31. data/ext/nmatrix/storage/yale/iterators/base.h +143 -0
  32. data/ext/nmatrix/storage/yale/iterators/iterator.h +131 -0
  33. data/ext/nmatrix/storage/yale/iterators/row.h +450 -0
  34. data/ext/nmatrix/storage/yale/iterators/row_stored.h +140 -0
  35. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +169 -0
  36. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +124 -0
  37. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  38. data/ext/nmatrix/storage/yale/yale.h +203 -0
  39. data/ext/nmatrix/types.h +55 -0
  40. data/ext/nmatrix/util/io.h +115 -0
  41. data/ext/nmatrix/util/sl_list.h +144 -0
  42. data/ext/nmatrix/util/util.h +78 -0
  43. data/ext/nmatrix_fftw/extconf.rb +122 -0
  44. data/ext/nmatrix_fftw/nmatrix_fftw.cpp +274 -0
  45. data/lib/nmatrix/fftw.rb +343 -0
  46. data/spec/00_nmatrix_spec.rb +736 -0
  47. data/spec/01_enum_spec.rb +190 -0
  48. data/spec/02_slice_spec.rb +389 -0
  49. data/spec/03_nmatrix_monkeys_spec.rb +78 -0
  50. data/spec/2x2_dense_double.mat +0 -0
  51. data/spec/4x4_sparse.mat +0 -0
  52. data/spec/4x5_dense.mat +0 -0
  53. data/spec/blas_spec.rb +193 -0
  54. data/spec/elementwise_spec.rb +303 -0
  55. data/spec/homogeneous_spec.rb +99 -0
  56. data/spec/io/fortran_format_spec.rb +88 -0
  57. data/spec/io/harwell_boeing_spec.rb +98 -0
  58. data/spec/io/test.rua +9 -0
  59. data/spec/io_spec.rb +149 -0
  60. data/spec/lapack_core_spec.rb +482 -0
  61. data/spec/leakcheck.rb +16 -0
  62. data/spec/math_spec.rb +807 -0
  63. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  64. data/spec/nmatrix_yale_spec.rb +286 -0
  65. data/spec/plugins/fftw/fftw_spec.rb +348 -0
  66. data/spec/rspec_monkeys.rb +56 -0
  67. data/spec/rspec_spec.rb +34 -0
  68. data/spec/shortcuts_spec.rb +310 -0
  69. data/spec/slice_set_spec.rb +157 -0
  70. data/spec/spec_helper.rb +149 -0
  71. data/spec/stat_spec.rb +203 -0
  72. data/spec/test.pcd +20 -0
  73. data/spec/utm5940.mtx +83844 -0
  74. metadata +151 -0
@@ -0,0 +1,241 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gemm.h
25
+ //
26
+ // Header file for interface with ATLAS's CBLAS gemm functions and
27
+ // native templated version of the BLAS gemm function.
28
+ //
29
+
30
+ #ifndef GEMM_H
31
+ # define GEMM_H
32
+
33
+ #include "cblas_enums.h"
34
+ #include "math/long_dtype.h"
35
+
36
+ namespace nm { namespace math {
37
+ /*
38
+ * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
39
+ *
40
+ * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
41
+ *
42
+ * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
43
+ *
44
+ * This version throws no errors. Use gemm<DType> instead for error checking.
45
+ */
46
+ template <typename DType>
47
+ inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
48
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
49
+ {
50
+
51
+ typename LongDType<DType>::type temp;
52
+
53
+ // Quick return if possible
54
+ if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
55
+
56
+ // For alpha = 0
57
+ if (*alpha == 0) {
58
+ if (*beta == 0) {
59
+ for (int j = 0; j < N; ++j)
60
+ for (int i = 0; i < M; ++i) {
61
+ C[i+j*ldc] = 0;
62
+ }
63
+ } else {
64
+ for (int j = 0; j < N; ++j)
65
+ for (int i = 0; i < M; ++i) {
66
+ C[i+j*ldc] *= *beta;
67
+ }
68
+ }
69
+ return;
70
+ }
71
+
72
+ // Start the operations
73
+ if (TransB == CblasNoTrans) {
74
+ if (TransA == CblasNoTrans) {
75
+ // C = alpha*A*B+beta*C
76
+ for (int j = 0; j < N; ++j) {
77
+ if (*beta == 0) {
78
+ for (int i = 0; i < M; ++i) {
79
+ C[i+j*ldc] = 0;
80
+ }
81
+ } else if (*beta != 1) {
82
+ for (int i = 0; i < M; ++i) {
83
+ C[i+j*ldc] *= *beta;
84
+ }
85
+ }
86
+
87
+ for (int l = 0; l < K; ++l) {
88
+ if (B[l+j*ldb] != 0) {
89
+ temp = *alpha * B[l+j*ldb];
90
+ for (int i = 0; i < M; ++i) {
91
+ C[i+j*ldc] += A[i+l*lda] * temp;
92
+ }
93
+ }
94
+ }
95
+ }
96
+
97
+ } else {
98
+
99
+ // C = alpha*A**DType*B + beta*C
100
+ for (int j = 0; j < N; ++j) {
101
+ for (int i = 0; i < M; ++i) {
102
+ temp = 0;
103
+ for (int l = 0; l < K; ++l) {
104
+ temp += A[l+i*lda] * B[l+j*ldb];
105
+ }
106
+
107
+ if (*beta == 0) {
108
+ C[i+j*ldc] = *alpha*temp;
109
+ } else {
110
+ C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
111
+ }
112
+ }
113
+ }
114
+
115
+ }
116
+
117
+ } else if (TransA == CblasNoTrans) {
118
+
119
+ // C = alpha*A*B**T + beta*C
120
+ for (int j = 0; j < N; ++j) {
121
+ if (*beta == 0) {
122
+ for (int i = 0; i < M; ++i) {
123
+ C[i+j*ldc] = 0;
124
+ }
125
+ } else if (*beta != 1) {
126
+ for (int i = 0; i < M; ++i) {
127
+ C[i+j*ldc] *= *beta;
128
+ }
129
+ }
130
+
131
+ for (int l = 0; l < K; ++l) {
132
+ if (B[j+l*ldb] != 0) {
133
+ temp = *alpha * B[j+l*ldb];
134
+ for (int i = 0; i < M; ++i) {
135
+ C[i+j*ldc] += A[i+l*lda] * temp;
136
+ }
137
+ }
138
+ }
139
+
140
+ }
141
+
142
+ } else {
143
+
144
+ // C = alpha*A**DType*B**T + beta*C
145
+ for (int j = 0; j < N; ++j) {
146
+ for (int i = 0; i < M; ++i) {
147
+ temp = 0;
148
+ for (int l = 0; l < K; ++l) {
149
+ temp += A[l+i*lda] * B[j+l*ldb];
150
+ }
151
+
152
+ if (*beta == 0) {
153
+ C[i+j*ldc] = *alpha*temp;
154
+ } else {
155
+ C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
156
+ }
157
+ }
158
+ }
159
+
160
+ }
161
+
162
+ return;
163
+ }
164
+
165
+
166
+
167
+ template <typename DType>
168
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
169
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
170
+ {
171
+ if (Order == CblasRowMajor) {
172
+ if (TransA == CblasNoTrans) {
173
+ if (lda < std::max(K,1)) {
174
+ rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
175
+ }
176
+ } else {
177
+ if (lda < std::max(M,1)) { // && TransA == CblasTrans
178
+ rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
179
+ }
180
+ }
181
+
182
+ if (TransB == CblasNoTrans) {
183
+ if (ldb < std::max(N,1)) {
184
+ rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
185
+ }
186
+ } else {
187
+ if (ldb < std::max(K,1)) {
188
+ rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
189
+ }
190
+ }
191
+
192
+ if (ldc < std::max(N,1)) {
193
+ rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
194
+ }
195
+ } else { // CblasColMajor
196
+ if (TransA == CblasNoTrans) {
197
+ if (lda < std::max(M,1)) {
198
+ rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
199
+ }
200
+ } else {
201
+ if (lda < std::max(K,1)) { // && TransA == CblasTrans
202
+ rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
203
+ }
204
+ }
205
+
206
+ if (TransB == CblasNoTrans) {
207
+ if (ldb < std::max(K,1)) {
208
+ rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
209
+ }
210
+ } else {
211
+ if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
212
+ rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
213
+ }
214
+ }
215
+
216
+ if (ldc < std::max(M,1)) {
217
+ rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
218
+ }
219
+ }
220
+
221
+ /*
222
+ * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
223
+ * we copy C and then subtract to preserve asymmetry.
224
+ */
225
+
226
+ if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
227
+ rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
228
+ /*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
229
+ syreflect(CblasUpper, N, C, ldc);
230
+ */
231
+ }
232
+
233
+ if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
234
+ else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
235
+
236
+ }
237
+
238
+
239
+ }} // end of namespace nm::math
240
+
241
+ #endif // GEMM_H
@@ -0,0 +1,178 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gemv.h
25
+ //
26
+ // Header file for interface with ATLAS's CBLAS gemv functions and
27
+ // native templated version of the BLAS gemv function.
28
+ //
29
+
30
+ #ifndef GEMV_H
31
+ # define GEMV_H
32
+
33
+ #include "math/long_dtype.h"
34
+
35
+ namespace nm { namespace math {
36
+
37
+ /*
38
+ * GEneral Matrix-Vector multiplication: based on dgemv.f from Netlib.
39
+ *
40
+ * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
41
+ *
42
+ * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
43
+ */
44
+ template <typename DType>
45
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const DType* alpha, const DType* A, const int lda,
46
+ const DType* X, const int incX, const DType* beta, DType* Y, const int incY) {
47
+ int lenX, lenY, i, j;
48
+ int kx, ky, iy, jx, jy, ix;
49
+
50
+ typename LongDType<DType>::type temp;
51
+
52
+ // Test the input parameters
53
+ if (Trans < 111 || Trans > 113) {
54
+ rb_raise(rb_eArgError, "GEMV: TransA must be CblasNoTrans, CblasTrans, or CblasConjTrans");
55
+ return false;
56
+ } else if (lda < std::max(1, N)) {
57
+ fprintf(stderr, "GEMV: N = %d; got lda=%d", N, lda);
58
+ rb_raise(rb_eArgError, "GEMV: Expected lda >= max(1, N)");
59
+ return false;
60
+ } else if (incX == 0) {
61
+ rb_raise(rb_eArgError, "GEMV: Expected incX != 0\n");
62
+ return false;
63
+ } else if (incY == 0) {
64
+ rb_raise(rb_eArgError, "GEMV: Expected incY != 0\n");
65
+ return false;
66
+ }
67
+
68
+ // Quick return if possible
69
+ if (!M or !N or (*alpha == 0 and *beta == 1)) return true;
70
+
71
+ if (Trans == CblasNoTrans) {
72
+ lenX = N;
73
+ lenY = M;
74
+ } else {
75
+ lenX = M;
76
+ lenY = N;
77
+ }
78
+
79
+ if (incX > 0) kx = 0;
80
+ else kx = (lenX - 1) * -incX;
81
+
82
+ if (incY > 0) ky = 0;
83
+ else ky = (lenY - 1) * -incY;
84
+
85
+ // Start the operations. In this version, the elements of A are accessed sequentially with one pass through A.
86
+ if (*beta != 1) {
87
+ if (incY == 1) {
88
+ if (*beta == 0) {
89
+ for (i = 0; i < lenY; ++i) {
90
+ Y[i] = 0;
91
+ }
92
+ } else {
93
+ for (i = 0; i < lenY; ++i) {
94
+ Y[i] *= *beta;
95
+ }
96
+ }
97
+ } else {
98
+ iy = ky;
99
+ if (*beta == 0) {
100
+ for (i = 0; i < lenY; ++i) {
101
+ Y[iy] = 0;
102
+ iy += incY;
103
+ }
104
+ } else {
105
+ for (i = 0; i < lenY; ++i) {
106
+ Y[iy] *= *beta;
107
+ iy += incY;
108
+ }
109
+ }
110
+ }
111
+ }
112
+
113
+ if (*alpha == 0) return false;
114
+
115
+ if (Trans == CblasNoTrans) {
116
+
117
+ // Form y := alpha*A*x + y.
118
+ jx = kx;
119
+ if (incY == 1) {
120
+ for (j = 0; j < N; ++j) {
121
+ if (X[jx] != 0) {
122
+ temp = *alpha * X[jx];
123
+ for (i = 0; i < M; ++i) {
124
+ Y[i] += A[j+i*lda] * temp;
125
+ }
126
+ }
127
+ jx += incX;
128
+ }
129
+ } else {
130
+ for (j = 0; j < N; ++j) {
131
+ if (X[jx] != 0) {
132
+ temp = *alpha * X[jx];
133
+ iy = ky;
134
+ for (i = 0; i < M; ++i) {
135
+ Y[iy] += A[j+i*lda] * temp;
136
+ iy += incY;
137
+ }
138
+ }
139
+ jx += incX;
140
+ }
141
+ }
142
+
143
+ } else { // TODO: Check that indices are correct! They're switched for C.
144
+
145
+ // Form y := alpha*A**DType*x + y.
146
+ jy = ky;
147
+
148
+ if (incX == 1) {
149
+ for (j = 0; j < N; ++j) {
150
+ temp = 0;
151
+ for (i = 0; i < M; ++i) {
152
+ temp += A[j+i*lda]*X[j];
153
+ }
154
+ Y[jy] += *alpha * temp;
155
+ jy += incY;
156
+ }
157
+ } else {
158
+ for (j = 0; j < N; ++j) {
159
+ temp = 0;
160
+ ix = kx;
161
+ for (i = 0; i < M; ++i) {
162
+ temp += A[j+i*lda] * X[ix];
163
+ ix += incX;
164
+ }
165
+
166
+ Y[jy] += *alpha * temp;
167
+ jy += incY;
168
+ }
169
+ }
170
+ }
171
+
172
+ return true;
173
+ } // end of GEMV
174
+
175
+
176
+ }} // end of namespace nm::math
177
+
178
+ #endif // GEMM_H
@@ -0,0 +1,255 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == getrf.h
25
+ //
26
+ // getrf function in native C++.
27
+ //
28
+
29
+ /*
30
+ * Automatically Tuned Linear Algebra Software v3.8.4
31
+ * (C) Copyright 1999 R. Clint Whaley
32
+ *
33
+ * Redistribution and use in source and binary forms, with or without
34
+ * modification, are permitted provided that the following conditions
35
+ * are met:
36
+ * 1. Redistributions of source code must retain the above copyright
37
+ * notice, this list of conditions and the following disclaimer.
38
+ * 2. Redistributions in binary form must reproduce the above copyright
39
+ * notice, this list of conditions, and the following disclaimer in the
40
+ * documentation and/or other materials provided with the distribution.
41
+ * 3. The name of the ATLAS group or the names of its contributers may
42
+ * not be used to endorse or promote products derived from this
43
+ * software without specific written permission.
44
+ *
45
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
46
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
47
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
48
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
49
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
50
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
51
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
52
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
53
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
54
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
55
+ * POSSIBILITY OF SUCH DAMAGE.
56
+ *
57
+ */
58
+
59
+ #ifndef GETRF_H
60
+ #define GETRF_H
61
+
62
+ #include "math/laswp.h"
63
+ #include "math/math.h"
64
+ #include "math/trsm.h"
65
+ #include "math/gemm.h"
66
+ #include "math/imax.h"
67
+ #include "math/scal.h"
68
+
69
+ namespace nm { namespace math {
70
+
71
/*
 * Multiplicative inverse of a value.
 *
 * The generic version defers to the element type's own inverse() member, which is needed for
 * types such as complex numbers. The built-in float and double types compute the reciprocal 1/n directly.
 */
template <typename DType>
inline DType numeric_inverse(const DType& n) {
  return n.inverse();
}

template <>
inline float numeric_inverse<float>(const float& n) {
  return 1.0f / n;
}

template <>
inline double numeric_inverse<double>(const double& n) {
  return 1.0 / n;
}
78
+
79
+ /*
80
+ * Templated version of row-order and column-order getrf, derived from ATL_getrfR.c (from ATLAS 3.8.0).
81
+ *
82
+ * 1. Row-major factorization of form
83
+ * A = L * U * P
84
+ * where P is a column-permutation matrix, L is lower triangular (lower
85
+ * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
86
+ * trapazoidal if M < N). This is the recursive Level 3 BLAS version.
87
+ *
88
+ * 2. Column-major factorization of form
89
+ * A = P * L * U
90
+ * where P is a row-permutation matrix, L is lower triangular with unit diagonal
91
+ * elements (lower trapazoidal if M > N), and U is upper triangular (upper
92
+ * trapazoidal if M < N). This is the recursive Level 3 BLAS version.
93
+ *
94
+ * Template argument determines whether 1 or 2 is utilized.
95
+ */
96
+ template <bool RowMajor, typename DType>
97
+ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* ipiv) {
98
+ const int MN = std::min(M, N);
99
+ int ierr = 0;
100
+
101
+ // Symbols used by ATLAS in the several versions of this function:
102
+ // Row Col Us
103
+ // Nup Nleft N_ul
104
+ // Ndown Nright N_dr
105
+ // We're going to use N_ul, N_dr
106
+
107
+ DType neg_one = -1, one = 1;
108
+
109
+ if (MN > 1) {
110
+ int N_ul = MN >> 1;
111
+
112
+ // FIXME: Figure out how ATLAS #defines NB
113
+ #ifdef NB
114
+ if (N_ul > NB) N_ul = ATL_MulByNB(ATL_DivByNB(N_ul));
115
+ #endif
116
+
117
+ int N_dr;
118
+ if (RowMajor) {
119
+ N_dr = M - N_ul;
120
+ } else {
121
+ N_dr = N - N_ul;
122
+ }
123
+
124
+ int i = RowMajor ? getrf_nothrow<true,DType>(N_ul, N, A, lda, ipiv) : getrf_nothrow<false,DType>(M, N_ul, A, lda, ipiv);
125
+
126
+ if (i) if (!ierr) ierr = i;
127
+
128
+ DType *Ar, *Ac, *An;
129
+ if (RowMajor) {
130
+ Ar = &(A[N_ul * lda]),
131
+ Ac = &(A[N_ul]);
132
+ An = &(Ar[N_ul]);
133
+
134
+ nm::math::laswp<DType>(N_dr, Ar, lda, 0, N_ul, ipiv, 1);
135
+
136
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, N_dr, N_ul, one, A, lda, Ar, lda);
137
+ nm::math::gemm<DType>(CblasRowMajor, CblasNoTrans, CblasNoTrans, N_dr, N-N_ul, N_ul, &neg_one, Ar, lda, Ac, lda, &one, An, lda);
138
+
139
+ i = getrf_nothrow<true,DType>(N_dr, N-N_ul, An, lda, ipiv+N_ul);
140
+ } else {
141
+ Ar = NULL;
142
+ Ac = &(A[N_ul * lda]);
143
+ An = &(Ac[N_ul]);
144
+
145
+ nm::math::laswp<DType>(N_dr, Ac, lda, 0, N_ul, ipiv, 1);
146
+
147
+ nm::math::trsm<DType>(CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N_ul, N_dr, one, A, lda, Ac, lda);
148
+ nm::math::gemm<DType>(CblasColMajor, CblasNoTrans, CblasNoTrans, M-N_ul, N_dr, N_ul, &neg_one, &(A[N_ul]), lda, Ac, lda, &one, An, lda);
149
+
150
+ i = getrf_nothrow<false,DType>(M-N_ul, N_dr, An, lda, ipiv+N_ul);
151
+ }
152
+
153
+ if (i) if (!ierr) ierr = N_ul + i;
154
+
155
+ for (i = N_ul; i != MN; i++) {
156
+ ipiv[i] += N_ul;
157
+ }
158
+
159
+ nm::math::laswp<DType>(N_ul, A, lda, N_ul, MN, ipiv, 1); /* apply pivots */
160
+
161
+ } else if (MN == 1) { // there's another case for the colmajor version, but it doesn't seem to be necessary.
162
+
163
+ int i;
164
+ if (RowMajor) {
165
+ i = *ipiv = nm::math::imax<DType>(N, A, 1); // cblas_iamax(N, A, 1);
166
+ } else {
167
+ i = *ipiv = nm::math::imax<DType>(M, A, 1);
168
+ }
169
+
170
+ DType tmp = A[i];
171
+ if (tmp != 0) {
172
+
173
+ nm::math::scal<DType>((RowMajor ? N : M), nm::math::numeric_inverse(tmp), A, 1);
174
+ A[i] = *A;
175
+ *A = tmp;
176
+
177
+ } else ierr = 1;
178
+
179
+ }
180
+ return(ierr);
181
+ }
182
+
183
+
184
+ /*
185
+ * From ATLAS 3.8.0:
186
+ *
187
+ * Computes one of two LU factorizations based on the setting of the Order
188
+ * parameter, as follows:
189
+ * ----------------------------------------------------------------------------
190
+ * Order == CblasColMajor
191
+ * Column-major factorization of form
192
+ * A = P * L * U
193
+ * where P is a row-permutation matrix, L is lower triangular with unit
194
+ * diagonal elements (lower trapazoidal if M > N), and U is upper triangular
195
+ * (upper trapazoidal if M < N).
196
+ *
197
+ * ----------------------------------------------------------------------------
198
+ * Order == CblasRowMajor
199
+ * Row-major factorization of form
200
+ * A = P * L * U
201
+ * where P is a column-permutation matrix, L is lower triangular (lower
202
+ * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
203
+ * trapazoidal if M < N).
204
+ *
205
+ * ============================================================================
206
+ * Let IERR be the return value of the function:
207
+ * If IERR == 0, successful exit.
208
+ * If (IERR < 0) the -IERR argument had an illegal value
209
+ * If (IERR > 0 && Order == CblasColMajor)
210
+ * U(i-1,i-1) is exactly zero. The factorization has been completed,
211
+ * but the factor U is exactly singular, and division by zero will
212
+ * occur if it is used to solve a system of equations.
213
+ * If (IERR > 0 && Order == CblasRowMajor)
214
+ * L(i-1,i-1) is exactly zero. The factorization has been completed,
215
+ * but the factor L is exactly singular, and division by zero will
216
+ * occur if it is used to solve a system of equations.
217
+ */
218
+ template <typename DType>
219
+ inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType* A, int lda, int* ipiv) {
220
+ if (Order == CblasRowMajor) {
221
+ if (lda < std::max(1,N)) {
222
+ rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(N,1): lda=%d N=%d", lda, N);
223
+ return -6;
224
+ }
225
+
226
+ return getrf_nothrow<true,DType>(M, N, A, lda, ipiv);
227
+ } else {
228
+ if (lda < std::max(1,M)) {
229
+ rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
230
+ return -6;
231
+ }
232
+
233
+ return getrf_nothrow<false,DType>(M, N, A, lda, ipiv);
234
+ //rb_raise(rb_eNotImpError, "column major getrf not implemented");
235
+ }
236
+ }
237
+
238
+
239
+
240
+ /*
241
+ * Function signature conversion for calling LAPACK's getrf functions as directly as possible.
242
+ *
243
+ * For documentation: http://www.netlib.org/lapack/double/dgetrf.f
244
+ *
245
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
246
+ */
247
+ template <typename DType>
248
+ inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, int* ipiv) {
249
+ return getrf<DType>(order, m, n, reinterpret_cast<DType*>(a), lda, ipiv);
250
+ }
251
+
252
+
253
+ } } // end nm::math
254
+
255
+ #endif