nmatrix-gemv 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +29 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +14 -0
  5. data/Gemfile +7 -0
  6. data/README.md +29 -0
  7. data/Rakefile +225 -0
  8. data/ext/nmatrix_gemv/binary_format.txt +53 -0
  9. data/ext/nmatrix_gemv/data/complex.h +399 -0
  10. data/ext/nmatrix_gemv/data/data.cpp +298 -0
  11. data/ext/nmatrix_gemv/data/data.h +771 -0
  12. data/ext/nmatrix_gemv/data/meta.h +70 -0
  13. data/ext/nmatrix_gemv/data/rational.h +436 -0
  14. data/ext/nmatrix_gemv/data/ruby_object.h +471 -0
  15. data/ext/nmatrix_gemv/extconf.rb +254 -0
  16. data/ext/nmatrix_gemv/math.cpp +1639 -0
  17. data/ext/nmatrix_gemv/math/asum.h +143 -0
  18. data/ext/nmatrix_gemv/math/geev.h +82 -0
  19. data/ext/nmatrix_gemv/math/gemm.h +271 -0
  20. data/ext/nmatrix_gemv/math/gemv.h +212 -0
  21. data/ext/nmatrix_gemv/math/ger.h +96 -0
  22. data/ext/nmatrix_gemv/math/gesdd.h +80 -0
  23. data/ext/nmatrix_gemv/math/gesvd.h +78 -0
  24. data/ext/nmatrix_gemv/math/getf2.h +86 -0
  25. data/ext/nmatrix_gemv/math/getrf.h +240 -0
  26. data/ext/nmatrix_gemv/math/getri.h +108 -0
  27. data/ext/nmatrix_gemv/math/getrs.h +129 -0
  28. data/ext/nmatrix_gemv/math/idamax.h +86 -0
  29. data/ext/nmatrix_gemv/math/inc.h +47 -0
  30. data/ext/nmatrix_gemv/math/laswp.h +165 -0
  31. data/ext/nmatrix_gemv/math/long_dtype.h +52 -0
  32. data/ext/nmatrix_gemv/math/math.h +1069 -0
  33. data/ext/nmatrix_gemv/math/nrm2.h +181 -0
  34. data/ext/nmatrix_gemv/math/potrs.h +129 -0
  35. data/ext/nmatrix_gemv/math/rot.h +141 -0
  36. data/ext/nmatrix_gemv/math/rotg.h +115 -0
  37. data/ext/nmatrix_gemv/math/scal.h +73 -0
  38. data/ext/nmatrix_gemv/math/swap.h +73 -0
  39. data/ext/nmatrix_gemv/math/trsm.h +387 -0
  40. data/ext/nmatrix_gemv/nm_memory.h +60 -0
  41. data/ext/nmatrix_gemv/nmatrix_gemv.cpp +90 -0
  42. data/ext/nmatrix_gemv/nmatrix_gemv.h +374 -0
  43. data/ext/nmatrix_gemv/ruby_constants.cpp +153 -0
  44. data/ext/nmatrix_gemv/ruby_constants.h +107 -0
  45. data/ext/nmatrix_gemv/ruby_nmatrix.c +84 -0
  46. data/ext/nmatrix_gemv/ttable_helper.rb +122 -0
  47. data/ext/nmatrix_gemv/types.h +54 -0
  48. data/ext/nmatrix_gemv/util/util.h +78 -0
  49. data/lib/nmatrix-gemv.rb +43 -0
  50. data/lib/nmatrix_gemv/blas.rb +85 -0
  51. data/lib/nmatrix_gemv/nmatrix_gemv.rb +35 -0
  52. data/lib/nmatrix_gemv/rspec.rb +75 -0
  53. data/nmatrix-gemv.gemspec +31 -0
  54. data/spec/blas_spec.rb +154 -0
  55. data/spec/spec_helper.rb +128 -0
  56. metadata +186 -0
