nmatrix 0.0.6 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -0
- data/Gemfile +5 -0
- data/History.txt +97 -0
- data/Manifest.txt +34 -7
- data/README.rdoc +13 -13
- data/Rakefile +36 -26
- data/ext/nmatrix/data/data.cpp +15 -2
- data/ext/nmatrix/data/data.h +4 -0
- data/ext/nmatrix/data/ruby_object.h +5 -14
- data/ext/nmatrix/extconf.rb +3 -2
- data/ext/nmatrix/{util/math.cpp → math.cpp} +296 -6
- data/ext/nmatrix/math/asum.h +143 -0
- data/ext/nmatrix/math/geev.h +82 -0
- data/ext/nmatrix/math/gemm.h +267 -0
- data/ext/nmatrix/math/gemv.h +208 -0
- data/ext/nmatrix/math/ger.h +96 -0
- data/ext/nmatrix/math/gesdd.h +80 -0
- data/ext/nmatrix/math/gesvd.h +78 -0
- data/ext/nmatrix/math/getf2.h +86 -0
- data/ext/nmatrix/math/getrf.h +240 -0
- data/ext/nmatrix/math/getri.h +107 -0
- data/ext/nmatrix/math/getrs.h +125 -0
- data/ext/nmatrix/math/idamax.h +86 -0
- data/ext/nmatrix/{util → math}/lapack.h +60 -356
- data/ext/nmatrix/math/laswp.h +165 -0
- data/ext/nmatrix/math/long_dtype.h +52 -0
- data/ext/nmatrix/math/math.h +1154 -0
- data/ext/nmatrix/math/nrm2.h +181 -0
- data/ext/nmatrix/math/potrs.h +125 -0
- data/ext/nmatrix/math/rot.h +141 -0
- data/ext/nmatrix/math/rotg.h +115 -0
- data/ext/nmatrix/math/scal.h +73 -0
- data/ext/nmatrix/math/swap.h +73 -0
- data/ext/nmatrix/math/trsm.h +383 -0
- data/ext/nmatrix/nmatrix.cpp +176 -152
- data/ext/nmatrix/nmatrix.h +1 -2
- data/ext/nmatrix/ruby_constants.cpp +9 -4
- data/ext/nmatrix/ruby_constants.h +1 -0
- data/ext/nmatrix/storage/dense.cpp +57 -41
- data/ext/nmatrix/storage/list.cpp +52 -50
- data/ext/nmatrix/storage/storage.cpp +59 -43
- data/ext/nmatrix/storage/yale.cpp +352 -333
- data/ext/nmatrix/storage/yale.h +4 -0
- data/lib/nmatrix.rb +2 -2
- data/lib/nmatrix/blas.rb +4 -4
- data/lib/nmatrix/enumerate.rb +241 -0
- data/lib/nmatrix/lapack.rb +54 -1
- data/lib/nmatrix/math.rb +462 -0
- data/lib/nmatrix/nmatrix.rb +210 -486
- data/lib/nmatrix/nvector.rb +0 -62
- data/lib/nmatrix/rspec.rb +75 -0
- data/lib/nmatrix/shortcuts.rb +136 -108
- data/lib/nmatrix/version.rb +1 -1
- data/spec/blas_spec.rb +20 -12
- data/spec/elementwise_spec.rb +22 -13
- data/spec/io_spec.rb +1 -0
- data/spec/lapack_spec.rb +197 -0
- data/spec/nmatrix_spec.rb +39 -38
- data/spec/nvector_spec.rb +3 -9
- data/spec/rspec_monkeys.rb +29 -0
- data/spec/rspec_spec.rb +34 -0
- data/spec/shortcuts_spec.rb +14 -16
- data/spec/slice_spec.rb +242 -186
- data/spec/spec_helper.rb +19 -0
- metadata +33 -5
- data/ext/nmatrix/util/math.h +0 -2612
data/spec/spec_helper.rb
CHANGED
@@ -24,6 +24,7 @@
|
|
24
24
|
#
|
25
25
|
# Common data for testing.
|
26
26
|
require "./lib/nmatrix"
|
27
|
+
require "./lib/nmatrix/rspec"
|
27
28
|
|
28
29
|
MATRIX43A_ARRAY = [14.0, 9.0, 3.0, 2.0, 11.0, 15.0, 0.0, 12.0, 17.0, 5.0, 2.0, 3.0]
|
29
30
|
MATRIX32A_ARRAY = [12.0, 25.0, 9.0, 10.0, 8.0, 5.0]
|
@@ -66,3 +67,21 @@ def create_vector(stype) #:nodoc:
|
|
66
67
|
|
67
68
|
m
|
68
69
|
end
|
70
|
+
|
71
|
+
# Stupid but independent comparison for slice_spec
|
72
|
+
def nm_eql(n, m) #:nodoc:
|
73
|
+
if n.shape != m.shape
|
74
|
+
false
|
75
|
+
else # NMatrix
|
76
|
+
n.shape[0].times do |i|
|
77
|
+
n.shape[1].times do |j|
|
78
|
+
if n[i,j] != m[i,j]
|
79
|
+
puts "n[#{i},#{j}] != m[#{i},#{j}] (#{n[i,j]} != #{m[i,j]})"
|
80
|
+
return false
|
81
|
+
end
|
82
|
+
end
|
83
|
+
end
|
84
|
+
end
|
85
|
+
true
|
86
|
+
end
|
87
|
+
|
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nmatrix
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- John Woods
|
@@ -10,7 +10,7 @@ authors:
|
|
10
10
|
autorequire:
|
11
11
|
bindir: bin
|
12
12
|
cert_chain: []
|
13
|
-
date: 2013-08-
|
13
|
+
date: 2013-08-22 00:00:00.000000000 Z
|
14
14
|
dependencies:
|
15
15
|
- !ruby/object:Gem::Dependency
|
16
16
|
name: rdoc
|
@@ -136,6 +136,30 @@ files:
|
|
136
136
|
- ext/nmatrix/data/rational.h
|
137
137
|
- ext/nmatrix/data/ruby_object.h
|
138
138
|
- ext/nmatrix/extconf.rb
|
139
|
+
- ext/nmatrix/math.cpp
|
140
|
+
- ext/nmatrix/math/asum.h
|
141
|
+
- ext/nmatrix/math/geev.h
|
142
|
+
- ext/nmatrix/math/gemm.h
|
143
|
+
- ext/nmatrix/math/gemv.h
|
144
|
+
- ext/nmatrix/math/ger.h
|
145
|
+
- ext/nmatrix/math/gesdd.h
|
146
|
+
- ext/nmatrix/math/gesvd.h
|
147
|
+
- ext/nmatrix/math/getf2.h
|
148
|
+
- ext/nmatrix/math/getrf.h
|
149
|
+
- ext/nmatrix/math/getri.h
|
150
|
+
- ext/nmatrix/math/getrs.h
|
151
|
+
- ext/nmatrix/math/idamax.h
|
152
|
+
- ext/nmatrix/math/lapack.h
|
153
|
+
- ext/nmatrix/math/laswp.h
|
154
|
+
- ext/nmatrix/math/long_dtype.h
|
155
|
+
- ext/nmatrix/math/math.h
|
156
|
+
- ext/nmatrix/math/nrm2.h
|
157
|
+
- ext/nmatrix/math/potrs.h
|
158
|
+
- ext/nmatrix/math/rot.h
|
159
|
+
- ext/nmatrix/math/rotg.h
|
160
|
+
- ext/nmatrix/math/scal.h
|
161
|
+
- ext/nmatrix/math/swap.h
|
162
|
+
- ext/nmatrix/math/trsm.h
|
139
163
|
- ext/nmatrix/nmatrix.cpp
|
140
164
|
- ext/nmatrix/nmatrix.h
|
141
165
|
- ext/nmatrix/ruby_constants.cpp
|
@@ -154,21 +178,21 @@ files:
|
|
154
178
|
- ext/nmatrix/types.h
|
155
179
|
- ext/nmatrix/util/io.cpp
|
156
180
|
- ext/nmatrix/util/io.h
|
157
|
-
- ext/nmatrix/util/lapack.h
|
158
|
-
- ext/nmatrix/util/math.cpp
|
159
|
-
- ext/nmatrix/util/math.h
|
160
181
|
- ext/nmatrix/util/sl_list.cpp
|
161
182
|
- ext/nmatrix/util/sl_list.h
|
162
183
|
- ext/nmatrix/util/util.h
|
163
184
|
- lib/nmatrix.rb
|
164
185
|
- lib/nmatrix/blas.rb
|
186
|
+
- lib/nmatrix/enumerate.rb
|
165
187
|
- lib/nmatrix/io/market.rb
|
166
188
|
- lib/nmatrix/io/mat5_reader.rb
|
167
189
|
- lib/nmatrix/io/mat_reader.rb
|
168
190
|
- lib/nmatrix/lapack.rb
|
191
|
+
- lib/nmatrix/math.rb
|
169
192
|
- lib/nmatrix/monkeys.rb
|
170
193
|
- lib/nmatrix/nmatrix.rb
|
171
194
|
- lib/nmatrix/nvector.rb
|
195
|
+
- lib/nmatrix/rspec.rb
|
172
196
|
- lib/nmatrix/shortcuts.rb
|
173
197
|
- lib/nmatrix/version.rb
|
174
198
|
- lib/nmatrix/yale_functions.rb
|
@@ -188,6 +212,8 @@ files:
|
|
188
212
|
- spec/nmatrix_yale_resize_test_associations.yaml
|
189
213
|
- spec/nmatrix_yale_spec.rb
|
190
214
|
- spec/nvector_spec.rb
|
215
|
+
- spec/rspec_monkeys.rb
|
216
|
+
- spec/rspec_spec.rb
|
191
217
|
- spec/shortcuts_spec.rb
|
192
218
|
- spec/slice_spec.rb
|
193
219
|
- spec/spec_helper.rb
|
@@ -258,6 +284,8 @@ test_files:
|
|
258
284
|
- spec/nmatrix_yale_resize_test_associations.yaml
|
259
285
|
- spec/nmatrix_yale_spec.rb
|
260
286
|
- spec/nvector_spec.rb
|
287
|
+
- spec/rspec_monkeys.rb
|
288
|
+
- spec/rspec_spec.rb
|
261
289
|
- spec/shortcuts_spec.rb
|
262
290
|
- spec/slice_spec.rb
|
263
291
|
- spec/spec_helper.rb
|
data/ext/nmatrix/util/math.h
DELETED
@@ -1,2612 +0,0 @@
|
|
1
|
-
/////////////////////////////////////////////////////////////////////
|
2
|
-
// = NMatrix
|
3
|
-
//
|
4
|
-
// A linear algebra library for scientific computation in Ruby.
|
5
|
-
// NMatrix is part of SciRuby.
|
6
|
-
//
|
7
|
-
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
-
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
-
//
|
10
|
-
// == Copyright Information
|
11
|
-
//
|
12
|
-
// SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
|
13
|
-
// NMatrix is Copyright (c) 2013, Ruby Science Foundation
|
14
|
-
//
|
15
|
-
// Please see LICENSE.txt for additional copyright notices.
|
16
|
-
//
|
17
|
-
// == Contributing
|
18
|
-
//
|
19
|
-
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
-
// our Contributor Agreement:
|
21
|
-
//
|
22
|
-
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
-
//
|
24
|
-
// == math.h
|
25
|
-
//
|
26
|
-
// Header file for math functions, interfacing with BLAS, etc.
|
27
|
-
//
|
28
|
-
// For instructions on adding CBLAS and CLAPACK functions, see the
|
29
|
-
// beginning of math.cpp.
|
30
|
-
//
|
31
|
-
// Some of these functions are from ATLAS. Here is the license for
|
32
|
-
// ATLAS:
|
33
|
-
//
|
34
|
-
/*
|
35
|
-
* Automatically Tuned Linear Algebra Software v3.8.4
|
36
|
-
* (C) Copyright 1999 R. Clint Whaley
|
37
|
-
*
|
38
|
-
* Redistribution and use in source and binary forms, with or without
|
39
|
-
* modification, are permitted provided that the following conditions
|
40
|
-
* are met:
|
41
|
-
* 1. Redistributions of source code must retain the above copyright
|
42
|
-
* notice, this list of conditions and the following disclaimer.
|
43
|
-
* 2. Redistributions in binary form must reproduce the above copyright
|
44
|
-
* notice, this list of conditions, and the following disclaimer in the
|
45
|
-
* documentation and/or other materials provided with the distribution.
|
46
|
-
* 3. The name of the ATLAS group or the names of its contributers may
|
47
|
-
* not be used to endorse or promote products derived from this
|
48
|
-
* software without specific written permission.
|
49
|
-
*
|
50
|
-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
51
|
-
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
52
|
-
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
53
|
-
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
|
54
|
-
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
55
|
-
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
56
|
-
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
57
|
-
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
58
|
-
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
59
|
-
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
60
|
-
* POSSIBILITY OF SUCH DAMAGE.
|
61
|
-
*
|
62
|
-
*/
|
63
|
-
|
64
|
-
#ifndef MATH_H
|
65
|
-
#define MATH_H
|
66
|
-
|
67
|
-
/*
|
68
|
-
* Standard Includes
|
69
|
-
*/
|
70
|
-
|
71
|
-
extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
|
72
|
-
#include <cblas.h>
|
73
|
-
|
74
|
-
#ifdef HAVE_CLAPACK_H
|
75
|
-
#include <clapack.h>
|
76
|
-
#endif
|
77
|
-
}
|
78
|
-
|
79
|
-
#include <algorithm> // std::min, std::max
|
80
|
-
#include <limits> // std::numeric_limits
|
81
|
-
|
82
|
-
/*
|
83
|
-
* Project Includes
|
84
|
-
*/
|
85
|
-
#include "data/data.h"
|
86
|
-
#include "lapack.h"
|
87
|
-
|
88
|
-
/*
|
89
|
-
* Macros
|
90
|
-
*/
|
91
|
-
#define REAL_RECURSE_LIMIT 4
|
92
|
-
|
93
|
-
/*
|
94
|
-
* Data
|
95
|
-
*/
|
96
|
-
|
97
|
-
|
98
|
-
extern "C" {
|
99
|
-
/*
|
100
|
-
* C accessors.
|
101
|
-
*/
|
102
|
-
void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
|
103
|
-
void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
|
104
|
-
void nm_math_init_blas(void);
|
105
|
-
}
|
106
|
-
|
107
|
-
|
108
|
-
namespace nm {
|
109
|
-
namespace math {
|
110
|
-
|
111
|
-
/*
|
112
|
-
* Types
|
113
|
-
*/
|
114
|
-
|
115
|
-
|
116
|
-
// These allow an increase in precision for intermediate values of gemm and gemv.
|
117
|
-
// See also: http://stackoverflow.com/questions/11873694/how-does-one-increase-precision-in-c-templates-in-a-template-typename-dependen
|
118
|
-
template <typename DType> struct LongDType;
|
119
|
-
template <> struct LongDType<uint8_t> { typedef int16_t type; };
|
120
|
-
template <> struct LongDType<int8_t> { typedef int16_t type; };
|
121
|
-
template <> struct LongDType<int16_t> { typedef int32_t type; };
|
122
|
-
template <> struct LongDType<int32_t> { typedef int64_t type; };
|
123
|
-
template <> struct LongDType<int64_t> { typedef int64_t type; };
|
124
|
-
template <> struct LongDType<float> { typedef double type; };
|
125
|
-
template <> struct LongDType<double> { typedef double type; };
|
126
|
-
template <> struct LongDType<Complex64> { typedef Complex128 type; };
|
127
|
-
template <> struct LongDType<Complex128> { typedef Complex128 type; };
|
128
|
-
template <> struct LongDType<Rational32> { typedef Rational128 type; };
|
129
|
-
template <> struct LongDType<Rational64> { typedef Rational128 type; };
|
130
|
-
template <> struct LongDType<Rational128> { typedef Rational128 type; };
|
131
|
-
template <> struct LongDType<RubyObject> { typedef RubyObject type; };
|
132
|
-
|
133
|
-
/*
|
134
|
-
* Functions
|
135
|
-
*/
|
136
|
-
|
137
|
-
/* Numeric inverse -- usually just 1 / f, but a little more complicated for complex. */
|
138
|
-
template <typename DType>
|
139
|
-
inline DType numeric_inverse(const DType& n) {
|
140
|
-
return n.inverse();
|
141
|
-
}
|
142
|
-
template <> inline float numeric_inverse<float>(const float& n) { return 1 / n; }
|
143
|
-
template <> inline double numeric_inverse<double>(const double& n) { return 1 / n; }
|
144
|
-
|
145
|
-
/*
|
146
|
-
* This version of trsm doesn't do any error checks and only works on column-major matrices.
|
147
|
-
*
|
148
|
-
* For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
|
149
|
-
* and parameter checks.
|
150
|
-
*
|
151
|
-
* Note that some of the boundary conditions here may be incorrect. Very little has been tested!
|
152
|
-
* This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
|
153
|
-
*/
|
154
|
-
template <typename DType>
|
155
|
-
inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
156
|
-
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
157
|
-
const int m, const int n, const DType alpha, const DType* a,
|
158
|
-
const int lda, DType* b, const int ldb)
|
159
|
-
{
|
160
|
-
|
161
|
-
// (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
|
162
|
-
|
163
|
-
if (m == 0 || n == 0) return; /* Quick return if possible. */
|
164
|
-
|
165
|
-
if (alpha == 0) { // Handle alpha == 0
|
166
|
-
for (int j = 0; j < n; ++j) {
|
167
|
-
for (int i = 0; i < m; ++i) {
|
168
|
-
b[i + j * ldb] = 0;
|
169
|
-
}
|
170
|
-
}
|
171
|
-
return;
|
172
|
-
}
|
173
|
-
|
174
|
-
if (side == CblasLeft) {
|
175
|
-
if (trans_a == CblasNoTrans) {
|
176
|
-
|
177
|
-
/* Form B := alpha*inv( A )*B. */
|
178
|
-
if (uplo == CblasUpper) {
|
179
|
-
for (int j = 0; j < n; ++j) {
|
180
|
-
if (alpha != 1) {
|
181
|
-
for (int i = 0; i < m; ++i) {
|
182
|
-
b[i + j * ldb] = alpha * b[i + j * ldb];
|
183
|
-
}
|
184
|
-
}
|
185
|
-
for (int k = m-1; k >= 0; --k) {
|
186
|
-
if (b[k + j * ldb] != 0) {
|
187
|
-
if (diag == CblasNonUnit) {
|
188
|
-
b[k + j * ldb] /= a[k + k * lda];
|
189
|
-
}
|
190
|
-
|
191
|
-
for (int i = 0; i < k-1; ++i) {
|
192
|
-
b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
|
193
|
-
}
|
194
|
-
}
|
195
|
-
}
|
196
|
-
}
|
197
|
-
} else {
|
198
|
-
for (int j = 0; j < n; ++j) {
|
199
|
-
if (alpha != 1) {
|
200
|
-
for (int i = 0; i < m; ++i) {
|
201
|
-
b[i + j * ldb] = alpha * b[i + j * ldb];
|
202
|
-
}
|
203
|
-
}
|
204
|
-
for (int k = 0; k < m; ++k) {
|
205
|
-
if (b[k + j * ldb] != 0.) {
|
206
|
-
if (diag == CblasNonUnit) {
|
207
|
-
b[k + j * ldb] /= a[k + k * lda];
|
208
|
-
}
|
209
|
-
for (int i = k+1; i < m; ++i) {
|
210
|
-
b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
|
211
|
-
}
|
212
|
-
}
|
213
|
-
}
|
214
|
-
}
|
215
|
-
}
|
216
|
-
} else { // CblasTrans
|
217
|
-
|
218
|
-
/* Form B := alpha*inv( A**T )*B. */
|
219
|
-
if (uplo == CblasUpper) {
|
220
|
-
for (int j = 0; j < n; ++j) {
|
221
|
-
for (int i = 0; i < m; ++i) {
|
222
|
-
DType temp = alpha * b[i + j * ldb];
|
223
|
-
for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
|
224
|
-
temp -= a[k + i * lda] * b[k + j * ldb];
|
225
|
-
}
|
226
|
-
if (diag == CblasNonUnit) {
|
227
|
-
temp /= a[i + i * lda];
|
228
|
-
}
|
229
|
-
b[i + j * ldb] = temp;
|
230
|
-
}
|
231
|
-
}
|
232
|
-
} else {
|
233
|
-
for (int j = 0; j < n; ++j) {
|
234
|
-
for (int i = m-1; i >= 0; --i) {
|
235
|
-
DType temp= alpha * b[i + j * ldb];
|
236
|
-
for (int k = i+1; k < m; ++k) {
|
237
|
-
temp -= a[k + i * lda] * b[k + j * ldb];
|
238
|
-
}
|
239
|
-
if (diag == CblasNonUnit) {
|
240
|
-
temp /= a[i + i * lda];
|
241
|
-
}
|
242
|
-
b[i + j * ldb] = temp;
|
243
|
-
}
|
244
|
-
}
|
245
|
-
}
|
246
|
-
}
|
247
|
-
} else { // right side
|
248
|
-
|
249
|
-
if (trans_a == CblasNoTrans) {
|
250
|
-
|
251
|
-
/* Form B := alpha*B*inv( A ). */
|
252
|
-
|
253
|
-
if (uplo == CblasUpper) {
|
254
|
-
for (int j = 0; j < n; ++j) {
|
255
|
-
if (alpha != 1) {
|
256
|
-
for (int i = 0; i < m; ++i) {
|
257
|
-
b[i + j * ldb] = alpha * b[i + j * ldb];
|
258
|
-
}
|
259
|
-
}
|
260
|
-
for (int k = 0; k < j-1; ++k) {
|
261
|
-
if (a[k + j * lda] != 0) {
|
262
|
-
for (int i = 0; i < m; ++i) {
|
263
|
-
b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
|
264
|
-
}
|
265
|
-
}
|
266
|
-
}
|
267
|
-
if (diag == CblasNonUnit) {
|
268
|
-
DType temp = 1 / a[j + j * lda];
|
269
|
-
for (int i = 0; i < m; ++i) {
|
270
|
-
b[i + j * ldb] = temp * b[i + j * ldb];
|
271
|
-
}
|
272
|
-
}
|
273
|
-
}
|
274
|
-
} else {
|
275
|
-
for (int j = n-1; j >= 0; --j) {
|
276
|
-
if (alpha != 1) {
|
277
|
-
for (int i = 0; i < m; ++i) {
|
278
|
-
b[i + j * ldb] = alpha * b[i + j * ldb];
|
279
|
-
}
|
280
|
-
}
|
281
|
-
|
282
|
-
for (int k = j+1; k < n; ++k) {
|
283
|
-
if (a[k + j * lda] != 0.) {
|
284
|
-
for (int i = 0; i < m; ++i) {
|
285
|
-
b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
|
286
|
-
}
|
287
|
-
}
|
288
|
-
}
|
289
|
-
if (diag == CblasNonUnit) {
|
290
|
-
DType temp = 1 / a[j + j * lda];
|
291
|
-
|
292
|
-
for (int i = 0; i < m; ++i) {
|
293
|
-
b[i + j * ldb] = temp * b[i + j * ldb];
|
294
|
-
}
|
295
|
-
}
|
296
|
-
}
|
297
|
-
}
|
298
|
-
} else { // CblasTrans
|
299
|
-
|
300
|
-
/* Form B := alpha*B*inv( A**T ). */
|
301
|
-
|
302
|
-
if (uplo == CblasUpper) {
|
303
|
-
for (int k = n-1; k >= 0; --k) {
|
304
|
-
if (diag == CblasNonUnit) {
|
305
|
-
DType temp= 1 / a[k + k * lda];
|
306
|
-
for (int i = 0; i < m; ++i) {
|
307
|
-
b[i + k * ldb] = temp * b[i + k * ldb];
|
308
|
-
}
|
309
|
-
}
|
310
|
-
for (int j = 0; j < k-1; ++j) {
|
311
|
-
if (a[j + k * lda] != 0.) {
|
312
|
-
DType temp= a[j + k * lda];
|
313
|
-
for (int i = 0; i < m; ++i) {
|
314
|
-
b[i + j * ldb] -= temp * b[i + k * ldb];
|
315
|
-
}
|
316
|
-
}
|
317
|
-
}
|
318
|
-
if (alpha != 1) {
|
319
|
-
for (int i = 0; i < m; ++i) {
|
320
|
-
b[i + k * ldb] = alpha * b[i + k * ldb];
|
321
|
-
}
|
322
|
-
}
|
323
|
-
}
|
324
|
-
} else {
|
325
|
-
for (int k = 0; k < n; ++k) {
|
326
|
-
if (diag == CblasNonUnit) {
|
327
|
-
DType temp = 1 / a[k + k * lda];
|
328
|
-
for (int i = 0; i < m; ++i) {
|
329
|
-
b[i + k * ldb] = temp * b[i + k * ldb];
|
330
|
-
}
|
331
|
-
}
|
332
|
-
for (int j = k+1; j < n; ++j) {
|
333
|
-
if (a[j + k * lda] != 0.) {
|
334
|
-
DType temp = a[j + k * lda];
|
335
|
-
for (int i = 0; i < m; ++i) {
|
336
|
-
b[i + j * ldb] -= temp * b[i + k * ldb];
|
337
|
-
}
|
338
|
-
}
|
339
|
-
}
|
340
|
-
if (alpha != 1) {
|
341
|
-
for (int i = 0; i < m; ++i) {
|
342
|
-
b[i + k * ldb] = alpha * b[i + k * ldb];
|
343
|
-
}
|
344
|
-
}
|
345
|
-
}
|
346
|
-
}
|
347
|
-
}
|
348
|
-
}
|
349
|
-
}
|
350
|
-
|
351
|
-
|
352
|
-
template <typename DType>
|
353
|
-
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
354
|
-
const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
|
355
|
-
rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
|
356
|
-
}
|
357
|
-
|
358
|
-
template <typename DType>
|
359
|
-
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
360
|
-
const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
|
361
|
-
rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
|
362
|
-
}
|
363
|
-
|
364
|
-
template <>
|
365
|
-
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
366
|
-
const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
|
367
|
-
cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
|
368
|
-
}
|
369
|
-
|
370
|
-
template <>
|
371
|
-
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
372
|
-
const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
|
373
|
-
cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
|
374
|
-
}
|
375
|
-
|
376
|
-
template <>
|
377
|
-
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
378
|
-
const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
|
379
|
-
cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
|
380
|
-
}
|
381
|
-
|
382
|
-
template <>
|
383
|
-
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
384
|
-
const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
|
385
|
-
cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
|
386
|
-
}
|
387
|
-
|
388
|
-
|
389
|
-
template <>
|
390
|
-
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
391
|
-
const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
|
392
|
-
cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
|
393
|
-
}
|
394
|
-
|
395
|
-
template <>
|
396
|
-
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
397
|
-
const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
|
398
|
-
cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
|
399
|
-
}
|
400
|
-
|
401
|
-
|
402
|
-
template <typename DType>
|
403
|
-
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
404
|
-
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
|
405
|
-
const DType* A, const int lda, DType* B, const int ldb) {
|
406
|
-
rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
|
407
|
-
}
|
408
|
-
|
409
|
-
template <>
|
410
|
-
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
411
|
-
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
|
412
|
-
const float* A, const int lda, float* B, const int ldb) {
|
413
|
-
cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
|
414
|
-
}
|
415
|
-
|
416
|
-
template <>
|
417
|
-
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
418
|
-
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
|
419
|
-
const double* A, const int lda, double* B, const int ldb) {
|
420
|
-
cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
|
421
|
-
}
|
422
|
-
|
423
|
-
template <>
|
424
|
-
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
425
|
-
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
|
426
|
-
const Complex64* A, const int lda, Complex64* B, const int ldb) {
|
427
|
-
cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
|
428
|
-
}
|
429
|
-
|
430
|
-
template <>
|
431
|
-
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
432
|
-
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
|
433
|
-
const Complex128* A, const int lda, Complex128* B, const int ldb) {
|
434
|
-
cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
|
435
|
-
}
|
436
|
-
|
437
|
-
|
438
|
-
/*
|
439
|
-
* BLAS' DTRSM function, generalized.
|
440
|
-
*/
|
441
|
-
template <typename DType, typename = typename std::enable_if<!std::is_integral<DType>::value>::type>
|
442
|
-
inline void trsm(const enum CBLAS_ORDER order,
|
443
|
-
const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
444
|
-
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
445
|
-
const int m, const int n, const DType alpha, const DType* a,
|
446
|
-
const int lda, DType* b, const int ldb)
|
447
|
-
{
|
448
|
-
/*using std::cerr;
|
449
|
-
using std::endl;*/
|
450
|
-
|
451
|
-
int num_rows_a = n;
|
452
|
-
if (side == CblasLeft) num_rows_a = m;
|
453
|
-
|
454
|
-
if (lda < std::max(1,num_rows_a)) {
|
455
|
-
fprintf(stderr, "TRSM: num_rows_a = %d; got lda=%d\n", num_rows_a, lda);
|
456
|
-
rb_raise(rb_eArgError, "TRSM: Expected lda >= max(1, num_rows_a)");
|
457
|
-
}
|
458
|
-
|
459
|
-
// Test the input parameters.
|
460
|
-
if (order == CblasRowMajor) {
|
461
|
-
if (ldb < std::max(1,n)) {
|
462
|
-
fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
|
463
|
-
rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,N)");
|
464
|
-
}
|
465
|
-
|
466
|
-
// For row major, need to switch side and uplo
|
467
|
-
enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
|
468
|
-
enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
|
469
|
-
|
470
|
-
/*
|
471
|
-
cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
|
472
|
-
<< (uplo_ == CblasUpper ? "upper " : "lower ")
|
473
|
-
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
474
|
-
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
475
|
-
<< n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
|
476
|
-
*/
|
477
|
-
trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
|
478
|
-
|
479
|
-
} else { // CblasColMajor
|
480
|
-
|
481
|
-
if (ldb < std::max(1,m)) {
|
482
|
-
fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
|
483
|
-
rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
|
484
|
-
}
|
485
|
-
/*
|
486
|
-
cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
|
487
|
-
<< (uplo == CblasUpper ? "upper " : "lower ")
|
488
|
-
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
489
|
-
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
490
|
-
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
491
|
-
*/
|
492
|
-
trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
493
|
-
|
494
|
-
}
|
495
|
-
|
496
|
-
}
|
497
|
-
|
498
|
-
|
499
|
-
template <>
|
500
|
-
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
501
|
-
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
502
|
-
const int m, const int n, const float alpha, const float* a,
|
503
|
-
const int lda, float* b, const int ldb)
|
504
|
-
{
|
505
|
-
cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
506
|
-
}
|
507
|
-
|
508
|
-
template <>
|
509
|
-
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
510
|
-
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
511
|
-
const int m, const int n, const double alpha, const double* a,
|
512
|
-
const int lda, double* b, const int ldb)
|
513
|
-
{
|
514
|
-
/* using std::cerr;
|
515
|
-
using std::endl;
|
516
|
-
cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
|
517
|
-
<< (uplo == CblasUpper ? "upper " : "lower ")
|
518
|
-
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
519
|
-
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
520
|
-
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
521
|
-
*/
|
522
|
-
cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
523
|
-
}
|
524
|
-
|
525
|
-
|
526
|
-
template <>
|
527
|
-
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
528
|
-
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
529
|
-
const int m, const int n, const Complex64 alpha, const Complex64* a,
|
530
|
-
const int lda, Complex64* b, const int ldb)
|
531
|
-
{
|
532
|
-
cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
533
|
-
}
|
534
|
-
|
535
|
-
template <>
|
536
|
-
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
537
|
-
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
538
|
-
const int m, const int n, const Complex128 alpha, const Complex128* a,
|
539
|
-
const int lda, Complex128* b, const int ldb)
|
540
|
-
{
|
541
|
-
cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
542
|
-
}
|
543
|
-
|
544
|
-
|
545
|
-
/*
|
546
|
-
* ATLAS function which performs row interchanges on a general rectangular matrix. Modeled after the LAPACK LASWP function.
|
547
|
-
*
|
548
|
-
* This version is templated for use by template <> getrf().
|
549
|
-
*/
|
550
|
-
template <typename DType>
|
551
|
-
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
|
552
|
-
//const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused.
|
553
|
-
|
554
|
-
int nb = N >> 5;
|
555
|
-
|
556
|
-
const int mr = N - (nb<<5);
|
557
|
-
const int incA = lda << 5;
|
558
|
-
|
559
|
-
if (K2 < K1) return;
|
560
|
-
|
561
|
-
int i1, i2;
|
562
|
-
if (inci < 0) {
|
563
|
-
piv -= (K2-1) * inci;
|
564
|
-
i1 = K2 - 1;
|
565
|
-
i2 = K1;
|
566
|
-
} else {
|
567
|
-
piv += K1 * inci;
|
568
|
-
i1 = K1;
|
569
|
-
i2 = K2-1;
|
570
|
-
}
|
571
|
-
|
572
|
-
if (nb) {
|
573
|
-
|
574
|
-
do {
|
575
|
-
const int* ipiv = piv;
|
576
|
-
int i = i1;
|
577
|
-
int KeepOn;
|
578
|
-
|
579
|
-
do {
|
580
|
-
int ip = *ipiv; ipiv += inci;
|
581
|
-
|
582
|
-
if (ip != i) {
|
583
|
-
DType *a0 = &(A[i]),
|
584
|
-
*a1 = &(A[ip]);
|
585
|
-
|
586
|
-
for (register int h = 32; h; h--) {
|
587
|
-
DType r = *a0;
|
588
|
-
*a0 = *a1;
|
589
|
-
*a1 = r;
|
590
|
-
|
591
|
-
a0 += lda;
|
592
|
-
a1 += lda;
|
593
|
-
}
|
594
|
-
|
595
|
-
}
|
596
|
-
if (inci > 0) KeepOn = (++i <= i2);
|
597
|
-
else KeepOn = (--i >= i2);
|
598
|
-
|
599
|
-
} while (KeepOn);
|
600
|
-
A += incA;
|
601
|
-
} while (--nb);
|
602
|
-
}
|
603
|
-
|
604
|
-
if (mr) {
|
605
|
-
const int* ipiv = piv;
|
606
|
-
int i = i1;
|
607
|
-
int KeepOn;
|
608
|
-
|
609
|
-
do {
|
610
|
-
int ip = *ipiv; ipiv += inci;
|
611
|
-
if (ip != i) {
|
612
|
-
DType *a0 = &(A[i]),
|
613
|
-
*a1 = &(A[ip]);
|
614
|
-
|
615
|
-
for (register int h = mr; h; h--) {
|
616
|
-
DType r = *a0;
|
617
|
-
*a0 = *a1;
|
618
|
-
*a1 = r;
|
619
|
-
|
620
|
-
a0 += lda;
|
621
|
-
a1 += lda;
|
622
|
-
}
|
623
|
-
}
|
624
|
-
|
625
|
-
if (inci > 0) KeepOn = (++i <= i2);
|
626
|
-
else KeepOn = (--i >= i2);
|
627
|
-
|
628
|
-
} while (KeepOn);
|
629
|
-
}
|
630
|
-
}
|
631
|
-
|
632
|
-
|
633
|
-
/*
|
634
|
-
* GEneral Matrix Multiplication: based on dgemm.f from Netlib.
|
635
|
-
*
|
636
|
-
* This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
|
637
|
-
*
|
638
|
-
* Template parameters: LT -- long version of type T. Type T is the matrix dtype.
|
639
|
-
*
|
640
|
-
* This version throws no errors. Use gemm<DType> instead for error checking.
|
641
|
-
*/
|
642
|
-
template <typename DType>
|
643
|
-
inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
644
|
-
const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
|
645
|
-
{
|
646
|
-
|
647
|
-
typename LongDType<DType>::type temp;
|
648
|
-
|
649
|
-
// Quick return if possible
|
650
|
-
if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;
|
651
|
-
|
652
|
-
// For alpha = 0
|
653
|
-
if (*alpha == 0) {
|
654
|
-
if (*beta == 0) {
|
655
|
-
for (int j = 0; j < N; ++j)
|
656
|
-
for (int i = 0; i < M; ++i) {
|
657
|
-
C[i+j*ldc] = 0;
|
658
|
-
}
|
659
|
-
} else {
|
660
|
-
for (int j = 0; j < N; ++j)
|
661
|
-
for (int i = 0; i < M; ++i) {
|
662
|
-
C[i+j*ldc] *= *beta;
|
663
|
-
}
|
664
|
-
}
|
665
|
-
return;
|
666
|
-
}
|
667
|
-
|
668
|
-
// Start the operations
|
669
|
-
if (TransB == CblasNoTrans) {
|
670
|
-
if (TransA == CblasNoTrans) {
|
671
|
-
// C = alpha*A*B+beta*C
|
672
|
-
for (int j = 0; j < N; ++j) {
|
673
|
-
if (*beta == 0) {
|
674
|
-
for (int i = 0; i < M; ++i) {
|
675
|
-
C[i+j*ldc] = 0;
|
676
|
-
}
|
677
|
-
} else if (*beta != 1) {
|
678
|
-
for (int i = 0; i < M; ++i) {
|
679
|
-
C[i+j*ldc] *= *beta;
|
680
|
-
}
|
681
|
-
}
|
682
|
-
|
683
|
-
for (int l = 0; l < K; ++l) {
|
684
|
-
if (B[l+j*ldb] != 0) {
|
685
|
-
temp = *alpha * B[l+j*ldb];
|
686
|
-
for (int i = 0; i < M; ++i) {
|
687
|
-
C[i+j*ldc] += A[i+l*lda] * temp;
|
688
|
-
}
|
689
|
-
}
|
690
|
-
}
|
691
|
-
}
|
692
|
-
|
693
|
-
} else {
|
694
|
-
|
695
|
-
// C = alpha*A**DType*B + beta*C
|
696
|
-
for (int j = 0; j < N; ++j) {
|
697
|
-
for (int i = 0; i < M; ++i) {
|
698
|
-
temp = 0;
|
699
|
-
for (int l = 0; l < K; ++l) {
|
700
|
-
temp += A[l+i*lda] * B[l+j*ldb];
|
701
|
-
}
|
702
|
-
|
703
|
-
if (*beta == 0) {
|
704
|
-
C[i+j*ldc] = *alpha*temp;
|
705
|
-
} else {
|
706
|
-
C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
|
707
|
-
}
|
708
|
-
}
|
709
|
-
}
|
710
|
-
|
711
|
-
}
|
712
|
-
|
713
|
-
} else if (TransA == CblasNoTrans) {
|
714
|
-
|
715
|
-
// C = alpha*A*B**T + beta*C
|
716
|
-
for (int j = 0; j < N; ++j) {
|
717
|
-
if (*beta == 0) {
|
718
|
-
for (int i = 0; i < M; ++i) {
|
719
|
-
C[i+j*ldc] = 0;
|
720
|
-
}
|
721
|
-
} else if (*beta != 1) {
|
722
|
-
for (int i = 0; i < M; ++i) {
|
723
|
-
C[i+j*ldc] *= *beta;
|
724
|
-
}
|
725
|
-
}
|
726
|
-
|
727
|
-
for (int l = 0; l < K; ++l) {
|
728
|
-
if (B[j+l*ldb] != 0) {
|
729
|
-
temp = *alpha * B[j+l*ldb];
|
730
|
-
for (int i = 0; i < M; ++i) {
|
731
|
-
C[i+j*ldc] += A[i+l*lda] * temp;
|
732
|
-
}
|
733
|
-
}
|
734
|
-
}
|
735
|
-
|
736
|
-
}
|
737
|
-
|
738
|
-
} else {
|
739
|
-
|
740
|
-
// C = alpha*A**DType*B**T + beta*C
|
741
|
-
for (int j = 0; j < N; ++j) {
|
742
|
-
for (int i = 0; i < M; ++i) {
|
743
|
-
temp = 0;
|
744
|
-
for (int l = 0; l < K; ++l) {
|
745
|
-
temp += A[l+i*lda] * B[j+l*ldb];
|
746
|
-
}
|
747
|
-
|
748
|
-
if (*beta == 0) {
|
749
|
-
C[i+j*ldc] = *alpha*temp;
|
750
|
-
} else {
|
751
|
-
C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
|
752
|
-
}
|
753
|
-
}
|
754
|
-
}
|
755
|
-
|
756
|
-
}
|
757
|
-
|
758
|
-
return;
|
759
|
-
}
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
template <typename DType>
|
764
|
-
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
765
|
-
const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
|
766
|
-
{
|
767
|
-
if (Order == CblasRowMajor) {
|
768
|
-
if (TransA == CblasNoTrans) {
|
769
|
-
if (lda < std::max(K,1)) {
|
770
|
-
rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
|
771
|
-
}
|
772
|
-
} else {
|
773
|
-
if (lda < std::max(M,1)) { // && TransA == CblasTrans
|
774
|
-
rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
|
775
|
-
}
|
776
|
-
}
|
777
|
-
|
778
|
-
if (TransB == CblasNoTrans) {
|
779
|
-
if (ldb < std::max(N,1)) {
|
780
|
-
rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
|
781
|
-
}
|
782
|
-
} else {
|
783
|
-
if (ldb < std::max(K,1)) {
|
784
|
-
rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
|
785
|
-
}
|
786
|
-
}
|
787
|
-
|
788
|
-
if (ldc < std::max(N,1)) {
|
789
|
-
rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
|
790
|
-
}
|
791
|
-
} else { // CblasColMajor
|
792
|
-
if (TransA == CblasNoTrans) {
|
793
|
-
if (lda < std::max(M,1)) {
|
794
|
-
rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
|
795
|
-
}
|
796
|
-
} else {
|
797
|
-
if (lda < std::max(K,1)) { // && TransA == CblasTrans
|
798
|
-
rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
|
799
|
-
}
|
800
|
-
}
|
801
|
-
|
802
|
-
if (TransB == CblasNoTrans) {
|
803
|
-
if (ldb < std::max(K,1)) {
|
804
|
-
rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d N=%d", ldb, K);
|
805
|
-
}
|
806
|
-
} else {
|
807
|
-
if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
|
808
|
-
rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
|
809
|
-
}
|
810
|
-
}
|
811
|
-
|
812
|
-
if (ldc < std::max(M,1)) {
|
813
|
-
rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d N=%d", ldc, M);
|
814
|
-
}
|
815
|
-
}
|
816
|
-
|
817
|
-
/*
|
818
|
-
* Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
|
819
|
-
* we copy C and then subtract to preserve asymmetry.
|
820
|
-
*/
|
821
|
-
|
822
|
-
if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
|
823
|
-
rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
|
824
|
-
/*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
|
825
|
-
syreflect(CblasUpper, N, C, ldc);
|
826
|
-
*/
|
827
|
-
}
|
828
|
-
|
829
|
-
if (Order == CblasRowMajor) gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
|
830
|
-
else gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
831
|
-
|
832
|
-
}
|
833
|
-
|
834
|
-
|
835
|
-
template <>
|
836
|
-
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
837
|
-
const float* alpha, const float* A, const int lda, const float* B, const int ldb, const float* beta, float* C, const int ldc) {
|
838
|
-
cblas_sgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
|
839
|
-
}
|
840
|
-
|
841
|
-
template <>
|
842
|
-
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
843
|
-
const double* alpha, const double* A, const int lda, const double* B, const int ldb, const double* beta, double* C, const int ldc) {
|
844
|
-
cblas_dgemm(Order, TransA, TransB, M, N, K, *alpha, A, lda, B, ldb, *beta, C, ldc);
|
845
|
-
}
|
846
|
-
|
847
|
-
template <>
|
848
|
-
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
849
|
-
const Complex64* alpha, const Complex64* A, const int lda, const Complex64* B, const int ldb, const Complex64* beta, Complex64* C, const int ldc) {
|
850
|
-
cblas_cgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
851
|
-
}
|
852
|
-
|
853
|
-
template <>
|
854
|
-
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
|
855
|
-
const Complex128* alpha, const Complex128* A, const int lda, const Complex128* B, const int ldb, const Complex128* beta, Complex128* C, const int ldc) {
|
856
|
-
cblas_zgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
857
|
-
}
|
858
|
-
|
859
|
-
|
860
|
-
/*
|
861
|
-
* GEneral Matrix-Vector multiplication: based on dgemv.f from Netlib.
|
862
|
-
*
|
863
|
-
* This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
|
864
|
-
*
|
865
|
-
* Template parameters: LT -- long version of type T. Type T is the matrix dtype.
|
866
|
-
*/
|
867
|
-
template <typename DType>
|
868
|
-
inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const DType* alpha, const DType* A, const int lda,
|
869
|
-
const DType* X, const int incX, const DType* beta, DType* Y, const int incY) {
|
870
|
-
int lenX, lenY, i, j;
|
871
|
-
int kx, ky, iy, jx, jy, ix;
|
872
|
-
|
873
|
-
typename LongDType<DType>::type temp;
|
874
|
-
|
875
|
-
// Test the input parameters
|
876
|
-
if (Trans < 111 || Trans > 113) {
|
877
|
-
rb_raise(rb_eArgError, "GEMV: TransA must be CblasNoTrans, CblasTrans, or CblasConjTrans");
|
878
|
-
return false;
|
879
|
-
} else if (lda < std::max(1, N)) {
|
880
|
-
fprintf(stderr, "GEMV: N = %d; got lda=%d", N, lda);
|
881
|
-
rb_raise(rb_eArgError, "GEMV: Expected lda >= max(1, N)");
|
882
|
-
return false;
|
883
|
-
} else if (incX == 0) {
|
884
|
-
rb_raise(rb_eArgError, "GEMV: Expected incX != 0\n");
|
885
|
-
return false;
|
886
|
-
} else if (incY == 0) {
|
887
|
-
rb_raise(rb_eArgError, "GEMV: Expected incY != 0\n");
|
888
|
-
return false;
|
889
|
-
}
|
890
|
-
|
891
|
-
// Quick return if possible
|
892
|
-
if (!M or !N or (*alpha == 0 and *beta == 1)) return true;
|
893
|
-
|
894
|
-
if (Trans == CblasNoTrans) {
|
895
|
-
lenX = N;
|
896
|
-
lenY = M;
|
897
|
-
} else {
|
898
|
-
lenX = M;
|
899
|
-
lenY = N;
|
900
|
-
}
|
901
|
-
|
902
|
-
if (incX > 0) kx = 0;
|
903
|
-
else kx = (lenX - 1) * -incX;
|
904
|
-
|
905
|
-
if (incY > 0) ky = 0;
|
906
|
-
else ky = (lenY - 1) * -incY;
|
907
|
-
|
908
|
-
// Start the operations. In this version, the elements of A are accessed sequentially with one pass through A.
|
909
|
-
if (*beta != 1) {
|
910
|
-
if (incY == 1) {
|
911
|
-
if (*beta == 0) {
|
912
|
-
for (i = 0; i < lenY; ++i) {
|
913
|
-
Y[i] = 0;
|
914
|
-
}
|
915
|
-
} else {
|
916
|
-
for (i = 0; i < lenY; ++i) {
|
917
|
-
Y[i] *= *beta;
|
918
|
-
}
|
919
|
-
}
|
920
|
-
} else {
|
921
|
-
iy = ky;
|
922
|
-
if (*beta == 0) {
|
923
|
-
for (i = 0; i < lenY; ++i) {
|
924
|
-
Y[iy] = 0;
|
925
|
-
iy += incY;
|
926
|
-
}
|
927
|
-
} else {
|
928
|
-
for (i = 0; i < lenY; ++i) {
|
929
|
-
Y[iy] *= *beta;
|
930
|
-
iy += incY;
|
931
|
-
}
|
932
|
-
}
|
933
|
-
}
|
934
|
-
}
|
935
|
-
|
936
|
-
if (*alpha == 0) return false;
|
937
|
-
|
938
|
-
if (Trans == CblasNoTrans) {
|
939
|
-
|
940
|
-
// Form y := alpha*A*x + y.
|
941
|
-
jx = kx;
|
942
|
-
if (incY == 1) {
|
943
|
-
for (j = 0; j < N; ++j) {
|
944
|
-
if (X[jx] != 0) {
|
945
|
-
temp = *alpha * X[jx];
|
946
|
-
for (i = 0; i < M; ++i) {
|
947
|
-
Y[i] += A[j+i*lda] * temp;
|
948
|
-
}
|
949
|
-
}
|
950
|
-
jx += incX;
|
951
|
-
}
|
952
|
-
} else {
|
953
|
-
for (j = 0; j < N; ++j) {
|
954
|
-
if (X[jx] != 0) {
|
955
|
-
temp = *alpha * X[jx];
|
956
|
-
iy = ky;
|
957
|
-
for (i = 0; i < M; ++i) {
|
958
|
-
Y[iy] += A[j+i*lda] * temp;
|
959
|
-
iy += incY;
|
960
|
-
}
|
961
|
-
}
|
962
|
-
jx += incX;
|
963
|
-
}
|
964
|
-
}
|
965
|
-
|
966
|
-
} else { // TODO: Check that indices are correct! They're switched for C.
|
967
|
-
|
968
|
-
// Form y := alpha*A**DType*x + y.
|
969
|
-
jy = ky;
|
970
|
-
|
971
|
-
if (incX == 1) {
|
972
|
-
for (j = 0; j < N; ++j) {
|
973
|
-
temp = 0;
|
974
|
-
for (i = 0; i < M; ++i) {
|
975
|
-
temp += A[j+i*lda]*X[j];
|
976
|
-
}
|
977
|
-
Y[jy] += *alpha * temp;
|
978
|
-
jy += incY;
|
979
|
-
}
|
980
|
-
} else {
|
981
|
-
for (j = 0; j < N; ++j) {
|
982
|
-
temp = 0;
|
983
|
-
ix = kx;
|
984
|
-
for (i = 0; i < M; ++i) {
|
985
|
-
temp += A[j+i*lda] * X[ix];
|
986
|
-
ix += incX;
|
987
|
-
}
|
988
|
-
|
989
|
-
Y[jy] += *alpha * temp;
|
990
|
-
jy += incY;
|
991
|
-
}
|
992
|
-
}
|
993
|
-
}
|
994
|
-
|
995
|
-
return true;
|
996
|
-
} // end of GEMV
|
997
|
-
|
998
|
-
template <>
|
999
|
-
inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const float* alpha, const float* A, const int lda,
|
1000
|
-
const float* X, const int incX, const float* beta, float* Y, const int incY) {
|
1001
|
-
cblas_sgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
|
1002
|
-
return true;
|
1003
|
-
}
|
1004
|
-
|
1005
|
-
template <>
|
1006
|
-
inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const double* alpha, const double* A, const int lda,
|
1007
|
-
const double* X, const int incX, const double* beta, double* Y, const int incY) {
|
1008
|
-
cblas_dgemv(CblasRowMajor, Trans, M, N, *alpha, A, lda, X, incX, *beta, Y, incY);
|
1009
|
-
return true;
|
1010
|
-
}
|
1011
|
-
|
1012
|
-
template <>
|
1013
|
-
inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex64* alpha, const Complex64* A, const int lda,
|
1014
|
-
const Complex64* X, const int incX, const Complex64* beta, Complex64* Y, const int incY) {
|
1015
|
-
cblas_cgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
|
1016
|
-
return true;
|
1017
|
-
}
|
1018
|
-
|
1019
|
-
template <>
|
1020
|
-
inline bool gemv(const enum CBLAS_TRANSPOSE Trans, const int M, const int N, const Complex128* alpha, const Complex128* A, const int lda,
|
1021
|
-
const Complex128* X, const int incX, const Complex128* beta, Complex128* Y, const int incY) {
|
1022
|
-
cblas_zgemv(CblasRowMajor, Trans, M, N, alpha, A, lda, X, incX, beta, Y, incY);
|
1023
|
-
return true;
|
1024
|
-
}
|
1025
|
-
|
1026
|
-
|
1027
|
-
// Yale: numeric matrix multiply c=a*b
|
1028
|
-
template <typename DType, typename IType>
|
1029
|
-
inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
|
1030
|
-
const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
|
1031
|
-
const unsigned int max_lmn = std::max(std::max(m, n), l);
|
1032
|
-
IType next[max_lmn];
|
1033
|
-
DType sums[max_lmn];
|
1034
|
-
|
1035
|
-
DType v;
|
1036
|
-
|
1037
|
-
IType head, length, temp, ndnz = 0;
|
1038
|
-
IType minmn = std::min(m,n);
|
1039
|
-
IType minlm = std::min(l,m);
|
1040
|
-
|
1041
|
-
for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
|
1042
|
-
next[idx] = std::numeric_limits<IType>::max();
|
1043
|
-
sums[idx] = 0;
|
1044
|
-
}
|
1045
|
-
|
1046
|
-
for (IType i = 0; i < n; ++i) { // walk down the rows
|
1047
|
-
head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
|
1048
|
-
length = 0;
|
1049
|
-
|
1050
|
-
for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
|
1051
|
-
IType j;
|
1052
|
-
|
1053
|
-
if (jj == ia[i+1]) { // if we're in the last entry for this row:
|
1054
|
-
if (!diaga || i >= minmn) continue;
|
1055
|
-
j = i; // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
|
1056
|
-
v = a[i];
|
1057
|
-
} else {
|
1058
|
-
j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
|
1059
|
-
v = a[jj];
|
1060
|
-
}
|
1061
|
-
|
1062
|
-
for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {
|
1063
|
-
|
1064
|
-
IType k;
|
1065
|
-
|
1066
|
-
if (kk == ib[j+1]) { // Get the column id for that entry
|
1067
|
-
if (!diagb || j >= minlm) continue;
|
1068
|
-
k = j;
|
1069
|
-
sums[k] += v*b[k];
|
1070
|
-
} else {
|
1071
|
-
k = jb[kk];
|
1072
|
-
sums[k] += v*b[kk];
|
1073
|
-
}
|
1074
|
-
|
1075
|
-
if (next[k] == std::numeric_limits<IType>::max()) {
|
1076
|
-
next[k] = head;
|
1077
|
-
head = k;
|
1078
|
-
++length;
|
1079
|
-
}
|
1080
|
-
} // end of kk loop
|
1081
|
-
} // end of jj loop
|
1082
|
-
|
1083
|
-
for (IType jj = 0; jj < length; ++jj) {
|
1084
|
-
if (sums[head] != 0) {
|
1085
|
-
if (diagc && head == i) {
|
1086
|
-
c[head] = sums[head];
|
1087
|
-
} else {
|
1088
|
-
jc[n+1+ndnz] = head;
|
1089
|
-
c[n+1+ndnz] = sums[head];
|
1090
|
-
++ndnz;
|
1091
|
-
}
|
1092
|
-
}
|
1093
|
-
|
1094
|
-
temp = head;
|
1095
|
-
head = next[head];
|
1096
|
-
|
1097
|
-
next[temp] = std::numeric_limits<IType>::max();
|
1098
|
-
sums[temp] = 0;
|
1099
|
-
}
|
1100
|
-
|
1101
|
-
ic[i+1] = n+1+ndnz;
|
1102
|
-
}
|
1103
|
-
} /* numbmm_ */
|
1104
|
-
|
1105
|
-
|
1106
|
-
/*
|
1107
|
-
template <typename DType, typename IType>
|
1108
|
-
inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
|
1109
|
-
unsigned int n = c_storage->shape[0],
|
1110
|
-
l = c_storage->shape[1];
|
1111
|
-
|
1112
|
-
// Create a working vector of dimension max(m,l,n) and initial value IType::max():
|
1113
|
-
std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
|
1114
|
-
|
1115
|
-
for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
|
1116
|
-
|
1117
|
-
IType j, k;
|
1118
|
-
size_t ndnz;
|
1119
|
-
|
1120
|
-
for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
|
1121
|
-
j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
|
1122
|
-
|
1123
|
-
if (j >= m) {
|
1124
|
-
if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
|
1125
|
-
else break;
|
1126
|
-
}
|
1127
|
-
|
1128
|
-
for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
|
1129
|
-
if (j >= m) continue; // first of all, does B *have* a row j?
|
1130
|
-
k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
|
1131
|
-
|
1132
|
-
if (k >= l) {
|
1133
|
-
if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
|
1134
|
-
else break;
|
1135
|
-
}
|
1136
|
-
|
1137
|
-
if (mask[k] == )
|
1138
|
-
}
|
1139
|
-
|
1140
|
-
}
|
1141
|
-
}
|
1142
|
-
}
|
1143
|
-
*/
|
1144
|
-
|
1145
|
-
// Yale: Symbolic matrix multiply c=a*b
|
1146
|
-
template <typename IType>
|
1147
|
-
inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
|
1148
|
-
const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
|
1149
|
-
unsigned int max_lmn = std::max(std::max(m,n), l);
|
1150
|
-
IType mask[max_lmn]; // INDEX in the SMMP paper.
|
1151
|
-
IType j, k; /* Local variables */
|
1152
|
-
size_t ndnz = n;
|
1153
|
-
|
1154
|
-
for (IType idx = 0; idx < max_lmn; ++idx)
|
1155
|
-
mask[idx] = std::numeric_limits<IType>::max();
|
1156
|
-
|
1157
|
-
if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
|
1158
|
-
if (diagc) ic[0] = n+1;
|
1159
|
-
else ic[0] = 0;
|
1160
|
-
}
|
1161
|
-
|
1162
|
-
IType minmn = std::min(m,n);
|
1163
|
-
IType minlm = std::min(l,m);
|
1164
|
-
|
1165
|
-
for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
|
1166
|
-
|
1167
|
-
for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
|
1168
|
-
|
1169
|
-
// j <- column index given by JA[jj], or handle diagonal.
|
1170
|
-
if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
|
1171
|
-
if (!diaga || i >= minmn) continue;
|
1172
|
-
j = i;
|
1173
|
-
} else j = ja[jj];
|
1174
|
-
|
1175
|
-
for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
|
1176
|
-
if (kk == ib[j+1]) {
|
1177
|
-
if (!diagb || j >= minlm) continue;
|
1178
|
-
k = j;
|
1179
|
-
} else k = jb[kk];
|
1180
|
-
|
1181
|
-
if (mask[k] != i) {
|
1182
|
-
mask[k] = i;
|
1183
|
-
++ndnz;
|
1184
|
-
}
|
1185
|
-
}
|
1186
|
-
}
|
1187
|
-
|
1188
|
-
if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
|
1189
|
-
|
1190
|
-
if (ic) ic[i+1] = ndnz;
|
1191
|
-
}
|
1192
|
-
|
1193
|
-
return ndnz;
|
1194
|
-
} /* symbmm_ */
|
1195
|
-
|
1196
|
-
|
1197
|
-
// In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
|
1198
|
-
namespace smmp_sort {
|
1199
|
-
const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer
|
1200
|
-
|
1201
|
-
template <typename DType, typename IType>
|
1202
|
-
void print_array(DType* vals, IType* array, IType left, IType right) {
|
1203
|
-
for (IType i = left; i <= right; ++i) {
|
1204
|
-
std::cerr << array[i] << ":" << vals[i] << " ";
|
1205
|
-
}
|
1206
|
-
std::cerr << std::endl;
|
1207
|
-
}
|
1208
|
-
|
1209
|
-
template <typename DType, typename IType>
|
1210
|
-
IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
|
1211
|
-
IType pivotJ = array[pivot];
|
1212
|
-
DType pivotV = vals[pivot];
|
1213
|
-
|
1214
|
-
// Swap pivot and right
|
1215
|
-
array[pivot] = array[right];
|
1216
|
-
vals[pivot] = vals[right];
|
1217
|
-
array[right] = pivotJ;
|
1218
|
-
vals[right] = pivotV;
|
1219
|
-
|
1220
|
-
IType store = left;
|
1221
|
-
for (IType idx = left; idx < right; ++idx) {
|
1222
|
-
if (array[idx] <= pivotJ) {
|
1223
|
-
// Swap i and store
|
1224
|
-
std::swap(array[idx], array[store]);
|
1225
|
-
std::swap(vals[idx], vals[store]);
|
1226
|
-
++store;
|
1227
|
-
}
|
1228
|
-
}
|
1229
|
-
|
1230
|
-
std::swap(array[store], array[right]);
|
1231
|
-
std::swap(vals[store], vals[right]);
|
1232
|
-
|
1233
|
-
return store;
|
1234
|
-
}
|
1235
|
-
|
1236
|
-
// Recommended to use the median of left, right, and mid for the pivot.
|
1237
|
-
template <typename IType>
|
1238
|
-
IType median(IType a, IType b, IType c) {
|
1239
|
-
if (a < b) {
|
1240
|
-
if (b < c) return b; // a b c
|
1241
|
-
if (a < c) return c; // a c b
|
1242
|
-
return a; // c a b
|
1243
|
-
|
1244
|
-
} else { // a > b
|
1245
|
-
if (a < c) return a; // b a c
|
1246
|
-
if (b < c) return c; // b c a
|
1247
|
-
return b; // c b a
|
1248
|
-
}
|
1249
|
-
}
|
1250
|
-
|
1251
|
-
|
1252
|
-
// Insertion sort is more efficient than quicksort for small N
|
1253
|
-
template <typename DType, typename IType>
|
1254
|
-
void insertion_sort(DType* vals, IType* array, IType left, IType right) {
|
1255
|
-
for (IType idx = left; idx <= right; ++idx) {
|
1256
|
-
IType col_to_insert = array[idx];
|
1257
|
-
DType val_to_insert = vals[idx];
|
1258
|
-
|
1259
|
-
IType hole_pos = idx;
|
1260
|
-
for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
|
1261
|
-
array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
|
1262
|
-
vals[hole_pos] = vals[hole_pos - 1]; // value goes along with it
|
1263
|
-
}
|
1264
|
-
|
1265
|
-
array[hole_pos] = col_to_insert;
|
1266
|
-
vals[hole_pos] = val_to_insert;
|
1267
|
-
}
|
1268
|
-
}
|
1269
|
-
|
1270
|
-
|
1271
|
-
template <typename DType, typename IType>
|
1272
|
-
void quicksort(DType* vals, IType* array, IType left, IType right) {
|
1273
|
-
|
1274
|
-
if (left < right) {
|
1275
|
-
if (right - left < THRESHOLD) {
|
1276
|
-
insertion_sort(vals, array, left, right);
|
1277
|
-
} else {
|
1278
|
-
// choose any pivot such that left < pivot < right
|
1279
|
-
IType pivot = median(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
|
1280
|
-
pivot = partition(vals, array, left, right, pivot);
|
1281
|
-
|
1282
|
-
// recursively sort elements smaller than the pivot
|
1283
|
-
quicksort<DType,IType>(vals, array, left, pivot-1);
|
1284
|
-
|
1285
|
-
// recursively sort elements at least as big as the pivot
|
1286
|
-
quicksort<DType,IType>(vals, array, pivot+1, right);
|
1287
|
-
}
|
1288
|
-
}
|
1289
|
-
}
|
1290
|
-
|
1291
|
-
|
1292
|
-
}; // end of namespace smmp_sort
|
1293
|
-
|
1294
|
-
|
1295
|
-
/*
|
1296
|
-
* For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
|
1297
|
-
* This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
|
1298
|
-
* about stability).
|
1299
|
-
*
|
1300
|
-
* TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
|
1301
|
-
*
|
1302
|
-
* TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
|
1303
|
-
* ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
|
1304
|
-
* sort.
|
1305
|
-
*/
|
1306
|
-
template <typename DType, typename IType>
|
1307
|
-
inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
|
1308
|
-
for (size_t i = 0; i < n; ++i) {
|
1309
|
-
if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing only one or two elements.
|
1310
|
-
else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
|
1311
|
-
smmp_sort::insertion_sort<DType, IType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
|
1312
|
-
} else {
|
1313
|
-
smmp_sort::quicksort<DType, IType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
|
1314
|
-
}
|
1315
|
-
}
|
1316
|
-
}
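For readers following the sort routines above, here is a minimal standalone sketch (plain C++, no nmatrix templates; every name in it is made up for illustration) of what smmp_sort_columns is responsible for: after a Yale/CSR multiply, each row's (column index, value) pairs must be reordered by column index, with the values carried along. It uses std::sort over a permutation rather than the in-place quicksort/insertion-sort hybrid above, so it is illustrative only.

// Standalone sketch (not part of nmatrix): sort the (column, value) pairs of
// each CSR row by column index, which is what smmp_sort_columns achieves with
// its in-place quicksort/insertion-sort hybrid after symbmm/numbmm.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // One 3x3 matrix in plain CSR form: ia are row pointers, ja column indices,
  // a the values. Rows 0 and 2 are deliberately out of column order.
  std::vector<std::size_t> ia = {0, 2, 3, 5};
  std::vector<std::size_t> ja = {2, 0, 1, 2, 0};
  std::vector<double>      a  = {3.0, 1.0, 5.0, 9.0, 7.0};

  for (std::size_t i = 0; i + 1 < ia.size(); ++i) {
    // Sort the slice [ia[i], ia[i+1]) of ja, carrying a along with it.
    std::vector<std::size_t> perm(ia[i+1] - ia[i]);
    for (std::size_t k = 0; k < perm.size(); ++k) perm[k] = ia[i] + k;
    std::sort(perm.begin(), perm.end(),
              [&](std::size_t x, std::size_t y) { return ja[x] < ja[y]; });

    std::vector<std::size_t> ja_sorted(perm.size());
    std::vector<double>      a_sorted(perm.size());
    for (std::size_t k = 0; k < perm.size(); ++k) {
      ja_sorted[k] = ja[perm[k]];
      a_sorted[k]  = a[perm[k]];
    }
    std::copy(ja_sorted.begin(), ja_sorted.end(), ja.begin() + ia[i]);
    std::copy(a_sorted.begin(),  a_sorted.end(),  a.begin()  + ia[i]);
  }

  for (std::size_t k = 0; k < ja.size(); ++k)
    std::cout << ja[k] << ":" << a[k] << " ";
  std::cout << std::endl;   // prints 0:1 2:3 1:5 0:7 2:9
  return 0;
}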
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
/*
|
1321
|
-
* Transposes a generic Yale matrix (old or new). Specify new by setting diaga = true.
|
1322
|
-
*
|
1323
|
-
* Based on transp from SMMP (same as symbmm and numbmm).
|
1324
|
-
*
|
1325
|
-
* This is not named in the same way as most yale_storage functions because it does not act on a YALE_STORAGE
|
1326
|
-
* object.
|
1327
|
-
*/
|
1328
|
-
template <typename DType, typename IType>
|
1329
|
-
void transpose_yale(const size_t n, const size_t m, const void* ia_, const void* ja_, const void* a_,
|
1330
|
-
const bool diaga, void* ib_, void* jb_, void* b_, const bool move)
|
1331
|
-
{
|
1332
|
-
const IType *ia = reinterpret_cast<const IType*>(ia_),
|
1333
|
-
*ja = reinterpret_cast<const IType*>(ja_);
|
1334
|
-
const DType *a = reinterpret_cast<const DType*>(a_);
|
1335
|
-
|
1336
|
-
IType *ib = reinterpret_cast<IType*>(ib_),
|
1337
|
-
*jb = reinterpret_cast<IType*>(jb_);
|
1338
|
-
DType *b = reinterpret_cast<DType*>(b_);
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
size_t index;
|
1343
|
-
|
1344
|
-
// Clear B
|
1345
|
-
for (size_t i = 0; i < m+1; ++i) ib[i] = 0;
|
1346
|
-
|
1347
|
-
if (move)
|
1348
|
-
for (size_t i = 0; i < m+1; ++i) b[i] = 0;
|
1349
|
-
|
1350
|
-
if (diaga) ib[0] = m + 1;
|
1351
|
-
else ib[0] = 0;
|
1352
|
-
|
1353
|
-
/* count indices for each column */
|
1354
|
-
|
1355
|
-
for (size_t i = 0; i < n; ++i) {
|
1356
|
-
for (size_t j = ia[i]; j < ia[i+1]; ++j) {
|
1357
|
-
++(ib[ja[j]+1]);
|
1358
|
-
}
|
1359
|
-
}
|
1360
|
-
|
1361
|
-
for (size_t i = 0; i < m; ++i) {
|
1362
|
-
ib[i+1] = ib[i] + ib[i+1];
|
1363
|
-
}
|
1364
|
-
|
1365
|
-
/* now make jb */
|
1366
|
-
|
1367
|
-
for (size_t i = 0; i < n; ++i) {
|
1368
|
-
|
1369
|
-
for (size_t j = ia[i]; j < ia[i+1]; ++j) {
|
1370
|
-
index = ja[j];
|
1371
|
-
jb[ib[index]] = i;
|
1372
|
-
|
1373
|
-
if (move)
|
1374
|
-
b[ib[index]] = a[j];
|
1375
|
-
|
1376
|
-
++(ib[index]);
|
1377
|
-
}
|
1378
|
-
}
|
1379
|
-
|
1380
|
-
/* now fixup ib */
|
1381
|
-
|
1382
|
-
for (size_t i = m; i >= 1; --i) {
|
1383
|
-
ib[i] = ib[i-1];
|
1384
|
-
}
|
1385
|
-
|
1386
|
-
|
1387
|
-
if (diaga) {
|
1388
|
-
if (move) {
|
1389
|
-
size_t j = std::min(n,m);
|
1390
|
-
|
1391
|
-
for (size_t i = 0; i < j; ++i) {
|
1392
|
-
b[i] = a[i];
|
1393
|
-
}
|
1394
|
-
}
|
1395
|
-
ib[0] = m + 1;
|
1396
|
-
|
1397
|
-
} else {
|
1398
|
-
ib[0] = 0;
|
1399
|
-
}
|
1400
|
-
}
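A minimal standalone sketch of the counting-and-scattering transpose that transpose_yale performs, written for a plain CSR matrix without the old/new Yale diagonal bookkeeping (diaga) and using a separate write cursor instead of the in-place ib fix-up; all names here are made up for illustration.

// Standalone sketch (not part of nmatrix): transpose a plain CSR matrix by
// counting entries per column and then scattering them, the same two-pass
// scheme transpose_yale uses.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t n = 2, m = 3;                 // 2x3 input, 3x2 transpose
  std::vector<std::size_t> ia = {0, 2, 3};        // row pointers of A
  std::vector<std::size_t> ja = {0, 2, 1};        // column indices of A
  std::vector<double>      a  = {1.0, 2.0, 3.0};  // values of A

  std::vector<std::size_t> ib(m + 1, 0), jb(ja.size());
  std::vector<double>      b(a.size());

  // Pass 1: count how many entries land in each column of A (= row of B).
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = ia[i]; j < ia[i+1]; ++j)
      ++ib[ja[j] + 1];
  for (std::size_t i = 0; i < m; ++i) ib[i+1] += ib[i];

  // Pass 2: scatter entries; cursor[col] is a moving write position per row of B.
  std::vector<std::size_t> cursor(ib.begin(), ib.end() - 1);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = ia[i]; j < ia[i+1]; ++j) {
      std::size_t col = ja[j];
      jb[cursor[col]] = i;
      b[cursor[col]]  = a[j];
      ++cursor[col];
    }

  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = ib[i]; j < ib[i+1]; ++j)
      std::cout << "B[" << i << "," << jb[j] << "] = " << b[j] << "\n";
  return 0;
}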
|
1401
|
-
|
1402
|
-
|
1403
|
-
/*
|
1404
|
-
* Templated version of row-order and column-order getrf, derived from ATL_getrfR.c (from ATLAS 3.8.0).
|
1405
|
-
*
|
1406
|
-
* 1. Row-major factorization of form
|
1407
|
-
* A = L * U * P
|
1408
|
-
* where P is a column-permutation matrix, L is lower triangular (lower
|
1409
|
-
* trapezoidal if M > N), and U is upper triangular with unit diagonals (upper
|
1410
|
-
* trapezoidal if M < N). This is the recursive Level 3 BLAS version.
|
1411
|
-
*
|
1412
|
-
* 2. Column-major factorization of form
|
1413
|
-
* A = P * L * U
|
1414
|
-
* where P is a row-permutation matrix, L is lower triangular with unit diagonal
|
1415
|
-
* elements (lower trapezoidal if M > N), and U is upper triangular (upper
|
1416
|
-
* trapezoidal if M < N). This is the recursive Level 3 BLAS version.
|
1417
|
-
*
|
1418
|
-
* Template argument determines whether 1 or 2 is utilized.
|
1419
|
-
*/
|
1420
|
-
template <bool RowMajor, typename DType>
|
1421
|
-
inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* ipiv) {
|
1422
|
-
const int MN = std::min(M, N);
|
1423
|
-
int ierr = 0;
|
1424
|
-
|
1425
|
-
// Symbols used by ATLAS:
|
1426
|
-
// Row Col Us
|
1427
|
-
// Nup Nleft N_ul
|
1428
|
-
// Ndown Nright N_dr
|
1429
|
-
// We're going to use N_ul, N_dr
|
1430
|
-
|
1431
|
-
DType neg_one = -1, one = 1;
|
1432
|
-
|
1433
|
-
if (MN > 1) {
|
1434
|
-
int N_ul = MN >> 1;
|
1435
|
-
|
1436
|
-
// FIXME: Figure out how ATLAS #defines NB
|
1437
|
-
#ifdef NB
|
1438
|
-
if (N_ul > NB) N_ul = ATL_MulByNB(ATL_DivByNB(N_ul));
|
1439
|
-
#endif
|
1440
|
-
|
1441
|
-
int N_dr = M - N_ul;
|
1442
|
-
|
1443
|
-
int i = RowMajor ? getrf_nothrow<true,DType>(N_ul, N, A, lda, ipiv) : getrf_nothrow<false,DType>(M, N_ul, A, lda, ipiv);
|
1444
|
-
|
1445
|
-
if (i) if (!ierr) ierr = i;
|
1446
|
-
|
1447
|
-
DType *Ar, *Ac, *An;
|
1448
|
-
if (RowMajor) {
|
1449
|
-
Ar = &(A[N_ul * lda]),
|
1450
|
-
Ac = &(A[N_ul]);
|
1451
|
-
An = &(Ar[N_ul]);
|
1452
|
-
|
1453
|
-
nm::math::laswp<DType>(N_dr, Ar, lda, 0, N_ul, ipiv, 1);
|
1454
|
-
|
1455
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, N_dr, N_ul, one, A, lda, Ar, lda);
|
1456
|
-
nm::math::gemm<DType>(CblasRowMajor, CblasNoTrans, CblasNoTrans, N_dr, N-N_ul, N_ul, &neg_one, Ar, lda, Ac, lda, &one, An, lda);
|
1457
|
-
|
1458
|
-
i = getrf_nothrow<true,DType>(N_dr, N-N_ul, An, lda, ipiv+N_ul);
|
1459
|
-
} else {
|
1460
|
-
Ar = NULL;
|
1461
|
-
Ac = &(A[N_ul * lda]);
|
1462
|
-
An = &(Ac[N_ul]);
|
1463
|
-
|
1464
|
-
nm::math::laswp<DType>(N_dr, Ac, lda, 0, N_ul, ipiv, 1);
|
1465
|
-
|
1466
|
-
nm::math::trsm<DType>(CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N_ul, N_dr, one, A, lda, Ac, lda);
|
1467
|
-
nm::math::gemm<DType>(CblasColMajor, CblasNoTrans, CblasNoTrans, M-N_ul, N_dr, N_ul, &neg_one, An, lda, Ac, lda, &one, An, lda);
|
1468
|
-
|
1469
|
-
i = getrf_nothrow<false,DType>(M-N_ul, N_dr, An, lda, ipiv+N_ul);
|
1470
|
-
}
|
1471
|
-
|
1472
|
-
if (i) if (!ierr) ierr = N_ul + i;
|
1473
|
-
|
1474
|
-
for (i = N_ul; i != MN; i++) {
|
1475
|
-
ipiv[i] += N_ul;
|
1476
|
-
}
|
1477
|
-
|
1478
|
-
nm::math::laswp<DType>(N_ul, A, lda, N_ul, MN, ipiv, 1); /* apply pivots */
|
1479
|
-
|
1480
|
-
} else if (MN == 1) { // there's another case for the colmajor version, but I don't know that it's that critical. Calls ATLAS LU2, who knows what that does.
|
1481
|
-
|
1482
|
-
int i = *ipiv = nm::math::lapack::idamax<DType>(N, A, 1); // cblas_iamax(N, A, 1);
|
1483
|
-
|
1484
|
-
DType tmp = A[i];
|
1485
|
-
if (tmp != 0) {
|
1486
|
-
|
1487
|
-
nm::math::lapack::scal<DType>((RowMajor ? N : M), nm::math::numeric_inverse(tmp), A, 1);
|
1488
|
-
A[i] = *A;
|
1489
|
-
*A = tmp;
|
1490
|
-
|
1491
|
-
} else ierr = 1;
|
1492
|
-
|
1493
|
-
}
|
1494
|
-
return(ierr);
|
1495
|
-
}
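As a reference point for the recursive blocked factorization above, here is a minimal standalone sketch of a plain partial-pivoting LU factorization followed by forward and back substitution on a small row-major system, i.e. the combined effect that getrf plus getrs achieve. It is the textbook algorithm, not the ATLAS blocked one, and every name in it is made up for illustration.

// Standalone sketch (not part of nmatrix): LU with partial pivoting, then
// solve, on a tiny row-major system.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t n = 3;
  std::vector<double> A = { 2, 1, 1,
                            4, 3, 3,
                            8, 7, 9 };            // row-major
  std::vector<double> b = { 5, 13, 33 };          // right-hand side
  std::vector<std::size_t> piv(n);

  // Factor A in place into L (unit lower) and U, recording row swaps.
  for (std::size_t k = 0; k < n; ++k) {
    std::size_t p = k;
    for (std::size_t i = k + 1; i < n; ++i)
      if (std::fabs(A[i*n + k]) > std::fabs(A[p*n + k])) p = i;
    piv[k] = p;
    if (p != k)
      for (std::size_t j = 0; j < n; ++j) std::swap(A[k*n + j], A[p*n + j]);
    for (std::size_t i = k + 1; i < n; ++i) {
      A[i*n + k] /= A[k*n + k];
      for (std::size_t j = k + 1; j < n; ++j)
        A[i*n + j] -= A[i*n + k] * A[k*n + j];
    }
  }

  // Apply the recorded swaps to b, then solve L*y = b and U*x = y.
  for (std::size_t k = 0; k < n; ++k) std::swap(b[k], b[piv[k]]);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < i; ++j) b[i] -= A[i*n + j] * b[j];
  for (std::size_t i = n; i-- > 0; ) {
    for (std::size_t j = i + 1; j < n; ++j) b[i] -= A[i*n + j] * b[j];
    b[i] /= A[i*n + i];
  }

  for (double x : b) std::cout << x << " ";   // expected: 1 1 2
  std::cout << std::endl;
  return 0;
}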
|
1496
|
-
|
1497
|
-
/*
|
1498
|
-
* Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
|
1499
|
-
*
|
1500
|
-
* From ATLAS 3.8.0.
|
1501
|
-
*/
|
1502
|
-
template <typename DType>
|
1503
|
-
int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
|
1504
|
-
const int lda, const int* ipiv, DType* B, const int ldb)
|
1505
|
-
{
|
1506
|
-
// enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
|
1507
|
-
|
1508
|
-
if (!N || !NRHS) return 0;
|
1509
|
-
|
1510
|
-
const DType ONE = 1;
|
1511
|
-
|
1512
|
-
if (Order == CblasColMajor) {
|
1513
|
-
if (Trans == CblasNoTrans) {
|
1514
|
-
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
|
1515
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1516
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1517
|
-
} else {
|
1518
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1519
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1520
|
-
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
|
1521
|
-
}
|
1522
|
-
} else {
|
1523
|
-
if (Trans == CblasNoTrans) {
|
1524
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1525
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1526
|
-
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
|
1527
|
-
} else {
|
1528
|
-
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
|
1529
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1530
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1531
|
-
}
|
1532
|
-
}
|
1533
|
-
return 0;
|
1534
|
-
}
|
1535
|
-
|
1536
|
-
|
1537
|
-
/*
|
1538
|
-
* Solves a system of linear equations A*X = B with a symmetric positive definite matrix A using the Cholesky factorization computed by POTRF.
|
1539
|
-
*
|
1540
|
-
* From ATLAS 3.8.0.
|
1541
|
-
*/
|
1542
|
-
template <typename DType, bool is_complex>
|
1543
|
-
int potrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, const DType* A,
|
1544
|
-
const int lda, DType* B, const int ldb)
|
1545
|
-
{
|
1546
|
-
// enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
|
1547
|
-
|
1548
|
-
CBLAS_TRANSPOSE MyTrans = is_complex ? CblasConjTrans : CblasTrans;
|
1549
|
-
|
1550
|
-
if (!N || !NRHS) return 0;
|
1551
|
-
|
1552
|
-
const DType ONE = 1;
|
1553
|
-
|
1554
|
-
if (Order == CblasColMajor) {
|
1555
|
-
if (Uplo == CblasUpper) {
|
1556
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1557
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1558
|
-
} else {
|
1559
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1560
|
-
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1561
|
-
}
|
1562
|
-
} else {
|
1563
|
-
// There's some kind of scaling operation that normally happens here in ATLAS. Not sure what it does, so we'll only
|
1564
|
-
// worry if something breaks. It probably has to do with their non-templated code and doesn't apply to us.
|
1565
|
-
|
1566
|
-
if (Uplo == CblasUpper) {
|
1567
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1568
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1569
|
-
} else {
|
1570
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasLower, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1571
|
-
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1572
|
-
}
|
1573
|
-
}
|
1574
|
-
return 0;
|
1575
|
-
}
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
/*
|
1580
|
-
* From ATLAS 3.8.0:
|
1581
|
-
*
|
1582
|
-
* Computes one of two LU factorizations based on the setting of the Order
|
1583
|
-
* parameter, as follows:
|
1584
|
-
* ----------------------------------------------------------------------------
|
1585
|
-
* Order == CblasColMajor
|
1586
|
-
* Column-major factorization of form
|
1587
|
-
* A = P * L * U
|
1588
|
-
* where P is a row-permutation matrix, L is lower triangular with unit
|
1589
|
-
* diagonal elements (lower trapezoidal if M > N), and U is upper triangular
|
1590
|
-
* (upper trapezoidal if M < N).
|
1591
|
-
*
|
1592
|
-
* ----------------------------------------------------------------------------
|
1593
|
-
* Order == CblasRowMajor
|
1594
|
-
* Row-major factorization of form
|
1595
|
-
* A = P * L * U
|
1596
|
-
* where P is a column-permutation matrix, L is lower triangular (lower
|
1597
|
-
* trapezoidal if M > N), and U is upper triangular with unit diagonals (upper
|
1598
|
-
* trapezoidal if M < N).
|
1599
|
-
*
|
1600
|
-
* ============================================================================
|
1601
|
-
* Let IERR be the return value of the function:
|
1602
|
-
* If IERR == 0, successful exit.
|
1603
|
-
* If (IERR < 0) the -IERR argument had an illegal value
|
1604
|
-
* If (IERR > 0 && Order == CblasColMajor)
|
1605
|
-
* U(i-1,i-1) is exactly zero. The factorization has been completed,
|
1606
|
-
* but the factor U is exactly singular, and division by zero will
|
1607
|
-
* occur if it is used to solve a system of equations.
|
1608
|
-
* If (IERR > 0 && Order == CblasRowMajor)
|
1609
|
-
* L(i-1,i-1) is exactly zero. The factorization has been completed,
|
1610
|
-
* but the factor L is exactly singular, and division by zero will
|
1611
|
-
* occur if it is used to solve a system of equations.
|
1612
|
-
*/
|
1613
|
-
template <typename DType>
|
1614
|
-
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType* A, int lda, int* ipiv) {
|
1615
|
-
if (Order == CblasRowMajor) {
|
1616
|
-
if (lda < std::max(1,N)) {
|
1617
|
-
rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(N,1): lda=%d N=%d", lda, N);
|
1618
|
-
return -6;
|
1619
|
-
}
|
1620
|
-
|
1621
|
-
return getrf_nothrow<true,DType>(M, N, A, lda, ipiv);
|
1622
|
-
} else {
|
1623
|
-
if (lda < std::max(1,M)) {
|
1624
|
-
rb_raise(rb_eArgError, "GETRF: lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
|
1625
|
-
return -6;
|
1626
|
-
}
|
1627
|
-
|
1628
|
-
return getrf_nothrow<false,DType>(M, N, A, lda, ipiv);
|
1629
|
-
//rb_raise(rb_eNotImpError, "column major getrf not implemented");
|
1630
|
-
}
|
1631
|
-
}
|
1632
|
-
|
1633
|
-
|
1634
|
-
/*
|
1635
|
-
* From ATLAS 3.8.0:
|
1636
|
-
*
|
1637
|
-
* Computes one of two LU factorizations based on the setting of the Order
|
1638
|
-
* parameter, as follows:
|
1639
|
-
* ----------------------------------------------------------------------------
|
1640
|
-
* Order == CblasColMajor
|
1641
|
-
* Column-major factorization of form
|
1642
|
-
* A = P * L * U
|
1643
|
-
* where P is a row-permutation matrix, L is lower triangular with unit
|
1644
|
-
* diagonal elements (lower trapezoidal if M > N), and U is upper triangular
|
1645
|
-
* (upper trapezoidal if M < N).
|
1646
|
-
*
|
1647
|
-
* ----------------------------------------------------------------------------
|
1648
|
-
* Order == CblasRowMajor
|
1649
|
-
* Row-major factorization of form
|
1650
|
-
* A = P * L * U
|
1651
|
-
* where P is a column-permutation matrix, L is lower triangular (lower
|
1652
|
-
* trapezoidal if M > N), and U is upper triangular with unit diagonals (upper
|
1653
|
-
* trapezoidal if M < N).
|
1654
|
-
*
|
1655
|
-
* ============================================================================
|
1656
|
-
* Let IERR be the return value of the function:
|
1657
|
-
* If IERR == 0, successful exit.
|
1658
|
-
* If (IERR < 0) the -IERR argument had an illegal value
|
1659
|
-
* If (IERR > 0 && Order == CblasColMajor)
|
1660
|
-
* U(i-1,i-1) is exactly zero. The factorization has been completed,
|
1661
|
-
* but the factor U is exactly singular, and division by zero will
|
1662
|
-
* occur if it is used to solve a system of equations.
|
1663
|
-
* If (IERR > 0 && Order == CblasRowMajor)
|
1664
|
-
* L(i-1,i-1) is exactly zero. The factorization has been completed,
|
1665
|
-
* but the factor L is exactly singular, and division by zero will
|
1666
|
-
* occur if it is used to solve a system of equations.
|
1667
|
-
*/
|
1668
|
-
template <typename DType>
|
1669
|
-
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
|
1670
|
-
#ifdef HAVE_CLAPACK_H
|
1671
|
-
rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
|
1672
|
-
#else
|
1673
|
-
rb_raise(rb_eNotImpError, "only LAPACK version implemented thus far");
|
1674
|
-
#endif
|
1675
|
-
return 0;
|
1676
|
-
}
|
1677
|
-
|
1678
|
-
#ifdef HAVE_CLAPACK_H
|
1679
|
-
template <>
|
1680
|
-
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
|
1681
|
-
return clapack_spotrf(order, uplo, N, A, lda);
|
1682
|
-
}
|
1683
|
-
|
1684
|
-
template <>
|
1685
|
-
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
|
1686
|
-
return clapack_dpotrf(order, uplo, N, A, lda);
|
1687
|
-
}
|
1688
|
-
|
1689
|
-
template <>
|
1690
|
-
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
|
1691
|
-
return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
|
1692
|
-
}
|
1693
|
-
|
1694
|
-
template <>
|
1695
|
-
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
|
1696
|
-
return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
|
1697
|
-
}
|
1698
|
-
#endif
|
1699
|
-
|
1700
|
-
|
1701
|
-
// This is the old BLAS version of this function. ATLAS has an optimized version, but
|
1702
|
-
// it's going to be tough to translate.
|
1703
|
-
template <typename DType>
|
1704
|
-
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
|
1705
|
-
if (N > 0) {
|
1706
|
-
int ix = 0, iy = 0;
|
1707
|
-
for (int i = 0; i < N; ++i) {
|
1708
|
-
DType temp = X[ix];
|
1709
|
-
X[ix] = Y[iy];
|
1710
|
-
Y[iy] = temp;
|
1711
|
-
|
1712
|
-
ix += incX;
|
1713
|
-
iy += incY;
|
1714
|
-
}
|
1715
|
-
}
|
1716
|
-
}
|
1717
|
-
|
1718
|
-
|
1719
|
-
// Copies an upper row-major array from U, zeroing U; U is unit, so diagonal is not copied.
|
1720
|
-
//
|
1721
|
-
// From ATLAS 3.8.0.
|
1722
|
-
template <typename DType>
|
1723
|
-
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
|
1724
|
-
|
1725
|
-
for (int i = 0; i != M; ++i) {
|
1726
|
-
for (int j = i+1; j < N; ++j) {
|
1727
|
-
C[j] = U[j];
|
1728
|
-
U[j] = 0;
|
1729
|
-
}
|
1730
|
-
|
1731
|
-
C += ldc;
|
1732
|
-
U += ldu;
|
1733
|
-
}
|
1734
|
-
}
|
1735
|
-
|
1736
|
-
|
1737
|
-
/*
|
1738
|
-
* Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
|
1739
|
-
* functions. This is probably really complicated.
|
1740
|
-
*
|
1741
|
-
* Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
|
1742
|
-
*
|
1743
|
-
*/
|
1744
|
-
|
1745
|
-
/*
|
1746
|
-
|
1747
|
-
template <bool RowMajor, bool Upper, typename DType>
|
1748
|
-
static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
1749
|
-
|
1750
|
-
if (RowMajor) {
|
1751
|
-
DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
|
1752
|
-
DType tmp;
|
1753
|
-
if (Upper) {
|
1754
|
-
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
1755
|
-
A12 = pA1[2], A13 = pA1[3],
|
1756
|
-
A23 = pA2[3];
|
1757
|
-
|
1758
|
-
if (Diag == CblasNonUnit) {
|
1759
|
-
pA0->inverse();
|
1760
|
-
(pA1+1)->inverse();
|
1761
|
-
(pA2+2)->inverse();
|
1762
|
-
(pA3+3)->inverse();
|
1763
|
-
|
1764
|
-
pA0[1] = -A01 * pA1[1] * pA0[0];
|
1765
|
-
pA1[2] = -A12 * pA2[2] * pA1[1];
|
1766
|
-
pA2[3] = -A23 * pA3[3] * pA2[2];
|
1767
|
-
|
1768
|
-
pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
|
1769
|
-
pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
|
1770
|
-
|
1771
|
-
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
|
1772
|
-
|
1773
|
-
} else {
|
1774
|
-
|
1775
|
-
pA0[1] = -A01;
|
1776
|
-
pA1[2] = -A12;
|
1777
|
-
pA2[3] = -A23;
|
1778
|
-
|
1779
|
-
pA0[2] = -(A01 * pA1[2] + A02);
|
1780
|
-
pA1[3] = -(A12 * pA2[3] + A13);
|
1781
|
-
|
1782
|
-
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
|
1783
|
-
}
|
1784
|
-
|
1785
|
-
} else { // Lower
|
1786
|
-
DType A10 = pA1[0],
|
1787
|
-
A20 = pA2[0], A21 = pA2[1],
|
1788
|
-
A30 = pA3[0], A31 = pA3[1], A32 = pA3[2];
|
1789
|
-
DType *B10 = pA1,
|
1790
|
-
*B20 = pA2,
|
1791
|
-
*B30 = pA3,
|
1792
|
-
*B21 = pA2+1,
|
1793
|
-
*B31 = pA3+1,
|
1794
|
-
*B32 = pA3+2;
|
1795
|
-
|
1796
|
-
|
1797
|
-
if (Diag == CblasNonUnit) {
|
1798
|
-
pA0->inverse();
|
1799
|
-
(pA1+1)->inverse();
|
1800
|
-
(pA2+2)->inverse();
|
1801
|
-
(pA3+3)->inverse();
|
1802
|
-
|
1803
|
-
*B10 = -A10 * pA0[0] * pA1[1];
|
1804
|
-
*B21 = -A21 * pA1[1] * pA2[2];
|
1805
|
-
*B32 = -A32 * pA2[2] * pA3[3];
|
1806
|
-
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
1807
|
-
*B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
|
1808
|
-
*B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3[3];
|
1809
|
-
} else {
|
1810
|
-
*B10 = -A10;
|
1811
|
-
*B21 = -A21;
|
1812
|
-
*B32 = -A32;
|
1813
|
-
*B20 = -(A20 + A21 * (*B10));
|
1814
|
-
*B31 = -(A31 + A32 * (*B21));
|
1815
|
-
*B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
|
1816
|
-
}
|
1817
|
-
}
|
1818
|
-
|
1819
|
-
} else {
|
1820
|
-
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1821
|
-
}
|
1822
|
-
|
1823
|
-
return 0;
|
1824
|
-
|
1825
|
-
}
|
1826
|
-
|
1827
|
-
|
1828
|
-
template <bool RowMajor, bool Upper, typename DType>
|
1829
|
-
static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
1830
|
-
|
1831
|
-
if (RowMajor) {
|
1832
|
-
|
1833
|
-
DType tmp;
|
1834
|
-
|
1835
|
-
if (Upper) {
|
1836
|
-
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
1837
|
-
A12 = pA1[2], A13 = pA1[3];
|
1838
|
-
|
1839
|
-
DType *B01 = pA0 + 1,
|
1840
|
-
*B02 = pA0 + 2,
|
1841
|
-
*B12 = pA1 + 2;
|
1842
|
-
|
1843
|
-
if (Diag == CblasNonUnit) {
|
1844
|
-
pA0->inverse();
|
1845
|
-
(pA1+1)->inverse();
|
1846
|
-
(pA2+2)->inverse();
|
1847
|
-
|
1848
|
-
*B01 = -A01 * pA1[1] * pA0[0];
|
1849
|
-
*B12 = -A12 * pA2[2] * pA1[1];
|
1850
|
-
*B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
|
1851
|
-
} else {
|
1852
|
-
*B01 = -A01;
|
1853
|
-
*B12 = -A12;
|
1854
|
-
*B02 = -(A01 * (*B12) + A02);
|
1855
|
-
}
|
1856
|
-
|
1857
|
-
} else { // Lower
|
1858
|
-
DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
|
1859
|
-
DType A10=pA1[0],
|
1860
|
-
A20=pA2[0], A21=pA2[1];
|
1861
|
-
|
1862
|
-
DType *B10 = pA1,
|
1863
|
-
*B20 = pA2,
|
1864
|
-
*B21 = pA2+1;
|
1865
|
-
|
1866
|
-
if (Diag == CblasNonUnit) {
|
1867
|
-
pA0->inverse();
|
1868
|
-
(pA1+1)->inverse();
|
1869
|
-
(pA2+2)->inverse();
|
1870
|
-
*B10 = -A10 * pA0[0] * pA1[1];
|
1871
|
-
*B21 = -A21 * pA1[1] * pA2[2];
|
1872
|
-
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
1873
|
-
} else {
|
1874
|
-
*B10 = -A10;
|
1875
|
-
*B21 = -A21;
|
1876
|
-
*B20 = -(A20 + A21 * (*B10));
|
1877
|
-
}
|
1878
|
-
}
|
1879
|
-
|
1880
|
-
|
1881
|
-
} else {
|
1882
|
-
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1883
|
-
}
|
1884
|
-
|
1885
|
-
return 0;
|
1886
|
-
|
1887
|
-
}
|
1888
|
-
|
1889
|
-
template <bool RowMajor, bool Upper, bool Real, typename DType>
|
1890
|
-
static int trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
|
1891
|
-
DType *Age, *Atr;
|
1892
|
-
DType tmp;
|
1893
|
-
int Nleft, Nright;
|
1894
|
-
|
1895
|
-
int ierr = 0;
|
1896
|
-
|
1897
|
-
static const DType ONE = 1;
|
1898
|
-
static const DType MONE = -1;
|
1899
|
-
static const DType NONE = -1;
|
1900
|
-
|
1901
|
-
if (RowMajor) {
|
1902
|
-
|
1903
|
-
// FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
|
1904
|
-
if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
|
1905
|
-
Nleft = N >> 1;
|
1906
|
-
#ifdef NB
|
1907
|
-
if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
1908
|
-
#endif
|
1909
|
-
|
1910
|
-
Nright = N - Nleft;
|
1911
|
-
|
1912
|
-
if (Upper) {
|
1913
|
-
Age = A + Nleft;
|
1914
|
-
Atr = A + (Nleft * (lda+1));
|
1915
|
-
|
1916
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
|
1917
|
-
Nleft, Nright, ONE, Atr, lda, Age, lda);
|
1918
|
-
|
1919
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
|
1920
|
-
Nleft, Nright, MONE, A, lda, Age, lda);
|
1921
|
-
|
1922
|
-
} else { // Lower
|
1923
|
-
Age = A + ((Nleft*lda));
|
1924
|
-
Atr = A + (Nleft * (lda+1));
|
1925
|
-
|
1926
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
|
1927
|
-
Nright, Nleft, ONE, A, lda, Age, lda);
|
1928
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
|
1929
|
-
Nright, Nleft, MONE, Atr, lda, Age, lda);
|
1930
|
-
}
|
1931
|
-
|
1932
|
-
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
|
1933
|
-
if (ierr) return ierr;
|
1934
|
-
|
1935
|
-
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
|
1936
|
-
if (ierr) return ierr + Nleft;
|
1937
|
-
|
1938
|
-
} else {
|
1939
|
-
if (Real) {
|
1940
|
-
if (N == 4) {
|
1941
|
-
return trtri_4<RowMajor,Upper,DType>(Diag, A, lda);
|
1942
|
-
} else if (N == 3) {
|
1943
|
-
return trtri_3<RowMajor,Upper,DType>(Diag, A, lda);
|
1944
|
-
} else if (N == 2) {
|
1945
|
-
if (Diag == CblasNonUnit) {
|
1946
|
-
A->inverse();
|
1947
|
-
(A+(lda+1))->inverse();
|
1948
|
-
|
1949
|
-
if (Upper) {
|
1950
|
-
*(A+1) *= *A; // TRI_MUL
|
1951
|
-
*(A+1) *= *(A+lda+1); // TRI_MUL
|
1952
|
-
} else {
|
1953
|
-
*(A+lda) *= *A; // TRI_MUL
|
1954
|
-
*(A+lda) *= *(A+lda+1); // TRI_MUL
|
1955
|
-
}
|
1956
|
-
}
|
1957
|
-
|
1958
|
-
if (Upper) *(A+1) = -*(A+1); // TRI_NEG
|
1959
|
-
else *(A+lda) = -*(A+lda); // TRI_NEG
|
1960
|
-
} else if (Diag == CblasNonUnit) A->inverse();
|
1961
|
-
} else { // not real
|
1962
|
-
if (Diag == CblasNonUnit) A->inverse();
|
1963
|
-
}
|
1964
|
-
}
|
1965
|
-
|
1966
|
-
} else {
|
1967
|
-
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1968
|
-
}
|
1969
|
-
|
1970
|
-
return ierr;
|
1971
|
-
}
|
1972
|
-
|
1973
|
-
|
1974
|
-
template <bool RowMajor, bool Real, typename DType>
|
1975
|
-
int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
|
1976
|
-
|
1977
|
-
if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1978
|
-
|
1979
|
-
int jb, nb, I, ndown, iret;
|
1980
|
-
|
1981
|
-
const DType ONE = 1, NONE = -1;
|
1982
|
-
|
1983
|
-
iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
|
1984
|
-
if (!iret && N > 1) {
|
1985
|
-
jb = lwrk / N;
|
1986
|
-
if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
|
1987
|
-
else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
|
1988
|
-
else nb = jb;
|
1989
|
-
if (!nb) return -6; // need at least 1 row of workspace
|
1990
|
-
|
1991
|
-
// only first iteration will have partial block, unroll it
|
1992
|
-
|
1993
|
-
jb = N - (N/nb) * nb;
|
1994
|
-
if (!jb) jb = nb;
|
1995
|
-
I = N - jb;
|
1996
|
-
A += lda * I;
|
1997
|
-
trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
|
1998
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
1999
|
-
jb, N, ONE, wrk, jb, A, lda);
|
2000
|
-
|
2001
|
-
if (I) {
|
2002
|
-
do {
|
2003
|
-
I -= nb;
|
2004
|
-
A -= nb * lda;
|
2005
|
-
ndown = N-I;
|
2006
|
-
trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
|
2007
|
-
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
2008
|
-
nb, N, ONE, wrk, ndown, A, lda);
|
2009
|
-
} while (I);
|
2010
|
-
}
|
2011
|
-
|
2012
|
-
// Apply row interchanges
|
2013
|
-
|
2014
|
-
for (I = N - 2; I >= 0; --I) {
|
2015
|
-
jb = ipiv[I];
|
2016
|
-
if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
|
2017
|
-
}
|
2018
|
-
}
|
2019
|
-
|
2020
|
-
return iret;
|
2021
|
-
}
|
2022
|
-
*/
|
2023
|
-
|
2024
|
-
|
2025
|
-
// TODO: Test this to see if it works properly on complex. ATLAS has a separate algorithm for complex, which looks like
|
2026
|
-
// TODO: it may actually be the same one.
|
2027
|
-
//
|
2028
|
-
// This function is called ATL_rot in ATLAS 3.8.4.
|
2029
|
-
template <typename DType>
|
2030
|
-
inline void rot_helper(const int N, DType* X, const int incX, DType* Y, const int incY, const DType c, const DType s) {
|
2031
|
-
if (c != 1 || s != 0) {
|
2032
|
-
if (incX == 1 && incY == 1) {
|
2033
|
-
for (int i = 0; i != N; ++i) {
|
2034
|
-
DType tmp = X[i] * c + Y[i] * s;
|
2035
|
-
Y[i] = Y[i] * c - X[i] * s;
|
2036
|
-
X[i] = tmp;
|
2037
|
-
}
|
2038
|
-
} else {
|
2039
|
-
for (int i = N; i > 0; --i, Y += incY, X += incX) {
|
2040
|
-
DType tmp = *X * c + *Y * s;
|
2041
|
-
*Y = *Y * c - *X * s;
|
2042
|
-
*X = tmp;
|
2043
|
-
}
|
2044
|
-
}
|
2045
|
-
}
|
2046
|
-
}
|
2047
|
-
|
2048
|
-
|
2049
|
-
/* Givens plane rotation. From ATLAS 3.8.4. */
|
2050
|
-
// FIXME: Need a specialized algorithm for Rationals. BLAS' algorithm simply will not work for most values due to the
|
2051
|
-
// FIXME: sqrt.
|
2052
|
-
template <typename DType>
|
2053
|
-
inline void rotg(DType* a, DType* b, DType* c, DType* s) {
|
2054
|
-
DType aa = std::abs(*a), ab = std::abs(*b);
|
2055
|
-
DType roe = aa > ab ? *a : *b;
|
2056
|
-
DType scal = aa + ab;
|
2057
|
-
|
2058
|
-
if (scal == 0) {
|
2059
|
-
*c = 1;
|
2060
|
-
*s = *a = *b = 0;
|
2061
|
-
} else {
|
2062
|
-
DType t0 = aa / scal, t1 = ab / scal;
|
2063
|
-
DType r = scal * std::sqrt(t0 * t0 + t1 * t1);
|
2064
|
-
if (roe < 0) r = -r;
|
2065
|
-
*c = *a / r;
|
2066
|
-
*s = *b / r;
|
2067
|
-
DType z = (*c != 0) ? (1 / *c) : DType(1);
|
2068
|
-
*a = r;
|
2069
|
-
*b = z;
|
2070
|
-
}
|
2071
|
-
}
|
2072
|
-
|
2073
|
-
template <>
|
2074
|
-
inline void rotg(float* a, float* b, float* c, float* s) {
|
2075
|
-
cblas_srotg(a, b, c, s);
|
2076
|
-
}
|
2077
|
-
|
2078
|
-
template <>
|
2079
|
-
inline void rotg(double* a, double* b, double* c, double* s) {
|
2080
|
-
cblas_drotg(a, b, c, s);
|
2081
|
-
}
|
2082
|
-
|
2083
|
-
template <>
|
2084
|
-
inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) {
|
2085
|
-
cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
|
2086
|
-
}
|
2087
|
-
|
2088
|
-
template <>
|
2089
|
-
inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) {
|
2090
|
-
cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
|
2091
|
-
}
|
2092
|
-
|
2093
|
-
template <typename DType>
|
2094
|
-
inline void cblas_rotg(void* a, void* b, void* c, void* s) {
|
2095
|
-
rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s));
|
2096
|
-
}
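A small standalone sketch of the Givens rotation that rotg constructs: it repeats the scaled formula from the template above for one (a, b) pair and checks that applying (c, s) zeroes the second component, which is how these rotations are used to introduce zeros. Double precision and made-up variable names are assumptions of the sketch.

// Standalone sketch (not part of nmatrix): build and verify a Givens rotation.
#include <cmath>
#include <iostream>

int main() {
  double a = 3.0, b = 4.0;

  double aa = std::fabs(a), ab = std::fabs(b);
  double roe = aa > ab ? a : b;
  double scal = aa + ab;

  double c, s, r;
  if (scal == 0.0) {
    c = 1.0; s = 0.0; r = 0.0;
  } else {
    double t0 = aa / scal, t1 = ab / scal;   // scaled to avoid overflow
    r = scal * std::sqrt(t0 * t0 + t1 * t1);
    if (roe < 0) r = -r;
    c = a / r;
    s = b / r;
  }

  // Rotating the vector (a, b) should give (r, 0): here r = 5, c = 0.6, s = 0.8.
  double x =  c * a + s * b;
  double y = -s * a + c * b;
  std::cout << "r=" << r << " c=" << c << " s=" << s
            << "  rotated=(" << x << ", " << y << ")" << std::endl;
  return 0;
}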
|
2097
|
-
|
2098
|
-
|
2099
|
-
/* Applies a plane rotation. From ATLAS 3.8.4. */
|
2100
|
-
template <typename DType, typename CSDType>
|
2101
|
-
inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) {
|
2102
|
-
DType *x = X, *y = Y;
|
2103
|
-
int incx = incX, incy = incY;
|
2104
|
-
|
2105
|
-
if (N > 0) {
|
2106
|
-
if (incX < 0) {
|
2107
|
-
if (incY < 0) { incx = -incx; incy = -incy; }
|
2108
|
-
else x += -incX * (N-1);
|
2109
|
-
} else if (incY < 0) {
|
2110
|
-
incy = -incy;
|
2111
|
-
incx = -incx;
|
2112
|
-
x += (N-1) * incX;
|
2113
|
-
}
|
2114
|
-
rot_helper<DType>(N, x, incx, y, incy, c, s);
|
2115
|
-
}
|
2116
|
-
}
|
2117
|
-
|
2118
|
-
template <>
|
2119
|
-
inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) {
|
2120
|
-
cblas_srot(N, X, incX, Y, incY, (float)c, (float)s);
|
2121
|
-
}
|
2122
|
-
|
2123
|
-
template <>
|
2124
|
-
inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) {
|
2125
|
-
cblas_drot(N, X, incX, Y, incY, c, s);
|
2126
|
-
}
|
2127
|
-
|
2128
|
-
template <>
|
2129
|
-
inline void rot(const int N, Complex64* X, const int incX, Complex64* Y, const int incY, const float c, const float s) {
|
2130
|
-
cblas_csrot(N, X, incX, Y, incY, c, s);
|
2131
|
-
}
|
2132
|
-
|
2133
|
-
template <>
|
2134
|
-
inline void rot(const int N, Complex128* X, const int incX, Complex128* Y, const int incY, const double c, const double s) {
|
2135
|
-
cblas_zdrot(N, X, incX, Y, incY, c, s);
|
2136
|
-
}
|
2137
|
-
|
2138
|
-
|
2139
|
-
template <typename DType, typename CSDType>
|
2140
|
-
inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) {
|
2141
|
-
rot<DType,CSDType>(N, reinterpret_cast<DType*>(X), incX, reinterpret_cast<DType*>(Y), incY,
|
2142
|
-
*reinterpret_cast<const CSDType*>(c), *reinterpret_cast<const CSDType*>(s));
|
2143
|
-
}
|
2144
|
-
|
2145
|
-
/*
|
2146
|
-
* Level 1 BLAS routine which returns the 2-norm of an n-vector x.
|
2147
|
-
*
|
2148
|
-
* Based on input types, these are the valid return types:
|
2149
|
-
* int -> int
|
2150
|
-
* float -> float or double
|
2151
|
-
* double -> double
|
2152
|
-
* complex64 -> float or double
|
2153
|
-
* complex128 -> double
|
2154
|
-
* rational -> rational
|
2155
|
-
*/
|
2156
|
-
template <typename ReturnDType, typename DType>
|
2157
|
-
ReturnDType nrm2(const int N, const DType* X, const int incX) {
|
2158
|
-
const DType ONE = 1, ZERO = 0;
|
2159
|
-
typename LongDType<DType>::type scale = 0, ssq = 1, absxi, temp;
|
2160
|
-
|
2161
|
-
|
2162
|
-
if ((N < 1) || (incX < 1)) return ZERO;
|
2163
|
-
else if (N == 1) return std::abs(X[0]);
|
2164
|
-
|
2165
|
-
for (int i = 0; i < N; ++i) {
|
2166
|
-
absxi = std::abs(X[i*incX]);
|
2167
|
-
if (scale < absxi) {
|
2168
|
-
temp = scale / absxi;
|
2169
|
-
scale = absxi;
|
2170
|
-
ssq = ONE + ssq * (temp * temp);
|
2171
|
-
} else {
|
2172
|
-
temp = absxi / scale;
|
2173
|
-
ssq += temp * temp;
|
2174
|
-
}
|
2175
|
-
}
|
2176
|
-
|
2177
|
-
return scale * std::sqrt( ssq );
|
2178
|
-
}
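A short standalone sketch of why nrm2 carries scale and ssq instead of summing squares directly: on large inputs the naive sum of squares overflows to infinity, while the scaled update does not. Plain C++ with made-up names, illustrative only.

// Standalone sketch (not part of nmatrix): scaled 2-norm vs naive 2-norm.
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> x = {1e200, 1e200, 1e200};

  // Naive: the sum of squares overflows, so the result is inf.
  double naive = 0;
  for (double v : x) naive += v * v;
  naive = std::sqrt(naive);

  // Scaled: keep the largest magnitude seen so far in `scale` and a sum of
  // squares of ratios in `ssq`, the same update the templated nrm2 uses.
  double scale = 0, ssq = 1;
  for (double v : x) {
    double absv = std::fabs(v);
    if (absv == 0) continue;
    if (scale < absv) {
      double t = scale / absv;
      ssq = 1 + ssq * t * t;
      scale = absv;
    } else {
      double t = absv / scale;
      ssq += t * t;
    }
  }
  double scaled = scale * std::sqrt(ssq);

  std::cout << "naive  = " << naive  << "\n";   // inf
  std::cout << "scaled = " << scaled << "\n";   // ~1.732e200
  return 0;
}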
|
2179
|
-
|
2180
|
-
|
2181
|
-
#ifdef HAVE_CBLAS_H
|
2182
|
-
template <>
|
2183
|
-
inline float nrm2(const int N, const float* X, const int incX) {
|
2184
|
-
return cblas_snrm2(N, X, incX);
|
2185
|
-
}
|
2186
|
-
|
2187
|
-
template <>
|
2188
|
-
inline double nrm2(const int N, const double* X, const int incX) {
|
2189
|
-
return cblas_dnrm2(N, X, incX);
|
2190
|
-
}
|
2191
|
-
|
2192
|
-
template <>
|
2193
|
-
inline float nrm2(const int N, const Complex64* X, const int incX) {
|
2194
|
-
return cblas_scnrm2(N, X, incX);
|
2195
|
-
}
|
2196
|
-
|
2197
|
-
template <>
|
2198
|
-
inline double nrm2(const int N, const Complex128* X, const int incX) {
|
2199
|
-
return cblas_dznrm2(N, X, incX);
|
2200
|
-
}
|
2201
|
-
#else
|
2202
|
-
template <typename FloatDType>
|
2203
|
-
static inline void nrm2_complex_helper(const FloatDType& xr, const FloatDType& xi, double& scale, double& ssq) {
|
2204
|
-
double absx = std::abs(xr);
|
2205
|
-
if (scale < absx) {
|
2206
|
-
double temp = scale / absx;
|
2207
|
-
scale = absx;
|
2208
|
-
ssq = 1.0 + ssq * (temp * temp);
|
2209
|
-
} else {
|
2210
|
-
double temp = absx / scale;
|
2211
|
-
ssq += temp * temp;
|
2212
|
-
}
|
2213
|
-
|
2214
|
-
absx = std::abs(xi);
|
2215
|
-
if (scale < absx) {
|
2216
|
-
double temp = scale / absx;
|
2217
|
-
scale = absx;
|
2218
|
-
ssq = 1.0 + ssq * (temp * temp);
|
2219
|
-
} else {
|
2220
|
-
double temp = absx / scale;
|
2221
|
-
ssq += temp * temp;
|
2222
|
-
}
|
2223
|
-
}
|
2224
|
-
|
2225
|
-
template <>
|
2226
|
-
float nrm2(const int N, const Complex64* X, const int incX) {
|
2227
|
-
double scale = 0, ssq = 1;
|
2228
|
-
|
2229
|
-
if ((N < 1) || (incX < 1)) return 0.0;
|
2230
|
-
|
2231
|
-
for (int i = 0; i < N; ++i) {
|
2232
|
-
nrm2_complex_helper<float>(X[i*incX].r, X[i*incX].i, scale, ssq);
|
2233
|
-
}
|
2234
|
-
|
2235
|
-
return scale * std::sqrt( ssq );
|
2236
|
-
}
|
2237
|
-
|
2238
|
-
template <>
|
2239
|
-
double nrm2(const int N, const Complex128* X, const int incX) {
|
2240
|
-
double scale = 0, ssq = 1;
|
2241
|
-
|
2242
|
-
if ((N < 1) || (incX < 1)) return 0.0;
|
2243
|
-
|
2244
|
-
for (int i = 0; i < N; ++i) {
|
2245
|
-
nrm2_complex_helper<double>(X[i*incX].r, X[i*incX].i, scale, ssq);
|
2246
|
-
}
|
2247
|
-
|
2248
|
-
return scale * std::sqrt( ssq );
|
2249
|
-
}
|
2250
|
-
#endif
|
2251
|
-
|
2252
|
-
template <typename ReturnDType, typename DType>
|
2253
|
-
inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) {
|
2254
|
-
*reinterpret_cast<ReturnDType*>( result ) = nrm2<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
|
2255
|
-
}
|
2256
|
-
|
2257
|
-
/*
|
2258
|
-
* Level 1 BLAS routine which sums the absolute values of a vector's contents. If the vector consists of complex values,
|
2259
|
-
* the routine sums the absolute values of the real and imaginary components as well.
|
2260
|
-
*
|
2261
|
-
* So, based on input types, these are the valid return types:
|
2262
|
-
* int -> int
|
2263
|
-
* float -> float or double
|
2264
|
-
* double -> double
|
2265
|
-
* complex64 -> float or double
|
2266
|
-
* complex128 -> double
|
2267
|
-
* rational -> rational
|
2268
|
-
*/
|
2269
|
-
template <typename ReturnDType, typename DType>
|
2270
|
-
inline ReturnDType asum(const int N, const DType* X, const int incX) {
|
2271
|
-
ReturnDType sum = 0;
|
2272
|
-
if ((N > 0) && (incX > 0)) {
|
2273
|
-
for (int i = 0; i < N; ++i) {
|
2274
|
-
sum += std::abs(X[i*incX]);
|
2275
|
-
}
|
2276
|
-
}
|
2277
|
-
return sum;
|
2278
|
-
}
|
2279
|
-
|
2280
|
-
|
2281
|
-
#ifdef HAVE_CBLAS_H
|
2282
|
-
template <>
|
2283
|
-
inline float asum(const int N, const float* X, const int incX) {
|
2284
|
-
return cblas_sasum(N, X, incX);
|
2285
|
-
}
|
2286
|
-
|
2287
|
-
template <>
|
2288
|
-
inline double asum(const int N, const double* X, const int incX) {
|
2289
|
-
return cblas_dasum(N, X, incX);
|
2290
|
-
}
|
2291
|
-
|
2292
|
-
template <>
|
2293
|
-
inline float asum(const int N, const Complex64* X, const int incX) {
|
2294
|
-
return cblas_scasum(N, X, incX);
|
2295
|
-
}
|
2296
|
-
|
2297
|
-
template <>
|
2298
|
-
inline double asum(const int N, const Complex128* X, const int incX) {
|
2299
|
-
return cblas_dzasum(N, X, incX);
|
2300
|
-
}
|
2301
|
-
#else
|
2302
|
-
template <>
|
2303
|
-
inline float asum(const int N, const Complex64* X, const int incX) {
|
2304
|
-
float sum = 0;
|
2305
|
-
if ((N > 0) && (incX > 0)) {
|
2306
|
-
for (int i = 0; i < N; ++i) {
|
2307
|
-
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
|
2308
|
-
}
|
2309
|
-
}
|
2310
|
-
return sum;
|
2311
|
-
}
|
2312
|
-
|
2313
|
-
template <>
|
2314
|
-
inline double asum(const int N, const Complex128* X, const int incX) {
|
2315
|
-
double sum = 0;
|
2316
|
-
if ((N > 0) && (incX > 0)) {
|
2317
|
-
for (int i = 0; i < N; ++i) {
|
2318
|
-
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
|
2319
|
-
}
|
2320
|
-
}
|
2321
|
-
return sum;
|
2322
|
-
}
|
2323
|
-
#endif
|
2324
|
-
|
2325
|
-
|
2326
|
-
template <typename ReturnDType, typename DType>
|
2327
|
-
inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
|
2328
|
-
*reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
|
2329
|
-
}
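A tiny standalone sketch of the quantity asum computes for complex input: the sum of |real| + |imag| per element, rather than the modulus. It uses std::complex purely for illustration; nmatrix's Complex64 type is not involved.

// Standalone sketch (not part of nmatrix): asum-style absolute sum for complex.
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::complex<float>> x = { {3.0f, -4.0f}, {-1.0f, 2.0f} };

  float sum = 0.0f;
  for (const auto& v : x)
    sum += std::fabs(v.real()) + std::fabs(v.imag());

  std::cout << sum << std::endl;   // 3 + 4 + 1 + 2 = 10
  return 0;
}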
|
2330
|
-
|
2331
|
-
|
2332
|
-
template <bool is_complex, typename DType>
|
2333
|
-
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
|
2334
|
-
|
2335
|
-
int Nleft, Nright;
|
2336
|
-
const DType ONE = 1;
|
2337
|
-
DType *G, *U0 = A, *U1;
|
2338
|
-
|
2339
|
-
if (N > 1) {
|
2340
|
-
Nleft = N >> 1;
|
2341
|
-
#ifdef NB
|
2342
|
-
if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
2343
|
-
#endif
|
2344
|
-
|
2345
|
-
Nright = N - Nleft;
|
2346
|
-
|
2347
|
-
// FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now.
|
2348
|
-
if (uplo == CblasUpper) {
|
2349
|
-
if (order == CblasRowMajor) {
|
2350
|
-
G = A + Nleft;
|
2351
|
-
U1 = G + Nleft * lda;
|
2352
|
-
} else {
|
2353
|
-
G = A + Nleft * lda;
|
2354
|
-
U1 = G + Nleft;
|
2355
|
-
}
|
2356
|
-
} else {
|
2357
|
-
if (order == CblasRowMajor) {
|
2358
|
-
G = A + Nleft * lda;
|
2359
|
-
U1 = G + Nleft;
|
2360
|
-
} else {
|
2361
|
-
G = A + Nleft;
|
2362
|
-
U1 = G + Nleft * lda;
|
2363
|
-
}
|
2364
|
-
}
|
2365
|
-
|
2366
|
-
lauum<is_complex, DType>(order, uplo, Nleft, U0, lda);
|
2367
|
-
|
2368
|
-
if (is_complex) {
|
2369
|
-
|
2370
|
-
nm::math::herk<DType>(order, uplo,
|
2371
|
-
uplo == CblasLower ? CblasConjTrans : CblasNoTrans,
|
2372
|
-
Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
|
2373
|
-
|
2374
|
-
nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
|
2375
|
-
} else {
|
2376
|
-
nm::math::syrk<DType>(order, uplo,
|
2377
|
-
uplo == CblasLower ? CblasTrans : CblasNoTrans,
|
2378
|
-
Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
|
2379
|
-
|
2380
|
-
nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
|
2381
|
-
}
|
2382
|
-
lauum<is_complex, DType>(order, uplo, Nright, U1, lda);
|
2383
|
-
|
2384
|
-
} else {
|
2385
|
-
*A = *A * *A;
|
2386
|
-
}
|
2387
|
-
}
|
2388
|
-
|
2389
|
-
|
2390
|
-
#ifdef HAVE_CLAPACK_H
|
2391
|
-
template <bool is_complex>
|
2392
|
-
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
|
2393
|
-
clapack_slauum(order, uplo, N, A, lda);
|
2394
|
-
}
|
2395
|
-
|
2396
|
-
template <bool is_complex>
|
2397
|
-
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
|
2398
|
-
clapack_dlauum(order, uplo, N, A, lda);
|
2399
|
-
}
|
2400
|
-
|
2401
|
-
template <bool is_complex>
|
2402
|
-
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
|
2403
|
-
clapack_clauum(order, uplo, N, A, lda);
|
2404
|
-
}
|
2405
|
-
|
2406
|
-
template <bool is_complex>
|
2407
|
-
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
|
2408
|
-
clapack_zlauum(order, uplo, N, A, lda);
|
2409
|
-
}
|
2410
|
-
#endif
|
2411
|
-
|
2412
|
-
|
2413
|
-
/*
|
2414
|
-
* Function signature conversion for calling LAPACK's lauum functions as directly as possible.
|
2415
|
-
*
|
2416
|
-
* For documentation: http://www.netlib.org/lapack/double/dlauum.f
|
2417
|
-
*
|
2418
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2419
|
-
*/
|
2420
|
-
template <bool is_complex, typename DType>
|
2421
|
-
inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2422
|
-
if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n);
|
2423
|
-
if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n);
|
2424
|
-
|
2425
|
-
if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2426
|
-
else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2427
|
-
|
2428
|
-
return 0;
|
2429
|
-
}
|
2430
|
-
|
2431
|
-
|
2432
|
-
|
2433
|
-
|
2434
|
-
/*
|
2435
|
-
* Macro for declaring LAPACK specializations of the getrf function.
|
2436
|
-
*
|
2437
|
-
* type is the DType; call is the specific function to call; cast_as is what the DType* should be
|
2438
|
-
* cast to in order to pass it to LAPACK.
|
2439
|
-
*/
|
2440
|
-
#define LAPACK_GETRF(type, call, cast_as) \
|
2441
|
-
template <> \
|
2442
|
-
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) { \
|
2443
|
-
int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv); \
|
2444
|
-
if (!info) return info; \
|
2445
|
-
else { \
|
2446
|
-
rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info); \
|
2447
|
-
return info; \
|
2448
|
-
} \
|
2449
|
-
}
|
2450
|
-
|
2451
|
-
/* Specialize for ATLAS types */
|
2452
|
-
/*LAPACK_GETRF(float, clapack_sgetrf, float)
|
2453
|
-
LAPACK_GETRF(double, clapack_dgetrf, double)
|
2454
|
-
LAPACK_GETRF(Complex64, clapack_cgetrf, void)
|
2455
|
-
LAPACK_GETRF(Complex128, clapack_zgetrf, void)
|
2456
|
-
*/
|
2457
|
-
|
2458
|
-
|
2459
|
-
/*
|
2460
|
-
* Function signature conversion for calling LAPACK's getrf functions as directly as possible.
|
2461
|
-
*
|
2462
|
-
* For documentation: http://www.netlib.org/lapack/double/dgetrf.f
|
2463
|
-
*
|
2464
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2465
|
-
*/
|
2466
|
-
template <typename DType>
|
2467
|
-
inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, int* ipiv) {
|
2468
|
-
return getrf<DType>(order, m, n, reinterpret_cast<DType*>(a), lda, ipiv);
|
2469
|
-
}
|
2470
|
-
|
2471
|
-
|
2472
|
-
/*
|
2473
|
-
* Function signature conversion for calling LAPACK's potrf functions as directly as possible.
|
2474
|
-
*
|
2475
|
-
* For documentation: http://www.netlib.org/lapack/double/dpotrf.f
|
2476
|
-
*
|
2477
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2478
|
-
*/
|
2479
|
-
template <typename DType>
|
2480
|
-
inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2481
|
-
return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2482
|
-
}
|
2483
|
-
|
2484
|
-
|
2485
|
-
/*
|
2486
|
-
* Function signature conversion for calling LAPACK's getrs functions as directly as possible.
|
2487
|
-
*
|
2488
|
-
* For documentation: http://www.netlib.org/lapack/double/dgetrs.f
|
2489
|
-
*
|
2490
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2491
|
-
*/
|
2492
|
-
template <typename DType>
|
2493
|
-
inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
|
2494
|
-
const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
|
2495
|
-
return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
|
2496
|
-
}
|
2497
|
-
|
2498
|
-
/*
|
2499
|
-
* Function signature conversion for calling LAPACK's potrs functions as directly as possible.
|
2500
|
-
*
|
2501
|
-
* For documentation: http://www.netlib.org/lapack/double/dpotrs.f
|
2502
|
-
*
|
2503
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2504
|
-
*/
|
2505
|
-
template <typename DType, bool is_complex>
|
2506
|
-
inline int clapack_potrs(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, const int nrhs,
|
2507
|
-
const void* a, const int lda, void* b, const int ldb) {
|
2508
|
-
return potrs<DType,is_complex>(order, uplo, n, nrhs, reinterpret_cast<const DType*>(a), lda, reinterpret_cast<DType*>(b), ldb);
|
2509
|
-
}
|
2510
|
-
|
2511
|
-
template <typename DType>
|
2512
|
-
inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
|
2513
|
-
rb_raise(rb_eNotImpError, "getri not yet implemented for non-BLAS dtypes");
|
2514
|
-
return 0;
|
2515
|
-
}
|
2516
|
-
|
2517
|
-
#ifdef HAVE_CLAPACK_H
|
2518
|
-
template <>
|
2519
|
-
inline int getri(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) {
|
2520
|
-
return clapack_sgetri(order, n, a, lda, ipiv);
|
2521
|
-
}
|
2522
|
-
|
2523
|
-
template <>
|
2524
|
-
inline int getri(const enum CBLAS_ORDER order, const int n, double* a, const int lda, const int* ipiv) {
|
2525
|
-
return clapack_dgetri(order, n, a, lda, ipiv);
|
2526
|
-
}
|
2527
|
-
|
2528
|
-
template <>
|
2529
|
-
inline int getri(const enum CBLAS_ORDER order, const int n, Complex64* a, const int lda, const int* ipiv) {
|
2530
|
-
return clapack_cgetri(order, n, reinterpret_cast<void*>(a), lda, ipiv);
|
2531
|
-
}
|
2532
|
-
|
2533
|
-
template <>
|
2534
|
-
inline int getri(const enum CBLAS_ORDER order, const int n, Complex128* a, const int lda, const int* ipiv) {
|
2535
|
-
return clapack_zgetri(order, n, reinterpret_cast<void*>(a), lda, ipiv);
|
2536
|
-
}
|
2537
|
-
#endif
|
2538
|
-
|
2539
|
-
|
2540
|
-
template <typename DType>
|
2541
|
-
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) {
|
2542
|
-
rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes");
|
2543
|
-
return 0;
|
2544
|
-
}
|
2545
|
-
|
2546
|
-
|
2547
|
-
#ifdef HAVE_CLAPACK_H
|
2548
|
-
template <>
|
2549
|
-
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) {
|
2550
|
-
return clapack_spotri(order, uplo, n, a, lda);
|
2551
|
-
}
|
2552
|
-
|
2553
|
-
template <>
|
2554
|
-
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) {
|
2555
|
-
return clapack_dpotri(order, uplo, n, a, lda);
|
2556
|
-
}
|
2557
|
-
|
2558
|
-
template <>
|
2559
|
-
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) {
|
2560
|
-
return clapack_cpotri(order, uplo, n, reinterpret_cast<void*>(a), lda);
|
2561
|
-
}
|
2562
|
-
|
2563
|
-
template <>
|
2564
|
-
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) {
|
2565
|
-
return clapack_zpotri(order, uplo, n, reinterpret_cast<void*>(a), lda);
|
2566
|
-
}
|
2567
|
-
#endif
|
2568
|
-
|
2569
|
-
/*
|
2570
|
-
* Function signature conversion for calling LAPACK's getri functions as directly as possible.
|
2571
|
-
*
|
2572
|
-
* For documentation: http://www.netlib.org/lapack/double/dgetri.f
|
2573
|
-
*
|
2574
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2575
|
-
*/
|
2576
|
-
template <typename DType>
|
2577
|
-
inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, const int lda, const int* ipiv) {
|
2578
|
-
return getri<DType>(order, n, reinterpret_cast<DType*>(a), lda, ipiv);
|
2579
|
-
}
|
2580
|
-
|
2581
|
-
|
2582
|
-
/*
|
2583
|
-
* Function signature conversion for calling LAPACK's potri functions as directly as possible.
|
2584
|
-
*
|
2585
|
-
* For documentation: http://www.netlib.org/lapack/double/dpotri.f
|
2586
|
-
*
|
2587
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2588
|
-
*/
|
2589
|
-
template <typename DType>
|
2590
|
-
inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2591
|
-
return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2592
|
-
}
|
2593
|
-
|
2594
|
-
|
2595
|
-
/*
|
2596
|
-
* Function signature conversion for calling LAPACK's laswp functions as directly as possible.
|
2597
|
-
*
|
2598
|
-
* For documentation: http://www.netlib.org/lapack/double/dlaswp.f
|
2599
|
-
*
|
2600
|
-
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2601
|
-
*/
|
2602
|
-
template <typename DType>
|
2603
|
-
inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
|
2604
|
-
laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
|
2605
|
-
}
|
2606
|
-
|
2607
|
-
|
2608
|
-
|
2609
|
-
}} // end namespace nm::math
|
2610
|
-
|
2611
|
-
|
2612
|
-
#endif // MATH_H
|