nmatrix 0.0.6 → 0.0.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (67) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -0
  3. data/Gemfile +5 -0
  4. data/History.txt +97 -0
  5. data/Manifest.txt +34 -7
  6. data/README.rdoc +13 -13
  7. data/Rakefile +36 -26
  8. data/ext/nmatrix/data/data.cpp +15 -2
  9. data/ext/nmatrix/data/data.h +4 -0
  10. data/ext/nmatrix/data/ruby_object.h +5 -14
  11. data/ext/nmatrix/extconf.rb +3 -2
  12. data/ext/nmatrix/{util/math.cpp → math.cpp} +296 -6
  13. data/ext/nmatrix/math/asum.h +143 -0
  14. data/ext/nmatrix/math/geev.h +82 -0
  15. data/ext/nmatrix/math/gemm.h +267 -0
  16. data/ext/nmatrix/math/gemv.h +208 -0
  17. data/ext/nmatrix/math/ger.h +96 -0
  18. data/ext/nmatrix/math/gesdd.h +80 -0
  19. data/ext/nmatrix/math/gesvd.h +78 -0
  20. data/ext/nmatrix/math/getf2.h +86 -0
  21. data/ext/nmatrix/math/getrf.h +240 -0
  22. data/ext/nmatrix/math/getri.h +107 -0
  23. data/ext/nmatrix/math/getrs.h +125 -0
  24. data/ext/nmatrix/math/idamax.h +86 -0
  25. data/ext/nmatrix/{util → math}/lapack.h +60 -356
  26. data/ext/nmatrix/math/laswp.h +165 -0
  27. data/ext/nmatrix/math/long_dtype.h +52 -0
  28. data/ext/nmatrix/math/math.h +1154 -0
  29. data/ext/nmatrix/math/nrm2.h +181 -0
  30. data/ext/nmatrix/math/potrs.h +125 -0
  31. data/ext/nmatrix/math/rot.h +141 -0
  32. data/ext/nmatrix/math/rotg.h +115 -0
  33. data/ext/nmatrix/math/scal.h +73 -0
  34. data/ext/nmatrix/math/swap.h +73 -0
  35. data/ext/nmatrix/math/trsm.h +383 -0
  36. data/ext/nmatrix/nmatrix.cpp +176 -152
  37. data/ext/nmatrix/nmatrix.h +1 -2
  38. data/ext/nmatrix/ruby_constants.cpp +9 -4
  39. data/ext/nmatrix/ruby_constants.h +1 -0
  40. data/ext/nmatrix/storage/dense.cpp +57 -41
  41. data/ext/nmatrix/storage/list.cpp +52 -50
  42. data/ext/nmatrix/storage/storage.cpp +59 -43
  43. data/ext/nmatrix/storage/yale.cpp +352 -333
  44. data/ext/nmatrix/storage/yale.h +4 -0
  45. data/lib/nmatrix.rb +2 -2
  46. data/lib/nmatrix/blas.rb +4 -4
  47. data/lib/nmatrix/enumerate.rb +241 -0
  48. data/lib/nmatrix/lapack.rb +54 -1
  49. data/lib/nmatrix/math.rb +462 -0
  50. data/lib/nmatrix/nmatrix.rb +210 -486
  51. data/lib/nmatrix/nvector.rb +0 -62
  52. data/lib/nmatrix/rspec.rb +75 -0
  53. data/lib/nmatrix/shortcuts.rb +136 -108
  54. data/lib/nmatrix/version.rb +1 -1
  55. data/spec/blas_spec.rb +20 -12
  56. data/spec/elementwise_spec.rb +22 -13
  57. data/spec/io_spec.rb +1 -0
  58. data/spec/lapack_spec.rb +197 -0
  59. data/spec/nmatrix_spec.rb +39 -38
  60. data/spec/nvector_spec.rb +3 -9
  61. data/spec/rspec_monkeys.rb +29 -0
  62. data/spec/rspec_spec.rb +34 -0
  63. data/spec/shortcuts_spec.rb +14 -16
  64. data/spec/slice_spec.rb +242 -186
  65. data/spec/spec_helper.rb +19 -0
  66. metadata +33 -5
  67. data/ext/nmatrix/util/math.h +0 -2612
data/spec/spec_helper.rb CHANGED
@@ -24,6 +24,7 @@
24
24
  #
25
25
  # Common data for testing.
26
26
  require "./lib/nmatrix"
27
+ require "./lib/nmatrix/rspec"
27
28
 
28
29
  MATRIX43A_ARRAY = [14.0, 9.0, 3.0, 2.0, 11.0, 15.0, 0.0, 12.0, 17.0, 5.0, 2.0, 3.0]
29
30
  MATRIX32A_ARRAY = [12.0, 25.0, 9.0, 10.0, 8.0, 5.0]
@@ -66,3 +67,21 @@ def create_vector(stype) #:nodoc:
66
67
 
67
68
  m
68
69
  end
70
+
71
# Naive element-by-element comparison of two matrices for slice_spec, kept
# deliberately simple so it stays independent of the code under test.
#
# Returns false when the shapes differ, or when any element n[i,j] differs
# from m[i,j] (the first mismatch is printed to stdout). Returns true
# otherwise. Only the first two entries of +shape+ are walked.
def nm_eql(n, m) #:nodoc:
  # This must be an explicit return: a bare `false` inside an if/else would be
  # discarded and the method would fall through to the trailing `true`.
  return false if n.shape != m.shape

  n.shape[0].times do |i|
    n.shape[1].times do |j|
      if n[i,j] != m[i,j]
        puts "n[#{i},#{j}] != m[#{i},#{j}] (#{n[i,j]} != #{m[i,j]})"
        return false
      end
    end
  end

  true
end
87
+
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nmatrix
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.6
4
+ version: 0.0.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - John Woods
@@ -10,7 +10,7 @@ authors:
10
10
  autorequire:
11
11
  bindir: bin
12
12
  cert_chain: []
13
- date: 2013-08-09 00:00:00.000000000 Z
13
+ date: 2013-08-22 00:00:00.000000000 Z
14
14
  dependencies:
15
15
  - !ruby/object:Gem::Dependency
16
16
  name: rdoc
@@ -136,6 +136,30 @@ files:
136
136
  - ext/nmatrix/data/rational.h
137
137
  - ext/nmatrix/data/ruby_object.h
138
138
  - ext/nmatrix/extconf.rb
139
+ - ext/nmatrix/math.cpp
140
+ - ext/nmatrix/math/asum.h
141
+ - ext/nmatrix/math/geev.h
142
+ - ext/nmatrix/math/gemm.h
143
+ - ext/nmatrix/math/gemv.h
144
+ - ext/nmatrix/math/ger.h
145
+ - ext/nmatrix/math/gesdd.h
146
+ - ext/nmatrix/math/gesvd.h
147
+ - ext/nmatrix/math/getf2.h
148
+ - ext/nmatrix/math/getrf.h
149
+ - ext/nmatrix/math/getri.h
150
+ - ext/nmatrix/math/getrs.h
151
+ - ext/nmatrix/math/idamax.h
152
+ - ext/nmatrix/math/lapack.h
153
+ - ext/nmatrix/math/laswp.h
154
+ - ext/nmatrix/math/long_dtype.h
155
+ - ext/nmatrix/math/math.h
156
+ - ext/nmatrix/math/nrm2.h
157
+ - ext/nmatrix/math/potrs.h
158
+ - ext/nmatrix/math/rot.h
159
+ - ext/nmatrix/math/rotg.h
160
+ - ext/nmatrix/math/scal.h
161
+ - ext/nmatrix/math/swap.h
162
+ - ext/nmatrix/math/trsm.h
139
163
  - ext/nmatrix/nmatrix.cpp
140
164
  - ext/nmatrix/nmatrix.h
141
165
  - ext/nmatrix/ruby_constants.cpp
@@ -154,21 +178,21 @@ files:
154
178
  - ext/nmatrix/types.h
155
179
  - ext/nmatrix/util/io.cpp
156
180
  - ext/nmatrix/util/io.h
157
- - ext/nmatrix/util/lapack.h
158
- - ext/nmatrix/util/math.cpp
159
- - ext/nmatrix/util/math.h
160
181
  - ext/nmatrix/util/sl_list.cpp
161
182
  - ext/nmatrix/util/sl_list.h
162
183
  - ext/nmatrix/util/util.h
163
184
  - lib/nmatrix.rb
164
185
  - lib/nmatrix/blas.rb
186
+ - lib/nmatrix/enumerate.rb
165
187
  - lib/nmatrix/io/market.rb
166
188
  - lib/nmatrix/io/mat5_reader.rb
167
189
  - lib/nmatrix/io/mat_reader.rb
168
190
  - lib/nmatrix/lapack.rb
191
+ - lib/nmatrix/math.rb
169
192
  - lib/nmatrix/monkeys.rb
170
193
  - lib/nmatrix/nmatrix.rb
171
194
  - lib/nmatrix/nvector.rb
195
+ - lib/nmatrix/rspec.rb
172
196
  - lib/nmatrix/shortcuts.rb
173
197
  - lib/nmatrix/version.rb
174
198
  - lib/nmatrix/yale_functions.rb
@@ -188,6 +212,8 @@ files:
188
212
  - spec/nmatrix_yale_resize_test_associations.yaml
189
213
  - spec/nmatrix_yale_spec.rb
190
214
  - spec/nvector_spec.rb
215
+ - spec/rspec_monkeys.rb
216
+ - spec/rspec_spec.rb
191
217
  - spec/shortcuts_spec.rb
192
218
  - spec/slice_spec.rb
193
219
  - spec/spec_helper.rb
@@ -258,6 +284,8 @@ test_files:
258
284
  - spec/nmatrix_yale_resize_test_associations.yaml
259
285
  - spec/nmatrix_yale_spec.rb
260
286
  - spec/nvector_spec.rb
287
+ - spec/rspec_monkeys.rb
288
+ - spec/rspec_spec.rb
261
289
  - spec/shortcuts_spec.rb
262
290
  - spec/slice_spec.rb
263
291
  - spec/spec_helper.rb
