nmatrix-atlas 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/data/complex.h +364 -0
  3. data/ext/nmatrix/data/data.h +638 -0
  4. data/ext/nmatrix/data/meta.h +64 -0
  5. data/ext/nmatrix/data/ruby_object.h +389 -0
  6. data/ext/nmatrix/math/asum.h +120 -0
  7. data/ext/nmatrix/math/cblas_enums.h +36 -0
  8. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  9. data/ext/nmatrix/math/gemm.h +241 -0
  10. data/ext/nmatrix/math/gemv.h +178 -0
  11. data/ext/nmatrix/math/getrf.h +255 -0
  12. data/ext/nmatrix/math/getrs.h +121 -0
  13. data/ext/nmatrix/math/imax.h +79 -0
  14. data/ext/nmatrix/math/laswp.h +165 -0
  15. data/ext/nmatrix/math/long_dtype.h +49 -0
  16. data/ext/nmatrix/math/math.h +744 -0
  17. data/ext/nmatrix/math/nrm2.h +160 -0
  18. data/ext/nmatrix/math/rot.h +117 -0
  19. data/ext/nmatrix/math/rotg.h +106 -0
  20. data/ext/nmatrix/math/scal.h +71 -0
  21. data/ext/nmatrix/math/trsm.h +332 -0
  22. data/ext/nmatrix/math/util.h +148 -0
  23. data/ext/nmatrix/nm_memory.h +60 -0
  24. data/ext/nmatrix/nmatrix.h +408 -0
  25. data/ext/nmatrix/ruby_constants.h +106 -0
  26. data/ext/nmatrix/storage/common.h +176 -0
  27. data/ext/nmatrix/storage/dense/dense.h +128 -0
  28. data/ext/nmatrix/storage/list/list.h +137 -0
  29. data/ext/nmatrix/storage/storage.h +98 -0
  30. data/ext/nmatrix/storage/yale/class.h +1139 -0
  31. data/ext/nmatrix/storage/yale/iterators/base.h +142 -0
  32. data/ext/nmatrix/storage/yale/iterators/iterator.h +130 -0
  33. data/ext/nmatrix/storage/yale/iterators/row.h +449 -0
  34. data/ext/nmatrix/storage/yale/iterators/row_stored.h +139 -0
  35. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +168 -0
  36. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +123 -0
  37. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  38. data/ext/nmatrix/storage/yale/yale.h +202 -0
  39. data/ext/nmatrix/types.h +54 -0
  40. data/ext/nmatrix/util/io.h +115 -0
  41. data/ext/nmatrix/util/sl_list.h +143 -0
  42. data/ext/nmatrix/util/util.h +78 -0
  43. data/ext/nmatrix_atlas/extconf.rb +250 -0
  44. data/ext/nmatrix_atlas/math_atlas.cpp +1206 -0
  45. data/ext/nmatrix_atlas/math_atlas/cblas_templates_atlas.h +72 -0
  46. data/ext/nmatrix_atlas/math_atlas/clapack_templates.h +332 -0
  47. data/ext/nmatrix_atlas/math_atlas/geev.h +82 -0
  48. data/ext/nmatrix_atlas/math_atlas/gesdd.h +83 -0
  49. data/ext/nmatrix_atlas/math_atlas/gesvd.h +81 -0
  50. data/ext/nmatrix_atlas/math_atlas/inc.h +47 -0
  51. data/ext/nmatrix_atlas/nmatrix_atlas.cpp +44 -0
  52. data/lib/nmatrix/atlas.rb +213 -0
  53. data/lib/nmatrix/lapack_ext_common.rb +69 -0
  54. data/spec/00_nmatrix_spec.rb +730 -0
  55. data/spec/01_enum_spec.rb +190 -0
  56. data/spec/02_slice_spec.rb +389 -0
  57. data/spec/03_nmatrix_monkeys_spec.rb +78 -0
  58. data/spec/2x2_dense_double.mat +0 -0
  59. data/spec/4x4_sparse.mat +0 -0
  60. data/spec/4x5_dense.mat +0 -0
  61. data/spec/blas_spec.rb +193 -0
  62. data/spec/elementwise_spec.rb +303 -0
  63. data/spec/homogeneous_spec.rb +99 -0
  64. data/spec/io/fortran_format_spec.rb +88 -0
  65. data/spec/io/harwell_boeing_spec.rb +98 -0
  66. data/spec/io/test.rua +9 -0
  67. data/spec/io_spec.rb +149 -0
  68. data/spec/lapack_core_spec.rb +482 -0
  69. data/spec/leakcheck.rb +16 -0
  70. data/spec/math_spec.rb +730 -0
  71. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  72. data/spec/nmatrix_yale_spec.rb +286 -0
  73. data/spec/plugins/atlas/atlas_spec.rb +242 -0
  74. data/spec/rspec_monkeys.rb +56 -0
  75. data/spec/rspec_spec.rb +34 -0
  76. data/spec/shortcuts_spec.rb +310 -0
  77. data/spec/slice_set_spec.rb +157 -0
  78. data/spec/spec_helper.rb +140 -0
  79. data/spec/stat_spec.rb +203 -0
  80. data/spec/test.pcd +20 -0
  81. data/spec/utm5940.mtx +83844 -0
  82. metadata +159 -0
