nmatrix-gemv 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +29 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +14 -0
  5. data/Gemfile +7 -0
  6. data/README.md +29 -0
  7. data/Rakefile +225 -0
  8. data/ext/nmatrix_gemv/binary_format.txt +53 -0
  9. data/ext/nmatrix_gemv/data/complex.h +399 -0
  10. data/ext/nmatrix_gemv/data/data.cpp +298 -0
  11. data/ext/nmatrix_gemv/data/data.h +771 -0
  12. data/ext/nmatrix_gemv/data/meta.h +70 -0
  13. data/ext/nmatrix_gemv/data/rational.h +436 -0
  14. data/ext/nmatrix_gemv/data/ruby_object.h +471 -0
  15. data/ext/nmatrix_gemv/extconf.rb +254 -0
  16. data/ext/nmatrix_gemv/math.cpp +1639 -0
  17. data/ext/nmatrix_gemv/math/asum.h +143 -0
  18. data/ext/nmatrix_gemv/math/geev.h +82 -0
  19. data/ext/nmatrix_gemv/math/gemm.h +271 -0
  20. data/ext/nmatrix_gemv/math/gemv.h +212 -0
  21. data/ext/nmatrix_gemv/math/ger.h +96 -0
  22. data/ext/nmatrix_gemv/math/gesdd.h +80 -0
  23. data/ext/nmatrix_gemv/math/gesvd.h +78 -0
  24. data/ext/nmatrix_gemv/math/getf2.h +86 -0
  25. data/ext/nmatrix_gemv/math/getrf.h +240 -0
  26. data/ext/nmatrix_gemv/math/getri.h +108 -0
  27. data/ext/nmatrix_gemv/math/getrs.h +129 -0
  28. data/ext/nmatrix_gemv/math/idamax.h +86 -0
  29. data/ext/nmatrix_gemv/math/inc.h +47 -0
  30. data/ext/nmatrix_gemv/math/laswp.h +165 -0
  31. data/ext/nmatrix_gemv/math/long_dtype.h +52 -0
  32. data/ext/nmatrix_gemv/math/math.h +1069 -0
  33. data/ext/nmatrix_gemv/math/nrm2.h +181 -0
  34. data/ext/nmatrix_gemv/math/potrs.h +129 -0
  35. data/ext/nmatrix_gemv/math/rot.h +141 -0
  36. data/ext/nmatrix_gemv/math/rotg.h +115 -0
  37. data/ext/nmatrix_gemv/math/scal.h +73 -0
  38. data/ext/nmatrix_gemv/math/swap.h +73 -0
  39. data/ext/nmatrix_gemv/math/trsm.h +387 -0
  40. data/ext/nmatrix_gemv/nm_memory.h +60 -0
  41. data/ext/nmatrix_gemv/nmatrix_gemv.cpp +90 -0
  42. data/ext/nmatrix_gemv/nmatrix_gemv.h +374 -0
  43. data/ext/nmatrix_gemv/ruby_constants.cpp +153 -0
  44. data/ext/nmatrix_gemv/ruby_constants.h +107 -0
  45. data/ext/nmatrix_gemv/ruby_nmatrix.c +84 -0
  46. data/ext/nmatrix_gemv/ttable_helper.rb +122 -0
  47. data/ext/nmatrix_gemv/types.h +54 -0
  48. data/ext/nmatrix_gemv/util/util.h +78 -0
  49. data/lib/nmatrix-gemv.rb +43 -0
  50. data/lib/nmatrix_gemv/blas.rb +85 -0
  51. data/lib/nmatrix_gemv/nmatrix_gemv.rb +35 -0
  52. data/lib/nmatrix_gemv/rspec.rb +75 -0
  53. data/nmatrix-gemv.gemspec +31 -0
  54. data/spec/blas_spec.rb +154 -0
  55. data/spec/spec_helper.rb +128 -0
  56. metadata +186 -0