@@ -1,2612 +0,0 @@
1
- /////////////////////////////////////////////////////////////////////
2
- // = NMatrix
3
- //
4
- // A linear algebra library for scientific computation in Ruby.
5
- // NMatrix is part of SciRuby.
6
- //
7
- // NMatrix was originally inspired by and derived from NArray, by
8
- // Masahiro Tanaka: http://narray.rubyforge.org
9
- //
10
- // == Copyright Information
11
- //
12
- // SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
13
- // NMatrix is Copyright (c) 2013, Ruby Science Foundation
14
- //
15
- // Please see LICENSE.txt for additional copyright notices.
16
- //
17
- // == Contributing
18
- //
19
- // By contributing source code to SciRuby, you agree to be bound by
20
- // our Contributor Agreement:
21
- //
22
- // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
- //
24
- // == math.h
25
- //
26
- // Header file for math functions, interfacing with BLAS, etc.
27
- //
28
- // For instructions on adding CBLAS and CLAPACK functions, see the
29
- // beginning of math.cpp.
30
- //
31
- // Some of these functions are from ATLAS. Here is the license for
32
- // ATLAS:
33
- //
34
- /*
35
- * Automatically Tuned Linear Algebra Software v3.8.4
36
- * (C) Copyright 1999 R. Clint Whaley
37
- *
38
- * Redistribution and use in source and binary forms, with or without
39
- * modification, are permitted provided that the following conditions
40
- * are met:
41
- * 1. Redistributions of source code must retain the above copyright
42
- * notice, this list of conditions and the following disclaimer.
43
- * 2. Redistributions in binary form must reproduce the above copyright
44
- * notice, this list of conditions, and the following disclaimer in the
45
- * documentation and/or other materials provided with the distribution.
46
- * 3. The name of the ATLAS group or the names of its contributers may
47
- * not be used to endorse or promote products derived from this
48
- * software without specific written permission.
49
- *
50
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
51
- * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
52
- * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
53
- * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
54
- * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
55
- * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
56
- * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
57
- * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
58
- * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
59
- * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
60
- * POSSIBILITY OF SUCH DAMAGE.
61
- *
62
- */
63
-
64
- #ifndef MATH_H
65
- #define MATH_H
66
-
67
- /*
68
- * Standard Includes
69
- */
70
-
71
- extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
72
- #include <cblas.h>
73
-
74
- #ifdef HAVE_CLAPACK_H
75
- #include <clapack.h>
76
- #endif
77
- }
78
-
79
- #include <algorithm> // std::min, std::max
80
- #include <limits> // std::numeric_limits
81
-
82
- /*
83
- * Project Includes
84
- */
85
- #include "data/data.h"
86
- #include "lapack.h"
87
-
88
- /*
89
- * Macros
90
- */
91
- #define REAL_RECURSE_LIMIT 4
92
-
93
- /*
94
- * Data
95
- */
96
-
97
-
98
- extern "C" {
99
- /*
100
- * C accessors.
101
- */
102
- void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
103
- void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
104
- void nm_math_init_blas(void);
105
- }
106
-
107
-
108
- namespace nm {
109
- namespace math {
110
-
111
- /*
112
- * Types
113
- */
114
-
115
-
116
- // These allow an increase in precision for intermediate values of gemm and gemv.
117
- // See also: http://stackoverflow.com/questions/11873694/how-does-one-increase-precision-in-c-templates-in-a-template-typename-dependen
118
- template <typename DType> struct LongDType;
119
- template <> struct LongDType<uint8_t> { typedef int16_t type; };
120
- template <> struct LongDType<int8_t> { typedef int16_t type; };
121
- template <> struct LongDType<int16_t> { typedef int32_t type; };
122
- template <> struct LongDType<int32_t> { typedef int64_t type; };
123
- template <> struct LongDType<int64_t> { typedef int64_t type; };
124
- template <> struct LongDType<float> { typedef double type; };
125
- template <> struct LongDType<double> { typedef double type; };
126
- template <> struct LongDType<Complex64> { typedef Complex128 type; };
127
- template <> struct LongDType<Complex128> { typedef Complex128 type; };
128
- template <> struct LongDType<Rational32> { typedef Rational128 type; };
129
- template <> struct LongDType<Rational64> { typedef Rational128 type; };
130
- template <> struct LongDType<Rational128> { typedef Rational128 type; };
131
- template <> struct LongDType<RubyObject> { typedef RubyObject type; };
132
-
133
- /*
134
- * Functions
135
- */
136
-
137
/*
 * Multiplicative inverse of n.
 *
 * The generic version delegates to the dtype's own inverse() member; the
 * float and double specializations simply compute 1 / n.
 */
template <typename DType>
inline DType numeric_inverse(const DType& n) {
  return n.inverse();
}

template <>
inline float numeric_inverse<float>(const float& n) {
  return 1 / n;
}

template <>
inline double numeric_inverse<double>(const double& n) {
  return 1 / n;
}
144
-
145
- /*
146
- * This version of trsm doesn't do any error checks and only works on column-major matrices.
147
- *
148
- * For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
149
- * and parameter checks.
150
- *
151
- * Note that some of the boundary conditions here may be incorrect. Very little has been tested!
152
- * This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
153
- */
154
- template <typename DType>
155
- inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
156
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
157
- const int m, const int n, const DType alpha, const DType* a,
158
- const int lda, DType* b, const int ldb)
159
- {
160
-
161
- // (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
162
-
163
- if (m == 0 || n == 0) return; /* Quick return if possible. */
164
-
165
- if (alpha == 0) { // Handle alpha == 0
166
- for (int j = 0; j < n; ++j) {
167
- for (int i = 0; i < m; ++i) {
168
- b[i + j * ldb] = 0;
169
- }
170
- }
171
- return;
172
- }
173
-
174
- if (side == CblasLeft) {
175
- if (trans_a == CblasNoTrans) {
176
-
177
- /* Form B := alpha*inv( A )*B. */
178
- if (uplo == CblasUpper) {
179
- for (int j = 0; j < n; ++j) {
180
- if (alpha != 1) {
181
- for (int i = 0; i < m; ++i) {
182
- b[i + j * ldb] = alpha * b[i + j * ldb];
183
- }
184
- }
185
- for (int k = m-1; k >= 0; --k) {
186
- if (b[k + j * ldb] != 0) {
187
- if (diag == CblasNonUnit) {
188
- b[k + j * ldb] /= a[k + k * lda];
189
- }
190
-
191
- for (int i = 0; i < k-1; ++i) {
192
- b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
193
- }
194
- }
195
- }
196
- }
197
- } else {
198
- for (int j = 0; j < n; ++j) {
199
- if (alpha != 1) {
200
- for (int i = 0; i < m; ++i) {
201
- b[i + j * ldb] = alpha * b[i + j * ldb];
202
- }
203
- }
204
- for (int k = 0; k < m; ++k) {
205
- if (b[k + j * ldb] != 0.) {
206
- if (diag == CblasNonUnit) {
207
- b[k + j * ldb] /= a[k + k * lda];
208
- }
209
- for (int i = k+1; i < m; ++i) {
210
- b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
211
- }
212
- }
213
- }
214
- }
215
- }
216
- } else { // CblasTrans
217
-
218
- /* Form B := alpha*inv( A**T )*B. */
219
- if (uplo == CblasUpper) {
220
- for (int j = 0; j < n; ++j) {
221
- for (int i = 0; i < m; ++i) {
222
- DType temp = alpha * b[i + j * ldb];
223
- for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
224
- temp -= a[k + i * lda] * b[k + j * ldb];
225
- }
226
- if (diag == CblasNonUnit) {
227
- temp /= a[i + i * lda];
228
- }
229
- b[i + j * ldb] = temp;
230
- }
231
- }
232
- } else {
233
- for (int j = 0; j < n; ++j) {
234
- for (int i = m-1; i >= 0; --i) {
235
- DType temp= alpha * b[i + j * ldb];
236
- for (int k = i+1; k < m; ++k) {
237
- temp -= a[k + i * lda] * b[k + j * ldb];
238
- }
239
- if (diag == CblasNonUnit) {
240
- temp /= a[i + i * lda];
241
- }
242
- b[i + j * ldb] = temp;
243
- }
244
- }
245
- }
246
- }
247
- } else { // right side
248
-
249
- if (trans_a == CblasNoTrans) {
250
-
251
- /* Form B := alpha*B*inv( A ). */
252
-
253
- if (uplo == CblasUpper) {
254
- for (int j = 0; j < n; ++j) {
255
- if (alpha != 1) {
256
- for (int i = 0; i < m; ++i) {
257
- b[i + j * ldb] = alpha * b[i + j * ldb];
258
- }
259
- }
260
- for (int k = 0; k < j-1; ++k) {
261
- if (a[k + j * lda] != 0) {
262
- for (int i = 0; i < m; ++i) {
263
- b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
264
- }
265
- }
266
- }
267
- if (diag == CblasNonUnit) {
268
- DType temp = 1 / a[j + j * lda];
269
- for (int i = 0; i < m; ++i) {
270
- b[i + j * ldb] = temp * b[i + j * ldb];
271
- }
272
- }
273
- }
274
- } else {
275
- for (int j = n-1; j >= 0; --j) {
276
- if (alpha != 1) {
277
- for (int i = 0; i < m; ++i) {
278
- b[i + j * ldb] = alpha * b[i + j * ldb];
279
- }
280
- }
281
-
282
- for (int k = j+1; k < n; ++k) {
283
- if (a[k + j * lda] != 0.) {
284
- for (int i = 0; i < m; ++i) {
285
- b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
286
- }
287
- }
288
- }
289
- if (diag == CblasNonUnit) {
290
- DType temp = 1 / a[j + j * lda];
291
-
292
- for (int i = 0; i < m; ++i) {
293
- b[i + j * ldb] = temp * b[i + j * ldb];
294
- }
295
- }
296
- }
297
- }
298
- } else { // CblasTrans
299
-
300
- /* Form B := alpha*B*inv( A**T ). */
301
-
302
- if (uplo == CblasUpper) {
303
- for (int k = n-1; k >= 0; --k) {
304
- if (diag == CblasNonUnit) {
305
- DType temp= 1 / a[k + k * lda];
306
- for (int i = 0; i < m; ++i) {
307
- b[i + k * ldb] = temp * b[i + k * ldb];
308
- }
309
- }
310
- for (int j = 0; j < k-1; ++j) {
311
- if (a[j + k * lda] != 0.) {
312
- DType temp= a[j + k * lda];
313
- for (int i = 0; i < m; ++i) {
314
- b[i + j * ldb] -= temp * b[i + k * ldb];
315
- }
316
- }
317
- }
318
- if (alpha != 1) {
319
- for (int i = 0; i < m; ++i) {
320
- b[i + k * ldb] = alpha * b[i + k * ldb];
321
- }
322
- }
323
- }
324
- } else {
325
- for (int k = 0; k < n; ++k) {
326
- if (diag == CblasNonUnit) {
327
- DType temp = 1 / a[k + k * lda];
328
- for (int i = 0; i < m; ++i) {
329
- b[i + k * ldb] = temp * b[i + k * ldb];
330
- }
331
- }
332
- for (int j = k+1; j < n; ++j) {
333
- if (a[j + k * lda] != 0.) {
334
- DType temp = a[j + k * lda];
335
- for (int i = 0; i < m; ++i) {
336
- b[i + j * ldb] -= temp * b[i + k * ldb];
337
- }
338
- }
339
- }
340
- if (alpha != 1) {
341
- for (int i = 0; i < m; ++i) {
342
- b[i + k * ldb] = alpha * b[i + k * ldb];
343
- }
344
- }
345
- }
346
- }
347
- }
348
- }
349
- }
350
-
351
-
352
- template <typename DType>
353
- inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
354
- const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
355
- rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
356
- }
357
-
358
- template <typename DType>
359
- inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
360
- const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
361
- rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
362
- }
363
-
364
- template <>
365
- inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
366
- const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
367
- cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
368
- }
369
-
370
- template <>
371
- inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
372
- const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
373
- cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
374
- }
375
-
376
- template <>
377
- inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
378
- const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
379
- cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
380
- }
381
-
382
- template <>
383
- inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
384
- const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
385
- cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
386
- }
387
-
388
-
389
- template <>
390
- inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
391
- const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
392
- cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
393
- }
394
-
395
- template <>
396
- inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
397
- const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
398
- cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
399
- }
400
-
401
-
402
- template <typename DType>
403
- inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
404
- const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
405
- const DType* A, const int lda, DType* B, const int ldb) {
406
- rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
407
- }
408
-
409
- template <>
410
- inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
411
- const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
412
- const float* A, const int lda, float* B, const int ldb) {
413
- cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
414
- }
415
-
416
- template <>
417
- inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
418
- const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
419
- const double* A, const int lda, double* B, const int ldb) {
420
- cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
421
- }
422
-
423
- template <>
424
- inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
425
- const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
426
- const Complex64* A, const int lda, Complex64* B, const int ldb) {
427
- cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
428
- }
429
-
430
- template <>
431
- inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
432
- const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
433
- const Complex128* A, const int lda, Complex128* B, const int ldb) {
434
- cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
435
- }
436
-
437
-
438
- /*
439
- * BLAS' DTRSM function, generalized.
440
- */
441
- template <typename DType, typename = typename std::enable_if<!std::is_integral<DType>::value>::type>
442
- inline void trsm(const enum CBLAS_ORDER order,
443
- const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
444
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
445
- const int m, const int n, const DType alpha, const DType* a,
446
- const int lda, DType* b, const int ldb)
447
- {
448
- /*using std::cerr;
449
- using std::endl;*/
450
-
451
- int num_rows_a = n;
452
- if (side == CblasLeft) num_rows_a = m;
453
-
454
- if (lda < std::max(1,num_rows_a)) {
455
- fprintf(stderr, "TRSM: num_rows_a = %d; got lda=%d\n", num_rows_a, lda);
456
- rb_raise(rb_eArgError, "TRSM: Expected lda >= max(1, num_rows_a)");
457
- }
458
-
459
- // Test the input parameters.
460
- if (order == CblasRowMajor) {
461
- if (ldb < std::max(1,n)) {
462
- fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
463
- rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,N)");
464
- }
465
-
466
- // For row major, need to switch side and uplo
467
- enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
468
- enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
469
-
470
- /*
471
- cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
472
- << (uplo_ == CblasUpper ? "upper " : "lower ")
473
- << (trans_a == CblasTrans ? "trans " : "notrans ")
474
- << (diag == CblasNonUnit ? "nonunit " : "unit ")
475
- << n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
476
- */
477
- trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
478
-
479
- } else { // CblasColMajor
480
-
481
- if (ldb < std::max(1,m)) {
482
- fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
483
- rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
484
- }
485
- /*
486
- cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
487
- << (uplo == CblasUpper ? "upper " : "lower ")
488
- << (trans_a == CblasTrans ? "trans " : "notrans ")
489
- << (diag == CblasNonUnit ? "nonunit " : "unit ")
490
- << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
491
- */
492
- trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
493
-
494
- }
495
-
496
- }
497
-
498
-
499
- template <>
500
- inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
501
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
502
- const int m, const int n, const float alpha, const float* a,
503
- const int lda, float* b, const int ldb)
504
- {
505
- cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
506
- }
507
-
508
- template <>
509
- inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
510
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
511
- const int m, const int n, const double alpha, const double* a,
512
- const int lda, double* b, const int ldb)
513
- {
514
- /* using std::cerr;
515
- using std::endl;
516
- cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
517
- << (uplo == CblasUpper ? "upper " : "lower ")
518
- << (trans_a == CblasTrans ? "trans " : "notrans ")
519
- << (diag == CblasNonUnit ? "nonunit " : "unit ")
520
- << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
521
- */
522
- cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
523
- }
524
-
525
-
526
- template <>
527
- inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
528
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
529
- const int m, const int n, const Complex64 alpha, const Complex64* a,
530
- const int lda, Complex64* b, const int ldb)
531
- {
532
- cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
533
- }
534
-
535
- template <>
536
- inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
537
- const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
538
- const int m, const int n, const Complex128 alpha, const Complex128* a,
539
- const int lda, Complex128* b, const int ldb)
540
- {
541
- cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
542
- }
543
-
544
-
545
/*
 * ATLAS function which performs row interchanges on a general rectangular
 * matrix. Modeled after the LAPACK LASWP function.
 *
 * This version is templated for use by template <> getrf().
 *
 * A is addressed as A[row + col * lda]. For each index i walked between K1
 * and K2-1, element A[i + c*lda] is exchanged with A[piv[...] + c*lda] in
 * each of the N columns c. The walk runs upwards from K1 when inci is
 * non-negative and downwards from K2-1 when inci is negative; piv is read
 * with stride inci, starting at piv[K1 * inci] or piv[-(K2-1) * inci]
 * respectively.
 *
 * Returns immediately when K2 < K1. Columns are processed in panels of 32,
 * followed by one panel holding the N % 32 remaining columns.
 *
 * The `register` storage class used by the original loops was dropped: it is
 * deprecated in C++11 and no longer valid in C++17.
 */