@@ -0,0 +1,115 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == rotg.h
25
+ //
26
+ // BLAS rotg function in native C++.
27
+ //
28
+
29
+ /*
30
+ * Automatically Tuned Linear Algebra Software v3.8.4
31
+ * (C) Copyright 1999 R. Clint Whaley
32
+ *
33
+ * Redistribution and use in source and binary forms, with or without
34
+ * modification, are permitted provided that the following conditions
35
+ * are met:
36
+ * 1. Redistributions of source code must retain the above copyright
37
+ * notice, this list of conditions and the following disclaimer.
38
+ * 2. Redistributions in binary form must reproduce the above copyright
39
+ * notice, this list of conditions, and the following disclaimer in the
40
+ * documentation and/or other materials provided with the distribution.
41
+ * 3. The name of the ATLAS group or the names of its contributers may
42
+ * not be used to endorse or promote products derived from this
43
+ * software without specific written permission.
44
+ *
45
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
46
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
47
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
48
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
49
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
50
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
51
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
52
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
53
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
54
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
55
+ * POSSIBILITY OF SUCH DAMAGE.
56
+ *
57
+ */
58
+
59
+ #ifndef ROTG_H
60
+ # define ROTG_H
61
+
62
+ namespace nm { namespace math {
63
+
64
/*
 * Givens plane rotation (BLAS xROTG), generic fallback. From ATLAS 3.8.4 / reference drotg.
 *
 * Given the pair (*a, *b), computes c, s and r such that
 *     [  c  s ] [ *a ]   [ r ]
 *     [ -s  c ] [ *b ] = [ 0 ]
 *
 * On return:
 *   *c, *s -- the rotation coefficients (c = 1, s = 0 when a == b == 0);
 *   *a     -- overwritten with r, which carries the sign of whichever input
 *             has the larger magnitude;
 *   *b     -- overwritten with the reconstruction value z:
 *               z = s    if |a| > |b|
 *               z = 1/c  if |b| >= |a| and c != 0
 *               z = 1    otherwise (and z = 0 when a == b == 0).
 *
 * float, double, Complex64 and Complex128 are specialized to call CBLAS.
 *
 * FIXME: Not working properly for Ruby objects.
 */
template <typename DType>
inline void rotg(DType* a, DType* b, DType* c, DType* s) {
  DType aa = std::abs(*a), ab = std::abs(*b);
  DType roe = aa > ab ? *a : *b;
  DType scal = aa + ab;

  if (scal == 0) {
    *c = 1;
    *s = *a = *b = 0;
    return;
  }

  // Scale before squaring so the intermediate sum cannot overflow/underflow.
  DType t0 = aa / scal, t1 = ab / scal;
  DType r = scal * std::sqrt(t0 * t0 + t1 * t1);
  if (roe < 0) r = -r;

  *c = *a / r;
  *s = *b / r;

  // Reconstruction value, per the reference BLAS: s when |a| dominates,
  // otherwise 1/c (or 1 if c is zero).
  DType z = 1;
  if (aa > ab) {
    z = *s;
  } else if (*c != 0) {
    z = 1 / *c;
  }

  *a = r;
  *b = z;
}
86
+
87
+ template <>
88
+ inline void rotg(float* a, float* b, float* c, float* s) {
89
+ cblas_srotg(a, b, c, s);
90
+ }
91
+
92
+ template <>
93
+ inline void rotg(double* a, double* b, double* c, double* s) {
94
+ cblas_drotg(a, b, c, s);
95
+ }
96
+
97
+ template <>
98
+ inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) {
99
+ cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
100
+ }
101
+
102
+ template <>
103
+ inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) {
104
+ cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
105
+ }
106
+
107
+ template <typename DType>
108
+ inline void cblas_rotg(void* a, void* b, void* c, void* s) {
109
+ rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s));
110
+ }
111
+
112
+
113
+ } } //nm::math
114
+
115
+ #endif // ROTG_H
@@ -0,0 +1,73 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == scal.h
25
+ //
26
+ // BLAS level 1 scal function in native C++.
27
+ //
28
+
29
+ #ifndef SCAL_H
30
+ #define SCAL_H
31
+
32
+ namespace nm { namespace math {
33
+
34
+ /* Purpose */
35
+ /* ======= */
36
+
37
+ /* DSCAL scales a vector by a constant. */
38
+ /* (The LINPACK original unrolls loops for increment equal to one; this port does not.) */
39
+
40
+ /* Further Details */
41
+ /* =============== */
42
+
43
+ /* jack dongarra, linpack, 3/11/78. */
44
+ /* modified 3/93 to return if incx .le. 0. */
45
+ /* modified 12/3/93, array(1) declarations changed to array(*) */
46
+
47
+ /* ===================================================================== */
48
+
49
/*
 * Scale a strided vector in place (BLAS xSCAL):
 *   dx[k * incx] = da * dx[k * incx]   for k = 0 .. n-1.
 *
 * As in the reference BLAS, this is a no-op when n <= 0 or incx <= 0.
 *
 * The element count drives the loop and the stride offset is kept in a
 * long long, so n * incx never has to fit in an int.
 * (This used to have unrolled loops, like dswap. They were in the way.)
 */
template <typename DType>
inline void scal(const int n, const DType da, DType* dx, const int incx) {
  if (n <= 0 || incx <= 0) return;

  long long offset = 0;
  for (int k = 0; k < n; ++k, offset += incx) {
    dx[offset] = da * dx[offset];
  }
} /* scal */
60
+
61
+
62
+ /*
63
+ * Function signature conversion for LAPACK's scal function.
64
+ */
65
+ template <typename DType>
66
+ inline void clapack_scal(const int n, const void* da, void* dx, const int incx) {
67
+ // FIXME: See if we can call the clapack version instead of our C++ version.
68
+ scal<DType>(n, *reinterpret_cast<const DType*>(da), reinterpret_cast<DType*>(dx), incx);
69
+ }
70
+
71
+ }} // end of nm::math
72
+
73
+ #endif
@@ -0,0 +1,73 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == swap.h
25
+ //
26
+ // BLAS level 1 swap function in native C++.
27
+ //
28
+
29
+ #ifndef SWAP_H
30
+ #define SWAP_H
31
+
32
+ namespace nm { namespace math {
33
+ /*
34
+ template <typename DType>
35
+ inline void swap(int n, DType *dx, int incx, DType *dy, int incy) {
36
+
37
+ if (n <= 0) return;
38
+
39
+ // For negative increments, start at the end of the array.
40
+ int ix = incx < 0 ? (-n+1)*incx : 0,
41
+ iy = incy < 0 ? (-n+1)*incy : 0;
42
+
43
+ if (incx < 0) ix = (-n + 1) * incx;
44
+ if (incy < 0) iy = (-n + 1) * incy;
45
+
46
+ for (size_t i = 0; i < n; ++i, ix += incx, iy += incy) {
47
+ DType dtemp = dx[ix];
48
+ dx[ix] = dy[iy];
49
+ dy[iy] = dtemp;
50
+ }
51
+ return;
52
+ } /* dswap */
53
+
54
/*
 * Interchange two strided vectors (BLAS xSWAP, after the reference dswap).
 *
 * Swaps X[ix] with Y[iy] for N element pairs, advancing ix by incX and iy by incY.
 * As in the reference BLAS, a negative increment starts at offset (1 - N) * inc,
 * i.e. at the far end of the vector, and walks backwards. N <= 0 is a no-op.
 *
 * This is the old BLAS version of this function. ATLAS has an optimized version,
 * but it's going to be tough to translate.
 */
template <typename DType>
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
  if (N <= 0) return;

  int ix = incX < 0 ? (1 - N) * incX : 0;
  int iy = incY < 0 ? (1 - N) * incY : 0;

  for (int i = 0; i < N; ++i, ix += incX, iy += incY) {
    DType temp = X[ix];
    X[ix] = Y[iy];
    Y[iy] = temp;
  }
}
70
+
71
+ }} // end nm::math
72
+
73
+ #endif
@@ -0,0 +1,387 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == trsm.h
25
+ //
26
+ // trsm function in native C++.
27
+ //
28
+ /*
29
+ * Automatically Tuned Linear Algebra Software v3.8.4
30
+ * (C) Copyright 1999 R. Clint Whaley
31
+ *
32
+ * Redistribution and use in source and binary forms, with or without
33
+ * modification, are permitted provided that the following conditions
34
+ * are met:
35
+ * 1. Redistributions of source code must retain the above copyright
36
+ * notice, this list of conditions and the following disclaimer.
37
+ * 2. Redistributions in binary form must reproduce the above copyright
38
+ * notice, this list of conditions, and the following disclaimer in the
39
+ * documentation and/or other materials provided with the distribution.
40
+ * 3. The name of the ATLAS group or the names of its contributers may
41
+ * not be used to endorse or promote products derived from this
42
+ * software without specific written permission.
43
+ *
44
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
45
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
46
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
47
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
48
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
49
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
50
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
51
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
52
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
53
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
54
+ * POSSIBILITY OF SUCH DAMAGE.
55
+ *
56
+ */
57
+
58
+ #ifndef TRSM_H
59
+ #define TRSM_H
60
+
61
+
62
+ extern "C" {
63
+ #if defined HAVE_CBLAS_H
64
+ #include <cblas.h>
65
+ #elif defined HAVE_ATLAS_CBLAS_H
66
+ #include <atlas/cblas.h>
67
+ #endif
68
+ }
69
+
70
+ namespace nm { namespace math {
71
+
72
+
73
+ /*
74
+ * This version of trsm doesn't do any error checks and only works on column-major matrices.
75
+ *
76
+ * For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
77
+ * and parameter checks.
78
+ *
79
+ * Note that some of the boundary conditions here may be incorrect. Very little has been tested!
80
+ * This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
81
+ */
82
+ template <typename DType>
83
+ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
84
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
85
+ const int m, const int n, const DType alpha, const DType* a,
86
+ const int lda, DType* b, const int ldb)
87
+ {
88
+
89
+ // (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
90
+
91
+ if (m == 0 || n == 0) return; /* Quick return if possible. */
92
+
93
+ if (alpha == 0) { // Handle alpha == 0
94
+ for (int j = 0; j < n; ++j) {
95
+ for (int i = 0; i < m; ++i) {
96
+ b[i + j * ldb] = 0;
97
+ }
98
+ }
99
+ return;
100
+ }
101
+
102
+ if (side == CblasLeft) {
103
+ if (trans_a == CblasNoTrans) {
104
+
105
+ /* Form B := alpha*inv( A )*B. */
106
+ if (uplo == CblasUpper) {
107
+ for (int j = 0; j < n; ++j) {
108
+ if (alpha != 1) {
109
+ for (int i = 0; i < m; ++i) {
110
+ b[i + j * ldb] = alpha * b[i + j * ldb];
111
+ }
112
+ }
113
+ for (int k = m-1; k >= 0; --k) {
114
+ if (b[k + j * ldb] != 0) {
115
+ if (diag == CblasNonUnit) {
116
+ b[k + j * ldb] /= a[k + k * lda];
117
+ }
118
+
119
+ for (int i = 0; i < k-1; ++i) {
120
+ b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
121
+ }
122
+ }
123
+ }
124
+ }
125
+ } else {
126
+ for (int j = 0; j < n; ++j) {
127
+ if (alpha != 1) {
128
+ for (int i = 0; i < m; ++i) {
129
+ b[i + j * ldb] = alpha * b[i + j * ldb];
130
+ }
131
+ }
132
+ for (int k = 0; k < m; ++k) {
133
+ if (b[k + j * ldb] != 0.) {
134
+ if (diag == CblasNonUnit) {
135
+ b[k + j * ldb] /= a[k + k * lda];
136
+ }
137
+ for (int i = k+1; i < m; ++i) {
138
+ b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
139
+ }
140
+ }
141
+ }
142
+ }
143
+ }
144
+ } else { // CblasTrans
145
+
146
+ /* Form B := alpha*inv( A**T )*B. */
147
+ if (uplo == CblasUpper) {
148
+ for (int j = 0; j < n; ++j) {
149
+ for (int i = 0; i < m; ++i) {
150
+ DType temp = alpha * b[i + j * ldb];
151
+ for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
152
+ temp -= a[k + i * lda] * b[k + j * ldb];
153
+ }
154
+ if (diag == CblasNonUnit) {
155
+ temp /= a[i + i * lda];
156
+ }
157
+ b[i + j * ldb] = temp;
158
+ }
159
+ }
160
+ } else {
161
+ for (int j = 0; j < n; ++j) {
162
+ for (int i = m-1; i >= 0; --i) {
163
+ DType temp= alpha * b[i + j * ldb];
164
+ for (int k = i+1; k < m; ++k) {
165
+ temp -= a[k + i * lda] * b[k + j * ldb];
166
+ }
167
+ if (diag == CblasNonUnit) {
168
+ temp /= a[i + i * lda];
169
+ }
170
+ b[i + j * ldb] = temp;
171
+ }
172
+ }
173
+ }
174
+ }
175
+ } else { // right side
176
+
177
+ if (trans_a == CblasNoTrans) {
178
+
179
+ /* Form B := alpha*B*inv( A ). */
180
+
181
+ if (uplo == CblasUpper) {
182
+ for (int j = 0; j < n; ++j) {
183
+ if (alpha != 1) {
184
+ for (int i = 0; i < m; ++i) {
185
+ b[i + j * ldb] = alpha * b[i + j * ldb];
186
+ }
187
+ }
188
+ for (int k = 0; k < j-1; ++k) {
189
+ if (a[k + j * lda] != 0) {
190
+ for (int i = 0; i < m; ++i) {
191
+ b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
192
+ }
193
+ }
194
+ }
195
+ if (diag == CblasNonUnit) {
196
+ DType temp = 1 / a[j + j * lda];
197
+ for (int i = 0; i < m; ++i) {
198
+ b[i + j * ldb] = temp * b[i + j * ldb];
199
+ }
200
+ }
201
+ }
202
+ } else {
203
+ for (int j = n-1; j >= 0; --j) {
204
+ if (alpha != 1) {
205
+ for (int i = 0; i < m; ++i) {
206
+ b[i + j * ldb] = alpha * b[i + j * ldb];
207
+ }
208
+ }
209
+
210
+ for (int k = j+1; k < n; ++k) {
211
+ if (a[k + j * lda] != 0.) {
212
+ for (int i = 0; i < m; ++i) {
213
+ b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
214
+ }
215
+ }
216
+ }
217
+ if (diag == CblasNonUnit) {
218
+ DType temp = 1 / a[j + j * lda];
219
+
220
+ for (int i = 0; i < m; ++i) {
221
+ b[i + j * ldb] = temp * b[i + j * ldb];
222
+ }
223
+ }
224
+ }
225
+ }
226
+ } else { // CblasTrans
227
+
228
+ /* Form B := alpha*B*inv( A**T ). */
229
+
230
+ if (uplo == CblasUpper) {
231
+ for (int k = n-1; k >= 0; --k) {
232
+ if (diag == CblasNonUnit) {
233
+ DType temp= 1 / a[k + k * lda];
234
+ for (int i = 0; i < m; ++i) {
235
+ b[i + k * ldb] = temp * b[i + k * ldb];
236
+ }
237
+ }
238
+ for (int j = 0; j < k-1; ++j) {
239
+ if (a[j + k * lda] != 0.) {
240
+ DType temp= a[j + k * lda];
241
+ for (int i = 0; i < m; ++i) {
242
+ b[i + j * ldb] -= temp * b[i + k * ldb];
243
+ }
244
+ }
245
+ }
246
+ if (alpha != 1) {
247
+ for (int i = 0; i < m; ++i) {
248
+ b[i + k * ldb] = alpha * b[i + k * ldb];
249
+ }
250
+ }
251
+ }
252
+ } else {
253
+ for (int k = 0; k < n; ++k) {
254
+ if (diag == CblasNonUnit) {
255
+ DType temp = 1 / a[k + k * lda];
256
+ for (int i = 0; i < m; ++i) {
257
+ b[i + k * ldb] = temp * b[i + k * ldb];
258
+ }
259
+ }
260
+ for (int j = k+1; j < n; ++j) {
261
+ if (a[j + k * lda] != 0.) {
262
+ DType temp = a[j + k * lda];
263
+ for (int i = 0; i < m; ++i) {
264
+ b[i + j * ldb] -= temp * b[i + k * ldb];
265
+ }
266
+ }
267
+ }
268
+ if (alpha != 1) {
269
+ for (int i = 0; i < m; ++i) {
270
+ b[i + k * ldb] = alpha * b[i + k * ldb];
271
+ }
272
+ }
273
+ }
274
+ }
275
+ }
276
+ }
277
+ }
278
+
279
+ /*
280
+ * BLAS' DTRSM function, generalized.
281
+ */
282
+ template <typename DType, typename = typename std::enable_if<!std::is_integral<DType>::value>::type>
283
+ inline void trsm(const enum CBLAS_ORDER order,
284
+ const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
285
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
286
+ const int m, const int n, const DType alpha, const DType* a,
287
+ const int lda, DType* b, const int ldb)
288
+ {
289
+ /*using std::cerr;
290
+ using std::endl;*/
291
+
292
+ int num_rows_a = n;
293
+ if (side == CblasLeft) num_rows_a = m;
294
+
295
+ if (lda < std::max(1,num_rows_a)) {
296
+ fprintf(stderr, "TRSM: num_rows_a = %d; got lda=%d\n", num_rows_a, lda);
297
+ rb_raise(rb_eArgError, "TRSM: Expected lda >= max(1, num_rows_a)");
298
+ }
299
+
300
+ // Test the input parameters.
301
+ if (order == CblasRowMajor) {
302
+ if (ldb < std::max(1,n)) {
303
+ fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
304
+ rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,N)");
305
+ }
306
+
307
+ // For row major, need to switch side and uplo
308
+ enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
309
+ enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
310
+
311
+ /*
312
+ cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
313
+ << (uplo_ == CblasUpper ? "upper " : "lower ")
314
+ << (trans_a == CblasTrans ? "trans " : "notrans ")
315
+ << (diag == CblasNonUnit ? "nonunit " : "unit ")
316
+ << n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
317
+ */
318
+ trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
319
+
320
+ } else { // CblasColMajor
321
+
322
+ if (ldb < std::max(1,m)) {
323
+ fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
324
+ rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
325
+ }
326
+ /*
327
+ cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
328
+ << (uplo == CblasUpper ? "upper " : "lower ")
329
+ << (trans_a == CblasTrans ? "trans " : "notrans ")
330
+ << (diag == CblasNonUnit ? "nonunit " : "unit ")
331
+ << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
332
+ */
333
+ trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
334
+
335
+ }
336
+
337
+ }
338
+
339
+
340
+ template <>
341
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
342
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
343
+ const int m, const int n, const float alpha, const float* a,
344
+ const int lda, float* b, const int ldb)
345
+ {
346
+ cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
347
+ }
348
+
349
+ template <>
350
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
351
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
352
+ const int m, const int n, const double alpha, const double* a,
353
+ const int lda, double* b, const int ldb)
354
+ {
355
+ /* using std::cerr;
356
+ using std::endl;
357
+ cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
358
+ << (uplo == CblasUpper ? "upper " : "lower ")
359
+ << (trans_a == CblasTrans ? "trans " : "notrans ")
360
+ << (diag == CblasNonUnit ? "nonunit " : "unit ")
361
+ << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
362
+ */
363
+ cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
364
+ }
365
+
366
+
367
+ template <>
368
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
369
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
370
+ const int m, const int n, const Complex64 alpha, const Complex64* a,
371
+ const int lda, Complex64* b, const int ldb)
372
+ {
373
+ cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
374
+ }
375
+
376
+ template <>
377
+ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
378
+ const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag,
379
+ const int m, const int n, const Complex128 alpha, const Complex128* a,
380
+ const int lda, Complex128* b, const int ldb)
381
+ {
382
+ cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
383
+ }
384
+
385
+
386
+ } } // namespace nm::math
387
+ #endif // TRSM_H