@@ -0,0 +1,86 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == idamax.h
25
+ //
26
+ // LAPACK idamax function in native C.
27
+ //
28
+
29
+ #ifndef IDAMAX_H
30
+ #define IDAMAX_H
31
+
32
+ namespace nm { namespace math {
33
+
34
/* Purpose */
/* ======= */

/* IDAMAX finds the index of the element having the maximum absolute value. */

/* Further Details */
/* =============== */

/* jack dongarra, linpack, 3/11/78. */
/* modified 3/93 to return if incx .le. 0. */
/* modified 12/3/93, array(1) declarations changed to array(*) */

/* ===================================================================== */

/*
 * \param n    number of elements to examine
 * \param dx   pointer to the first element of the vector
 * \param incx stride between consecutive elements (must be positive)
 * \return zero-based index of the element with the largest absolute value,
 *         or -1 when n < 1 or incx <= 0.
 */
template <typename DType>
inline int idamax(size_t n, DType *dx, int incx) {

  /* Function Body */
  if (n < 1 || incx <= 0) return -1; // empty vector or non-positive stride
  if (n == 1) return 0;              // single element: trivially the max

  DType dmax;
  size_t imax = 0;

  if (incx == 1) { // if incrementing by 1

    // FIX: was plain `abs(dx[0])`, which resolves to the C integer abs()
    // and truncates floating-point magnitudes (e.g. abs(5.9) == 5), so the
    // first element's magnitude was compared incorrectly. std::abs selects
    // the proper overload for DType.
    dmax = std::abs(dx[0]);

    for (size_t i = 1; i < n; ++i) {
      if (std::abs(dx[i]) > dmax) {
        imax = i;
        dmax = std::abs(dx[i]);
      }
    }

  } else { // if incrementing by more than 1

    dmax = std::abs(dx[0]);

    for (size_t i = 1, ix = incx; i < n; ++i, ix += incx) {
      if (std::abs(dx[ix]) > dmax) {
        imax = i;
        dmax = std::abs(dx[ix]);
      }
    }
  }
  return imax;
} /* idamax_ */
82
+
83
+ }} // end of namespace nm::math
84
+
85
+ #endif
86
+
@@ -0,0 +1,47 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == inc.h
25
+ //
26
+ // Includes needed for LAPACK, CLAPACK, and CBLAS functions.
27
+ //
28
+
29
+ #ifndef INC_H
30
+ # define INC_H
31
+
32
+
33
+ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
34
+ #if defined HAVE_CBLAS_H
35
+ #include <cblas.h>
36
+ #elif defined HAVE_ATLAS_CBLAS_H
37
+ #include <atlas/cblas.h>
38
+ #endif
39
+
40
+ #if defined HAVE_CLAPACK_H
41
+ #include <clapack.h>
42
+ #elif defined HAVE_ATLAS_CLAPACK_H
43
+ #include <atlas/clapack.h>
44
+ #endif
45
+ }
46
+
47
+ #endif // INC_H
@@ -0,0 +1,165 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == laswp.h
25
+ //
26
+ // laswp function in native C++.
27
+ //
28
+ /*
29
+ * Automatically Tuned Linear Algebra Software v3.8.4
30
+ * (C) Copyright 1999 R. Clint Whaley
31
+ *
32
+ * Redistribution and use in source and binary forms, with or without
33
+ * modification, are permitted provided that the following conditions
34
+ * are met:
35
+ * 1. Redistributions of source code must retain the above copyright
36
+ * notice, this list of conditions and the following disclaimer.
37
+ * 2. Redistributions in binary form must reproduce the above copyright
38
+ * notice, this list of conditions, and the following disclaimer in the
39
+ * documentation and/or other materials provided with the distribution.
40
+ * 3. The name of the ATLAS group or the names of its contributers may
41
+ * not be used to endorse or promote products derived from this
42
+ * software without specific written permission.
43
+ *
44
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
45
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
46
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
47
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
48
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
49
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
50
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
51
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
52
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
53
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
54
+ * POSSIBILITY OF SUCH DAMAGE.
55
+ *
56
+ */
57
+
58
+ #ifndef LASWP_H
59
+ #define LASWP_H
60
+
61
+ namespace nm { namespace math {
62
+
63
+
64
/*
 * ATLAS function which performs row interchanges on a general rectangular
 * matrix. Modeled after the LAPACK LASWP function.
 *
 * Rows K1..K2-1 (zero-based) of the N-column, column-major matrix A (leading
 * dimension lda) are swapped with the rows named by the pivot vector piv,
 * which is walked with stride inci (a negative inci walks it backwards).
 * Columns are processed in panels of 32 (nb full panels plus an mr-column
 * remainder), presumably to keep the working set cache-friendly.
 *
 * This version is templated for use by template <> getrf().
 */
template <typename DType>
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
  int nb = N >> 5;             // number of full 32-column panels

  const int mr = N - (nb<<5);  // leftover columns after the full panels
  const int incA = lda << 5;   // pointer advance per 32-column panel

  if (K2 < K1) return;         // empty pivot range: nothing to do

  int i1, i2;
  if (inci < 0) {              // walk pivots from K2-1 down to K1
    piv -= (K2-1) * inci;
    i1 = K2 - 1;
    i2 = K1;
  } else {                     // walk pivots from K1 up to K2-1
    piv += K1 * inci;
    i1 = K1;
    i2 = K2-1;
  }

  if (nb) {

    do {
      const int* ipiv = piv;
      int i = i1;
      int KeepOn;

      do {
        int ip = *ipiv; ipiv += inci;

        if (ip != i) {         // swap rows i and ip across this 32-column panel
          DType *a0 = &(A[i]),
                *a1 = &(A[ip]);

          // FIX: `register` removed from the loop variable -- the keyword is
          // ill-formed in C++17 (and was a no-op hint long before that).
          for (int h = 32; h; h--) {
            DType r = *a0;
            *a0 = *a1;
            *a1 = r;

            a0 += lda;
            a1 += lda;
          }

        }
        if (inci > 0) KeepOn = (++i <= i2);
        else KeepOn = (--i >= i2);

      } while (KeepOn);
      A += incA;               // advance to the next 32-column panel
    } while (--nb);
  }

  if (mr) {                    // handle the remaining mr (< 32) columns
    const int* ipiv = piv;
    int i = i1;
    int KeepOn;

    do {
      int ip = *ipiv; ipiv += inci;
      if (ip != i) {
        DType *a0 = &(A[i]),
              *a1 = &(A[ip]);

        for (int h = mr; h; h--) {
          DType r = *a0;
          *a0 = *a1;
          *a1 = r;

          a0 += lda;
          a1 += lda;
        }
      }

      if (inci > 0) KeepOn = (++i <= i2);
      else KeepOn = (--i >= i2);

    } while (KeepOn);
  }
}
150
+
151
+
152
+ /*
153
+ * Function signature conversion for calling LAPACK's laswp functions as directly as possible.
154
+ *
155
+ * For documentation: http://www.netlib.org/lapack/double/dlaswp.f
156
+ *
157
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
158
+ */
159
+ template <typename DType>
160
+ inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
161
+ laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
162
+ }
163
+
164
+ } } // namespace nm::math
165
+ #endif // LASWP_H
@@ -0,0 +1,52 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == long_dtype.h
25
+ //
26
+ // Declarations necessary for the native versions of GEMM and GEMV.
27
+ //
28
+
29
+ #ifndef LONG_DTYPE_H
30
+ #define LONG_DTYPE_H
31
+
32
+ namespace nm { namespace math {
33
+ // These allow an increase in precision for intermediate values of gemm and gemv.
34
+ // See also: http://stackoverflow.com/questions/11873694/how-does-one-increase-precision-in-c-templates-in-a-template-typename-dependen
35
+ template <typename DType> struct LongDType;
36
+ template <> struct LongDType<uint8_t> { typedef int16_t type; };
37
+ template <> struct LongDType<int8_t> { typedef int16_t type; };
38
+ template <> struct LongDType<int16_t> { typedef int32_t type; };
39
+ template <> struct LongDType<int32_t> { typedef int64_t type; };
40
+ template <> struct LongDType<int64_t> { typedef int64_t type; };
41
+ template <> struct LongDType<float> { typedef double type; };
42
+ template <> struct LongDType<double> { typedef double type; };
43
+ template <> struct LongDType<Complex64> { typedef Complex128 type; };
44
+ template <> struct LongDType<Complex128> { typedef Complex128 type; };
45
+ template <> struct LongDType<Rational32> { typedef Rational128 type; };
46
+ template <> struct LongDType<Rational64> { typedef Rational128 type; };
47
+ template <> struct LongDType<Rational128> { typedef Rational128 type; };
48
+ template <> struct LongDType<RubyObject> { typedef RubyObject type; };
49
+
50
+ }} // end of namespace nm::math
51
+
52
+ #endif
@@ -0,0 +1,1069 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == math.h
25
+ //
26
+ // Header file for math functions, interfacing with BLAS, etc.
27
+ //
28
+ // For instructions on adding CBLAS and CLAPACK functions, see the
29
+ // beginning of math.cpp.
30
+ //
31
+ // Some of these functions are from ATLAS. Here is the license for
32
+ // ATLAS:
33
+ //
34
+ /*
35
+ * Automatically Tuned Linear Algebra Software v3.8.4
36
+ * (C) Copyright 1999 R. Clint Whaley
37
+ *
38
+ * Redistribution and use in source and binary forms, with or without
39
+ * modification, are permitted provided that the following conditions
40
+ * are met:
41
+ * 1. Redistributions of source code must retain the above copyright
42
+ * notice, this list of conditions and the following disclaimer.
43
+ * 2. Redistributions in binary form must reproduce the above copyright
44
+ * notice, this list of conditions, and the following disclaimer in the
45
+ * documentation and/or other materials provided with the distribution.
46
+ * 3. The name of the ATLAS group or the names of its contributers may
47
+ * not be used to endorse or promote products derived from this
48
+ * software without specific written permission.
49
+ *
50
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
51
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
52
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
53
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
54
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
55
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
56
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
57
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
58
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
59
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
60
+ * POSSIBILITY OF SUCH DAMAGE.
61
+ *
62
+ */
63
+
64
+ #ifndef MATH_H
65
+ #define MATH_H
66
+
67
+ /*
68
+ * Standard Includes
69
+ */
70
+
71
+ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
72
+ #if defined HAVE_CBLAS_H
73
+ #include <cblas.h>
74
+ #elif defined HAVE_ATLAS_CBLAS_H
75
+ #include <atlas/cblas.h>
76
+ #endif
77
+
78
+ #if defined HAVE_CLAPACK_H
79
+ #include <clapack.h>
80
+ #elif defined HAVE_ATLAS_CLAPACK_H
81
+ #include <atlas/clapack.h>
82
+ #endif
83
+ }
84
+
85
#include <algorithm> // std::min, std::max
#include <limits>    // std::numeric_limits
#include <vector>    // std::vector (scratch arrays in numbmm/symbmm)
87
+
88
+ /*
89
+ * Project Includes
90
+ */
91
+
92
+ /*
93
+ * Macros
94
+ */
95
+ #define REAL_RECURSE_LIMIT 4
96
+
97
+ /*
98
+ * Data
99
+ */
100
+
101
+
102
+ extern "C" {
103
+ /*
104
+ * C accessors.
105
+ */
106
+ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
107
+ void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
108
+ void nm_math_init_blas(void);
109
+
110
+ }
111
+
112
+
113
+ namespace nm {
114
+ namespace math {
115
+
116
+ /*
117
+ * Types
118
+ */
119
+
120
+
121
+ /*
122
+ * Functions
123
+ */
124
+
125
+
126
/*
 * Symmetric rank-k update, forwarding to CBLAS where a native routine
 * exists. Generic fallback: raises NotImplementedError in Ruby for any
 * dtype without a specialization below.
 *
 * Note: scalars arrive by pointer (uniform calling convention across
 * dtypes) and are dereferenced where CBLAS wants them by value.
 */
template <typename DType>
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
  rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
}

/*
 * Hermitian rank-k update; generic fallback as above. Only the complex
 * dtypes have native specializations (HERK is only defined for complex).
 */
template <typename DType>
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
  rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
}