template <typename DType>
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
  if (K2 < K1) return;

  // First pivot entry to read and the inclusive range of rows to walk.
  const int* first_piv;
  int i1, i2;
  if (inci < 0) {
    first_piv = piv - (K2-1) * inci;
    i1 = K2 - 1;
    i2 = K1;
  } else {
    first_piv = piv + K1 * inci;
    i1 = K1;
    i2 = K2 - 1;
  }

  // Applies every interchange in [i1, i2] to `ncols` consecutive columns
  // starting at `panel`.
  auto interchange_panel = [=](DType* panel, const int ncols) {
    const int* ipiv = first_piv;
    int i = i1;
    bool keep_on;

    do {
      const int ip = *ipiv;
      ipiv += inci;

      if (ip != i) {
        DType* row_i  = panel + i;
        DType* row_ip = panel + ip;

        for (int h = ncols; h; --h) {
          DType saved = *row_i;
          *row_i  = *row_ip;
          *row_ip = saved;

          row_i  += lda;
          row_ip += lda;
        }
      }

      keep_on = (inci > 0) ? (++i <= i2) : (--i >= i2);
    } while (keep_on);
  };

  int       nb   = N >> 5;           // number of full 32-column panels
  const int mr   = N - (nb << 5);    // columns left over after the full panels
  const int incA = lda << 5;         // element distance between panels

  for (; nb; --nb) {
    interchange_panel(A, 32);
    A += incA;
  }

  if (mr) interchange_panel(A, mr);
}
631
-
632
-
633
- /*
634
- * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
635
- *
636
- * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
637
- *
638
- * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
639
- *
640
- * This version throws no errors. Use gemm<DType> instead for error checking.
641
- */
642
- template <typename DType>
643
- inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
644
- const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
645
- {
646
-
647
- typename LongDType<DType>::type temp;
648
-
649
- // Quick return if possible
650
- if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
651
-
652
- // For alpha = 0
653
- if (*alpha == 0) {
654
- if (*beta == 0) {
655
- for (int j = 0; j < N; ++j)
656
- for (int i = 0; i < M; ++i) {
657
- C[i+j*ldc] = 0;
658
- }
659
- } else {
660
- for (int j = 0; j < N; ++j)
661
- for (int i = 0; i < M; ++i) {
662
- C[i+j*ldc] *= *beta;
663
- }
664
- }
665
- return;
666
- }
667
-
668
- // Start the operations
669
- if (TransB == CblasNoTrans) {
670
- if (TransA == CblasNoTrans) {
671
- // C = alpha*A*B+beta*C
672
- for (int j = 0; j < N; ++j) {
673
- if (*beta == 0) {
674
- for (int i = 0; i < M; ++i) {
675
- C[i+j*ldc] = 0;
676
- }
677
- } else if (*beta != 1) {
678
- for (int i = 0; i < M; ++i) {
679
- C[i+j*ldc] *= *beta;
680
- }
681
- }
682
-
683
- for (int l = 0; l < K; ++l) {
684
- if (B[l+j*ldb] != 0) {
685
- temp = *alpha * B[l+j*ldb];
686
- for (int i = 0; i < M; ++i) {
687
- C[i+j*ldc] += A[i+l*lda] * temp;
688
- }
689
- }
690
- }
691
- }
692
-
693
- } else {
694
-
695
- // C = alpha*A**DType*B + beta*C
696
- for (int j = 0; j < N; ++j) {
697
- for (int i = 0; i < M; ++i) {
698
- temp = 0;
699
- for (int l = 0; l < K; ++l) {
700
- temp += A[l+i*lda] * B[l+j*ldb];
701
- }
702
-
703
- if (*beta == 0) {
704
- C[i+j*ldc] = *alpha*temp;
705
- } else {
706
- C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
707
- }
708
- }
709
- }
710
-
711
- }
712
-
713
- } else if (TransA == CblasNoTrans) {
714
-
715
- // C = alpha*A*B**T + beta*C
716
- for (int j = 0; j < N; ++j) {
717
- if (*beta == 0) {
718
- for (int i = 0; i < M; ++i) {
719
- C[i+j*ldc] = 0;
720
- }
721
- } else if (*beta != 1) {
722
- for (int i = 0; i < M; ++i) {
723
- C[i+j*ldc] *= *beta;
724
- }
725
- }
726
-
727
- for (int l = 0; l < K; ++l) {
728
- if (B[j+l*ldb] != 0) {
729
- temp = *alpha * B[j+l*ldb];
730
- for (int i = 0; i < M; ++i) {
731
- C[i+j*ldc] += A[i+l*lda] * temp;
732
- }
733
- }
734
- }
735
-
736
- }
737
-
738
- } else {
739
-
740
- // C = alpha*A**DType*B**T + beta*C
741
- for (int j = 0; j < N; ++j) {
742
- for (int i = 0; i < M; ++i) {
743
- temp = 0;
744
- for (int l = 0; l < K; ++l) {
745
- temp += A[l+i*lda] * B[j+l*ldb];
746
- }
747
-
748
- if (*beta == 0) {
749
- C[i+j*ldc] = *alpha*temp;
750
- } else {
751
- C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
752
- }
753
- }
754
- }
755
-
756
- }
757
-
758
- return;
759
- }
760
-
761
-
762
-
763
- template <typename DType>
764
- inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
765
- const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
766
- {
767
- if (Order == CblasRowMajor) {
768
- if (TransA == CblasNoTrans) {
769
- if (lda < std::max(K,1)) {
770
- rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
771
- }
772
- } else {
773
- if (lda < std::max(M,1)) { // && TransA == CblasTrans
774
- rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
775
- }
776
- }
777
-
778
- if (TransB == CblasNoTrans) {
779
- if (ldb < std::max(N,1)) {
780
- rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
781
- }
782
- } else {
783
- if (ldb < std::max(K,1)) {
784
- rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
785
- }
786
- }
787
-
788
- if (ldc < std::max(N,1)) {
789
- rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
790
- }
791
- } else { // CblasColMajor
792
- if (TransA == CblasNoTrans) {
793
- if (lda < std::max(M,1)) {
794
- rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
795
- }
796
- } else {
797
- if (lda < std::max(K,1)) { // && TransA == CblasTrans
798
- rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
799
- }
800
- }
801
-
802
- if (TransB == CblasNoTrans) {
803
- if (ldb < std::max(K,1)) {
804
- rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
805
- }
806
- } else {
807
- if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
808
- rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
809
- }
810
- }
811
-
812
- if (ldc < std::max(M,1)) {
813
- rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
814
- }
815
- }
816
-
817
- /*
818
- * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
819
- * we copy C and then subtract to preserve asymmetry.
820
- */
821
-
822
- if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
823
- rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
824
- /*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
825
- syreflect(CblasUpper, N, C, ldc);
826
- */
827
- }
828
-
829
- if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
830
- else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
831
-
832
- }
833
-
834
-
835
- template <>
836
- inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
837
- const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
838
- cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
839
- }
840
-
841
- template <>
842
- inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
843
- const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
844
- cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
845
- }
846
-
847
- template <>
848
- inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
849
- const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
850
- cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
851
- }
852
-
853
- template <>
854
- inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
855
- const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
856
- cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
857
- }
858
-
859
-
860
- /*
861
- * GEneral Matrix-Vector multiplication: based on dgemv.f from Netlib.
862
- *
863
- * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
864
- *
865
- * Template parameter: DType -- the matrix dtype. Sums are accumulated in LongDType<DType>::type, the long version of DType.
866
- */
867
- template <typename DType>
868
- inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const DType* alpha, const DType* A, const int lda,
869
- const DType* X, const int incX, const DType* beta, DType* Y, const int incY) {
870
- int lenX, lenY, i, j;
871
- int kx, ky, iy, jx, jy, ix;
872
-
873
- typename LongDType<DType>::type temp;
874
-
875
- // Test the input parameters
876
- if (Trans < 111 || Trans > 113) {
877
- rb_raise(rb_eArgError, "GEMV: TransA must be CblasNoTrans, CblasTrans, or CblasConjTrans");
878
- return false;
879
- } else if (lda < std::max(1, N)) {
880
- fprintf(stderr, "GEMV: N = %d; got lda=%d", N, lda);
881
- rb_raise(rb_eArgError, "GEMV: Expected lda >= max(1, N)");
882
- return false;
883
- } else if (incX == 0) {
884
- rb_raise(rb_eArgError, "GEMV: Expected incX != 0\n");
885
- return false;
886
- } else if (incY == 0) {
887
- rb_raise(rb_eArgError, "GEMV: Expected incY != 0\n");
888
- return false;
889
- }
890
-
891
- // Quick return if possible
892
- if (!M or !N or (*alpha == 0 and *beta == 1)) return true;
893
-
894
- if (Trans == CblasNoTrans) {
895
- lenX = N;
896
- lenY = M;
897
- } else {
898
- lenX = M;
899
- lenY = N;
900
- }
901
-
902
- if (incX > 0) kx = 0;
903
- else kx = (lenX - 1) * -incX;
904
-
905
- if (incY > 0) ky = 0;
906
- else ky = (lenY - 1) * -incY;
907
-
908
- // Start the operations. In this version, the elements of A are accessed sequentially with one pass through A.
909
- if (*beta != 1) {
910
- if (incY == 1) {
911
- if (*beta == 0) {
912
- for (i = 0; i < lenY; ++i) {
913
- Y[i] = 0;
914
- }
915
- } else {
916
- for (i = 0; i < lenY; ++i) {
917
- Y[i] *= *beta;
918
- }
919
- }
920
- } else {
921
- iy = ky;
922
- if (*beta == 0) {
923
- for (i = 0; i < lenY; ++i) {
924
- Y[iy] = 0;
925
- iy += incY;
926
- }
927
- } else {
928
- for (i = 0; i < lenY; ++i) {
929
- Y[iy] *= *beta;
930
- iy += incY;
931
- }
932
- }
933
- }
934
- }
935
-
936
- if (*alpha == 0) return false;
937
-
938
- if (Trans == CblasNoTrans) {
939
-
940
- // Form y := alpha*A*x + y.
941
- jx = kx;
942
- if (incY == 1) {
943
- for (j = 0; j < N; ++j) {
944
- if (X[jx] != 0) {
945
- temp = *alpha * X[jx];
946
- for (i = 0; i < M; ++i) {
947
- Y[i] += A[j+i*lda] * temp;
948
- }
949
- }
950
- jx += incX;
951
- }
952
- } else {
953
- for (j = 0; j < N; ++j) {
954
- if (X[jx] != 0) {
955
- temp = *alpha * X[jx];
956
- iy = ky;
957
- for (i = 0; i < M; ++i) {
958
- Y[iy] += A[j+i*lda] * temp;
959
- iy += incY;
960
- }
961
- }
962
- jx += incX;
963
- }
964
- }
965
-
966
- } else { // TODO: Check that indices are correct! They're switched for C.
967
-
968
- // Form y := alpha*A**T*x + y.
969
- jy = ky;
970
-
971
- if (incX == 1) {
972
- for (j = 0; j < N; ++j) {
973
- temp = 0;
974
- for (i = 0; i < M; ++i) {
975
- temp += A[j+i*lda]*X[j];
976
- }
977
- Y[jy] += *alpha * temp;
978
- jy += incY;
979
- }
980
- } else {
981
- for (j = 0; j < N; ++j) {
982
- temp = 0;
983
- ix = kx;
984
- for (i = 0; i < M; ++i) {
985
- temp += A[j+i*lda] * X[ix];
986
- ix += incX;
987
- }
988
-
989
- Y[jy] += *alpha * temp;
990
- jy += incY;
991
- }
992
- }
993
- }
994
-
995
- return true;
996
- } // end of GEMV
997
-
998
- template <>
999
- inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const float* alpha, const float* A, const int lda,
1000
- const float* X, const int incX, const float* beta, float* Y, const int incY) {
1001
- cblas_sgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
1002
- return true;
1003
- }
1004
-
1005
- template <>
1006
- inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const double* alpha, const double* A, const int lda,
1007
- const double* X, const int incX, const double* beta, double* Y, const int incY) {
1008
- cblas_dgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
1009
- return true;
1010
- }
1011
-
1012
- template <>
1013
- inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex64* alpha, const Complex64* A, const int lda,
1014
- const Complex64* X, const int incX, const Complex64* beta, Complex64* Y, const int incY) {
1015
- cblas_cgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
1016
- return true;
1017
- }
1018
-
1019
- template <>
1020
- inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex128* alpha, const Complex128* A, const int lda,
1021
- const Complex128* X, const int incX, const Complex128* beta, Complex128* Y, const int incY) {
1022
- cblas_zgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
1023
- return true;
1024
- }
1025
-
1026
-
1027
- // Yale: numeric matrix multiply c=a*b
1028
- template <typename DType, typename IType>
1029
- inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
1030
- const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
1031
- const unsigned int max_lmn = std::max(std::max(m, n), l);
1032
- IType next[max_lmn];
1033
- DType sums[max_lmn];
1034
-
1035
- DType v;
1036
-
1037
- IType head, length, temp, ndnz = 0;
1038
- IType minmn = std::min(m,n);
1039
- IType minlm = std::min(l,m);
1040
-
1041
- for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
1042
- next[idx] = std::numeric_limits<IType>::max();
1043
- sums[idx] = 0;
1044
- }
1045
-
1046
- for (IType i = 0; i < n; ++i) { // walk down the rows
1047
- head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
1048
- length = 0;
1049
-
1050
- for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
1051
- IType j;
1052
-
1053
- if (jj == ia[i+1]) { // if we're in the last entry for this row:
1054
- if (!diaga || i >= minmn) continue;
1055
- j = i; // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
1056
- v = a[i];
1057
- } else {
1058
- j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
1059
- v = a[jj];
1060
- }
1061
-
1062
- for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {
1063
-
1064
- IType k;
1065
-
1066
- if (kk == ib[j+1]) { // Get the column id for that entry
1067
- if (!diagb || j >= minlm) continue;
1068
- k = j;
1069
- sums[k] += v*b[k];
1070
- } else {
1071
- k = jb[kk];
1072
- sums[k] += v*b[kk];
1073
- }
1074
-
1075
- if (next[k] == std::numeric_limits<IType>::max()) {
1076
- next[k] = head;
1077
- head = k;
1078
- ++length;
1079
- }
1080
- } // end of kk loop
1081
- } // end of jj loop
1082
-
1083
- for (IType jj = 0; jj < length; ++jj) {
1084
- if (sums[head] != 0) {
1085
- if (diagc && head == i) {
1086
- c[head] = sums[head];
1087
- } else {
1088
- jc[n+1+ndnz] = head;
1089
- c[n+1+ndnz] = sums[head];
1090
- ++ndnz;
1091
- }
1092
- }
1093
-
1094
- temp = head;
1095
- head = next[head];
1096
-
1097
- next[temp] = std::numeric_limits<IType>::max();
1098
- sums[temp] = 0;
1099
- }
1100
-
1101
- ic[i+1] = n+1+ndnz;
1102
- }
1103
- } /* numbmm_ */
1104
-
1105
-
1106
- /*
1107
- template <typename DType, typename IType>
1108
- inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
1109
- unsigned int n = c_storage->shape[0],
1110
- l = c_storage->shape[1];
1111
-
1112
- // Create a working vector of dimension max(m,l,n) and initial value IType::max():
1113
- std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
1114
-
1115
- for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
1116
-
1117
- IType j, k;
1118
- size_t ndnz;
1119
-
1120
- for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
1121
- j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
1122
-
1123
- if (j >= m) {
1124
- if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
1125
- else break;
1126
- }
1127
-
1128
- for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
1129
- if (j >= m) continue; // first of all, does B *have* a row j?
1130
- k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
1131
-
1132
- if (k >= l) {
1133
- if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
1134
- else break;
1135
- }
1136
-
1137
- if (mask[k] == )
1138
- }
1139
-
1140
- }
1141
- }
1142
- }
1143
- */
1144
-
1145
- // Yale: Symbolic matrix multiply c=a*b
1146
- template <typename IType>
1147
- inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
1148
- const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
1149
- unsigned int max_lmn = std::max(std::max(m,n), l);
1150
- IType mask[max_lmn]; // INDEX in the SMMP paper.
1151
- IType j, k; /* Local variables */
1152
- size_t ndnz = n;
1153
-
1154
- for (IType idx = 0; idx < max_lmn; ++idx)
1155
- mask[idx] = std::numeric_limits<IType>::max();
1156
-
1157
- if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
1158
- if (diagc) ic[0] = n+1;
1159
- else ic[0] = 0;
1160
- }
1161
-
1162
- IType minmn = std::min(m,n);
1163
- IType minlm = std::min(l,m);
1164
-
1165
- for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
1166
-
1167
- for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
1168
-
1169
- // j <- column index given by JA[jj], or handle diagonal.
1170
- if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
1171
- if (!diaga || i >= minmn) continue;
1172
- j = i;
1173
- } else j = ja[jj];
1174
-
1175
- for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
1176
- if (kk == ib[j+1]) {
1177
- if (!diagb || j >= minlm) continue;
1178
- k = j;
1179
- } else k = jb[kk];
1180
-
1181
- if (mask[k] != i) {
1182
- mask[k] = i;
1183
- ++ndnz;
1184
- }
1185
- }
1186
- }
1187
-
1188
- if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
1189
-
1190
- if (ic) ic[i+1] = ndnz;
1191
- }
1192
-
1193
- return ndnz;
1194
- } /* symbmm_ */
1195
-
1196
-
1197
- // In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
1198
- namespace smmp_sort {
1199
- const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer
1200
-
1201
- template <typename DType, typename IType>
1202
- void print_array(DType* vals, IType* array, IType left, IType right) {
1203
- for (IType i = left; i <= right; ++i) {
1204
- std::cerr << array[i] << ":" << vals[i] << " ";
1205
- }
1206
- std::cerr << std::endl;
1207
- }
1208
-
1209
- template <typename DType, typename IType>
1210
- IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
1211
- IType pivotJ = array[pivot];
1212
- DType pivotV = vals[pivot];
1213
-
1214
- // Swap pivot and right
1215
- array[pivot] = array[right];
1216
- vals[pivot] = vals[right];
1217
- array[right] = pivotJ;
1218
- vals[right] = pivotV;
1219
-
1220
- IType store = left;
1221
- for (IType idx = left; idx < right; ++idx) {
1222
- if (array[idx] <= pivotJ) {
1223
- // Swap i and store
1224
- std::swap(array[idx], array[store]);
1225
- std::swap(vals[idx], vals[store]);
1226
- ++store;
1227
- }
1228
- }
1229
-
1230
- std::swap(array[store], array[right]);
1231
- std::swap(vals[store], vals[right]);
1232
-
1233
- return store;
1234
- }
1235
-
1236
- // Recommended to use the median of left, right, and mid for the pivot.
1237
- template <typename IType>
1238
- IType median(IType a, IType b, IType c) {
1239
- if (a < b) {
1240
- if (b < c) return b; // a b c
1241
- if (a < c) return c; // a c b
1242
- return a; // c a b
1243
-
1244
- } else { // a > b
1245
- if (a < c) return a; // b a c
1246
- if (b < c) return c; // b c a
1247
- return b; // c b a
1248
- }
1249
- }
1250
-
1251
-
1252
- // Insertion sort is more efficient than quicksort for small N
1253
- template <typename DType, typename IType>
1254
- void insertion_sort(DType* vals, IType* array, IType left, IType right) {
1255
- for (IType idx = left; idx <= right; ++idx) {
1256
- IType col_to_insert = array[idx];
1257
- DType val_to_insert = vals[idx];
1258
-
1259
- IType hole_pos = idx;
1260
- for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
1261
- array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
1262
- vals[hole_pos] = vals[hole_pos - 1]; // value goes along with it
1263
- }
1264
-
1265
- array[hole_pos] = col_to_insert;
1266
- vals[hole_pos] = val_to_insert;
1267
- }
1268
- }
1269
-
1270
-
1271
- template <typename DType, typename IType>
1272
- void quicksort(DType* vals, IType* array, IType left, IType right) {
1273
-
1274
- if (left < right) {
1275
- if (right - left < THRESHOLD) {
1276
- insertion_sort(vals, array, left, right);
1277
- } else {
1278
- // choose any pivot such that left < pivot < right
1279
- IType pivot = median(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
1280
- pivot = partition(vals, array, left, right, pivot);
1281
-
1282
- // recursively sort elements smaller than the pivot
1283
- quicksort<DType,IType>(vals, array, left, pivot-1);
1284
-
1285
- // recursively sort elements at least as big as the pivot
1286
- quicksort<DType,IType>(vals, array, pivot+1, right);
1287
- }
1288
- }
1289
- }
1290
-
1291
-
1292
- }; // end of namespace smmp_sort
1293
-
1294
-
1295
- /*
1296
- * For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
1297
- * This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
1298
- * about stability).
1299
- *
1300
- * TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
1301
- *
1302
- * TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
1303
- * ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
1304
- * sort.
1305
- */
1306
- template <typename DType, typename IType>
1307
- inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
1308
- for (size_t i = 0; i < n; ++i) {
1309
- if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing fewer than two elements.
1310
- else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
1311
- smmp_sort::insertion_sort<DType, IType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
1312
- } else {
1313
- smmp_sort::quicksort<DType, IType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
1314
- }
1315
- }
1316
- }
1317
-
1318
-
1319
-
1320
- /*
1321
- * Transposes a generic Yale matrix (old or new). Specify new by setting diaga = true.
1322
- *
1323
- * Based on transp from SMMP (same as symbmm and numbmm).
1324
- *
1325
- * This is not named in the same way as most yale_storage functions because it does not act on a YALE_STORAGE
1326
- * object.
1327
- */
1328
- template <typename DType, typename IType>
1329
- void transpose_yale(const size_t n, const size_t m, const void* ia_, const void* ja_, const void* a_,
1330
- const bool diaga, void* ib_, void* jb_, void* b_, const bool move)
1331
- {
1332
- const IType *ia = reinterpret_cast<const IType*>(ia_),
1333
- *ja = reinterpret_cast<const IType*>(ja_);
1334
- const DType *a = reinterpret_cast<const DType*>(a_);
1335
-
1336
- IType *ib = reinterpret_cast<IType*>(ib_),
1337
- *jb = reinterpret_cast<IType*>(jb_);
1338
- DType *b = reinterpret_cast<DType*>(b_);
1339
-
1340
-
1341
-
1342
- size_t index;
1343
-
1344
- // Clear B
1345
- for (size_t i = 0; i < m+1; ++i) ib[i] = 0;
1346
-
1347
- if (move)
1348
- for (size_t i = 0; i < m+1; ++i) b[i] = 0;
1349
-
1350
- if (diaga) ib[0] = m + 1;
1351
- else ib[0] = 0;
1352
-
1353
- /* count indices for each column */
1354
-
1355
- for (size_t i = 0; i < n; ++i) {
1356
- for (size_t j = ia[i]; j < ia[i+1]; ++j) {
1357
- ++(ib[ja[j]+1]);
1358
- }
1359
- }
1360
-
1361
- for (size_t i = 0; i < m; ++i) {
1362
- ib[i+1] = ib[i] + ib[i+1];
1363
- }
1364
-
1365
- /* now make jb */
1366
-
1367
- for (size_t i = 0; i < n; ++i) {
1368
-
1369
- for (size_t j = ia[i]; j < ia[i+1]; ++j) {
1370
- index = ja[j];
1371
- jb[ib[index]] = i;
1372
-
1373
- if (move)
1374
- b[ib[index]] = a[j];
1375
-
1376
- ++(ib[index]);
1377
- }
1378
- }
1379
-
1380
- /* now fixup ib */
1381
-
1382
- for (size_t i = m; i >= 1; --i) {
1383
- ib[i] = ib[i-1];
1384
- }
1385
-
1386
-
1387
- if (diaga) {
1388
- if (move) {
1389
- size_t j = std::min(n,m);
1390
-
1391
- for (size_t i = 0; i < j; ++i) {
1392
- b[i] = a[i];
1393
- }
1394
- }
1395
- ib[0] = m + 1;
1396
-
1397
- } else {
1398
- ib[0] = 0;
1399
- }
1400
- }
1401
-
1402
-
1403
- /*
1404
- * Templated version of row-order and column-order getrf, derived from ATL_getrfR.c (from ATLAS 3.8.0).
1405
- *
1406
- * 1. Row-major factorization of form
1407
- * A = L * U * P
1408
- * where P is a column-permutation matrix, L is lower triangular (lower
1409
- * trapezoidal if M > N), and U is upper triangular with unit diagonals (upper
1410
- * trapezoidal if M < N). This is the recursive Level 3 BLAS version.
1411
- *
1412
- * 2. Column-major factorization of form
1413
- * A = P * L * U
1414
- * where P is a row-permutation matrix, L is lower triangular with unit diagonal
1415
- * elements (lower trapezoidal if M > N), and U is upper triangular (upper
1416
- * trapezoidal if M < N). This is the recursive Level 3 BLAS version.
1417
- *
1418
- * Template argument determines whether 1 or 2 is utilized.
1419
- */
1420
- template <bool RowMajor, typename DType>
1421
- inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* ipiv) {
1422
- const int MN = std::min(M, N);
1423
- int ierr = 0;
1424
-
1425
- // Symbols used by ATLAS:
1426
- // Row Col Us
1427
- // Nup Nleft N_ul
1428
- // Ndown Nright N_dr
1429
- // We're going to use N_ul, N_dr
1430
-
1431
- DType neg_one = -1, one = 1;
1432
-
1433
- if (MN > 1) {
1434
- int N_ul = MN >> 1;
1435
-
1436
- // FIXME: Figure out how ATLAS #defines NB
1437
- #ifdef NB
1438
- if (N_ul > NB) N_ul = ATL_MulByNB(ATL_DivByNB(N_ul));
1439
- #endif
1440
-
1441
- int N_dr = M - N_ul;
1442
-
1443
- int i = RowMajor ? getrf_nothrow<true,DType>(N_ul, N, A, lda, ipiv) : getrf_nothrow<false,DType>(M, N_ul, A, lda, ipiv);
1444
-
1445
- if (i) if (!ierr) ierr = i;
1446
-
1447
- DType *Ar, *Ac, *An;
1448
- if (RowMajor) {
1449
- Ar = &(A[N_ul * lda]),
1450
- Ac = &(A[N_ul]);
1451
- An = &(Ar[N_ul]);
1452
-
1453
- nm::math::laswp<DType>(N_dr, Ar, lda, 0, N_ul, ipiv, 1);
1454
-
1455
- nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, N_dr, N_ul, one, A, lda, Ar, lda);
1456
- nm::math::gemm<DType>(CblasRowMajor, CblasNoTrans, CblasNoTrans, N_dr, N-N_ul, N_ul, &neg_one, Ar, lda, Ac, lda, &one, An, lda);
1457
-
1458
- i = getrf_nothrow<true,DType>(N_dr, N-N_ul, An, lda, ipiv+N_ul);
1459
- } else {
1460
- Ar = NULL;
1461
- Ac = &(A[N_ul * lda]);
1462
- An = &(Ac[N_ul]);
1463
-
1464
- nm::math::laswp<DType>(N_dr, Ac, lda, 0, N_ul, ipiv, 1);
1465
-
1466
- nm::math::trsm<DType>(CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N_ul, N_dr, one, A, lda, Ac, lda);
1467
- nm::math::gemm<DType>(CblasColMajor, CblasNoTrans, CblasNoTrans, M-N_ul, N_dr, N_ul, &neg_one, An, lda, Ac, lda, &one, An, lda);
1468
-
1469
- i = getrf_nothrow<false,DType>(M-N_ul, N_dr, An, lda, ipiv+N_ul);
1470
- }
1471
-
1472
- if (i) if (!ierr) ierr = N_ul + i;
1473
-
1474
- for (i = N_ul; i != MN; i++) {
1475
- ipiv[i] += N_ul;
1476
- }
1477
-
1478
- nm::math::laswp<DType>(N_ul, A, lda, N_ul, MN, ipiv, 1); /* apply pivots */
1479
-
1480
- } else if (MN == 1) { // there's another case for the colmajor version, but i don't know that it's that critical. Calls ATLAS LU2, who knows what that does.
1481
-
1482
- int i = *ipiv = nm::math::lapack::idamax<DType>(N, A, 1); // cblas_iamax(N, A, 1);
1483
-
1484
- DType tmp = A[i];
1485
- if (tmp != 0) {
1486
-
1487
- nm::math::lapack::scal<DType>((RowMajor ? N : M), nm::math::numeric_inverse(tmp), A, 1);
1488
- A[i] = *A;
1489
- *A = tmp;
1490
-
1491
- } else ierr = 1;
1492
-
1493
- }
1494
- return(ierr);
1495
- }
1496
-
1497
- /*
1498
- * Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
1499
- *
1500
- * From ATLAS 3.8.0.
1501
- */
1502
- template <typename DType>
1503
- int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
1504
- const int lda, const int* ipiv, DType* B, const int ldb)
1505
- {
1506
- // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
1507
-
1508
- if (!N || !NRHS) return 0;
1509
-
1510
- const DType ONE = 1;
1511
-
1512
- if (Order == CblasColMajor) {
1513
- if (Trans == CblasNoTrans) {
1514
- nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
1515
- nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
1516
- nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1517
- } else {
1518
- nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1519
- nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
1520
- nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
1521
- }
1522
- } else {
1523
- if (Trans == CblasNoTrans) {
1524
- nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1525
- nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
1526
- nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
1527
- } else {
1528
- nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
1529
- nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
1530
- nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1531
- }
1532
- }
1533
- return 0;
1534
- }
1535
-
1536
-
1537
- /*
1538
- * Solves a system of linear equations A*X = B with a symmetric positive definite matrix A using the Cholesky factorization computed by POTRF.
1539
- *
1540
- * From ATLAS 3.8.0.
1541
- */
1542
- template <typename DType, bool is_complex>
1543
- int potrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, const DType* A,
1544
- const int lda, DType* B, const int ldb)
1545
- {
1546
- // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
1547
-
1548
- CBLAS_TRANSPOSE MyTrans = is_complex ? CblasConjTrans : CblasTrans;
1549
-
1550
- if (!N || !NRHS) return 0;
1551
-
1552
- const DType ONE = 1;
1553
-
1554
- if (Order == CblasColMajor) {
1555
- if (Uplo == CblasUpper) {
1556
- nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1557
- nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1558
- } else {
1559
- nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1560
- nm::math::trsm<DType>(Order, CblasLeft, CblasLower, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1561
- }
1562
- } else {
1563
- // There's some kind of scaling operation that normally happens here in ATLAS. Not sure what it does, so we'll only
1564
- // worry if something breaks. It probably has to do with their non-templated code and doesn't apply to us.
1565
-
1566
- if (Uplo == CblasUpper) {
1567
- nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1568
- nm::math::trsm<DType>(Order, CblasRight, CblasUpper, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1569
- } else {
1570
- nm::math::trsm<DType>(Order, CblasRight, CblasLower, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1571
- nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1572
- }
1573
- }
1574
- return 0;
1575
- }
1576
-
1577
-
1578
-
1579
- /*
1580
- * From ATLAS 3.8.0:
1581
- *
1582
- * Computes one of two LU factorizations based on the setting of the Order
1583
- * parameter, as follows:
1584
- * ----------------------------------------------------------------------------
1585
- * Order == CblasColMajor
1586
- * Column-major factorization of form
1587
- * A = P * L * U
1588
- * where P is a row-permutation matrix, L is lower triangular with unit
1589
- * diagonal elements (lower trapezoidal if M > N), and U is upper triangular
1590
- * (upper trapezoidal if M < N).
1591
- *
1592
- * ----------------------------------------------------------------------------
1593
- * Order == CblasRowMajor
1594
- * Row-major factorization of form
1595
- * A = L * U * P
1596
- * where P is a column-permutation matrix, L is lower triangular (lower
1597
- * trapezoidal if M > N), and U is upper triangular with unit diagonals (upper
1598
- * trapezoidal if M < N).
1599
- *
1600
- * ============================================================================
1601
- * Let IERR be the return value of the function:
1602
- * If IERR == 0, successful exit.
1603
- * If (IERR < 0) the -IERR argument had an illegal value
1604
- * If (IERR > 0 && Order == CblasColMajor)
1605
- * U(i-1,i-1) is exactly zero. The factorization has been completed,
1606
- * but the factor U is exactly singular, and division by zero will
1607
- * occur if it is used to solve a system of equations.
1608
- * If (IERR > 0 && Order == CblasRowMajor)
1609
- * L(i-1,i-1) is exactly zero. The factorization has been completed,
1610
- * but the factor L is exactly singular, and division by zero will
1611
- * occur if it is used to solve a system of equations.
1612
- */
1613
- template <typename DType>
1614
- inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType* A, int lda, int* ipiv) {
1615
- if (Order == CblasRowMajor) {
1616
- if (lda < std::max(1,N)) {
1617
- rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(N,1): lda=%d N=%d", lda, N);
1618
- return -6;
1619
- }
1620
-
1621
- return getrf_nothrow<true,DType>(M, N, A, lda, ipiv);
1622
- } else {
1623
- if (lda < std::max(1,M)) {
1624
- rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
1625
- return -6;
1626
- }
1627
-
1628
- return getrf_nothrow<false,DType>(M, N, A, lda, ipiv);
1629
- //rb_raise(rb_eNotImpError, "column major getrf not implemented");
1630
- }
1631
- }
1632
-
1633
-
1634
- /*
1635
- * From ATLAS 3.8.0:
1636
- *
1637
- * Computes one of two LU factorizations based on the setting of the Order
1638
- * parameter, as follows:
1639
- * ----------------------------------------------------------------------------
1640
- * Order == CblasColMajor
1641
- * Column-major factorization of form
1642
- * A = P * L * U
1643
- * where P is a row-permutation matrix, L is lower triangular with unit
1644
- * diagonal elements (lower trapazoidal if M > N), and U is upper triangular
1645
- * (upper trapazoidal if M < N).
1646
- *
1647
- * ----------------------------------------------------------------------------
1648
- * Order == CblasRowMajor
1649
- * Row-major factorization of form
1650
- * A = P * L * U
1651
- * where P is a column-permutation matrix, L is lower triangular (lower
1652
- * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
1653
- * trapazoidal if M < N).
1654
- *
1655
- * ============================================================================
1656
- * Let IERR be the return value of the function:
1657
- * If IERR == 0, successful exit.
1658
- * If (IERR < 0) the -IERR argument had an illegal value
1659
- * If (IERR > 0 && Order == CblasColMajor)
1660
- * U(i-1,i-1) is exactly zero. The factorization has been completed,
1661
- * but the factor U is exactly singular, and division by zero will
1662
- * occur if it is used to solve a system of equations.
1663
- * If (IERR > 0 && Order == CblasRowMajor)
1664
- * L(i-1,i-1) is exactly zero. The factorization has been completed,
1665
- * but the factor L is exactly singular, and division by zero will
1666
- * occur if it is used to solve a system of equations.
1667
- */
1668
- template <typename DType>
1669
- inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
1670
- #ifdef HAVE_CLAPACK_H
1671
- rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
1672
- #else
1673
- rb_raise(rb_eNotImpError, "only LAPACK version implemented thus far");
1674
- #endif
1675
- return 0;
1676
- }
1677
-
1678
#ifdef HAVE_CLAPACK_H
// potrf specializations for the dtypes clapack supports. Each one forwards its
// arguments to the matching clapack_?potrf routine and returns that routine's
// status code unchanged.

// float -> clapack_spotrf
template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  const int info = clapack_spotrf(order, uplo, N, A, lda);
  return info;
}

