nmatrix-fftw 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74) hide show
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/data/complex.h +388 -0
  3. data/ext/nmatrix/data/data.h +652 -0
  4. data/ext/nmatrix/data/meta.h +64 -0
  5. data/ext/nmatrix/data/ruby_object.h +389 -0
  6. data/ext/nmatrix/math/asum.h +120 -0
  7. data/ext/nmatrix/math/cblas_enums.h +36 -0
  8. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  9. data/ext/nmatrix/math/gemm.h +241 -0
  10. data/ext/nmatrix/math/gemv.h +178 -0
  11. data/ext/nmatrix/math/getrf.h +255 -0
  12. data/ext/nmatrix/math/getrs.h +121 -0
  13. data/ext/nmatrix/math/imax.h +79 -0
  14. data/ext/nmatrix/math/laswp.h +165 -0
  15. data/ext/nmatrix/math/long_dtype.h +49 -0
  16. data/ext/nmatrix/math/math.h +745 -0
  17. data/ext/nmatrix/math/nrm2.h +160 -0
  18. data/ext/nmatrix/math/rot.h +117 -0
  19. data/ext/nmatrix/math/rotg.h +106 -0
  20. data/ext/nmatrix/math/scal.h +71 -0
  21. data/ext/nmatrix/math/trsm.h +332 -0
  22. data/ext/nmatrix/math/util.h +148 -0
  23. data/ext/nmatrix/nm_memory.h +60 -0
  24. data/ext/nmatrix/nmatrix.h +438 -0
  25. data/ext/nmatrix/ruby_constants.h +106 -0
  26. data/ext/nmatrix/storage/common.h +177 -0
  27. data/ext/nmatrix/storage/dense/dense.h +129 -0
  28. data/ext/nmatrix/storage/list/list.h +138 -0
  29. data/ext/nmatrix/storage/storage.h +99 -0
  30. data/ext/nmatrix/storage/yale/class.h +1139 -0
  31. data/ext/nmatrix/storage/yale/iterators/base.h +143 -0
  32. data/ext/nmatrix/storage/yale/iterators/iterator.h +131 -0
  33. data/ext/nmatrix/storage/yale/iterators/row.h +450 -0
  34. data/ext/nmatrix/storage/yale/iterators/row_stored.h +140 -0
  35. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +169 -0
  36. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +124 -0
  37. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  38. data/ext/nmatrix/storage/yale/yale.h +203 -0
  39. data/ext/nmatrix/types.h +55 -0
  40. data/ext/nmatrix/util/io.h +115 -0
  41. data/ext/nmatrix/util/sl_list.h +144 -0
  42. data/ext/nmatrix/util/util.h +78 -0
  43. data/ext/nmatrix_fftw/extconf.rb +122 -0
  44. data/ext/nmatrix_fftw/nmatrix_fftw.cpp +274 -0
  45. data/lib/nmatrix/fftw.rb +343 -0
  46. data/spec/00_nmatrix_spec.rb +736 -0
  47. data/spec/01_enum_spec.rb +190 -0
  48. data/spec/02_slice_spec.rb +389 -0
  49. data/spec/03_nmatrix_monkeys_spec.rb +78 -0
  50. data/spec/2x2_dense_double.mat +0 -0
  51. data/spec/4x4_sparse.mat +0 -0
  52. data/spec/4x5_dense.mat +0 -0
  53. data/spec/blas_spec.rb +193 -0
  54. data/spec/elementwise_spec.rb +303 -0
  55. data/spec/homogeneous_spec.rb +99 -0
  56. data/spec/io/fortran_format_spec.rb +88 -0
  57. data/spec/io/harwell_boeing_spec.rb +98 -0
  58. data/spec/io/test.rua +9 -0
  59. data/spec/io_spec.rb +149 -0
  60. data/spec/lapack_core_spec.rb +482 -0
  61. data/spec/leakcheck.rb +16 -0
  62. data/spec/math_spec.rb +807 -0
  63. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  64. data/spec/nmatrix_yale_spec.rb +286 -0
  65. data/spec/plugins/fftw/fftw_spec.rb +348 -0
  66. data/spec/rspec_monkeys.rb +56 -0
  67. data/spec/rspec_spec.rb +34 -0
  68. data/spec/shortcuts_spec.rb +310 -0
  69. data/spec/slice_set_spec.rb +157 -0
  70. data/spec/spec_helper.rb +149 -0
  71. data/spec/stat_spec.rb +203 -0
  72. data/spec/test.pcd +20 -0
  73. data/spec/utm5940.mtx +83844 -0
  74. metadata +151 -0