@@ -0,0 +1,241 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gemm.h
25
+ //
26
+ // Header file for interface with ATLAS's CBLAS gemm functions and
27
+ // native templated version of BLAS's gemm function.
28
+ //
29
+
30
+ #ifndef GEMM_H
31
+ # define GEMM_H
32
+
33
+ #include "cblas_enums.h"
34
+ #include "math/long_dtype.h"
35
+
36
+ namespace nm { namespace math {
37
+ /*
38
+ * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
39
+ *
40
+ * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
41
+ *
42
+ * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
43
+ *
44
+ * This version throws no errors. Use gemm<DType> instead for error checking.
45
+ */
46
+ template <typename DType>
47
+ inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
48
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
49
+ {
50
+
51
+ typename LongDType<DType>::type temp;
52
+
53
+ // Quick return if possible
54
+ if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
55
+
56
+ // For alpha = 0
57
+ if (*alpha == 0) {
58
+ if (*beta == 0) {
59
+ for (int j = 0; j < N; ++j)
60
+ for (int i = 0; i < M; ++i) {
61
+ C[i+j*ldc] = 0;
62
+ }
63
+ } else {
64
+ for (int j = 0; j < N; ++j)
65
+ for (int i = 0; i < M; ++i) {
66
+ C[i+j*ldc] *= *beta;
67
+ }
68
+ }
69
+ return;
70
+ }
71
+
72
+ // Start the operations
73
+ if (TransB == CblasNoTrans) {
74
+ if (TransA == CblasNoTrans) {
75
+ // C = alpha*A*B+beta*C
76
+ for (int j = 0; j < N; ++j) {
77
+ if (*beta == 0) {
78
+ for (int i = 0; i < M; ++i) {
79
+ C[i+j*ldc] = 0;
80
+ }
81
+ } else if (*beta != 1) {
82
+ for (int i = 0; i < M; ++i) {
83
+ C[i+j*ldc] *= *beta;
84
+ }
85
+ }
86
+
87
+ for (int l = 0; l < K; ++l) {
88
+ if (B[l+j*ldb] != 0) {
89
+ temp = *alpha * B[l+j*ldb];
90
+ for (int i = 0; i < M; ++i) {
91
+ C[i+j*ldc] += A[i+l*lda] * temp;
92
+ }
93
+ }
94
+ }
95
+ }
96
+
97
+ } else {
98
+
99
+ // C = alpha*A**DType*B + beta*C
100
+ for (int j = 0; j < N; ++j) {
101
+ for (int i = 0; i < M; ++i) {
102
+ temp = 0;
103
+ for (int l = 0; l < K; ++l) {
104
+ temp += A[l+i*lda] * B[l+j*ldb];
105
+ }
106
+
107
+ if (*beta == 0) {
108
+ C[i+j*ldc] = *alpha*temp;
109
+ } else {
110
+ C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
111
+ }
112
+ }
113
+ }
114
+
115
+ }
116
+
117
+ } else if (TransA == CblasNoTrans) {
118
+
119
+ // C = alpha*A*B**T + beta*C
120
+ for (int j = 0; j < N; ++j) {
121
+ if (*beta == 0) {
122
+ for (int i = 0; i < M; ++i) {
123
+ C[i+j*ldc] = 0;
124
+ }
125
+ } else if (*beta != 1) {
126
+ for (int i = 0; i < M; ++i) {
127
+ C[i+j*ldc] *= *beta;
128
+ }
129
+ }
130
+
131
+ for (int l = 0; l < K; ++l) {
132
+ if (B[j+l*ldb] != 0) {
133
+ temp = *alpha * B[j+l*ldb];
134
+ for (int i = 0; i < M; ++i) {
135
+ C[i+j*ldc] += A[i+l*lda] * temp;
136
+ }
137
+ }
138
+ }
139
+
140
+ }
141
+
142
+ } else {
143
+
144
+ // C = alpha*A**DType*B**T + beta*C
145
+ for (int j = 0; j < N; ++j) {
146
+ for (int i = 0; i < M; ++i) {
147
+ temp = 0;
148
+ for (int l = 0; l < K; ++l) {
149
+ temp += A[l+i*lda] * B[j+l*ldb];
150
+ }
151
+
152
+ if (*beta == 0) {
153
+ C[i+j*ldc] = *alpha*temp;
154
+ } else {
155
+ C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
156
+ }
157
+ }
158
+ }
159
+
160
+ }
161
+
162
+ return;
163
+ }
164
+
165
+
166
+
167
+ template <typename DType>
168
+ inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
169
+ const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
170
+ {
171
+ if (Order == CblasRowMajor) {
172
+ if (TransA == CblasNoTrans) {
173
+ if (lda < std::max(K,1)) {
174
+ rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
175
+ }
176
+ } else {
177
+ if (lda < std::max(M,1)) { // && TransA == CblasTrans
178
+ rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
179
+ }
180
+ }
181
+
182
+ if (TransB == CblasNoTrans) {
183
+ if (ldb < std::max(N,1)) {
184
+ rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
185
+ }
186
+ } else {
187
+ if (ldb < std::max(K,1)) {
188
+ rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
189
+ }
190
+ }
191
+
192
+ if (ldc < std::max(N,1)) {
193
+ rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
194
+ }
195
+ } else { // CblasColMajor
196
+ if (TransA == CblasNoTrans) {
197
+ if (lda < std::max(M,1)) {
198
+ rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
199
+ }
200
+ } else {
201
+ if (lda < std::max(K,1)) { // && TransA == CblasTrans
202
+ rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
203
+ }
204
+ }
205
+
206
+ if (TransB == CblasNoTrans) {
207
+ if (ldb < std::max(K,1)) {
208
+ rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
209
+ }
210
+ } else {
211
+ if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
212
+ rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
213
+ }
214
+ }
215
+
216
+ if (ldc < std::max(M,1)) {
217
+ rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
218
+ }
219
+ }
220
+
221
+ /*
222
+ * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
223
+ * we copy C and then subtract to preserve asymmetry.
224
+ */
225
+
226
+ if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
227
+ rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
228
+ /*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
229
+ syreflect(CblasUpper, N, C, ldc);
230
+ */
231
+ }
232
+
233
+ if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
234
+ else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
235
+
236
+ }
237
+
238
+
239
+ }} // end of namespace nm::math
240
+
241
+ #endif // GEMM_H
@@ -0,0 +1,178 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gemv.h
25
+ //
26
+ // Header file for interface with ATLAS's CBLAS gemv functions and
27
+ // native templated version of BLAS's gemv function.
28
+ //
29
+
30
+ #ifndef GEMV_H
31
+ # define GEMV_H
32
+
33
+ #include "math/long_dtype.h"
34
+
35
+ namespace nm { namespace math {
36
+
37
+ /*
38
+ * GEneral Matrix-Vector multiplication: based on dgemv.f from Netlib.
39
+ *
40
+ * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
41
+ *
42
+ * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
43
+ */
44
+ template <typename DType>
45
+ inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const DType* alpha, const DType* A, const int lda,
46
+ const DType* X, const int incX, const DType* beta, DType* Y, const int incY) {
47
+ int lenX, lenY, i, j;
48
+ int kx, ky, iy, jx, jy, ix;
49
+
50
+ typename LongDType<DType>::type temp;
51
+
52
+ // Test the input parameters
53
+ if (Trans < 111 || Trans > 113) {
54
+ rb_raise(rb_eArgError, "GEMV: TransA must be CblasNoTrans, CblasTrans, or CblasConjTrans");
55
+ return false;
56
+ } else if (lda < std::max(1, N)) {
57
+ fprintf(stderr, "GEMV: N = %d; got lda=%d", N, lda);
58
+ rb_raise(rb_eArgError, "GEMV: Expected lda >= max(1, N)");
59
+ return false;
60
+ } else if (incX == 0) {
61
+ rb_raise(rb_eArgError, "GEMV: Expected incX != 0\n");
62
+ return false;
63
+ } else if (incY == 0) {
64
+ rb_raise(rb_eArgError, "GEMV: Expected incY != 0\n");
65
+ return false;
66
+ }
67
+
68
+ // Quick return if possible
69
+ if (!M or !N or (*alpha == 0 and *beta == 1)) return true;
70
+
71
+ if (Trans == CblasNoTrans) {
72
+ lenX = N;
73
+ lenY = M;
74
+ } else {
75
+ lenX = M;
76
+ lenY = N;
77
+ }
78
+
79
+ if (incX > 0) kx = 0;
80
+ else kx = (lenX - 1) * -incX;
81
+
82
+ if (incY > 0) ky = 0;
83
+ else ky = (lenY - 1) * -incY;
84
+
85
+ // Start the operations. In this version, the elements of A are accessed sequentially with one pass through A.
86
+ if (*beta != 1) {
87
+ if (incY == 1) {
88
+ if (*beta == 0) {
89
+ for (i = 0; i < lenY; ++i) {
90
+ Y[i] = 0;
91
+ }
92
+ } else {
93
+ for (i = 0; i < lenY; ++i) {
94
+ Y[i] *= *beta;
95
+ }
96
+ }
97
+ } else {
98
+ iy = ky;
99
+ if (*beta == 0) {
100
+ for (i = 0; i < lenY; ++i) {
101
+ Y[iy] = 0;
102
+ iy += incY;
103
+ }
104
+ } else {
105
+ for (i = 0; i < lenY; ++i) {
106
+ Y[iy] *= *beta;
107
+ iy += incY;
108
+ }
109
+ }
110
+ }
111
+ }
112
+
113
+ if (*alpha == 0) return false;
114
+
115
+ if (Trans == CblasNoTrans) {
116
+
117
+ // Form y := alpha*A*x + y.
118
+ jx = kx;
119
+ if (incY == 1) {
120
+ for (j = 0; j < N; ++j) {
121
+ if (X[jx] != 0) {
122
+ temp = *alpha * X[jx];
123
+ for (i = 0; i < M; ++i) {
124
+ Y[i] += A[j+i*lda] * temp;
125
+ }
126
+ }
127
+ jx += incX;
128
+ }
129
+ } else {
130
+ for (j = 0; j < N; ++j) {
131
+ if (X[jx] != 0) {
132
+ temp = *alpha * X[jx];
133
+ iy = ky;
134
+ for (i = 0; i < M; ++i) {
135
+ Y[iy] += A[j+i*lda] * temp;
136
+ iy += incY;
137
+ }
138
+ }
139
+ jx += incX;
140
+ }
141
+ }
142
+
143
+ } else { // TODO: Check that indices are correct! They're switched for C.
144
+
145
+ // Form y := alpha*A**DType*x + y.
146
+ jy = ky;
147
+
148
+ if (incX == 1) {
149
+ for (j = 0; j < N; ++j) {
150
+ temp = 0;
151
+ for (i = 0; i < M; ++i) {
152
+ temp += A[j+i*lda]*X[j];
153
+ }
154
+ Y[jy] += *alpha * temp;
155
+ jy += incY;
156
+ }
157
+ } else {
158
+ for (j = 0; j < N; ++j) {
159
+ temp = 0;
160
+ ix = kx;
161
+ for (i = 0; i < M; ++i) {
162
+ temp += A[j+i*lda] * X[ix];
163
+ ix += incX;
164
+ }
165
+
166
+ Y[jy] += *alpha * temp;
167
+ jy += incY;
168
+ }
169
+ }
170
+ }
171
+
172
+ return true;
173
+ } // end of GEMV
174
+
175
+
176
+ }} // end of namespace nm::math
177
+
178
+ #endif // GEMV_H
@@ -0,0 +1,255 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == getrf.h
25
+ //
26
+ // getrf function in native C++.
27
+ //
28
+
29
+ /*
30
+ * Automatically Tuned Linear Algebra Software v3.8.4
31
+ * (C) Copyright 1999 R. Clint Whaley
32
+ *
33
+ * Redistribution and use in source and binary forms, with or without
34
+ * modification, are permitted provided that the following conditions
35
+ * are met:
36
+ * 1. Redistributions of source code must retain the above copyright
37
+ * notice, this list of conditions and the following disclaimer.
38
+ * 2. Redistributions in binary form must reproduce the above copyright
39
+ * notice, this list of conditions, and the following disclaimer in the
40
+ * documentation and/or other materials provided with the distribution.
41
+ * 3. The name of the ATLAS group or the names of its contributers may
42
+ * not be used to endorse or promote products derived from this
43
+ * software without specific written permission.
44
+ *
45
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
46
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
47
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
48
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
49
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
50
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
51
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
52
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
53
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
54
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
55
+ * POSSIBILITY OF SUCH DAMAGE.
56
+ *
57
+ */
58
+
59
+ #ifndef GETRF_H
60
+ #define GETRF_H
61
+
62
+ #include "math/laswp.h"
63
+ #include "math/math.h"
64
+ #include "math/trsm.h"
65
+ #include "math/gemm.h"
66
+ #include "math/imax.h"
67
+ #include "math/scal.h"
68
+
69
+ namespace nm { namespace math {
70
+
71
/*
 * Numeric (multiplicative) inverse of a value.
 *
 * The generic version defers to the value's own inverse() member, which covers
 * dtypes where the inverse is more involved than a plain division (e.g. complex
 * numbers). float and double are specialized to compute 1 / n directly.
 */
template <typename DType>
inline DType numeric_inverse(const DType& n) {
  return n.inverse();
}

template <>
inline float numeric_inverse<float>(const float& n) {
  return 1.0f / n;
}

template <>
inline double numeric_inverse<double>(const double& n) {
  return 1.0 / n;
}
78
+
79
+ /*
80
+ * Templated version of row-order and column-order getrf, derived from ATL_getrfR.c (from ATLAS 3.8.0).
81
+ *
82
+ * 1. Row-major factorization of form
83
+ * A = L * U * P
84
+ * where P is a column-permutation matrix, L is lower triangular (lower
85
+ * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
86
+ * trapazoidal if M < N). This is the recursive Level 3 BLAS version.
87
+ *
88
+ * 2. Column-major factorization of form
89
+ * A = P * L * U
90
+ * where P is a row-permutation matrix, L is lower triangular with unit diagonal
91
+ * elements (lower trapazoidal if M > N), and U is upper triangular (upper
92
+ * trapazoidal if M < N). This is the recursive Level 3 BLAS version.
93
+ *
94
+ * Template argument determines whether 1 or 2 is utilized.
95
+ */
96
+ template <bool RowMajor, typename DType>
97
+ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* ipiv) {
98
+ const int MN = std::min(M, N);
99
+ int ierr = 0;
100
+
101
+ // Symbols used by ATLAS in the several versions of this function:
102
+ // Row Col Us
103
+ // Nup Nleft N_ul
104
+ // Ndown Nright N_dr
105
+ // We're going to use N_ul, N_dr
106
+
107
+ DType neg_one = -1, one = 1;
108
+
109
+ if (MN > 1) {
110
+ int N_ul = MN >> 1;
111
+
112
+ // FIXME: Figure out how ATLAS #defines NB
113
+ #ifdef NB
114
+ if (N_ul > NB) N_ul = ATL_MulByNB(ATL_DivByNB(N_ul));
115
+ #endif
116
+
117
+ int N_dr;
118
+ if (RowMajor) {
119
+ N_dr = M - N_ul;
120
+ } else {
121
+ N_dr = N - N_ul;
122
+ }
123
+
124
+ int i = RowMajor ? getrf_nothrow<true,DType>(N_ul, N, A, lda, ipiv) : getrf_nothrow<false,DType>(M, N_ul, A, lda, ipiv);
125
+
126
+ if (i) if (!ierr) ierr = i;
127
+
128
+ DType *Ar, *Ac, *An;
129
+ if (RowMajor) {
130
+ Ar = &(A[N_ul * lda]),
131
+ Ac = &(A[N_ul]);
132
+ An = &(Ar[N_ul]);
133
+
134
+ nm::math::laswp<DType>(N_dr, Ar, lda, 0, N_ul, ipiv, 1);
135
+
136
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, N_dr, N_ul, one, A, lda, Ar, lda);
137
+ nm::math::gemm<DType>(CblasRowMajor, CblasNoTrans, CblasNoTrans, N_dr, N-N_ul, N_ul, &neg_one, Ar, lda, Ac, lda, &one, An, lda);
138
+
139
+ i = getrf_nothrow<true,DType>(N_dr, N-N_ul, An, lda, ipiv+N_ul);
140
+ } else {
141
+ Ar = NULL;
142
+ Ac = &(A[N_ul * lda]);
143
+ An = &(Ac[N_ul]);
144
+
145
+ nm::math::laswp<DType>(N_dr, Ac, lda, 0, N_ul, ipiv, 1);
146
+
147
+ nm::math::trsm<DType>(CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N_ul, N_dr, one, A, lda, Ac, lda);
148
+ nm::math::gemm<DType>(CblasColMajor, CblasNoTrans, CblasNoTrans, M-N_ul, N_dr, N_ul, &neg_one, &(A[N_ul]), lda, Ac, lda, &one, An, lda);
149
+
150
+ i = getrf_nothrow<false,DType>(M-N_ul, N_dr, An, lda, ipiv+N_ul);
151
+ }
152
+
153
+ if (i) if (!ierr) ierr = N_ul + i;
154
+
155
+ for (i = N_ul; i != MN; i++) {
156
+ ipiv[i] += N_ul;
157
+ }
158
+
159
+ nm::math::laswp<DType>(N_ul, A, lda, N_ul, MN, ipiv, 1); /* apply pivots */
160
+
161
+ } else if (MN == 1) { // there's another case for the colmajor version, but it doesn't seem to be necessary.
162
+
163
+ int i;
164
+ if (RowMajor) {
165
+ i = *ipiv = nm::math::imax<DType>(N, A, 1); // cblas_iamax(N, A, 1);
166
+ } else {
167
+ i = *ipiv = nm::math::imax<DType>(M, A, 1);
168
+ }
169
+
170
+ DType tmp = A[i];
171
+ if (tmp != 0) {
172
+
173
+ nm::math::scal<DType>((RowMajor ? N : M), nm::math::numeric_inverse(tmp), A, 1);
174
+ A[i] = *A;
175
+ *A = tmp;
176
+
177
+ } else ierr = 1;
178
+
179
+ }
180
+ return(ierr);
181
+ }
182
+
183
+
184
+ /*
185
+ * From ATLAS 3.8.0:
186
+ *
187
+ * Computes one of two LU factorizations based on the setting of the Order
188
+ * parameter, as follows:
189
+ * ----------------------------------------------------------------------------
190
+ * Order == CblasColMajor
191
+ * Column-major factorization of form
192
+ * A = P * L * U
193
+ * where P is a row-permutation matrix, L is lower triangular with unit
194
+ * diagonal elements (lower trapazoidal if M > N), and U is upper triangular
195
+ * (upper trapazoidal if M < N).
196
+ *
197
+ * ----------------------------------------------------------------------------
198
+ * Order == CblasRowMajor
199
+ * Row-major factorization of form
200
+ * A = P * L * U
201
+ * where P is a column-permutation matrix, L is lower triangular (lower
202
+ * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
203
+ * trapazoidal if M < N).
204
+ *
205
+ * ============================================================================
206
+ * Let IERR be the return value of the function:
207
+ * If IERR == 0, successful exit.
208
+ * If (IERR < 0) the -IERR argument had an illegal value
209
+ * If (IERR > 0 && Order == CblasColMajor)
210
+ * U(i-1,i-1) is exactly zero. The factorization has been completed,
211
+ * but the factor U is exactly singular, and division by zero will
212
+ * occur if it is used to solve a system of equations.
213
+ * If (IERR > 0 && Order == CblasRowMajor)
214
+ * L(i-1,i-1) is exactly zero. The factorization has been completed,
215
+ * but the factor L is exactly singular, and division by zero will
216
+ * occur if it is used to solve a system of equations.
217
+ */
218
+ template <typename DType>
219
+ inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType* A, int lda, int* ipiv) {
220
+ if (Order == CblasRowMajor) {
221
+ if (lda < std::max(1,N)) {
222
+ rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(N,1): lda=%d N=%d", lda, N);
223
+ return -6;
224
+ }
225
+
226
+ return getrf_nothrow<true,DType>(M, N, A, lda, ipiv);
227
+ } else {
228
+ if (lda < std::max(1,M)) {
229
+ rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
230
+ return -6;
231
+ }
232
+
233
+ return getrf_nothrow<false,DType>(M, N, A, lda, ipiv);
234
+ //rb_raise(rb_eNotImpError, "column major getrf not implemented");
235
+ }
236
+ }
237
+
238
+
239
+
240
+ /*
241
+ * Function signature conversion for calling LAPACK's getrf functions as directly as possible.
242
+ *
243
+ * For documentation: http://www.netlib.org/lapack/double/dgetrf.f
244
+ *
245
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
246
+ */
247
+ template <typename DType>
248
+ inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, int* ipiv) {
249
+ return getrf<DType>(order, m, n, reinterpret_cast<DType*>(a), lda, ipiv);
250
+ }
251
+
252
+
253
+ } } // end nm::math
254
+
255
+ #endif