// double -> clapack_dpotrf
template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  const int info = clapack_dpotrf(order, uplo, N, A, lda);
  return info;
}

// Complex64 -> clapack_cpotrf (the matrix is handed over as an untyped pointer)
template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  void* raw = reinterpret_cast<void*>(A);
  const int info = clapack_cpotrf(order, uplo, N, raw, lda);
  return info;
}

// Complex128 -> clapack_zpotrf (the matrix is handed over as an untyped pointer)
template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  void* raw = reinterpret_cast<void*>(A);
  const int info = clapack_zpotrf(order, uplo, N, raw, lda);
  return info;
}
#endif
1699
-
1700
-
1701
// Exchanges N elements between the strided vectors X and Y.
//
// This is the old BLAS version of this function. ATLAS has an optimized version, but
// it's going to be tough to translate.
//
//   N     number of elements to exchange; N <= 0 is a no-op.
//   X     first vector, visited with stride incX.
//   Y     second vector, visited with stride incY.
//
// Element i of X lives at X[i*incX] (and likewise for Y). A negative stride
// walks the vector backwards starting from its last logical element, which is
// the reference BLAS xSWAP convention.
template <typename DType>
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
  if (N <= 0) return;

  // With a negative stride the first logical element is stored last.
  int ix = (incX < 0) ? (1 - N) * incX : 0;
  int iy = (incY < 0) ? (1 - N) * incY : 0;

  for (int i = 0; i < N; ++i, ix += incX, iy += incY) {
    const DType temp = X[ix];
    X[ix] = Y[iy];
    Y[iy] = temp;
  }
}
1717
-
1718
-
1719
// Moves the strictly-upper part of the row-major matrix U into C and zeroes
// it in U. U is treated as unit-diagonal, so the diagonal is neither copied
// nor cleared.
//
//   M, N   number of rows / columns visited.
//   U      source matrix, row stride ldu; entries (i, j) with j > i are zeroed.
//   C      destination, row stride ldc; receives the entries removed from U.
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  DType* src_row = U;
  DType* dst_row = C;

  for (int row = 0; row != M; ++row, dst_row += ldc, src_row += ldu) {
    int col = row + 1;   // first entry to the right of the diagonal
    while (col < N) {
      dst_row[col] = src_row[col];
      src_row[col] = 0;
      ++col;
    }
  }
}
1735
-
1736
-
1737
- /*
1738
- * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
1739
- * functions. This is probably really complicated.
1740
- *
1741
- * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
1742
- *
1743
- */
1744
-
1745
- /*
1746
-
1747
- template <bool RowMajor, bool Upper, typename DType>
1748
- static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
1749
-
1750
- if (RowMajor) {
1751
- DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
1752
- DType tmp;
1753
- if (Upper) {
1754
- DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
1755
- A12 = pA1[2], A13 = pA1[3],
1756
- A23 = pA2[3];
1757
-
1758
- if (Diag == CblasNonUnit) {
1759
- pA0->inverse();
1760
- (pA1+1)->inverse();
1761
- (pA2+2)->inverse();
1762
- (pA3+3)->inverse();
1763
-
1764
- pA0[1] = -A01 * pA1[1] * pA0[0];
1765
- pA1[2] = -A12 * pA2[2] * pA1[1];
1766
- pA2[3] = -A23 * pA3[3] * pA2[2];
1767
-
1768
- pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
1769
- pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
1770
-
1771
- pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
1772
-
1773
- } else {
1774
-
1775
- pA0[1] = -A01;
1776
- pA1[2] = -A12;
1777
- pA2[3] = -A23;
1778
-
1779
- pA0[2] = -(A01 * pA1[2] + A02);
1780
- pA1[3] = -(A12 * pA2[3] + A13);
1781
-
1782
- pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
1783
- }
1784
-
1785
- } else { // Lower
1786
- DType A10 = pA1[0],
1787
- A20 = pA2[0], A21 = pA2[1],
1788
- A30 = PA3[0], A31 = pA3[1], A32 = pA3[2];
1789
- DType *B10 = pA1,
1790
- *B20 = pA2,
1791
- *B30 = pA3,
1792
- *B21 = pA2+1,
1793
- *B31 = pA3+1,
1794
- *B32 = pA3+2;
1795
-
1796
-
1797
- if (Diag == CblasNonUnit) {
1798
- pA0->inverse();
1799
- (pA1+1)->inverse();
1800
- (pA2+2)->inverse();
1801
- (pA3+3)->inverse();
1802
-
1803
- *B10 = -A10 * pA0[0] * pA1[1];
1804
- *B21 = -A21 * pA1[1] * pA2[2];
1805
- *B32 = -A32 * pA2[2] * pA3[3];
1806
- *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
1807
- *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
1808
- *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3;
1809
- } else {
1810
- *B10 = -A10;
1811
- *B21 = -A21;
1812
- *B32 = -A32;
1813
- *B20 = -(A20 + A21 * (*B10));
1814
- *B31 = -(A31 + A32 * (*B21));
1815
- *B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
1816
- }
1817
- }
1818
-
1819
- } else {
1820
- rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1821
- }
1822
-
1823
- return 0;
1824
-
1825
- }
1826
-
1827
-
1828
- template <bool RowMajor, bool Upper, typename DType>
1829
- static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
1830
-
1831
- if (RowMajor) {
1832
-
1833
- DType tmp;
1834
-
1835
- if (Upper) {
1836
- DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
1837
- A12 = pA1[2], A13 = pA1[3];
1838
-
1839
- DType *B01 = pA0 + 1,
1840
- *B02 = pA0 + 2,
1841
- *B12 = pA1 + 2;
1842
-
1843
- if (Diag == CblasNonUnit) {
1844
- pA0->inverse();
1845
- (pA1+1)->inverse();
1846
- (pA2+2)->inverse();
1847
-
1848
- *B01 = -A01 * pA1[1] * pA0[0];
1849
- *B12 = -A12 * pA2[2] * pA1[1];
1850
- *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
1851
- } else {
1852
- *B01 = -A01;
1853
- *B12 = -A12;
1854
- *B02 = -(A01 * (*B12) + A02);
1855
- }
1856
-
1857
- } else { // Lower
1858
- DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
1859
- DType A10=pA1[0],
1860
- A20=pA2[0], A21=pA2[1];
1861
-
1862
- DType *B10 = pA1,
1863
- *B20 = pA2;
1864
- *B21 = pA2+1;
1865
-
1866
- if (Diag == CblasNonUnit) {
1867
- pA0->inverse();
1868
- (pA1+1)->inverse();
1869
- (pA2+2)->inverse();
1870
- *B10 = -A10 * pA0[0] * pA1[1];
1871
- *B21 = -A21 * pA1[1] * pA2[2];
1872
- *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
1873
- } else {
1874
- *B10 = -A10;
1875
- *B21 = -A21;
1876
- *B20 = -(A20 + A21 * (*B10));
1877
- }
1878
- }
1879
-
1880
-
1881
- } else {
1882
- rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1883
- }
1884
-
1885
- return 0;
1886
-
1887
- }
1888
-
1889
- template <bool RowMajor, bool Upper, bool Real, typename DType>
1890
- static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
1891
- DType *Age, *Atr;
1892
- DType tmp;
1893
- int Nleft, Nright;
1894
-
1895
- int ierr = 0;
1896
-
1897
- static const DType ONE = 1;
1898
- static const DType MONE -1;
1899
- static const DType NONE = -1;
1900
-
1901
- if (RowMajor) {
1902
-
1903
- // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
1904
- if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
1905
- Nleft = N >> 1;
1906
- #ifdef NB
1907
- if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
1908
- #endif
1909
-
1910
- Nright = N - Nleft;
1911
-
1912
- if (Upper) {
1913
- Age = A + Nleft;
1914
- Atr = A + (Nleft * (lda+1));
1915
-
1916
- nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
1917
- Nleft, Nright, ONE, Atr, lda, Age, lda);
1918
-
1919
- nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
1920
- Nleft, Nright, MONE, A, lda, Age, lda);
1921
-
1922
- } else { // Lower
1923
- Age = A + ((Nleft*lda));
1924
- Atr = A + (Nleft * (lda+1));
1925
-
1926
- nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
1927
- Nright, Nleft, ONE, A, lda, Age, lda);
1928
- nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
1929
- Nright, Nleft, MONE, Atr, lda, Age, lda);
1930
- }
1931
-
1932
- ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
1933
- if (ierr) return ierr;
1934
-
1935
- ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
1936
- if (ierr) return ierr + Nleft;
1937
-
1938
- } else {
1939
- if (Real) {
1940
- if (N == 4) {
1941
- return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
1942
- } else if (N == 3) {
1943
- return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
1944
- } else if (N == 2) {
1945
- if (Diag == CblasNonUnit) {
1946
- A->inverse();
1947
- (A+(lda+1))->inverse();
1948
-
1949
- if (Upper) {
1950
- *(A+1) *= *A; // TRI_MUL
1951
- *(A+1) *= *(A+lda+1); // TRI_MUL
1952
- } else {
1953
- *(A+lda) *= *A; // TRI_MUL
1954
- *(A+lda) *= *(A+lda+1); // TRI_MUL
1955
- }
1956
- }
1957
-
1958
- if (Upper) *(A+1) = -*(A+1); // TRI_NEG
1959
- else *(A+lda) = -*(A+lda); // TRI_NEG
1960
- } else if (Diag == CblasNonUnit) A->inverse();
1961
- } else { // not real
1962
- if (Diag == CblasNonUnit) A->inverse();
1963
- }
1964
- }
1965
-
1966
- } else {
1967
- rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1968
- }
1969
-
1970
- return ierr;
1971
- }
1972
-
1973
-
1974
- template <bool RowMajor, bool Real, typename DType>
1975
- int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
1976
-
1977
- if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1978
-
1979
- int jb, nb, I, ndown, iret;
1980
-
1981
- const DType ONE = 1, NONE = -1;
1982
-
1983
- int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
1984
- if (!iret && N > 1) {
1985
- jb = lwrk / N;
1986
- if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
1987
- else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
1988
- else nb = jb;
1989
- if (!nb) return -6; // need at least 1 row of workspace
1990
-
1991
- // only first iteration will have partial block, unroll it
1992
-
1993
- jb = N - (N/nb) * nb;
1994
- if (!jb) jb = nb;
1995
- I = N - jb;
1996
- A += lda * I;
1997
- trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
1998
- nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
1999
- jb, N, ONE, wrk, jb, A, lda);
2000
-
2001
- if (I) {
2002
- do {
2003
- I -= nb;
2004
- A -= nb * lda;
2005
- ndown = N-I;
2006
- trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
2007
- nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
2008
- nb, N, ONE, wrk, ndown, A, lda);
2009
- } while (I);
2010
- }
2011
-
2012
- // Apply row interchanges
2013
-
2014
- for (I = N - 2; I >= 0; --I) {
2015
- jb = ipiv[I];
2016
- if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
2017
- }
2018
- }
2019
-
2020
- return iret;
2021
- }
2022
- */
2023
-
2024
-
2025
// Applies the plane rotation (c, s) to N element pairs taken from X and Y:
//
//   x' = x * c + y * s
//   y' = y * c - x * s
//
// Nothing is touched when the rotation is the identity (c == 1 and s == 0).
// Unit strides are handled by indexing; any other stride combination walks
// both pointers by incX / incY.
//
// TODO: Test this to see if it works properly on complex. ATLAS has a separate algorithm for complex, which looks like
// TODO: it may actually be the same one.
//
// This function is called ATL_rot in ATLAS 3.8.4.
template <typename DType>
inline void rot_helper(const int N, DType* X, const int incX, DType* Y, const int incY, const DType c, const DType s) {
  const bool rotates = (c != 1 || s != 0);
  if (!rotates) return;

  if (incX == 1 && incY == 1) {
    for (int k = 0; k != N; ++k) {
      const DType rotated_x = X[k] * c + Y[k] * s;
      Y[k] = Y[k] * c - X[k] * s;   // must still see the old X[k]
      X[k] = rotated_x;
    }
    return;
  }

  DType* px = X;
  DType* py = Y;
  for (int remaining = N; remaining > 0; --remaining) {
    const DType rotated_x = *px * c + *py * s;
    *py = *py * c - *px * s;        // must still see the old *px
    *px = rotated_x;
    py += incY;
    px += incX;
  }
}
2047
-
2048
-
2049
/*
 * Givens plane rotation. From ATLAS 3.8.4.
 *
 * On entry *a and *b hold the two values the rotation is built from. On exit:
 *   *c, *s  the rotation's cosine and sine (c = a/r, s = b/r),
 *   *a      r, carrying the sign of whichever input has the larger magnitude
 *           (*b on a tie),
 *   *b      the reconstruction value z:
 *             z = s     if |a| > |b|
 *             z = 1/c   if |a| <= |b| and c != 0
 *             z = 1     otherwise.
 * If both inputs are zero the result is c = 1, s = 0, a = b = 0.
 *
 * The |a| > |b| case of z is what reference BLAS xROTG and ATLAS's ATL_rotg
 * return; without it *b would receive 1/c for every input with c != 0.
 */