// float: dereference the scalar pointers for cblas_ssyrk.
template <>
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
  cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
}

// double: dereference the scalar pointers for cblas_dsyrk.
template <>
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
  cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
}

// Complex64: CBLAS takes complex scalars as void*, so the pointers pass through.
template <>
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
  cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
}

// Complex128: as above, via cblas_zsyrk.
template <>
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
  cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
}


// HERK takes real scalars; ->r is passed -- presumably the real component of
// the project's Complex64 type (declared elsewhere) -- TODO confirm.
template <>
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
  cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
}

// Double-precision Hermitian rank-k update, via cblas_zherk.
template <>
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
                 const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
  cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
}
174
+
175
+
176
/*
 * Triangular matrix-matrix multiply (B := alpha*op(A)*B or B := alpha*B*op(A),
 * per `side`), forwarding to CBLAS. Generic fallback raises
 * NotImplementedError in Ruby for dtypes without a specialization below.
 */
template <typename DType>
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
                 const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
                 const DType* A, const int lda, DType* B, const int ldb) {
  rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
}

// float: scalar pointer dereferenced for cblas_strmm.
template <>
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
                 const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
                 const float* A, const int lda, float* B, const int ldb) {
  cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
}

// double: scalar pointer dereferenced for cblas_dtrmm.
template <>
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
                 const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
                 const double* A, const int lda, double* B, const int ldb) {
  cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
}