@@ -0,0 +1,121 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == getrs.h
25
+ //
26
+ // getrs function in native C++.
27
+ //
28
+
29
+ /*
30
+ * Automatically Tuned Linear Algebra Software v3.8.4
31
+ * (C) Copyright 1999 R. Clint Whaley
32
+ *
33
+ * Redistribution and use in source and binary forms, with or without
34
+ * modification, are permitted provided that the following conditions
35
+ * are met:
36
+ * 1. Redistributions of source code must retain the above copyright
37
+ * notice, this list of conditions and the following disclaimer.
38
+ * 2. Redistributions in binary form must reproduce the above copyright
39
+ * notice, this list of conditions, and the following disclaimer in the
40
+ * documentation and/or other materials provided with the distribution.
41
+ * 3. The name of the ATLAS group or the names of its contributers may
42
+ * not be used to endorse or promote products derived from this
43
+ * software without specific written permission.
44
+ *
45
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
46
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
47
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
48
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
49
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
50
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
51
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
52
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
53
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
54
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
55
+ * POSSIBILITY OF SUCH DAMAGE.
56
+ *
57
+ */
58
+
59
+ #ifndef GETRS_H
60
+ #define GETRS_H
61
+
62
+ namespace nm { namespace math {
63
+
64
+
65
+ /*
66
+ * Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
67
+ *
68
+ * From ATLAS 3.8.0.
69
+ */
70
+ template <typename DType>
71
+ int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
72
+ const int lda, const int* ipiv, DType* B, const int ldb)
73
+ {
74
+ // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
75
+
76
+ if (!N || !NRHS) return 0;
77
+
78
+ const DType ONE = 1;
79
+
80
+ if (Order == CblasColMajor) {
81
+ if (Trans == CblasNoTrans) {
82
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
83
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
84
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
85
+ } else {
86
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
87
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
88
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
89
+ }
90
+ } else {
91
+ if (Trans == CblasNoTrans) {
92
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
93
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
94
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
95
+ } else {
96
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
97
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
98
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
99
+ }
100
+ }
101
+ return 0;
102
+ }
103
+
104
+
105
+ /*
106
+ * Function signature conversion for calling LAPACK's getrs functions as directly as possible.
107
+ *
108
+ * For documentation: http://www.netlib.org/lapack/double/dgetrs.f
109
+ *
110
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
111
+ */
112
+ template <typename DType>
113
+ inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
114
+ const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
115
+ return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
116
+ }
117
+
118
+
119
+ } } // end nm::math
120
+
121
+ #endif // GETRS_H
@@ -0,0 +1,79 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == imax.h
25
+ //
26
+ // BLAS level 1 function imax.
27
+ //
28
+
29
+ #ifndef IMAX_H
30
+ #define IMAX_H
31
+
32
+ namespace nm { namespace math {
33
+
34
/*
 * BLAS level-1 IxAMAX: returns the 0-based index of the first element of x
 * with the largest absolute value, scanning n elements with stride incx.
 *
 * Returns -1 for degenerate arguments (n < 1 or incx <= 0), and 0 when
 * n == 1, matching the reference BLAS convention.
 *
 * Fix: the unit-stride branch previously called the unqualified abs(), which
 * resolves to the C integer overload and truncates floating-point values,
 * while the strided branch already used std::abs. Both branches now use
 * std::abs. The local variable formerly named `imax` (shadowing the function
 * itself) was renamed for clarity.
 */
template<typename DType>
inline int imax(const int n, const DType *x, const int incx) {

  if (n < 1 || incx <= 0) {
    return -1;
  }
  if (n == 1) {
    return 0;
  }

  DType dmax = std::abs(x[0]);  // largest magnitude seen so far
  int max_idx = 0;              // index of that element

  if (incx == 1) { // unit stride

    for (int i = 1; i < n; ++i) {
      if (std::abs(x[i]) > dmax) {  // strict >: first maximum wins ties
        max_idx = i;
        dmax = std::abs(x[i]);
      }
    }

  } else { // general stride: ix walks the raw array, i counts logical elements

    for (int i = 1, ix = incx; i < n; ++i, ix += incx) {
      if (std::abs(x[ix]) > dmax) {
        max_idx = i;
        dmax = std::abs(x[ix]);
      }
    }
  }
  return max_idx;
}
71
+
72
+ template<typename DType>
73
+ inline int cblas_imax(const int n, const void* x, const int incx) {
74
+ return imax<DType>(n, reinterpret_cast<const DType*>(x), incx);
75
+ }
76
+
77
+ }} // end of namespace nm::math
78
+
79
+ #endif /* IMAX_H */
@@ -0,0 +1,165 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == laswp.h
25
+ //
26
+ // laswp function in native C++.
27
+ //
28
+ /*
29
+ * Automatically Tuned Linear Algebra Software v3.8.4
30
+ * (C) Copyright 1999 R. Clint Whaley
31
+ *
32
+ * Redistribution and use in source and binary forms, with or without
33
+ * modification, are permitted provided that the following conditions
34
+ * are met:
35
+ * 1. Redistributions of source code must retain the above copyright
36
+ * notice, this list of conditions and the following disclaimer.
37
+ * 2. Redistributions in binary form must reproduce the above copyright
38
+ * notice, this list of conditions, and the following disclaimer in the
39
+ * documentation and/or other materials provided with the distribution.
40
+ * 3. The name of the ATLAS group or the names of its contributers may
41
+ * not be used to endorse or promote products derived from this
42
+ * software without specific written permission.
43
+ *
44
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
45
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
46
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
47
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
48
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
49
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
50
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
51
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
52
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
53
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
54
+ * POSSIBILITY OF SUCH DAMAGE.
55
+ *
56
+ */
57
+
58
+ #ifndef LASWP_H
59
+ #define LASWP_H
60
+
61
+ namespace nm { namespace math {
62
+
63
+
64
/*
 * ATLAS function which performs row interchanges on a general rectangular
 * matrix A (column-major, N columns, leading dimension lda). Modeled after
 * the LAPACK LASWP function; this version is templated for use by
 * template <> getrf().
 *
 * One interchange is applied per pivot entry for rows K1..K2-1 (0-based);
 * inci is the stride through piv, and a negative inci applies the pivots in
 * reverse order (used to undo a permutation).
 *
 * The column loop is blocked by 32 (nb full blocks plus an mr-column
 * remainder) so each pass touches a bounded slab of memory.
 *
 * Fix: dropped the `register` storage-class specifier on the inner loop
 * counters — it was removed in C++17 and is a hard error under -std=c++17.
 */
template <typename DType>
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
  //const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused.

  int nb = N >> 5;                // number of full 32-column blocks

  const int mr   = N - (nb<<5);   // leftover columns after the full blocks
  const int incA = lda << 5;      // element offset to the next 32-column block

  if (K2 < K1) return;            // empty pivot range: nothing to do

  int i1, i2;
  if (inci < 0) {                 // apply pivots from K2-1 down to K1
    piv -= (K2-1) * inci;
    i1 = K2 - 1;
    i2 = K1;
  } else {                        // apply pivots from K1 up to K2-1
    piv += K1 * inci;
    i1 = K1;
    i2 = K2-1;
  }

  if (nb) {

    do {
      const int* ipiv = piv;
      int i = i1;
      int KeepOn;

      do {
        int ip = *ipiv; ipiv += inci;

        if (ip != i) {            // swap rows i and ip across this 32-column block
          DType *a0 = &(A[i]),
                *a1 = &(A[ip]);

          for (int h = 32; h; h--) {
            DType r = *a0;
            *a0 = *a1;
            *a1 = r;

            a0 += lda;
            a1 += lda;
          }

        }
        if (inci > 0) KeepOn = (++i <= i2);
        else          KeepOn = (--i >= i2);

      } while (KeepOn);
      A += incA;                  // advance to the next block of 32 columns
    } while (--nb);
  }

  if (mr) {                       // remainder block of fewer than 32 columns
    const int* ipiv = piv;
    int i = i1;
    int KeepOn;

    do {
      int ip = *ipiv; ipiv += inci;
      if (ip != i) {
        DType *a0 = &(A[i]),
              *a1 = &(A[ip]);

        for (int h = mr; h; h--) {
          DType r = *a0;
          *a0 = *a1;
          *a1 = r;

          a0 += lda;
          a1 += lda;
        }
      }

      if (inci > 0) KeepOn = (++i <= i2);
      else          KeepOn = (--i >= i2);

    } while (KeepOn);
  }
}
150
+
151
+
152
+ /*
153
+ * Function signature conversion for calling LAPACK's laswp functions as directly as possible.
154
+ *
155
+ * For documentation: http://www.netlib.org/lapack/double/dlaswp.f
156
+ *
157
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
158
+ */
159
+ template <typename DType>
160
+ inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
161
+ laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
162
+ }
163
+
164
+ } } // namespace nm::math
165
+ #endif // LASWP_H
@@ -0,0 +1,49 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == long_dtype.h
25
+ //
26
+ // Declarations necessary for the native versions of GEMM and GEMV.
27
+ //
28
+
29
+ #ifndef LONG_DTYPE_H
30
+ #define LONG_DTYPE_H
31
+
32
+ namespace nm { namespace math {
33
+ // These allow an increase in precision for intermediate values of gemm and gemv.
34
+ // See also: http://stackoverflow.com/questions/11873694/how-does-one-increase-precision-in-c-templates-in-a-template-typename-dependen
35
+ template <typename DType> struct LongDType;
36
+ template <> struct LongDType<uint8_t> { typedef int16_t type; };
37
+ template <> struct LongDType<int8_t> { typedef int16_t type; };
38
+ template <> struct LongDType<int16_t> { typedef int32_t type; };
39
+ template <> struct LongDType<int32_t> { typedef int64_t type; };
40
+ template <> struct LongDType<int64_t> { typedef int64_t type; };
41
+ template <> struct LongDType<float> { typedef double type; };
42
+ template <> struct LongDType<double> { typedef double type; };
43
+ template <> struct LongDType<Complex64> { typedef Complex128 type; };
44
+ template <> struct LongDType<Complex128> { typedef Complex128 type; };
45
+ template <> struct LongDType<RubyObject> { typedef RubyObject type; };
46
+
47
+ }} // end of namespace nm::math
48
+
49
+ #endif
@@ -0,0 +1,745 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == math.h
25
+ //
26
+ // Header file for math functions, interfacing with BLAS, etc.
27
+ //
28
+ // For instructions on adding CBLAS and CLAPACK functions, see the
29
+ // beginning of math.cpp.
30
+ //
31
+ // Some of these functions are from ATLAS. Here is the license for
32
+ // ATLAS:
33
+ //
34
+ /*
35
+ * Automatically Tuned Linear Algebra Software v3.8.4
36
+ * (C) Copyright 1999 R. Clint Whaley
37
+ *
38
+ * Redistribution and use in source and binary forms, with or without
39
+ * modification, are permitted provided that the following conditions
40
+ * are met:
41
+ * 1. Redistributions of source code must retain the above copyright
42
+ * notice, this list of conditions and the following disclaimer.
43
+ * 2. Redistributions in binary form must reproduce the above copyright
44
+ * notice, this list of conditions, and the following disclaimer in the
45
+ * documentation and/or other materials provided with the distribution.
46
+ * 3. The name of the ATLAS group or the names of its contributers may
47
+ * not be used to endorse or promote products derived from this
48
+ * software without specific written permission.
49
+ *
50
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
51
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
52
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
53
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
54
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
55
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
56
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
57
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
58
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
59
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
60
+ * POSSIBILITY OF SUCH DAMAGE.
61
+ *
62
+ */
63
+
64
+ #ifndef MATH_H
65
+ #define MATH_H
66
+
67
+ /*
68
+ * Standard Includes
69
+ */
70
+
71
+ #include "cblas_enums.h"
72
+
73
+ #include <algorithm> // std::min, std::max
74
+ #include <limits> // std::numeric_limits
75
+ #include <memory> // std::unique_ptr
76
+
77
+ /*
78
+ * Project Includes
79
+ */
80
+
81
+ /*
82
+ * Macros
83
+ */
84
+ #define REAL_RECURSE_LIMIT 4
85
+
86
+ /*
87
+ * Data
88
+ */
89
+
90
+
91
+ extern "C" {
92
+ /*
93
+ * C accessors.
94
+ */
95
+
96
+ void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
97
+ void nm_math_init_blas(void);
98
+
99
+ /*
100
+ * Pure math implementations.
101
+ */
102
+ void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv);
103
+ void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
104
+ void nm_math_hessenberg(VALUE a);
105
+ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
106
+ void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
107
+ }
108
+
109
+
110
+ namespace nm {
111
+ namespace math {
112
+
113
+ /*
114
+ * Types
115
+ */
116
+
117
+
118
+ /*
119
+ * Functions
120
+ */
121
+
122
// Yale: numeric matrix multiply c=a*b
//
// Numeric phase of the SMMP sparse product C = A*B in Yale storage:
// ia/ja/a describe the n-row left matrix, ib/jb/b the m-row right matrix,
// and ic/jc/c receive the result with l columns. Each diag* flag indicates
// whether that matrix stores its diagonal separately ("new Yale" layout,
// where row i's diagonal is handled at jj == ia[i+1]).
//
// `next` threads a singly linked list through the columns touched in the
// current row (rooted at `head`), and `sums` accumulates the dot products
// for those columns; both are reset to their sentinel values as the row is
// drained back out, so they can be reused for the next row.
template <typename DType>
inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
                   const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
  const unsigned int max_lmn = std::max(std::max(m, n), l);
  std::unique_ptr<IType[]> next(new IType[max_lmn]);  // linked list threading the columns seen in the current row
  std::unique_ptr<DType[]> sums(new DType[max_lmn]);  // per-column accumulators for the current row

  DType v;  // current value of A being multiplied through

  IType head, length, temp, ndnz = 0;  // ndnz counts stored off-diagonal entries of C
  IType minmn = std::min(m,n);
  IType minlm = std::min(l,m);

  for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
    next[idx] = std::numeric_limits<IType>::max();  // max() is the "column not in list" sentinel
    sums[idx] = 0;
  }

  for (IType i = 0; i < n; ++i) { // walk down the rows
    head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
    length = 0;

    for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
      IType j;

      if (jj == ia[i+1]) { // if we're in the last entry for this row:
        if (!diaga || i >= minmn) continue;
        j = i;      // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
        v = a[i];
      } else {
        j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
        v = a[jj];
      }

      for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {

        IType k;

        if (kk == ib[j+1]) { // Get the column id for that entry
          if (!diagb || j >= minlm) continue;
          k = j;                 // separately-stored diagonal of B's row j
          sums[k] += v*b[k];
        } else {
          k = jb[kk];
          sums[k] += v*b[kk];
        }

        // First time column k is touched this row: push it onto the list.
        if (next[k] == std::numeric_limits<IType>::max()) {
          next[k] = head;
          head = k;
          ++length;
        }
      } // end of kk loop
    } // end of jj loop

    // Drain row i's column list into C, resetting the scratch arrays as we go.
    for (IType jj = 0; jj < length; ++jj) {
      if (sums[jj == 0 ? head : head] != 0) { } // (no-op guard removed)  -- see below
      if (sums[head] != 0) {   // drop entries that summed to exactly zero
        if (diagc && head == i) {
          c[head] = sums[head];        // diagonal goes in its reserved slot
        } else {
          jc[n+1+ndnz] = head;         // off-diagonal storage begins at offset n+1
          c[n+1+ndnz] = sums[head];
          ++ndnz;
        }
      }

      temp = head;
      head = next[head];

      next[temp] = std::numeric_limits<IType>::max();  // restore sentinel
      sums[temp] = 0;
    }

    ic[i+1] = n+1+ndnz;  // row pointer for the next row of C
  }
} /* numbmm_ */
199
+
200
+
201
+ /*
202
+ template <typename DType, typename IType>
203
+ inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
204
+ unsigned int n = c_storage->shape[0],
205
+ l = c_storage->shape[1];
206
+
207
+ // Create a working vector of dimension max(m,l,n) and initial value IType::max():
208
+ std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
209
+
210
+ for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
211
+
212
+ IType j, k;
213
+ size_t ndnz;
214
+
215
+ for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
216
+ j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
217
+
218
+ if (j >= m) {
219
+ if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
220
+ else break;
221
+ }
222
+
223
+ for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
224
+ if (j >= m) continue; // first of all, does B *have* a row j?
225
+ k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
226
+
227
+ if (k >= l) {
228
+ if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
229
+ else break;
230
+ }
231
+
232
+ if (mask[k] == )
233
+ }
234
+
235
+ }
236
+ }
237
+ }
238
+ */
239
+
240
+ // Yale: Symbolic matrix multiply c=a*b
241
+ inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
242
+ const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
243
+ unsigned int max_lmn = std::max(std::max(m,n), l);
244
+ IType mask[max_lmn]; // INDEX in the SMMP paper.
245
+ IType j, k; /* Local variables */
246
+ size_t ndnz = n;
247
+
248
+ for (IType idx = 0; idx < max_lmn; ++idx)
249
+ mask[idx] = std::numeric_limits<IType>::max();
250
+
251
+ if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
252
+ if (diagc) ic[0] = n+1;
253
+ else ic[0] = 0;
254
+ }
255
+
256
+ IType minmn = std::min(m,n);
257
+ IType minlm = std::min(l,m);
258
+
259
+ for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
260
+
261
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
262
+
263
+ // j <- column index given by JA[jj], or handle diagonal.
264
+ if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
265
+ if (!diaga || i >= minmn) continue;
266
+ j = i;
267
+ } else j = ja[jj];
268
+
269
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
270
+ if (kk == ib[j+1]) {
271
+ if (!diagb || j >= minlm) continue;
272
+ k = j;
273
+ } else k = jb[kk];
274
+
275
+ if (mask[k] != i) {
276
+ mask[k] = i;
277
+ ++ndnz;
278
+ }
279
+ }
280
+ }
281
+
282
+ if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
283
+
284
+ if (ic) ic[i+1] = ndnz;
285
+ }
286
+
287
+ return ndnz;
288
+ } /* symbmm_ */
289
+
290
+
291
// In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
//
// The sort keys are the column indices in `array` (ja); `vals` (the value
// array a) is permuted in lockstep so each column index keeps its value.
namespace smmp_sort {
  // Subranges of this many elements or fewer are handled by insertion sort.
  const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer

  // Debugging aid: dumps "column:value" pairs of the inclusive subrange
  // [left, right] to stderr.
  template <typename DType>
  void print_array(DType* vals, IType* array, IType left, IType right) {
    for (IType i = left; i <= right; ++i) {
      std::cerr << array[i] << ":" << vals[i] << " ";
    }
    std::cerr << std::endl;
  }

  // Lomuto-style partition of [left, right]: parks the pivot element at the
  // end, sweeps everything <= the pivot's column index to the front, then
  // drops the pivot into its final slot and returns that index.
  template <typename DType>
  IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
    IType pivotJ = array[pivot];
    DType pivotV = vals[pivot];

    // Swap pivot and right
    array[pivot] = array[right];
    vals[pivot] = vals[right];
    array[right] = pivotJ;
    vals[right] = pivotV;

    IType store = left;  // next slot for an element <= the pivot key
    for (IType idx = left; idx < right; ++idx) {
      if (array[idx] <= pivotJ) {
        // Swap i and store
        std::swap(array[idx], array[store]);
        std::swap(vals[idx], vals[store]);
        ++store;
      }
    }

    // Move the pivot (parked at right) into its final position.
    std::swap(array[store], array[right]);
    std::swap(vals[store], vals[right]);

    return store;
  }

  // Recommended to use the median of left, right, and mid for the pivot.
  template <typename I>
  inline I median(I a, I b, I c) {
    if (a < b) {
      if (b < c) return b; // a b c
      if (a < c) return c; // a c b
      return a;            // c a b

    } else {               // a > b
      if (a < c) return a; // b a c
      if (b < c) return c; // b c a
      return b;            // c b a
    }
  }


  // Insertion sort is more efficient than quicksort for small N
  template <typename DType>
  void insertion_sort(DType* vals, IType* array, IType left, IType right) {
    for (IType idx = left; idx <= right; ++idx) {
      IType col_to_insert = array[idx];
      DType val_to_insert = vals[idx];

      // Walk the hole left past every larger key.
      IType hole_pos = idx;
      for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
        array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
        vals[hole_pos] = vals[hole_pos - 1];   // value goes along with it
      }

      array[hole_pos] = col_to_insert;
      vals[hole_pos] = val_to_insert;
    }
  }


  // Recursive driver: insertion sort for short subranges, otherwise
  // median-of-three quicksort on the inclusive range [left, right].
  template <typename DType>
  void quicksort(DType* vals, IType* array, IType left, IType right) {

    if (left < right) {
      if (right - left < THRESHOLD) {
        insertion_sort(vals, array, left, right);
      } else {
        // choose any pivot such that left < pivot < right
        IType pivot = median<IType>(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
        pivot = partition(vals, array, left, right, pivot);

        // recursively sort elements smaller than the pivot
        quicksort<DType>(vals, array, left, pivot-1);

        // recursively sort elements at least as big as the pivot
        quicksort<DType>(vals, array, pivot+1, right);
      }
    }
  }


}; // end of namespace smmp_sort
387
+
388
+
389
+ /*
390
+ * For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
391
+ * This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
392
+ * about stability).
393
+ *
394
+ * TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
395
+ *
396
+ * TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
397
+ * ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
398
+ * sort.
399
+ */
400
+ template <typename DType>
401
+ inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
402
+ for (size_t i = 0; i < n; ++i) {
403
+ if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing only one or two elements.
404
+ else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
405
+ smmp_sort::insertion_sort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
406
+ } else {
407
+ smmp_sort::quicksort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
408
+ }
409
+ }
410
+ }
411
+
412
+
413
// Copies the strict upper triangle of the row-major array U into C, zeroing
// each copied entry of U as it goes. U is unit upper-triangular, so the
// diagonal is neither copied nor touched.
//
// M, N:   number of rows / columns of the region to process
// U, ldu: source array and its leading (row) dimension
// C, ldc: destination array and its leading dimension
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  DType* u_row = U;
  DType* c_row = C;

  for (int i = 0; i != M; ++i, u_row += ldu, c_row += ldc) {
    // Only columns strictly right of the diagonal belong to the strict upper triangle.
    for (int j = i + 1; j < N; ++j) {
      c_row[j] = u_row[j];
      u_row[j] = 0;
    }
  }
}
429
+
430
+
431
+ /*
432
+ * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
433
+ * functions. This is probably really complicated.
434
+ *
435
+ * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
436
+ *
437
+ */
438
+
439
+ /*
440
+
441
+ template <bool RowMajor, bool Upper, typename DType>
442
+ static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
443
+
444
+ if (RowMajor) {
445
+ DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
446
+ DType tmp;
447
+ if (Upper) {
448
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
449
+ A12 = pA1[2], A13 = pA1[3],
450
+ A23 = pA2[3];
451
+
452
+ if (Diag == CblasNonUnit) {
453
+ pA0->inverse();
454
+ (pA1+1)->inverse();
455
+ (pA2+2)->inverse();
456
+ (pA3+3)->inverse();
457
+
458
+ pA0[1] = -A01 * pA1[1] * pA0[0];
459
+ pA1[2] = -A12 * pA2[2] * pA1[1];
460
+ pA2[3] = -A23 * pA3[3] * pA2[2];
461
+
462
+ pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
463
+ pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
464
+
465
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
466
+
467
+ } else {
468
+
469
+ pA0[1] = -A01;
470
+ pA1[2] = -A12;
471
+ pA2[3] = -A23;
472
+
473
+ pA0[2] = -(A01 * pA1[2] + A02);
474
+ pA1[3] = -(A12 * pA2[3] + A13);
475
+
476
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
477
+ }
478
+
479
+ } else { // Lower
480
+ DType A10 = pA1[0],
481
+ A20 = pA2[0], A21 = pA2[1],
482
+ A30 = pA3[0], A31 = pA3[1], A32 = pA3[2];
483
+ DType *B10 = pA1,
484
+ *B20 = pA2,
485
+ *B30 = pA3,
486
+ *B21 = pA2+1,
487
+ *B31 = pA3+1,
488
+ *B32 = pA3+2;
489
+
490
+
491
+ if (Diag == CblasNonUnit) {
492
+ pA0->inverse();
493
+ (pA1+1)->inverse();
494
+ (pA2+2)->inverse();
495
+ (pA3+3)->inverse();
496
+
497
+ *B10 = -A10 * pA0[0] * pA1[1];
498
+ *B21 = -A21 * pA1[1] * pA2[2];
499
+ *B32 = -A32 * pA2[2] * pA3[3];
500
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
501
+ *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
502
+ *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3[3];
503
+ } else {
504
+ *B10 = -A10;
505
+ *B21 = -A21;
506
+ *B32 = -A32;
507
+ *B20 = -(A20 + A21 * (*B10));
508
+ *B31 = -(A31 + A32 * (*B21));
509
+ *B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
510
+ }
511
+ }
512
+
513
+ } else {
514
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
515
+ }
516
+
517
+ return 0;
518
+
519
+ }
520
+
521
+
522
+ template <bool RowMajor, bool Upper, typename DType>
523
+ static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
524
+
525
+ if (RowMajor) {
526
+
527
+ DType *pA0 = A, *pA1 = A + lda, *pA2 = A + 2*lda;
+ DType tmp;
528
+
529
+ if (Upper) {
530
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
531
+ A12 = pA1[2], A13 = pA1[3];
532
+
533
+ DType *B01 = pA0 + 1,
534
+ *B02 = pA0 + 2,
535
+ *B12 = pA1 + 2;
536
+
537
+ if (Diag == CblasNonUnit) {
538
+ pA0->inverse();
539
+ (pA1+1)->inverse();
540
+ (pA2+2)->inverse();
541
+
542
+ *B01 = -A01 * pA1[1] * pA0[0];
543
+ *B12 = -A12 * pA2[2] * pA1[1];
544
+ *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
545
+ } else {
546
+ *B01 = -A01;
547
+ *B12 = -A12;
548
+ *B02 = -(A01 * (*B12) + A02);
549
+ }
550
+
551
+ } else { // Lower
552
+ DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
553
+ DType A10=pA1[0],
554
+ A20=pA2[0], A21=pA2[1];
555
+
556
+ DType *B10 = pA1,
557
+ *B20 = pA2,
558
+ *B21 = pA2+1;
559
+
560
+ if (Diag == CblasNonUnit) {
561
+ pA0->inverse();
562
+ (pA1+1)->inverse();
563
+ (pA2+2)->inverse();
564
+ *B10 = -A10 * pA0[0] * pA1[1];
565
+ *B21 = -A21 * pA1[1] * pA2[2];
566
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
567
+ } else {
568
+ *B10 = -A10;
569
+ *B21 = -A21;
570
+ *B20 = -(A20 + A21 * (*B10));
571
+ }
572
+ }
573
+
574
+
575
+ } else {
576
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
577
+ }
578
+
579
+ return 0;
580
+
581
+ }
582
+
583
+ template <bool RowMajor, bool Upper, bool Real, typename DType>
584
+ static int trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
585
+ DType *Age, *Atr;
586
+ DType tmp;
587
+ int Nleft, Nright;
588
+
589
+ int ierr = 0;
590
+
591
+ static const DType ONE = 1;
592
+ static const DType MONE = -1;
593
+ static const DType NONE = -1;
594
+
595
+ if (RowMajor) {
596
+
597
+ // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
598
+ if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
599
+ Nleft = N >> 1;
600
+ #ifdef NB
601
+ if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
602
+ #endif
603
+
604
+ Nright = N - Nleft;
605
+
606
+ if (Upper) {
607
+ Age = A + Nleft;
608
+ Atr = A + (Nleft * (lda+1));
609
+
610
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
611
+ Nleft, Nright, ONE, Atr, lda, Age, lda);
612
+
613
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
614
+ Nleft, Nright, MONE, A, lda, Age, lda);
615
+
616
+ } else { // Lower
617
+ Age = A + ((Nleft*lda));
618
+ Atr = A + (Nleft * (lda+1));
619
+
620
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
621
+ Nright, Nleft, ONE, A, lda, Age, lda);
622
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
623
+ Nright, Nleft, MONE, Atr, lda, Age, lda);
624
+ }
625
+
626
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
627
+ if (ierr) return ierr;
628
+
629
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
630
+ if (ierr) return ierr + Nleft;
631
+
632
+ } else {
633
+ if (Real) {
634
+ if (N == 4) {
635
+ return trtri_4<RowMajor,Upper,DType>(Diag, A, lda);
636
+ } else if (N == 3) {
637
+ return trtri_3<RowMajor,Upper,DType>(Diag, A, lda);
638
+ } else if (N == 2) {
639
+ if (Diag == CblasNonUnit) {
640
+ A->inverse();
641
+ (A+(lda+1))->inverse();
642
+
643
+ if (Upper) {
644
+ *(A+1) *= *A; // TRI_MUL
645
+ *(A+1) *= *(A+lda+1); // TRI_MUL
646
+ } else {
647
+ *(A+lda) *= *A; // TRI_MUL
648
+ *(A+lda) *= *(A+lda+1); // TRI_MUL
649
+ }
650
+ }
651
+
652
+ if (Upper) *(A+1) = -*(A+1); // TRI_NEG
653
+ else *(A+lda) = -*(A+lda); // TRI_NEG
654
+ } else if (Diag == CblasNonUnit) A->inverse();
655
+ } else { // not real
656
+ if (Diag == CblasNonUnit) A->inverse();
657
+ }
658
+ }
659
+
660
+ } else {
661
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
662
+ }
663
+
664
+ return ierr;
665
+ }
666
+
667
+
668
+ template <bool RowMajor, bool Real, typename DType>
669
+ int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
670
+
671
+ if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
672
+
673
+ int jb, nb, I, ndown, iret;
674
+
675
+ const DType ONE = 1, NONE = -1;
676
+
677
+ iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
678
+ if (!iret && N > 1) {
679
+ jb = lwrk / N;
680
+ if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
681
+ else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
682
+ else nb = jb;
683
+ if (!nb) return -6; // need at least 1 row of workspace
684
+
685
+ // only first iteration will have partial block, unroll it
686
+
687
+ jb = N - (N/nb) * nb;
688
+ if (!jb) jb = nb;
689
+ I = N - jb;
690
+ A += lda * I;
691
+ trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
692
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
693
+ jb, N, ONE, wrk, jb, A, lda);
694
+
695
+ if (I) {
696
+ do {
697
+ I -= nb;
698
+ A -= nb * lda;
699
+ ndown = N-I;
700
+ trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
701
+ nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
702
+ nb, N, ONE, wrk, ndown, A, lda);
703
+ } while (I);
704
+ }
705
+
706
+ // Apply row interchanges
707
+
708
+ for (I = N - 2; I >= 0; --I) {
709
+ jb = ipiv[I];
710
+ if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
711
+ }
712
+ }
713
+
714
+ return iret;
715
+ }
716
+ */
717
+
718
/*
 * Macro for declaring LAPACK specializations of the getrf function.
 *
 * type:    the DType of the specialization
 * call:    the specific LAPACK routine to invoke
 * cast_as: the element type the routine expects; the DType* is
 *          reinterpret_cast to cast_as* before the call
 *
 * A nonzero return code from the routine raises an ArgumentError in Ruby;
 * the code is returned either way.
 */
#define LAPACK_GETRF(type, call, cast_as) \
template <> \
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) { \
  int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv); \
  if (info != 0) { \
    rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info); \
  } \
  return info; \
}
734
+
735
+ /* Specialize for ATLAS types */
736
+ /*LAPACK_GETRF(float, clapack_sgetrf, float)
737
+ LAPACK_GETRF(double, clapack_dgetrf, double)
738
+ LAPACK_GETRF(Complex64, clapack_cgetrf, void)
739
+ LAPACK_GETRF(Complex128, clapack_zgetrf, void)
740
+ */
741
+
742
+ }} // end namespace nm::math
743
+
744
+
745
+ #endif // MATH_H