// FIXME: Need a specialized algorithm for Rationals. BLAS' algorithm simply will not work for most values due to the
// FIXME: sqrt.
template <typename DType>
inline void rotg(DType* a, DType* b, DType* c, DType* s) {
  const DType aa = std::abs(*a), ab = std::abs(*b);
  const DType roe = aa > ab ? *a : *b;
  const DType scal = aa + ab;

  if (scal == 0) {
    *c = 1;
    *s = *a = *b = 0;
    return;
  }

  // Scale before squaring to avoid overflow/underflow in the hypotenuse.
  const DType t0 = aa / scal, t1 = ab / scal;
  DType r = scal * std::sqrt(t0 * t0 + t1 * t1);
  if (roe < 0) r = -r;

  *c = *a / r;
  *s = *b / r;

  DType z;
  if (aa > ab)       z = *s;
  else if (*c != 0)  z = 1 / *c;
  else               z = DType(1);

  *a = r;
  *b = z;
}
2072
-
2073
- template <>
2074
- inline void rotg(float* a, float* b, float* c, float* s) {
2075
- cblas_srotg(a, b, c, s);
2076
- }
2077
-
2078
- template <>
2079
- inline void rotg(double* a, double* b, double* c, double* s) {
2080
- cblas_drotg(a, b, c, s);
2081
- }
2082
-
2083
- template <>
2084
- inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) {
2085
- cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
2086
- }
2087
-
2088
- template <>
2089
- inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) {
2090
- cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
2091
- }
2092
-
2093
- template <typename DType>
2094
- inline void cblas_rotg(void* a, void* b, void* c, void* s) {
2095
- rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s));
2096
- }
2097
-
2098
-
2099
- /* Applies a plane rotation. From ATLAS 3.8.4. */
2100
- template <typename DType, typename CSDType>
2101
- inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) {
2102
- DType *x = X, *y = Y;
2103
- int incx = incX, incy = incY;
2104
-
2105
- if (N > 0) {
2106
- if (incX < 0) {
2107
- if (incY < 0) { incx = -incx; incy = -incy; }
2108
- else x += -incX * (N-1);
2109
- } else if (incY < 0) {
2110
- incy = -incy;
2111
- incx = -incx;
2112
- x += (N-1) * incX;
2113
- }
2114
- rot_helper<DType>(N, x, incx, y, incy, c, s);
2115
- }
2116
- }
2117
-
2118
- template <>
2119
- inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) {
2120
- cblas_srot(N, X, incX, Y, incY, (float)c, (float)s);
2121
- }
2122
-
2123
- template <>
2124
- inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) {
2125
- cblas_drot(N, X, incX, Y, incY, c, s);
2126
- }
2127
-
2128
- template <>
2129
- inline void rot(const int N, Complex64* X, const int incX, Complex64* Y, const int incY, const float c, const float s) {
2130
- cblas_csrot(N, X, incX, Y, incY, c, s);
2131
- }
2132
-
2133
- template <>
2134
- inline void rot(const int N, Complex128* X, const int incX, Complex128* Y, const int incY, const double c, const double s) {
2135
- cblas_zdrot(N, X, incX, Y, incY, c, s);
2136
- }
2137
-
2138
-
2139
- template <typename DType, typename CSDType>
2140
- inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) {
2141
- rot<DType,CSDType>(N, reinterpret_cast<DType*>(X), incX, reinterpret_cast<DType*>(Y), incY,
2142
- *reinterpret_cast<const CSDType*>(c), *reinterpret_cast<const CSDType*>(s));
2143
- }
2144
-
2145
- /*
2146
- * Level 1 BLAS routine which returns the 2-norm of an n-vector x.
2147
- #
2148
- * Based on input types, these are the valid return types:
2149
- * int -> int
2150
- * float -> float or double
2151
- * double -> double
2152
- * complex64 -> float or double
2153
- * complex128 -> double
2154
- * rational -> rational
2155
- */
2156
- template <typename ReturnDType, typename DType>
2157
- ReturnDType nrm2(const int N, const DType* X, const int incX) {
2158
- const DType ONE = 1, ZERO = 0;
2159
- typename LongDType<DType>::type scale = 0, ssq = 1, absxi, temp;
2160
-
2161
-
2162
- if ((N < 1) || (incX < 1)) return ZERO;
2163
- else if (N == 1) return std::abs(X[0]);
2164
-
2165
- for (int i = 0; i < N; ++i) {
2166
- absxi = std::abs(X[i*incX]);
2167
- if (scale < absxi) {
2168
- temp = scale / absxi;
2169
- scale = absxi;
2170
- ssq = ONE + ssq * (temp * temp);
2171
- } else {
2172
- temp = absxi / scale;
2173
- ssq += temp * temp;
2174
- }
2175
- }
2176
-
2177
- return scale * std::sqrt( ssq );
2178
- }
2179
-
2180
-
2181
- #ifdef HAVE_CBLAS_H
2182
- template <>
2183
- inline float nrm2(const int N, const float* X, const int incX) {
2184
- return cblas_snrm2(N, X, incX);
2185
- }
2186
-
2187
- template <>
2188
- inline double nrm2(const int N, const double* X, const int incX) {
2189
- return cblas_dnrm2(N, X, incX);
2190
- }
2191
-
2192
- template <>
2193
- inline float nrm2(const int N, const Complex64* X, const int incX) {
2194
- return cblas_scnrm2(N, X, incX);
2195
- }
2196
-
2197
- template <>
2198
- inline double nrm2(const int N, const Complex128* X, const int incX) {
2199
- return cblas_dznrm2(N, X, incX);
2200
- }
2201
- #else
2202
// Folds one complex value, given as its real part xr and imaginary part xi,
// into a running scaled sum of squares. The invariant maintained is
//
//   scale^2 * ssq == (scale^2 * ssq before the call) + xr^2 + xi^2
//
// with scale being the largest component magnitude seen so far. Callers start
// from scale = 0, ssq = 1 and finish with scale * sqrt(ssq).
//
// Zero components are skipped: while scale is still zero they would be
// divided 0/0 and turn ssq into NaN.
template <typename FloatDType>
static inline void nrm2_complex_helper(const FloatDType& xr, const FloatDType& xi, double& scale, double& ssq) {
  const double magnitudes[2] = { std::abs(xr), std::abs(xi) };

  for (int k = 0; k < 2; ++k) {
    const double absx = magnitudes[k];
    if (absx == 0.0) continue;

    if (scale < absx) {
      // New largest magnitude: rescale the accumulated sum to it.
      const double ratio = scale / absx;
      scale = absx;
      ssq   = 1.0 + ssq * (ratio * ratio);
    } else {
      const double ratio = absx / scale;
      ssq += ratio * ratio;
    }
  }
}
2224
-
2225
- template <>
2226
- float nrm2(const int N, const Complex64* X, const int incX) {
2227
- double scale = 0, ssq = 1, temp;
2228
-
2229
- if ((N < 1) || (incX < 1)) return 0.0;
2230
-
2231
- for (int i = 0; i < N; ++i) {
2232
- nrm2_complex_helper<float>(X[i*incX].r, X[i*incX].i, scale, temp);
2233
- }
2234
-
2235
- return scale * std::sqrt( ssq );
2236
- }
2237
-
2238
- template <>
2239
- double nrm2(const int N, const Complex128* X, const int incX) {
2240
- double scale = 0, ssq = 1, temp;
2241
-
2242
- if ((N < 1) || (incX < 1)) return 0.0;
2243
-
2244
- for (int i = 0; i < N; ++i) {
2245
- nrm2_complex_helper<double>(X[i*incX].r, X[i*incX].i, scale, temp);
2246
- }
2247
-
2248
- return scale * std::sqrt( ssq );
2249
- }
2250
- #endif
2251
-
2252
- template <typename ReturnDType, typename DType>
2253
- inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) {
2254
- *reinterpret_cast<ReturnDType*>( result ) = nrm2<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
2255
- }
2256
-
2257
/*
 * Level 1 BLAS routine which sums the absolute values of a vector's contents. If the vector consists of complex values,
 * the routine sums the absolute values of the real and imaginary components as well.
 *
 * So, based on input types, these are the valid return types:
 *    int -> int
 *    float -> float or double
 *    double -> double
 *    complex64 -> float or double
 *    complex128 -> double
 *    rational -> rational
 *
 * Element i is read from X[i*incX]. The result is 0 unless both N and incX
 * are positive.
 */
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
  ReturnDType total = 0;

  if (N <= 0 || incX <= 0) return total;

  for (int i = 0; i != N; ++i) {
    total += std::abs(X[i*incX]);
  }

  return total;
}
2279
-
2280
-
2281
- #ifdef HAVE_CBLAS_H
2282
- template <>
2283
- inline float asum(const int N, const float* X, const int incX) {
2284
- return cblas_sasum(N, X, incX);
2285
- }
2286
-
2287
- template <>
2288
- inline double asum(const int N, const double* X, const int incX) {
2289
- return cblas_dasum(N, X, incX);
2290
- }
2291
-
2292
- template <>
2293
- inline float asum(const int N, const Complex64* X, const int incX) {
2294
- return cblas_scasum(N, X, incX);
2295
- }
2296
-
2297
- template <>
2298
- inline double asum(const int N, const Complex128* X, const int incX) {
2299
- return cblas_dzasum(N, X, incX);
2300
- }
2301
- #else
2302
- template <>
2303
- inline float asum(const int N, const Complex64* X, const int incX) {
2304
- float sum = 0;
2305
- if ((N > 0) && (incX > 0)) {
2306
- for (int i = 0; i < N; ++i) {
2307
- sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
2308
- }
2309
- }
2310
- return sum;
2311
- }
2312
-
2313
- template <>
2314
- inline double asum(const int N, const Complex128* X, const int incX) {
2315
- double sum = 0;
2316
- if ((N > 0) && (incX > 0)) {
2317
- for (int i = 0; i < N; ++i) {
2318
- sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
2319
- }
2320
- }
2321
- return sum;
2322
- }
2323
- #endif
2324
-
2325
-
2326
- template <typename ReturnDType, typename DType>
2327
- inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
2328
- *reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
2329
- }
2330
-
2331
-
2332
- template <bool is_complex, typename DType>
2333
- inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
2334
-
2335
- int Nleft, Nright;
2336
- const DType ONE = 1;
2337
- DType *G, *U0 = A, *U1;
2338
-
2339
- if (N > 1) {
2340
- Nleft = N >> 1;
2341
- #ifdef NB
2342
- if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
2343
- #endif
2344
-
2345
- Nright = N - Nleft;
2346
-
2347
- // FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now.
2348
- if (uplo == CblasUpper) {
2349
- if (order == CblasRowMajor) {
2350
- G = A + Nleft;
2351
- U1 = G + Nleft * lda;
2352
- } else {
2353
- G = A + Nleft * lda;
2354
- U1 = G + Nleft;
2355
- }
2356
- } else {
2357
- if (order == CblasRowMajor) {
2358
- G = A + Nleft * lda;
2359
- U1 = G + Nleft;
2360
- } else {
2361
- G = A + Nleft;
2362
- U1 = G + Nleft * lda;
2363
- }
2364
- }
2365
-
2366
- lauum<is_complex, DType>(order, uplo, Nleft, U0, lda);
2367
-
2368
- if (is_complex) {
2369
-
2370
- nm::math::herk<DType>(order, uplo,
2371
- uplo == CblasLower ? CblasConjTrans : CblasNoTrans,
2372
- Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
2373
-
2374
- nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
2375
- } else {
2376
- nm::math::syrk<DType>(order, uplo,
2377
- uplo == CblasLower ? CblasTrans : CblasNoTrans,
2378
- Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
2379
-
2380
- nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
2381
- }
2382
- lauum<is_complex, DType>(order, uplo, Nright, U1, lda);
2383
-
2384
- } else {
2385
- *A = *A * *A;
2386
- }
2387
- }
2388
-
2389
-
2390
#ifdef HAVE_CLAPACK_H
// lauum overloads for the dtypes clapack supports; each forwards to the
// matching clapack_?lauum routine.
//
// NOTE(review): these are overloads with a single template parameter, not
// specializations of the <bool, typename> template. A call spelled
// lauum<is_complex, DType>(...) supplies two explicit template arguments, so
// it looks like these cannot be selected from clapack_lauum -- confirm.

// float -> clapack_slauum
template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  float* const matrix = A;
  clapack_slauum(order, uplo, N, matrix, lda);
}