// Complex64: CBLAS takes complex scalars as void*, so the pointer passes through.
template <>
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
                 const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
                 const Complex64* A, const int lda, Complex64* B, const int ldb) {
  cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
}

// Complex128: as above, via cblas_ztrmm.
template <>
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
                 const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
                 const Complex128* A, const int lda, Complex128* B, const int ldb) {
  cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
}
210
+
211
+
212
+
213
+ // Yale: numeric matrix multiply c=a*b
214
+ template <typename DType>
215
+ inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
216
+ const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
217
+ const unsigned int max_lmn = std::max(std::max(m, n), l);
218
+ IType next[max_lmn];
219
+ DType sums[max_lmn];
220
+
221
+ DType v;
222
+
223
+ IType head, length, temp, ndnz = 0;
224
+ IType minmn = std::min(m,n);
225
+ IType minlm = std::min(l,m);
226
+
227
+ for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
228
+ next[idx] = std::numeric_limits<IType>::max();
229
+ sums[idx] = 0;
230
+ }
231
+
232
+ for (IType i = 0; i < n; ++i) { // walk down the rows
233
+ head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
234
+ length = 0;
235
+
236
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
237
+ IType j;
238
+
239
+ if (jj == ia[i+1]) { // if we're in the last entry for this row:
240
+ if (!diaga || i >= minmn) continue;
241
+ j = i; // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
242
+ v = a[i];
243
+ } else {
244
+ j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
245
+ v = a[jj];
246
+ }
247
+
248
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {
249
+
250
+ IType k;
251
+
252
+ if (kk == ib[j+1]) { // Get the column id for that entry
253
+ if (!diagb || j >= minlm) continue;
254
+ k = j;
255
+ sums[k] += v*b[k];
256
+ } else {
257
+ k = jb[kk];
258
+ sums[k] += v*b[kk];
259
+ }
260
+
261
+ if (next[k] == std::numeric_limits<IType>::max()) {
262
+ next[k] = head;
263
+ head = k;
264
+ ++length;
265
+ }
266
+ } // end of kk loop
267
+ } // end of jj loop
268
+
269
+ for (IType jj = 0; jj < length; ++jj) {
270
+ if (sums[head] != 0) {
271
+ if (diagc && head == i) {
272
+ c[head] = sums[head];
273
+ } else {
274
+ jc[n+1+ndnz] = head;
275
+ c[n+1+ndnz] = sums[head];
276
+ ++ndnz;
277
+ }
278
+ }
279
+
280
+ temp = head;
281
+ head = next[head];
282
+
283
+ next[temp] = std::numeric_limits<IType>::max();
284
+ sums[temp] = 0;
285
+ }
286
+
287
+ ic[i+1] = n+1+ndnz;
288
+ }
289
+ } /* numbmm_ */
290
+
291
+
292
+ /*
293
+ template <typename DType, typename IType>
294
+ inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
295
+ unsigned int n = c_storage->shape[0],
296
+ l = c_storage->shape[1];
297
+
298
+ // Create a working vector of dimension max(m,l,n) and initial value IType::max():
299
+ std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
300
+
301
+ for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
302
+
303
+ IType j, k;
304
+ size_t ndnz;
305
+
306
+ for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
307
+ j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
308
+
309
+ if (j >= m) {
310
+ if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
311
+ else break;
312
+ }
313
+
314
+ for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
315
+ if (j >= m) continue; // first of all, does B *have* a row j?
316
+ k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
317
+
318
+ if (k >= l) {
319
+ if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
320
+ else break;
321
+ }
322
+
323
+ if (mask[k] == )
324
+ }
325
+
326
+ }
327
+ }
328
+ }
329
+ */
330
+
331
+ // Yale: Symbolic matrix multiply c=a*b
332
+ inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
333
+ const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
334
+ unsigned int max_lmn = std::max(std::max(m,n), l);
335
+ IType mask[max_lmn]; // INDEX in the SMMP paper.
336
+ IType j, k; /* Local variables */
337
+ size_t ndnz = n;
338
+
339
+ for (IType idx = 0; idx < max_lmn; ++idx)
340
+ mask[idx] = std::numeric_limits<IType>::max();
341
+
342
+ if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
343
+ if (diagc) ic[0] = n+1;
344
+ else ic[0] = 0;
345
+ }
346
+
347
+ IType minmn = std::min(m,n);
348
+ IType minlm = std::min(l,m);
349
+
350
+ for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
351
+
352
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
353
+
354
+ // j <- column index given by JA[jj], or handle diagonal.
355
+ if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
356
+ if (!diaga || i >= minmn) continue;
357
+ j = i;
358
+ } else j = ja[jj];
359
+
360
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
361
+ if (kk == ib[j+1]) {
362
+ if (!diagb || j >= minlm) continue;
363
+ k = j;
364
+ } else k = jb[kk];
365
+
366
+ if (mask[k] != i) {
367
+ mask[k] = i;
368
+ ++ndnz;
369
+ }
370
+ }
371
+ }
372
+
373
+ if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
374
+
375
+ if (ic) ic[i+1] = ndnz;
376
+ }
377
+
378
+ return ndnz;
379
+ } /* symbmm_ */
380
+
381
+
382
+ // In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
383
+ namespace smmp_sort {
384
+ const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer
385
+
386
+ template <typename DType>
387
+ void print_array(DType* vals, IType* array, IType left, IType right) {
388
+ for (IType i = left; i <= right; ++i) {
389
+ std::cerr << array[i] << ":" << vals[i] << " ";
390
+ }
391
+ std::cerr << std::endl;
392
+ }
393
+
394
+ template <typename DType>
395
+ IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
396
+ IType pivotJ = array[pivot];
397
+ DType pivotV = vals[pivot];
398
+
399
+ // Swap pivot and right
400
+ array[pivot] = array[right];
401
+ vals[pivot] = vals[right];
402
+ array[right] = pivotJ;
403
+ vals[right] = pivotV;
404
+
405
+ IType store = left;
406
+ for (IType idx = left; idx < right; ++idx) {
407
+ if (array[idx] <= pivotJ) {
408
+ // Swap i and store
409
+ std::swap(array[idx], array[store]);
410
+ std::swap(vals[idx], vals[store]);
411
+ ++store;
412
+ }
413
+ }
414
+
415
+ std::swap(array[store], array[right]);
416
+ std::swap(vals[store], vals[right]);
417
+
418
+ return store;
419
+ }
420
+
421
// Recommended to use the median of left, right, and mid for the pivot.
// Returns the middle value of a, b, c, expressed via min/max rather than a
// comparison ladder: clamp c between the smaller and larger of a and b.
template <typename I>
inline I median(I a, I b, I c) {
  const I lo = std::min(a, b);
  const I hi = std::max(a, b);
  return std::max(lo, std::min(hi, c));
}
435
+
436
+
437
+ // Insertion sort is more efficient than quicksort for small N
438
+ template <typename DType>
439
+ void insertion_sort(DType* vals, IType* array, IType left, IType right) {
440
+ for (IType idx = left; idx <= right; ++idx) {
441
+ IType col_to_insert = array[idx];
442
+ DType val_to_insert = vals[idx];
443
+
444
+ IType hole_pos = idx;
445
+ for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
446
+ array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
447
+ vals[hole_pos] = vals[hole_pos - 1]; // value goes along with it
448
+ }
449
+
450
+ array[hole_pos] = col_to_insert;
451
+ vals[hole_pos] = val_to_insert;
452
+ }
453
+ }
454
+
455
+
456
+ template <typename DType>
457
+ void quicksort(DType* vals, IType* array, IType left, IType right) {
458
+
459
+ if (left < right) {
460
+ if (right - left < THRESHOLD) {
461
+ insertion_sort(vals, array, left, right);
462
+ } else {
463
+ // choose any pivot such that left < pivot < right
464
+ IType pivot = median<IType>(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
465
+ pivot = partition(vals, array, left, right, pivot);
466
+
467
+ // recursively sort elements smaller than the pivot
468
+ quicksort<DType>(vals, array, left, pivot-1);
469
+
470
+ // recursively sort elements at least as big as the pivot
471
+ quicksort<DType>(vals, array, pivot+1, right);
472
+ }
473
+ }
474
+ }
475
+
476
+
477
+ }; // end of namespace smmp_sort
478
+
479
+
480
+ /*
481
+ * For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
482
+ * This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
483
+ * about stability).
484
+ *
485
+ * TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
486
+ *
487
+ * TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
488
+ * ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
489
+ * sort.
490
+ */
491
+ template <typename DType>
492
+ inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
493
+ for (size_t i = 0; i < n; ++i) {
494
+ if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing only one or two elements.
495
+ else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
496
+ smmp_sort::insertion_sort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
497
+ } else {
498
+ smmp_sort::quicksort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
499
+ }
500
+ }
501
+ }
502
+
503
+
504
+ /*
505
+ * From ATLAS 3.8.0:
506
+ *
507
+ * Computes one of two LU factorizations based on the setting of the Order
508
+ * parameter, as follows:
509
+ * ----------------------------------------------------------------------------
510
+ * Order == CblasColMajor
511
+ * Column-major factorization of form
512
+ * A = P * L * U
513
+ * where P is a row-permutation matrix, L is lower triangular with unit
514
+ * diagonal elements (lower trapazoidal if M > N), and U is upper triangular
515
+ * (upper trapazoidal if M < N).
516
+ *
517
+ * ----------------------------------------------------------------------------
518
+ * Order == CblasRowMajor
519
+ * Row-major factorization of form
520
+ * A = P * L * U
521
+ * where P is a column-permutation matrix, L is lower triangular (lower
522
+ * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
523
+ * trapazoidal if M < N).
524
+ *
525
+ * ============================================================================
526
+ * Let IERR be the return value of the function:
527
+ * If IERR == 0, successful exit.
528
+ * If (IERR < 0) the -IERR argument had an illegal value
529
+ * If (IERR > 0 && Order == CblasColMajor)
530
+ * U(i-1,i-1) is exactly zero. The factorization has been completed,
531
+ * but the factor U is exactly singular, and division by zero will
532
+ * occur if it is used to solve a system of equations.
533
+ * If (IERR > 0 && Order == CblasRowMajor)
534
+ * L(i-1,i-1) is exactly zero. The factorization has been completed,
535
+ * but the factor L is exactly singular, and division by zero will
536
+ * occur if it is used to solve a system of equations.
537
+ *
+ * NOTE(review): the description above documents an LU factorization (getrf),
+ * apparently carried over from ATLAS; the function that follows is potrf,
+ * a Cholesky factorization wrapper.  The comment should be rewritten to
+ * describe potrf — confirm against the ATLAS sources.
+ */
538
/*
 * Generic fallback for potrf (Cholesky factorization).  Always raises:
 * only the float/double/Complex64/Complex128 specializations below
 * (available when CLAPACK headers were found) actually factorize.
 */
