nmatrix 0.0.6 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -0
- data/Gemfile +5 -0
- data/History.txt +97 -0
- data/Manifest.txt +34 -7
- data/README.rdoc +13 -13
- data/Rakefile +36 -26
- data/ext/nmatrix/data/data.cpp +15 -2
- data/ext/nmatrix/data/data.h +4 -0
- data/ext/nmatrix/data/ruby_object.h +5 -14
- data/ext/nmatrix/extconf.rb +3 -2
- data/ext/nmatrix/{util/math.cpp → math.cpp} +296 -6
- data/ext/nmatrix/math/asum.h +143 -0
- data/ext/nmatrix/math/geev.h +82 -0
- data/ext/nmatrix/math/gemm.h +267 -0
- data/ext/nmatrix/math/gemv.h +208 -0
- data/ext/nmatrix/math/ger.h +96 -0
- data/ext/nmatrix/math/gesdd.h +80 -0
- data/ext/nmatrix/math/gesvd.h +78 -0
- data/ext/nmatrix/math/getf2.h +86 -0
- data/ext/nmatrix/math/getrf.h +240 -0
- data/ext/nmatrix/math/getri.h +107 -0
- data/ext/nmatrix/math/getrs.h +125 -0
- data/ext/nmatrix/math/idamax.h +86 -0
- data/ext/nmatrix/{util → math}/lapack.h +60 -356
- data/ext/nmatrix/math/laswp.h +165 -0
- data/ext/nmatrix/math/long_dtype.h +52 -0
- data/ext/nmatrix/math/math.h +1154 -0
- data/ext/nmatrix/math/nrm2.h +181 -0
- data/ext/nmatrix/math/potrs.h +125 -0
- data/ext/nmatrix/math/rot.h +141 -0
- data/ext/nmatrix/math/rotg.h +115 -0
- data/ext/nmatrix/math/scal.h +73 -0
- data/ext/nmatrix/math/swap.h +73 -0
- data/ext/nmatrix/math/trsm.h +383 -0
- data/ext/nmatrix/nmatrix.cpp +176 -152
- data/ext/nmatrix/nmatrix.h +1 -2
- data/ext/nmatrix/ruby_constants.cpp +9 -4
- data/ext/nmatrix/ruby_constants.h +1 -0
- data/ext/nmatrix/storage/dense.cpp +57 -41
- data/ext/nmatrix/storage/list.cpp +52 -50
- data/ext/nmatrix/storage/storage.cpp +59 -43
- data/ext/nmatrix/storage/yale.cpp +352 -333
- data/ext/nmatrix/storage/yale.h +4 -0
- data/lib/nmatrix.rb +2 -2
- data/lib/nmatrix/blas.rb +4 -4
- data/lib/nmatrix/enumerate.rb +241 -0
- data/lib/nmatrix/lapack.rb +54 -1
- data/lib/nmatrix/math.rb +462 -0
- data/lib/nmatrix/nmatrix.rb +210 -486
- data/lib/nmatrix/nvector.rb +0 -62
- data/lib/nmatrix/rspec.rb +75 -0
- data/lib/nmatrix/shortcuts.rb +136 -108
- data/lib/nmatrix/version.rb +1 -1
- data/spec/blas_spec.rb +20 -12
- data/spec/elementwise_spec.rb +22 -13
- data/spec/io_spec.rb +1 -0
- data/spec/lapack_spec.rb +197 -0
- data/spec/nmatrix_spec.rb +39 -38
- data/spec/nvector_spec.rb +3 -9
- data/spec/rspec_monkeys.rb +29 -0
- data/spec/rspec_spec.rb +34 -0
- data/spec/shortcuts_spec.rb +14 -16
- data/spec/slice_spec.rb +242 -186
- data/spec/spec_helper.rb +19 -0
- metadata +33 -5
- data/ext/nmatrix/util/math.h +0 -2612
@@ -0,0 +1,73 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2013, Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == scal.h
|
25
|
+
//
|
26
|
+
// BLAS level 1 scal function in native C++.
|
27
|
+
//
|
28
|
+
|
29
|
+
#ifndef SCAL_H
|
30
|
+
#define SCAL_H
|
31
|
+
|
32
|
+
namespace nm { namespace math {

/* Purpose */
/* ======= */

/* DSCAL scales a vector by a constant. */
/* The reference version uses unrolled loops for increment equal to one; */
/* this port uses a single strided loop instead. */

/* Further Details */
/* =============== */

/* jack dongarra, linpack, 3/11/78. */
/* modified 3/93 to return if incx .le. 0. */
/* modified 12/3/93, array(1) declarations changed to array(*) */

/* ===================================================================== */

/*
 * Scale a strided vector in place: dx[k*incx] = da * dx[k*incx] for k in [0, n).
 *
 * n    - number of elements to scale; nothing happens if n <= 0.
 * da   - scale factor.
 * dx   - vector to scale; must hold at least 1 + (n-1)*incx elements.
 * incx - stride between consecutive elements; as in reference BLAS, nothing
 *        happens if incx <= 0.
 */
template <typename DType>
inline void scal(const int n, const DType da, DType* dx, const int incx) {

  // This used to have unrolled loops, like dswap. They were in the way.

  if (n <= 0 || incx <= 0) return;

  // Count elements rather than comparing an index against n*incx: that
  // product can overflow int for long vectors with large strides. Because
  // incx > 0 is guaranteed by the guard above, no negative-stride case exists.
  for (int k = 0; k < n; ++k) {
    const int ix = k * incx;
    dx[ix] = da * dx[ix];
  }
} /* scal */


/*
 * Function signature conversion for LAPACK's scal function.
 *
 * Reinterprets the type-erased `da` (scale factor) and `dx` (vector) pointers
 * as DType and forwards to scal<DType>. The caller must pass pointers to DType
 * values.
 */
template <typename DType>
inline void clapack_scal(const int n, const void* da, void* dx, const int incx) {
  // FIXME: See if we can call the clapack version instead of our C++ version.
  scal<DType>(n, *reinterpret_cast<const DType*>(da), reinterpret_cast<DType*>(dx), incx);
}

}} // end of nm::math
|
72
|
+
|
73
|
+
#endif
|
@@ -0,0 +1,73 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2013, Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == swap.h
|
25
|
+
//
|
26
|
+
// BLAS level 1 swap function in native C++.
|
27
|
+
//
|
28
|
+
|
29
|
+
#ifndef SWAP_H
|
30
|
+
#define SWAP_H
|
31
|
+
|
32
|
+
namespace nm { namespace math {

/*
 * BLAS level 1 xSWAP: interchange the elements of two strided vectors.
 *
 * For N element pairs, exchanges X[ix] and Y[iy], with ix advancing by incX
 * and iy by incY. As in reference BLAS, a negative increment walks its vector
 * backwards, starting at offset (1 - N) * inc so that every index visited is
 * non-negative. Nothing happens when N <= 0.
 *
 * This is the old (reference) BLAS version of this function. ATLAS has an
 * optimized version, but it's going to be tough to translate.
 */
template <typename DType>
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
  if (N <= 0) return;