// double -> clapack_dlauum
template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  double* const matrix = A;
  clapack_dlauum(order, uplo, N, matrix, lda);
}

// Complex64 -> clapack_clauum
template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  Complex64* const matrix = A;
  clapack_clauum(order, uplo, N, matrix, lda);
}

// Complex128 -> clapack_zlauum
template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  Complex128* const matrix = A;
  clapack_zlauum(order, uplo, N, matrix, lda);
}
#endif
2411
-
2412
-
2413
- /*
2414
- * Function signature conversion for calling LAPACK's lauum functions as directly as possible.
2415
- *
2416
- * For documentation: http://www.netlib.org/lapack/double/dlauum.f
2417
- *
2418
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2419
- */
2420
- template <bool is_complex, typename DType>
2421
- inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
2422
- if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n);
2423
- if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n);
2424
-
2425
- if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2426
- else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2427
-
2428
- return 0;
2429
- }
2430
-
2431
-
2432
-
2433
-
2434
/*
 * Macro for declaring LAPACK specializations of the getrf function.
 *
 * type is the DType; call is the specific function to call; cast_as is what the DType* should be
 * cast to in order to pass it to LAPACK.
 *
 * The generated specialization returns call's status. Any non-zero status
 * raises ArgumentError first.
 */
#define LAPACK_GETRF(type, call, cast_as)                                                                        \
template <>                                                                                                      \
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) {   \
  const int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv);                                 \
  if (info) {                                                                                                    \
    rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info);                                           \
  }                                                                                                              \
  return info;                                                                                                   \
}
2450
-
2451
- /* Specialize for ATLAS types */
2452
- /*LAPACK_GETRF(float, clapack_sgetrf, float)
2453
- LAPACK_GETRF(double, clapack_dgetrf, double)
2454
- LAPACK_GETRF(Complex64, clapack_cgetrf, void)
2455
- LAPACK_GETRF(Complex128, clapack_zgetrf, void)
2456
- */
2457
-
2458
-
2459
- /*
2460
- * Function signature conversion for calling LAPACK's getrf functions as directly as possible.
2461
- *
2462
- * For documentation: http://www.netlib.org/lapack/double/dgetrf.f
2463
- *
2464
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2465
- */
2466
- template <typename DType>
2467
- inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, int* ipiv) {
2468
- return getrf<DType>(order, m, n, reinterpret_cast<DType*>(a), lda, ipiv);
2469
- }
2470
-
2471
-
2472
- /*
2473
- * Function signature conversion for calling LAPACK's potrf functions as directly as possible.
2474
- *
2475
- * For documentation: http://www.netlib.org/lapack/double/dpotrf.f
2476
- *
2477
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2478
- */
2479
- template <typename DType>
2480
- inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
2481
- return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2482
- }
2483
-
2484
-
2485
- /*
2486
- * Function signature conversion for calling LAPACK's getrs functions as directly as possible.
2487
- *
2488
- * For documentation: http://www.netlib.org/lapack/double/dgetrs.f
2489
- *
2490
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2491
- */
2492
- template <typename DType>
2493
- inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
2494
- const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
2495
- return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
2496
- }
2497
-
2498
- /*
2499
- * Function signature conversion for calling LAPACK's potrs functions as directly as possible.
2500
- *
2501
- * For documentation: http://www.netlib.org/lapack/double/dpotrs.f
2502
- *
2503
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2504
- */
2505
- template <typename DType, bool is_complex>
2506
- inline int clapack_potrs(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, const int nrhs,
2507
- const void* a, const int lda, void* b, const int ldb) {
2508
- return potrs<DType,is_complex>(order, uplo, n, nrhs, reinterpret_cast<const DType*>(a), lda, reinterpret_cast<DType*>(b), ldb);
2509
- }
2510
-
2511
- template <typename DType>
2512
- inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
2513
- rb_raise(rb_eNotImpError, "getri not yet implemented for non-BLAS dtypes");
2514
- return 0;
2515
- }
2516
-
2517
#ifdef HAVE_CLAPACK_H
// getri specializations for the dtypes clapack supports. Each one forwards
// its arguments to the matching clapack_?getri routine and returns that
// routine's status code unchanged.