template <typename DType>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H
  // CLAPACK is present, so reaching this template means the dtype has no
  // BLAS equivalent (e.g. integer or Ruby-object dtypes).
  rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
#else
  rb_raise(rb_eNotImpError, "only CLAPACK version implemented thus far");
#endif
  return 0; // unreachable: rb_raise does not return
}
547
+
548
#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H
// CLAPACK-backed potrf specializations for the four BLAS dtypes.  The int
// returned by clapack_*potrf is passed straight through to the caller
// (0 on success; see the CLAPACK/ATLAS docs for nonzero meanings).
template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  return clapack_spotrf(order, uplo, N, A, lda);
}

template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  return clapack_dpotrf(order, uplo, N, A, lda);
}

// Complex variants: CLAPACK takes void* for complex data, so the
// Complex64/Complex128 pointers are forwarded via reinterpret_cast.
template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
}

template <>
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
}
#endif
569
+
570
+
571
+
572
// Copies the strictly-upper-triangular part of the row-major M x N array U
// into C, zeroing each copied entry out of U.  U is unit-diagonal, so the
// diagonal itself is neither copied nor cleared.  ldu/ldc are the row
// strides of U and C.
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  // Advance both row pointers in the loop header; the inner loop only
  // touches columns strictly right of the diagonal.
  for (int row = 0; row < M; ++row, U += ldu, C += ldc) {
    for (int col = row + 1; col < N; ++col) {
      C[col] = U[col];
      U[col] = 0;
    }
  }
}
588
+
589
+
590
+ /*
591
+ * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
592
+ * functions. This is probably really complicated.
593
+ *
594
+ * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
595
+ *
596
+ */
597
+
598
+ /*
599
+
600
+ template <bool RowMajor, bool Upper, typename DType>
601
+ static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
602
+
603
+ if (RowMajor) {
604
+ DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
605
+ DType tmp;
606
+ if (Upper) {
607
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
608
+ A12 = pA1[2], A13 = pA1[3],
609
+ A23 = pA2[3];
610
+
611
+ if (Diag == CblasNonUnit) {
612
+ pA0->inverse();
613
+ (pA1+1)->inverse();
614
+ (pA2+2)->inverse();
615
+ (pA3+3)->inverse();
616
+
617
+ pA0[1] = -A01 * pA1[1] * pA0[0];
618
+ pA1[2] = -A12 * pA2[2] * pA1[1];
619
+ pA2[3] = -A23 * pA3[3] * pA2[2];
620
+
621
+ pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
622
+ pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
623
+
624
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
625
+
626
+ } else {
627
+
628
+ pA0[1] = -A01;
629
+ pA1[2] = -A12;
630
+ pA2[3] = -A23;
631
+
632
+ pA0[2] = -(A01 * pA1[2] + A02);
633
+ pA1[3] = -(A12 * pA2[3] + A13);
634
+
635
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
636
+ }
637
+
638
+ } else { // Lower
639
+ DType A10 = pA1[0],
640
+ A20 = pA2[0], A21 = pA2[1],
641
+ A30 = pA3[0], A31 = pA3[1], A32 = pA3[2];
642
+ DType *B10 = pA1,
643
+ *B20 = pA2,
644
+ *B30 = pA3,
645
+ *B21 = pA2+1,
646
+ *B31 = pA3+1,
647
+ *B32 = pA3+2;
648
+
649
+
650
+ if (Diag == CblasNonUnit) {
651
+ pA0->inverse();
652
+ (pA1+1)->inverse();
653
+ (pA2+2)->inverse();
654
+ (pA3+3)->inverse();
655
+
656
+ *B10 = -A10 * pA0[0] * pA1[1];
657
+ *B21 = -A21 * pA1[1] * pA2[2];
658
+ *B32 = -A32 * pA2[2] * pA3[3];
659
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
660
+ *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
661
+ *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3[3];
662
+ } else {
663
+ *B10 = -A10;
664
+ *B21 = -A21;
665
+ *B32 = -A32;
666
+ *B20 = -(A20 + A21 * (*B10));
667
+ *B31 = -(A31 + A32 * (*B21));
668
+ *B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
669
+ }
670
+ }
671
+
672
+ } else {
673
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
674
+ }
675
+
676
+ return 0;
677
+
678
+ }
679
+
680
+
681
+ template <bool RowMajor, bool Upper, typename DType>
682
+ static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
683
+
684
+ if (RowMajor) {
685
+
686
+ DType tmp;
687
+
688
+ if (Upper) {
689
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
690
+ A12 = pA1[2], A13 = pA1[3];
691
+
692
+ DType *B01 = pA0 + 1,
693
+ *B02 = pA0 + 2,
694
+ *B12 = pA1 + 2;
695
+
696
+ if (Diag == CblasNonUnit) {
697
+ pA0->inverse();
698
+ (pA1+1)->inverse();
699
+ (pA2+2)->inverse();
700
+
701
+ *B01 = -A01 * pA1[1] * pA0[0];
702
+ *B12 = -A12 * pA2[2] * pA1[1];
703
+ *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
704
+ } else {
705
+ *B01 = -A01;
706
+ *B12 = -A12;
707
+ *B02 = -(A01 * (*B12) + A02);
708
+ }
709
+
710
+ } else { // Lower
711
+ DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
712
+ DType A10=pA1[0],
713
+ A20=pA2[0], A21=pA2[1];
714
+
715
+ DType *B10 = pA1,
716
+ *B20 = pA2,
+ *B21 = pA2+1;
718
+
719
+ if (Diag == CblasNonUnit) {
720
+ pA0->inverse();
721
+ (pA1+1)->inverse();
722
+ (pA2+2)->inverse();
723
+ *B10 = -A10 * pA0[0] * pA1[1];
724
+ *B21 = -A21 * pA1[1] * pA2[2];
725
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
726
+ } else {
727
+ *B10 = -A10;
728
+ *B21 = -A21;
729
+ *B20 = -(A20 + A21 * (*B10));
730
+ }
731
+ }
732
+
733
+
734
+ } else {
735
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
736
+ }
737
+
738
+ return 0;
739
+
740
+ }
741
+
742
+ template <bool RowMajor, bool Upper, bool Real, typename DType>
743
+ static int trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
744
+ DType *Age, *Atr;
745
+ DType tmp;
746
+ int Nleft, Nright;
747
+
748
+ int ierr = 0;
749
+
750
+ static const DType ONE = 1;
751
+ static const DType MONE = -1;
752
+ static const DType NONE = -1;
753
+
754
+ if (RowMajor) {
755
+
756
+ // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
757
+ if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
758
+ Nleft = N >> 1;
759
+ #ifdef NB
760
+ if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
761
+ #endif
762
+
763
+ Nright = N - Nleft;
764
+
765
+ if (Upper) {
766
+ Age = A + Nleft;
767
+ Atr = A + (Nleft * (lda+1));
768
+
769
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
770
+ Nleft, Nright, ONE, Atr, lda, Age, lda);
771
+
772
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
773
+ Nleft, Nright, MONE, A, lda, Age, lda);
774
+
775
+ } else { // Lower
776
+ Age = A + ((Nleft*lda));
777
+ Atr = A + (Nleft * (lda+1));
778
+
779
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
780
+ Nright, Nleft, ONE, A, lda, Age, lda);
781
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
782
+ Nright, Nleft, MONE, Atr, lda, Age, lda);
783
+ }
784
+
785
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
786
+ if (ierr) return ierr;
787
+
788
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
789
+ if (ierr) return ierr + Nleft;
790
+
791
+ } else {
792
+ if (Real) {
793
+ if (N == 4) {
794
+ return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
795
+ } else if (N == 3) {
796
+ return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
797
+ } else if (N == 2) {
798
+ if (Diag == CblasNonUnit) {
799
+ A->inverse();
800
+ (A+(lda+1))->inverse();
801
+
802
+ if (Upper) {
803
+ *(A+1) *= *A; // TRI_MUL
804
+ *(A+1) *= *(A+lda+1); // TRI_MUL
805
+ } else {
806
+ *(A+lda) *= *A; // TRI_MUL
807
+ *(A+lda) *= *(A+lda+1); // TRI_MUL
808
+ }
809
+ }
810
+
811
+ if (Upper) *(A+1) = -*(A+1); // TRI_NEG
812
+ else *(A+lda) = -*(A+lda); // TRI_NEG
813
+ } else if (Diag == CblasNonUnit) A->inverse();
814
+ } else { // not real
815
+ if (Diag == CblasNonUnit) A->inverse();
816
+ }
817
+ }
818
+
819
+ } else {
820
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
821
+ }
822
+
823
+ return ierr;
824
+ }
825
+
826
+
827
+ template <bool RowMajor, bool Real, typename DType>
828
+ int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
829
+
830
+ if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
831
+
832
+ int jb, nb, I, ndown, iret;
833
+
834
+ const DType ONE = 1, NONE = -1;
835
+
836
+ iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
837
+ if (!iret && N > 1) {
838
+ jb = lwrk / N;
839
+ if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
840
+ else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
841
+ else nb = jb;
842
+ if (!nb) return -6; // need at least 1 row of workspace
843
+
844
+ // only first iteration will have partial block, unroll it
845
+
846
+ jb = N - (N/nb) * nb;
847
+ if (!jb) jb = nb;
848
+ I = N - jb;
849
+ A += lda * I;
850
+ trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
851
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
852
+ jb, N, ONE, wrk, jb, A, lda);
853
+
854
+ if (I) {
855
+ do {
856
+ I -= nb;
857
+ A -= nb * lda;
858
+ ndown = N-I;
859
+ trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
860
+ nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
861
+ nb, N, ONE, wrk, ndown, A, lda);
862
+ } while (I);
863
+ }
864
+
865
+ // Apply row interchanges
866
+
867
+ for (I = N - 2; I >= 0; --I) {
868
+ jb = ipiv[I];
869
+ if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
870
+ }
871
+ }
872
+
873
+ return iret;
874
+ }
875
+ */
876
+
877
+
878
+
879
+ template <bool is_complex, typename DType>
880
+ inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
881
+
882
+ int Nleft, Nright;
883
+ const DType ONE = 1;
884
+ DType *G, *U0 = A, *U1;
885
+
886
+ if (N > 1) {
887
+ Nleft = N >> 1;
888
+ #ifdef NB
889
+ if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
890
+ #endif
891
+
892
+ Nright = N - Nleft;
893
+
894
+ // FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now.
895
+ if (uplo == CblasUpper) {
896
+ if (order == CblasRowMajor) {
897
+ G = A + Nleft;
898
+ U1 = G + Nleft * lda;
899
+ } else {
900
+ G = A + Nleft * lda;
901
+ U1 = G + Nleft;
902
+ }
903
+ } else {
904
+ if (order == CblasRowMajor) {
905
+ G = A + Nleft * lda;
906
+ U1 = G + Nleft;
907
+ } else {
908
+ G = A + Nleft;
909
+ U1 = G + Nleft * lda;
910
+ }
911
+ }
912
+
913
+ lauum<is_complex, DType>(order, uplo, Nleft, U0, lda);
914
+
915
+ if (is_complex) {
916
+
917
+ nm::math::herk<DType>(order, uplo,
918
+ uplo == CblasLower ? CblasConjTrans : CblasNoTrans,
919
+ Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
920
+
921
+ nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
922
+ } else {
923
+ nm::math::syrk<DType>(order, uplo,
924
+ uplo == CblasLower ? CblasTrans : CblasNoTrans,
925
+ Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
926
+
927
+ nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
928
+ }
929
+ lauum<is_complex, DType>(order, uplo, Nright, U1, lda);
930
+
931
+ } else {
932
+ *A = *A * *A;
933
+ }
934
+ }
935
+
936
+
937
#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H
// CLAPACK-backed lauum overloads for the four BLAS dtypes.  These keep the
// template<bool is_complex> shape of the generic version so call sites do
// not change, but the flag is unused here: the concrete pointer type picks
// the overload.  NOTE(review): the int status returned by clapack_*lauum is
// discarded to match the generic version's void return — confirm callers do
// not need it.
template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  clapack_slauum(order, uplo, N, A, lda);
}