  // For negative increments, start at the far end of the vector.
  int ix = incX < 0 ? (1 - N) * incX : 0;
  int iy = incY < 0 ? (1 - N) * incY : 0;

  // Index with the strided offsets (not the loop counter) so that incX/incY
  // are honored.
  for (int i = 0; i < N; ++i, ix += incX, iy += incY) {
    DType temp = X[ix];
    X[ix] = Y[iy];
    Y[iy] = temp;
  }
}

}} // end nm::math
|
72
|
+
|
73
|
+
#endif
|
@@ -0,0 +1,383 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2013, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2013, Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == trsm.h
|
25
|
+
//
|
26
|
+
// trsm function in native C++.
|
27
|
+
//
|
28
|
+
/*
|
29
|
+
* Automatically Tuned Linear Algebra Software v3.8.4
|
30
|
+
* (C) Copyright 1999 R. Clint Whaley
|
31
|
+
*
|
32
|
+
* Redistribution and use in source and binary forms, with or without
|
33
|
+
* modification, are permitted provided that the following conditions
|
34
|
+
* are met:
|
35
|
+
* 1. Redistributions of source code must retain the above copyright
|
36
|
+
* notice, this list of conditions and the following disclaimer.
|
37
|
+
* 2. Redistributions in binary form must reproduce the above copyright
|
38
|
+
* notice, this list of conditions, and the following disclaimer in the
|
39
|
+
* documentation and/or other materials provided with the distribution.
|
40
|
+
* 3. The name of the ATLAS group or the names of its contributers may
|
41
|
+
* not be used to endorse or promote products derived from this
|
42
|
+
* software without specific written permission.
|
43
|
+
*
|
44
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
45
|
+
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
46
|
+
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
47
|
+
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
|
48
|
+
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
49
|
+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
50
|
+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
51
|
+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
52
|
+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
53
|
+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
54
|
+
* POSSIBILITY OF SUCH DAMAGE.
|
55
|
+
*
|
56
|
+
*/
|
57
|
+
|
58
|
+
#ifndef TRSM_H
|
59
|
+
#define TRSM_H
|
60
|
+
|
61
|
+
|
62
|
+
extern "C" {
|
63
|
+
#include <cblas.h>
|
64
|
+
}
|
65
|
+
|
66
|
+
namespace nm { namespace math {
|
67
|
+
|
68
|
+
|
69
|
+
/*
|
70
|
+
* This version of trsm doesn't do any error checks and only works on column-major matrices.
|
71
|
+
*
|
72
|
+
* For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
|
73
|
+
* and parameter checks.
|
74
|
+
*
|
75
|
+
* Note that some of the boundary conditions here may be incorrect. Very little has been tested!
|
76
|
+
* This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
|
77
|
+
*/
|
78
|
+
template <typename DType>
|
79
|
+
inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
80
|
+
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
81
|
+
const int m, const int n, const DType alpha, const DType* a,
|
82
|
+
const int lda, DType* b, const int ldb)
|
83
|
+
{
|
84
|
+
|
85
|
+
// (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
|
86
|
+
|
87
|
+
if (m == 0 || n == 0) return; /* Quick return if possible. */
|
88
|
+
|
89
|
+
if (alpha == 0) { // Handle alpha == 0
|
90
|
+
for (int j = 0; j < n; ++j) {
|
91
|
+
for (int i = 0; i < m; ++i) {
|
92
|
+
b[i + j * ldb] = 0;
|
93
|
+
}
|
94
|
+
}
|
95
|
+
return;
|
96
|
+
}
|
97
|
+
|
98
|
+
if (side == CblasLeft) {
|
99
|
+
if (trans_a == CblasNoTrans) {
|
100
|
+
|
101
|
+
/* Form B := alpha*inv( A )*B. */
|
102
|
+
if (uplo == CblasUpper) {
|
103
|
+
for (int j = 0; j < n; ++j) {
|
104
|
+
if (alpha != 1) {
|
105
|
+
for (int i = 0; i < m; ++i) {
|
106
|
+
b[i + j * ldb] = alpha * b[i + j * ldb];
|
107
|
+
}
|
108
|
+
}
|
109
|
+
for (int k = m-1; k >= 0; --k) {
|
110
|
+
if (b[k + j * ldb] != 0) {
|
111
|
+
if (diag == CblasNonUnit) {
|
112
|
+
b[k + j * ldb] /= a[k + k * lda];
|
113
|
+
}
|
114
|
+
|
115
|
+
for (int i = 0; i < k-1; ++i) {
|
116
|
+
b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
|
117
|
+
}
|
118
|
+
}
|
119
|
+
}
|
120
|
+
}
|
121
|
+
} else {
|
122
|
+
for (int j = 0; j < n; ++j) {
|
123
|
+
if (alpha != 1) {
|
124
|
+
for (int i = 0; i < m; ++i) {
|
125
|
+
b[i + j * ldb] = alpha * b[i + j * ldb];
|
126
|
+
}
|
127
|
+
}
|
128
|
+
for (int k = 0; k < m; ++k) {
|
129
|
+
if (b[k + j * ldb] != 0.) {
|
130
|
+
if (diag == CblasNonUnit) {
|
131
|
+
b[k + j * ldb] /= a[k + k * lda];
|
132
|
+
}
|
133
|
+
for (int i = k+1; i < m; ++i) {
|
134
|
+
b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
|
135
|
+
}
|
136
|
+
}
|
137
|
+
}
|
138
|
+
}
|
139
|
+
}
|
140
|
+
} else { // CblasTrans
|
141
|
+
|
142
|
+
/* Form B := alpha*inv( A**T )*B. */
|
143
|
+
if (uplo == CblasUpper) {
|
144
|
+
for (int j = 0; j < n; ++j) {
|
145
|
+
for (int i = 0; i < m; ++i) {
|
146
|
+
DType temp = alpha * b[i + j * ldb];
|
147
|
+
for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
|
148
|
+
temp -= a[k + i * lda] * b[k + j * ldb];
|
149
|
+
}
|
150
|
+
if (diag == CblasNonUnit) {
|
151
|
+
temp /= a[i + i * lda];
|
152
|
+
}
|
153
|
+
b[i + j * ldb] = temp;
|
154
|
+
}
|
155
|
+
}
|
156
|
+
} else {
|
157
|
+
for (int j = 0; j < n; ++j) {
|
158
|
+
for (int i = m-1; i >= 0; --i) {
|
159
|
+
DType temp= alpha * b[i + j * ldb];
|
160
|
+
for (int k = i+1; k < m; ++k) {
|
161
|
+
temp -= a[k + i * lda] * b[k + j * ldb];
|
162
|
+
}
|
163
|
+
if (diag == CblasNonUnit) {
|
164
|
+
temp /= a[i + i * lda];
|
165
|
+
}
|
166
|
+
b[i + j * ldb] = temp;
|
167
|
+
}
|
168
|
+
}
|
169
|
+
}
|
170
|
+
}
|
171
|
+
} else { // right side
|
172
|
+
|
173
|
+
if (trans_a == CblasNoTrans) {
|
174
|
+
|
175
|
+
/* Form B := alpha*B*inv( A ). */
|
176
|
+
|
177
|
+
if (uplo == CblasUpper) {
|
178
|
+
for (int j = 0; j < n; ++j) {
|
179
|
+
if (alpha != 1) {
|
180
|
+
for (int i = 0; i < m; ++i) {
|
181
|
+
b[i + j * ldb] = alpha * b[i + j * ldb];
|
182
|
+
}
|
183
|
+
}
|
184
|
+
for (int k = 0; k < j-1; ++k) {
|
185
|
+
if (a[k + j * lda] != 0) {
|
186
|
+
for (int i = 0; i < m; ++i) {
|
187
|
+
b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
|
188
|
+
}
|
189
|
+
}
|
190
|
+
}
|
191
|
+
if (diag == CblasNonUnit) {
|
192
|
+
DType temp = 1 / a[j + j * lda];
|
193
|
+
for (int i = 0; i < m; ++i) {
|
194
|
+
b[i + j * ldb] = temp * b[i + j * ldb];
|
195
|
+
}
|
196
|
+
}
|
197
|
+
}
|
198
|
+
} else {
|
199
|
+
for (int j = n-1; j >= 0; --j) {
|
200
|
+
if (alpha != 1) {
|
201
|
+
for (int i = 0; i < m; ++i) {
|
202
|
+
b[i + j * ldb] = alpha * b[i + j * ldb];
|
203
|
+
}
|
204
|
+
}
|
205
|
+
|
206
|
+
for (int k = j+1; k < n; ++k) {
|
207
|
+
if (a[k + j * lda] != 0.) {
|
208
|
+
for (int i = 0; i < m; ++i) {
|
209
|
+
b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
|
210
|
+
}
|
211
|
+
}
|
212
|
+
}
|
213
|
+
if (diag == CblasNonUnit) {
|
214
|
+
DType temp = 1 / a[j + j * lda];
|
215
|
+
|
216
|
+
for (int i = 0; i < m; ++i) {
|
217
|
+
b[i + j * ldb] = temp * b[i + j * ldb];
|
218
|
+
}
|
219
|
+
}
|
220
|
+
}
|
221
|
+
}
|
222
|
+
} else { // CblasTrans
|
223
|
+
|
224
|
+
/* Form B := alpha*B*inv( A**T ). */
|
225
|
+
|
226
|
+
if (uplo == CblasUpper) {
|
227
|
+
for (int k = n-1; k >= 0; --k) {
|
228
|
+
if (diag == CblasNonUnit) {
|
229
|
+
DType temp= 1 / a[k + k * lda];
|
230
|
+
for (int i = 0; i < m; ++i) {
|
231
|
+
b[i + k * ldb] = temp * b[i + k * ldb];
|
232
|
+
}
|
233
|
+
}
|
234
|
+
for (int j = 0; j < k-1; ++j) {
|
235
|
+
if (a[j + k * lda] != 0.) {
|
236
|
+
DType temp= a[j + k * lda];
|
237
|
+
for (int i = 0; i < m; ++i) {
|
238
|
+
b[i + j * ldb] -= temp * b[i + k * ldb];
|
239
|
+
}
|
240
|
+
}
|
241
|
+
}
|
242
|
+
if (alpha != 1) {
|
243
|
+
for (int i = 0; i < m; ++i) {
|
244
|
+
b[i + k * ldb] = alpha * b[i + k * ldb];
|
245
|
+
}
|
246
|
+
}
|
247
|
+
}
|
248
|
+
} else {
|
249
|
+
for (int k = 0; k < n; ++k) {
|
250
|
+
if (diag == CblasNonUnit) {
|
251
|
+
DType temp = 1 / a[k + k * lda];
|
252
|
+
for (int i = 0; i < m; ++i) {
|
253
|
+
b[i + k * ldb] = temp * b[i + k * ldb];
|
254
|
+
}
|
255
|
+
}
|
256
|
+
for (int j = k+1; j < n; ++j) {
|
257
|
+
if (a[j + k * lda] != 0.) {
|
258
|
+
DType temp = a[j + k * lda];
|
259
|
+
for (int i = 0; i < m; ++i) {
|
260
|
+
b[i + j * ldb] -= temp * b[i + k * ldb];
|
261
|
+
}
|
262
|
+
}
|
263
|
+
}
|
264
|
+
if (alpha != 1) {
|
265
|
+
for (int i = 0; i < m; ++i) {
|
266
|
+
b[i + k * ldb] = alpha * b[i + k * ldb];
|
267
|
+
}
|
268
|
+
}
|
269
|
+
}
|
270
|
+
}
|
271
|
+
}
|
272
|
+
}
|
273
|
+
}
|
274
|
+
|
275
|
+
/*
|
276
|
+
* BLAS' DTRSM function, generalized.
|
277
|
+
*/
|
278
|
+
template <typename DType, typename = typename std::enable_if<!std::is_integral<DType>::value>::type>
|
279
|
+
inline void trsm(const enum CBLAS_ORDER order,
|
280
|
+
const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
281
|
+
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
282
|
+
const int m, const int n, const DType alpha, const DType* a,
|
283
|
+
const int lda, DType* b, const int ldb)
|
284
|
+
{
|
285
|
+
/*using std::cerr;
|
286
|
+
using std::endl;*/
|
287
|
+
|
288
|
+
int num_rows_a = n;
|
289
|
+
if (side == CblasLeft) num_rows_a = m;
|
290
|
+
|
291
|
+
if (lda < std::max(1,num_rows_a)) {
|
292
|
+
fprintf(stderr, "TRSM: num_rows_a = %d; got lda=%d\n", num_rows_a, lda);
|
293
|
+
rb_raise(rb_eArgError, "TRSM: Expected lda >= max(1, num_rows_a)");
|
294
|
+
}
|
295
|
+
|
296
|
+
// Test the input parameters.
|
297
|
+
if (order == CblasRowMajor) {
|
298
|
+
if (ldb < std::max(1,n)) {
|
299
|
+
fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
|
300
|
+
rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,N)");
|
301
|
+
}
|
302
|
+
|
303
|
+
// For row major, need to switch side and uplo
|
304
|
+
enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
|
305
|
+
enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
|
306
|
+
|
307
|
+
/*
|
308
|
+
cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
|
309
|
+
<< (uplo_ == CblasUpper ? "upper " : "lower ")
|
310
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
311
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
312
|
+
<< n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
|
313
|
+
*/
|
314
|
+
trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
|
315
|
+
|
316
|
+
} else { // CblasColMajor
|
317
|
+
|
318
|
+
if (ldb < std::max(1,m)) {
|
319
|
+
fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
|
320
|
+
rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
|
321
|
+
}
|
322
|
+
/*
|
323
|
+
cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
|
324
|
+
<< (uplo == CblasUpper ? "upper " : "lower ")
|
325
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
326
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
327
|
+
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
328
|
+
*/
|
329
|
+
trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
330
|
+
|
331
|
+
}
|
332
|
+
|
333
|
+
}
|
334
|
+
|
335
|
+
|
336
|
+
template <>
|
337
|
+
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
338
|
+
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
339
|
+
const int m, const int n, const float alpha, const float* a,
|
340
|
+
const int lda, float* b, const int ldb)
|
341
|
+
{
|
342
|
+
cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
343
|
+
}
|
344
|
+
|
345
|
+
template <>
|
346
|
+
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
347
|
+
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
348
|
+
const int m, const int n, const double alpha, const double* a,
|
349
|
+
const int lda, double* b, const int ldb)
|
350
|
+
{
|
351
|
+
/* using std::cerr;
|
352
|
+
using std::endl;
|
353
|
+
cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
|
354
|
+
<< (uplo == CblasUpper ? "upper " : "lower ")
|
355
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
356
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
357
|
+
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
358
|
+
*/
|
359
|
+
cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
360
|
+
}
|
361
|
+
|
362
|
+
|
363
|
+
template <>
|
364
|
+
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
365
|
+
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
366
|
+
const int m, const int n, const Complex64 alpha, const Complex64* a,
|
367
|
+
const int lda, Complex64* b, const int ldb)
|
368
|
+
{
|
369
|
+
cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
370
|
+
}
|
371
|
+
|
372
|
+
template <>
|
373
|
+
inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
374
|
+
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
|
375
|
+
const int m, const int n, const Complex128 alpha, const Complex128* a,
|
376
|
+
const int lda, Complex128* b, const int ldb)
|
377
|
+
{
|
378
|
+
cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
379
|
+
}
|
380
|
+
|
381
|
+
|
382
|
+
} } // namespace nm::math
|
383
|
+
#endif // TRSM_H
|