// float -> clapack_sgetri
template <>
inline int getri(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) {
  const int info = clapack_sgetri(order, n, a, lda, ipiv);
  return info;
}

// double -> clapack_dgetri
template <>
inline int getri(const enum CBLAS_ORDER order, const int n, double* a, const int lda, const int* ipiv) {
  const int info = clapack_dgetri(order, n, a, lda, ipiv);
  return info;
}

// Complex64 -> clapack_cgetri (the matrix is handed over as an untyped pointer)
template <>
inline int getri(const enum CBLAS_ORDER order, const int n, Complex64* a, const int lda, const int* ipiv) {
  void* raw = reinterpret_cast<void*>(a);
  const int info = clapack_cgetri(order, n, raw, lda, ipiv);
  return info;
}

// Complex128 -> clapack_zgetri (the matrix is handed over as an untyped pointer)
template <>
inline int getri(const enum CBLAS_ORDER order, const int n, Complex128* a, const int lda, const int* ipiv) {
  void* raw = reinterpret_cast<void*>(a);
  const int info = clapack_zgetri(order, n, raw, lda, ipiv);
  return info;
}
#endif
2538
-
2539
-
2540
- template <typename DType>
2541
- inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) {
2542
- rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes");
2543
- return 0;
2544
- }
2545
-
2546
-
2547
#ifdef HAVE_CLAPACK_H
// potri specializations for the dtypes clapack supports. Each one forwards
// its arguments to the matching clapack_?potri routine and returns that
// routine's status code unchanged.