template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  clapack_dlauum(order, uplo, N, A, lda);
}

template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  clapack_clauum(order, uplo, N, A, lda);
}

template <bool is_complex>
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  clapack_zlauum(order, uplo, N, A, lda);
}
#endif
958
+
959
+
960
+ /*
961
+ * Function signature conversion for calling LAPACK's lauum functions as directly as possible.
962
+ *
963
+ * For documentation: http://www.netlib.org/lapack/double/dlauum.f
964
+ *
965
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
966
+ */
967
+ template <bool is_complex, typename DType>
968
+ inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
969
+ if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n);
970
+ if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n);
971
+
972
+ if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
973
+ else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
974
+
975
+ return 0;
976
+ }
977
+
978
+
979
+
980
+
981
+ /*
982
+ * Macro for declaring LAPACK specializations of the getrf function.
983
+ *
984
+ * type is the DType; call is the specific function to call; cast_as is what the DType* should be
985
+ * cast to in order to pass it to LAPACK.
986
+ */
987
+ #define LAPACK_GETRF(type, call, cast_as) \
988
+ template <> \
989
+ inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) { \
990
+ int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv); \
991
+ if (!info) return info; \
992
+ else { \
993
+ rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info); \
994
+ return info; \
995
+ } \
996
+ }
997
+
998
+ /* Specialize for ATLAS types */
999
+ /*LAPACK_GETRF(float, clapack_sgetrf, float)
1000
+ LAPACK_GETRF(double, clapack_dgetrf, double)
1001
+ LAPACK_GETRF(Complex64, clapack_cgetrf, void)
1002
+ LAPACK_GETRF(Complex128, clapack_zgetrf, void)
1003
+ */
1004
+
1005
+
1006
+
1007
+ /*
1008
+ * Function signature conversion for calling LAPACK's potrf functions as directly as possible.
1009
+ *
1010
+ * For documentation: http://www.netlib.org/lapack/double/dpotrf.f
1011
+ *
1012
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
1013
+ */
1014
+ template <typename DType>
1015
+ inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
1016
+ return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
1017
+ }
1018
+
1019
+
1020
+
1021
// Generic fallback for potri (matrix inverse from a Cholesky factor).
// Always raises: only the CLAPACK-backed specializations below implement it.
template <typename DType>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) {
  rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes");
  return 0; // unreachable: rb_raise does not return
}
1026
+
1027
+
1028
#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H
// CLAPACK-backed potri specializations for the four BLAS dtypes.  The int
// status from clapack_*potri is returned to the caller unchanged.
template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) {
  return clapack_spotri(order, uplo, n, a, lda);
}

template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) {
  return clapack_dpotri(order, uplo, n, a, lda);
}

// Complex variants: CLAPACK takes void* for complex data, hence the casts.
template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) {
  return clapack_cpotri(order, uplo, n, reinterpret_cast<void*>(a), lda);
}

template <>
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) {
  return clapack_zpotri(order, uplo, n, reinterpret_cast<void*>(a), lda);
}
#endif
1049
+
1050
+
1051
+ /*
1052
+ * Function signature conversion for calling LAPACK's potri functions as directly as possible.
1053
+ *
1054
+ * For documentation: http://www.netlib.org/lapack/double/dpotri.f
1055
+ *
1056
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
1057
+ */
1058
+ template <typename DType>
1059
+ inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
1060
+ return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
1061
+ }
1062
+
1063
+
1064
+
1065
+
1066
+ }} // end namespace nm::math
1067
+
1068
+
1069
+ #endif // MATH_H