// float -> clapack_spotri
template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) {
  const int info = clapack_spotri(order, uplo, n, a, lda);
  return info;
}

// double -> clapack_dpotri
template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) {
  const int info = clapack_dpotri(order, uplo, n, a, lda);
  return info;
}

// Complex64 -> clapack_cpotri (the matrix is handed over as an untyped pointer)
template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) {
  void* raw = reinterpret_cast<void*>(a);
  const int info = clapack_cpotri(order, uplo, n, raw, lda);
  return info;
}

// Complex128 -> clapack_zpotri (the matrix is handed over as an untyped pointer)
template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) {
  void* raw = reinterpret_cast<void*>(a);
  const int info = clapack_zpotri(order, uplo, n, raw, lda);
  return info;
}
#endif
2568
-
2569
- /*
2570
- * Function signature conversion for calling LAPACK's getri functions as directly as possible.
2571
- *
2572
- * For documentation: http://www.netlib.org/lapack/double/dgetri.f
2573
- *
2574
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2575
- */
2576
- template <typename DType>
2577
- inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, const int lda, const int* ipiv) {
2578
- return getri<DType>(order, n, reinterpret_cast<DType*>(a), lda, ipiv);
2579
- }
2580
-
2581
-
2582
- /*
2583
- * Function signature conversion for calling LAPACK's potri functions as directly as possible.
2584
- *
2585
- * For documentation: http://www.netlib.org/lapack/double/dpotri.f
2586
- *
2587
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2588
- */
2589
- template <typename DType>
2590
- inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
2591
- return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2592
- }
2593
-
2594
-
2595
- /*
2596
- * Function signature conversion for calling LAPACK's laswp functions as directly as possible.
2597
- *
2598
- * For documentation: http://www.netlib.org/lapack/double/dlaswp.f
2599
- *
2600
- * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2601
- */
2602
- template <typename DType>
2603
- inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
2604
- laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
2605
- }
2606
-
2607
-
2608
-
2609
- }} // end namespace nm::math
2610
-
2611
-
2612
- #endif // MATH_H