nmatrix-atlas 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (82) hide show
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/data/complex.h +364 -0
  3. data/ext/nmatrix/data/data.h +638 -0
  4. data/ext/nmatrix/data/meta.h +64 -0
  5. data/ext/nmatrix/data/ruby_object.h +389 -0
  6. data/ext/nmatrix/math/asum.h +120 -0
  7. data/ext/nmatrix/math/cblas_enums.h +36 -0
  8. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  9. data/ext/nmatrix/math/gemm.h +241 -0
  10. data/ext/nmatrix/math/gemv.h +178 -0
  11. data/ext/nmatrix/math/getrf.h +255 -0
  12. data/ext/nmatrix/math/getrs.h +121 -0
  13. data/ext/nmatrix/math/imax.h +79 -0
  14. data/ext/nmatrix/math/laswp.h +165 -0
  15. data/ext/nmatrix/math/long_dtype.h +49 -0
  16. data/ext/nmatrix/math/math.h +744 -0
  17. data/ext/nmatrix/math/nrm2.h +160 -0
  18. data/ext/nmatrix/math/rot.h +117 -0
  19. data/ext/nmatrix/math/rotg.h +106 -0
  20. data/ext/nmatrix/math/scal.h +71 -0
  21. data/ext/nmatrix/math/trsm.h +332 -0
  22. data/ext/nmatrix/math/util.h +148 -0
  23. data/ext/nmatrix/nm_memory.h +60 -0
  24. data/ext/nmatrix/nmatrix.h +408 -0
  25. data/ext/nmatrix/ruby_constants.h +106 -0
  26. data/ext/nmatrix/storage/common.h +176 -0
  27. data/ext/nmatrix/storage/dense/dense.h +128 -0
  28. data/ext/nmatrix/storage/list/list.h +137 -0
  29. data/ext/nmatrix/storage/storage.h +98 -0
  30. data/ext/nmatrix/storage/yale/class.h +1139 -0
  31. data/ext/nmatrix/storage/yale/iterators/base.h +142 -0
  32. data/ext/nmatrix/storage/yale/iterators/iterator.h +130 -0
  33. data/ext/nmatrix/storage/yale/iterators/row.h +449 -0
  34. data/ext/nmatrix/storage/yale/iterators/row_stored.h +139 -0
  35. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +168 -0
  36. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +123 -0
  37. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  38. data/ext/nmatrix/storage/yale/yale.h +202 -0
  39. data/ext/nmatrix/types.h +54 -0
  40. data/ext/nmatrix/util/io.h +115 -0
  41. data/ext/nmatrix/util/sl_list.h +143 -0
  42. data/ext/nmatrix/util/util.h +78 -0
  43. data/ext/nmatrix_atlas/extconf.rb +250 -0
  44. data/ext/nmatrix_atlas/math_atlas.cpp +1206 -0
  45. data/ext/nmatrix_atlas/math_atlas/cblas_templates_atlas.h +72 -0
  46. data/ext/nmatrix_atlas/math_atlas/clapack_templates.h +332 -0
  47. data/ext/nmatrix_atlas/math_atlas/geev.h +82 -0
  48. data/ext/nmatrix_atlas/math_atlas/gesdd.h +83 -0
  49. data/ext/nmatrix_atlas/math_atlas/gesvd.h +81 -0
  50. data/ext/nmatrix_atlas/math_atlas/inc.h +47 -0
  51. data/ext/nmatrix_atlas/nmatrix_atlas.cpp +44 -0
  52. data/lib/nmatrix/atlas.rb +213 -0
  53. data/lib/nmatrix/lapack_ext_common.rb +69 -0
  54. data/spec/00_nmatrix_spec.rb +730 -0
  55. data/spec/01_enum_spec.rb +190 -0
  56. data/spec/02_slice_spec.rb +389 -0
  57. data/spec/03_nmatrix_monkeys_spec.rb +78 -0
  58. data/spec/2x2_dense_double.mat +0 -0
  59. data/spec/4x4_sparse.mat +0 -0
  60. data/spec/4x5_dense.mat +0 -0
  61. data/spec/blas_spec.rb +193 -0
  62. data/spec/elementwise_spec.rb +303 -0
  63. data/spec/homogeneous_spec.rb +99 -0
  64. data/spec/io/fortran_format_spec.rb +88 -0
  65. data/spec/io/harwell_boeing_spec.rb +98 -0
  66. data/spec/io/test.rua +9 -0
  67. data/spec/io_spec.rb +149 -0
  68. data/spec/lapack_core_spec.rb +482 -0
  69. data/spec/leakcheck.rb +16 -0
  70. data/spec/math_spec.rb +730 -0
  71. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  72. data/spec/nmatrix_yale_spec.rb +286 -0
  73. data/spec/plugins/atlas/atlas_spec.rb +242 -0
  74. data/spec/rspec_monkeys.rb +56 -0
  75. data/spec/rspec_spec.rb +34 -0
  76. data/spec/shortcuts_spec.rb +310 -0
  77. data/spec/slice_set_spec.rb +157 -0
  78. data/spec/spec_helper.rb +140 -0
  79. data/spec/stat_spec.rb +203 -0
  80. data/spec/test.pcd +20 -0
  81. data/spec/utm5940.mtx +83844 -0
  82. metadata +159 -0
@@ -0,0 +1,121 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == getrs.h
25
+ //
26
+ // getrs function in native C++.
27
+ //
28
+
29
+ /*
30
+ * Automatically Tuned Linear Algebra Software v3.8.4
31
+ * (C) Copyright 1999 R. Clint Whaley
32
+ *
33
+ * Redistribution and use in source and binary forms, with or without
34
+ * modification, are permitted provided that the following conditions
35
+ * are met:
36
+ * 1. Redistributions of source code must retain the above copyright
37
+ * notice, this list of conditions and the following disclaimer.
38
+ * 2. Redistributions in binary form must reproduce the above copyright
39
+ * notice, this list of conditions, and the following disclaimer in the
40
+ * documentation and/or other materials provided with the distribution.
41
+ * 3. The name of the ATLAS group or the names of its contributers may
42
+ * not be used to endorse or promote products derived from this
43
+ * software without specific written permission.
44
+ *
45
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
46
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
47
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
48
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
49
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
50
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
51
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
52
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
53
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
54
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
55
+ * POSSIBILITY OF SUCH DAMAGE.
56
+ *
57
+ */
58
+
59
+ #ifndef GETRS_H
60
+ #define GETRS_H
61
+
62
+ namespace nm { namespace math {
63
+
64
+
65
+ /*
66
+ * Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
67
+ *
68
+ * From ATLAS 3.8.0.
69
+ */
70
+ template <typename DType>
71
+ int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
72
+ const int lda, const int* ipiv, DType* B, const int ldb)
73
+ {
74
+ // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
75
+
76
+ if (!N || !NRHS) return 0;
77
+
78
+ const DType ONE = 1;
79
+
80
+ if (Order == CblasColMajor) {
81
+ if (Trans == CblasNoTrans) {
82
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
83
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
84
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
85
+ } else {
86
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
87
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
88
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
89
+ }
90
+ } else {
91
+ if (Trans == CblasNoTrans) {
92
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
93
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
94
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
95
+ } else {
96
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
97
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
98
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
99
+ }
100
+ }
101
+ return 0;
102
+ }
103
+
104
+
105
+ /*
106
+ * Function signature conversion for calling LAPACK's getrs functions as directly as possible.
107
+ *
108
+ * For documentation: http://www.netlib.org/lapack/double/dgetrs.f
109
+ *
110
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
111
+ */
112
+ template <typename DType>
113
+ inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
114
+ const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
115
+ return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
116
+ }
117
+
118
+
119
+ } } // end nm::math
120
+
121
+ #endif // GETRS_H
@@ -0,0 +1,79 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == imax.h
25
+ //
26
+ // BLAS level 1 function imax.
27
+ //
28
+
29
+ #ifndef IMAX_H
30
+ #define IMAX_H
31
+
32
+ namespace nm { namespace math {
33
+
34
// Index of the element with the largest absolute value in a strided vector.
//
//   n    - number of elements to examine.
//   x    - first element of the vector.
//   incx - stride between consecutive elements; must be positive.
//
// Returns the 0-based position (counted in elements, not in storage offsets) of
// the first element whose std::abs is largest, or -1 when n < 1 or incx <= 0.
template<typename DType>
inline int imax(const int n, const DType *x, const int incx) {

  if (n < 1 || incx <= 0) {
    return -1;
  }
  if (n == 1) {
    return 0;
  }

  // std::abs is used throughout: an unqualified abs() can bind to the C
  // int-only overload and truncate fractional magnitudes.
  DType dmax = std::abs(x[0]);
  int best = 0;

  // One loop covers every stride; with incx == 1 the offset ix simply equals i.
  for (int i = 1, ix = incx; i < n; ++i, ix += incx) {
    if (std::abs(x[ix]) > dmax) { // strict '>' keeps the earliest maximum on ties
      best = i;
      dmax = std::abs(x[ix]);
    }
  }

  return best;
}
71
+
72
+ template<typename DType>
73
+ inline int cblas_imax(const int n, const void* x, const int incx) {
74
+ return imax<DType>(n, reinterpret_cast<const DType*>(x), incx);
75
+ }
76
+
77
+ }} // end of namespace nm::math
78
+
79
+ #endif /* IMAX_H */
@@ -0,0 +1,165 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == laswp.h
25
+ //
26
+ // laswp function in native C++.
27
+ //
28
+ /*
29
+ * Automatically Tuned Linear Algebra Software v3.8.4
30
+ * (C) Copyright 1999 R. Clint Whaley
31
+ *
32
+ * Redistribution and use in source and binary forms, with or without
33
+ * modification, are permitted provided that the following conditions
34
+ * are met:
35
+ * 1. Redistributions of source code must retain the above copyright
36
+ * notice, this list of conditions and the following disclaimer.
37
+ * 2. Redistributions in binary form must reproduce the above copyright
38
+ * notice, this list of conditions, and the following disclaimer in the
39
+ * documentation and/or other materials provided with the distribution.
40
+ * 3. The name of the ATLAS group or the names of its contributers may
41
+ * not be used to endorse or promote products derived from this
42
+ * software without specific written permission.
43
+ *
44
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
45
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
46
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
47
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
48
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
49
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
50
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
51
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
52
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
53
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
54
+ * POSSIBILITY OF SUCH DAMAGE.
55
+ *
56
+ */
57
+
58
+ #ifndef LASWP_H
59
+ #define LASWP_H
60
+
61
+ namespace nm { namespace math {
62
+
63
+
64
+ /*
65
+ * ATLAS function which performs row interchanges on a general rectangular matrix. Modeled after the LAPACK LASWP function.
66
+ *
67
+ * This version is templated for use by template <> getrf().
68
+ */
69
// Parameters:
//   N      - number of lda-strided vectors to which each interchange is applied.
//   A, lda - the matrix; element i of vector c is A[i + c*lda].
//   K1, K2 - interchange range; nothing is done when K2 < K1.
//   piv    - pivot indices: position i is exchanged with position piv[i].
//   inci   - stride through piv; a negative value walks the range from K2-1
//            down to K1 instead of from K1 up to K2-1.
//
// The matrix is processed in panels of 32 vectors (plus one narrower trailing
// panel), applying the whole pivot sequence to one panel before moving on.
template <typename DType>
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
  const int PANEL_WIDTH  = 32;
  const int panel_stride = lda * PANEL_WIDTH;

  if (K2 < K1) return;

  int first, last;
  if (inci < 0) {
    piv  -= (K2-1) * inci;
    first = K2 - 1;
    last  = K1;
  } else {
    piv  += K1 * inci;
    first = K1;
    last  = K2 - 1;
  }

  for (int remaining = N; remaining > 0; remaining -= PANEL_WIDTH) {
    const int width = (remaining < PANEL_WIDTH) ? remaining : PANEL_WIDTH;

    const int* ipiv = piv;
    int  i = first;
    bool keep_on;

    // The body runs at least once, so K1 == K2 still examines one pivot entry.
    do {
      const int ip = *ipiv;
      ipiv += inci;

      if (ip != i) {
        DType* a0 = A + i;
        DType* a1 = A + ip;

        // Exchange positions i and ip in each vector of this panel.
        for (int h = width; h; --h) {
          DType tmp = *a0;
          *a0 = *a1;
          *a1 = tmp;

          a0 += lda;
          a1 += lda;
        }
      }

      keep_on = (inci > 0) ? (++i <= last) : (--i >= last);
    } while (keep_on);

    // Advance only when another panel follows.
    if (remaining > PANEL_WIDTH) A += panel_stride;
  }
}
150
+
151
+
152
+ /*
153
+ * Function signature conversion for calling LAPACK's laswp functions as directly as possible.
154
+ *
155
+ * For documentation: http://www.netlib.org/lapack/double/dlaswp.f
156
+ *
157
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
158
+ */
159
+ template <typename DType>
160
+ inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
161
+ laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
162
+ }
163
+
164
+ } } // namespace nm::math
165
+ #endif // LASWP_H
@@ -0,0 +1,49 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == long_dtype.h
25
+ //
26
+ // Declarations necessary for the native versions of GEMM and GEMV.
27
+ //
28
+
29
+ #ifndef LONG_DTYPE_H
30
+ #define LONG_DTYPE_H
31
+
32
+ namespace nm { namespace math {
33
+ // These allow an increase in precision for intermediate values of gemm and gemv.
34
+ // See also: http://stackoverflow.com/questions/11873694/how-does-one-increase-precision-in-c-templates-in-a-template-typename-dependen
35
+ template <typename DType> struct LongDType;
36
+ template <> struct LongDType<uint8_t> { typedef int16_t type; };
37
+ template <> struct LongDType<int8_t> { typedef int16_t type; };
38
+ template <> struct LongDType<int16_t> { typedef int32_t type; };
39
+ template <> struct LongDType<int32_t> { typedef int64_t type; };
40
+ template <> struct LongDType<int64_t> { typedef int64_t type; };
41
+ template <> struct LongDType<float> { typedef double type; };
42
+ template <> struct LongDType<double> { typedef double type; };
43
+ template <> struct LongDType<Complex64> { typedef Complex128 type; };
44
+ template <> struct LongDType<Complex128> { typedef Complex128 type; };
45
+ template <> struct LongDType<RubyObject> { typedef RubyObject type; };
46
+
47
+ }} // end of namespace nm::math
48
+
49
+ #endif
@@ -0,0 +1,744 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == math.h
25
+ //
26
+ // Header file for math functions, interfacing with BLAS, etc.
27
+ //
28
+ // For instructions on adding CBLAS and CLAPACK functions, see the
29
+ // beginning of math.cpp.
30
+ //
31
+ // Some of these functions are from ATLAS. Here is the license for
32
+ // ATLAS:
33
+ //
34
+ /*
35
+ * Automatically Tuned Linear Algebra Software v3.8.4
36
+ * (C) Copyright 1999 R. Clint Whaley
37
+ *
38
+ * Redistribution and use in source and binary forms, with or without
39
+ * modification, are permitted provided that the following conditions
40
+ * are met:
41
+ * 1. Redistributions of source code must retain the above copyright
42
+ * notice, this list of conditions and the following disclaimer.
43
+ * 2. Redistributions in binary form must reproduce the above copyright
44
+ * notice, this list of conditions, and the following disclaimer in the
45
+ * documentation and/or other materials provided with the distribution.
46
+ * 3. The name of the ATLAS group or the names of its contributers may
47
+ * not be used to endorse or promote products derived from this
48
+ * software without specific written permission.
49
+ *
50
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
51
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
52
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
53
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
54
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
55
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
56
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
57
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
58
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
59
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
60
+ * POSSIBILITY OF SUCH DAMAGE.
61
+ *
62
+ */
63
+
64
+ #ifndef MATH_H
65
+ #define MATH_H
66
+
67
+ /*
68
+ * Standard Includes
69
+ */
70
+
71
+ #include "cblas_enums.h"
72
+
73
+ #include <algorithm> // std::min, std::max
74
+ #include <limits> // std::numeric_limits
+ #include <vector> // std::vector (scratch buffers in numbmm and symbmm)
75
+
76
+ /*
77
+ * Project Includes
78
+ */
79
+
80
+ /*
81
+ * Macros
82
+ */
83
+ #define REAL_RECURSE_LIMIT 4
84
+
85
+ /*
86
+ * Data
87
+ */
88
+
89
+
90
+ extern "C" {
91
+ /*
92
+ * C accessors.
93
+ */
94
+
95
+ void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
96
+ void nm_math_init_blas(void);
97
+
98
+ /*
99
+ * Pure math implementations.
100
+ */
101
+ void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv);
102
+ void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
103
+ void nm_math_hessenberg(VALUE a);
104
+ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
105
+ void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
106
+ }
107
+
108
+
109
+ namespace nm {
110
+ namespace math {
111
+
112
+ /*
113
+ * Types
114
+ */
115
+
116
+
117
+ /*
118
+ * Functions
119
+ */
120
+
121
+ // Yale: numeric matrix multiply c=a*b
122
+ template <typename DType>
123
+ inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
124
+ const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
125
+ const unsigned int max_lmn = std::max(std::max(m, n), l);
126
+ IType next[max_lmn];
127
+ DType sums[max_lmn];
128
+
129
+ DType v;
130
+
131
+ IType head, length, temp, ndnz = 0;
132
+ IType minmn = std::min(m,n);
133
+ IType minlm = std::min(l,m);
134
+
135
+ for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
136
+ next[idx] = std::numeric_limits<IType>::max();
137
+ sums[idx] = 0;
138
+ }
139
+
140
+ for (IType i = 0; i < n; ++i) { // walk down the rows
141
+ head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
142
+ length = 0;
143
+
144
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
145
+ IType j;
146
+
147
+ if (jj == ia[i+1]) { // if we're in the last entry for this row:
148
+ if (!diaga || i >= minmn) continue;
149
+ j = i; // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
150
+ v = a[i];
151
+ } else {
152
+ j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
153
+ v = a[jj];
154
+ }
155
+
156
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {
157
+
158
+ IType k;
159
+
160
+ if (kk == ib[j+1]) { // Get the column id for that entry
161
+ if (!diagb || j >= minlm) continue;
162
+ k = j;
163
+ sums[k] += v*b[k];
164
+ } else {
165
+ k = jb[kk];
166
+ sums[k] += v*b[kk];
167
+ }
168
+
169
+ if (next[k] == std::numeric_limits<IType>::max()) {
170
+ next[k] = head;
171
+ head = k;
172
+ ++length;
173
+ }
174
+ } // end of kk loop
175
+ } // end of jj loop
176
+
177
+ for (IType jj = 0; jj < length; ++jj) {
178
+ if (sums[head] != 0) {
179
+ if (diagc && head == i) {
180
+ c[head] = sums[head];
181
+ } else {
182
+ jc[n+1+ndnz] = head;
183
+ c[n+1+ndnz] = sums[head];
184
+ ++ndnz;
185
+ }
186
+ }
187
+
188
+ temp = head;
189
+ head = next[head];
190
+
191
+ next[temp] = std::numeric_limits<IType>::max();
192
+ sums[temp] = 0;
193
+ }
194
+
195
+ ic[i+1] = n+1+ndnz;
196
+ }
197
+ } /* numbmm_ */
198
+
199
+
200
+ /*
201
+ template <typename DType, typename IType>
202
+ inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
203
+ unsigned int n = c_storage->shape[0],
204
+ l = c_storage->shape[1];
205
+
206
+ // Create a working vector of dimension max(m,l,n) and initial value IType::max():
207
+ std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
208
+
209
+ for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
210
+
211
+ IType j, k;
212
+ size_t ndnz;
213
+
214
+ for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
215
+ j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
216
+
217
+ if (j >= m) {
218
+ if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
219
+ else break;
220
+ }
221
+
222
+ for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
223
+ if (j >= m) continue; // first of all, does B *have* a row j?
224
+ k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
225
+
226
+ if (k >= l) {
227
+ if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
228
+ else break;
229
+ }
230
+
231
+ if (mask[k] == )
232
+ }
233
+
234
+ }
235
+ }
236
+ }
237
+ */
238
+
239
+ // Yale: Symbolic matrix multiply c=a*b
240
+ inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
241
+ const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
242
+ unsigned int max_lmn = std::max(std::max(m,n), l);
243
+ IType mask[max_lmn]; // INDEX in the SMMP paper.
244
+ IType j, k; /* Local variables */
245
+ size_t ndnz = n;
246
+
247
+ for (IType idx = 0; idx < max_lmn; ++idx)
248
+ mask[idx] = std::numeric_limits<IType>::max();
249
+
250
+ if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
251
+ if (diagc) ic[0] = n+1;
252
+ else ic[0] = 0;
253
+ }
254
+
255
+ IType minmn = std::min(m,n);
256
+ IType minlm = std::min(l,m);
257
+
258
+ for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
259
+
260
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
261
+
262
+ // j <- column index given by JA[jj], or handle diagonal.
263
+ if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
264
+ if (!diaga || i >= minmn) continue;
265
+ j = i;
266
+ } else j = ja[jj];
267
+
268
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
269
+ if (kk == ib[j+1]) {
270
+ if (!diagb || j >= minlm) continue;
271
+ k = j;
272
+ } else k = jb[kk];
273
+
274
+ if (mask[k] != i) {
275
+ mask[k] = i;
276
+ ++ndnz;
277
+ }
278
+ }
279
+ }
280
+
281
+ if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
282
+
283
+ if (ic) ic[i+1] = ndnz;
284
+ }
285
+
286
+ return ndnz;
287
+ } /* symbmm_ */
288
+
289
+
290
+ // In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
291
+ namespace smmp_sort {
292
+ const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer
293
+
294
+ template <typename DType>
295
+ void print_array(DType* vals, IType* array, IType left, IType right) {
296
+ for (IType i = left; i <= right; ++i) {
297
+ std::cerr << array[i] << ":" << vals[i] << " ";
298
+ }
299
+ std::cerr << std::endl;
300
+ }
301
+
302
+ template <typename DType>
303
+ IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
304
+ IType pivotJ = array[pivot];
305
+ DType pivotV = vals[pivot];
306
+
307
+ // Swap pivot and right
308
+ array[pivot] = array[right];
309
+ vals[pivot] = vals[right];
310
+ array[right] = pivotJ;
311
+ vals[right] = pivotV;
312
+
313
+ IType store = left;
314
+ for (IType idx = left; idx < right; ++idx) {
315
+ if (array[idx] <= pivotJ) {
316
+ // Swap i and store
317
+ std::swap(array[idx], array[store]);
318
+ std::swap(vals[idx], vals[store]);
319
+ ++store;
320
+ }
321
+ }
322
+
323
+ std::swap(array[store], array[right]);
324
+ std::swap(vals[store], vals[right]);
325
+
326
+ return store;
327
+ }
328
+
329
// Returns the middle of three values (the one that would sit second if the
// three were sorted). Only operator< is required of I.
template <typename I>
inline I median(I a, I b, I c) {
  // Order the first two, then slot c in relative to them.
  const I lo = (a < b) ? a : b;
  const I hi = (a < b) ? b : a;

  if (!(lo < c)) return lo; // c is at or below both
  if (c < hi)    return c;  // c lies strictly between
  return hi;                // c is at or above both
}
343
+
344
+
345
+ // Insertion sort is more efficient than quicksort for small N
346
+ template <typename DType>
347
+ void insertion_sort(DType* vals, IType* array, IType left, IType right) {
348
+ for (IType idx = left; idx <= right; ++idx) {
349
+ IType col_to_insert = array[idx];
350
+ DType val_to_insert = vals[idx];
351
+
352
+ IType hole_pos = idx;
353
+ for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
354
+ array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
355
+ vals[hole_pos] = vals[hole_pos - 1]; // value goes along with it
356
+ }
357
+
358
+ array[hole_pos] = col_to_insert;
359
+ vals[hole_pos] = val_to_insert;
360
+ }
361
+ }
362
+
363
+
364
+ template <typename DType>
365
+ void quicksort(DType* vals, IType* array, IType left, IType right) {
366
+
367
+ if (left < right) {
368
+ if (right - left < THRESHOLD) {
369
+ insertion_sort(vals, array, left, right);
370
+ } else {
371
+ // choose any pivot such that left < pivot < right
372
+ IType pivot = median<IType>(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
373
+ pivot = partition(vals, array, left, right, pivot);
374
+
375
+ // recursively sort elements smaller than the pivot
376
+ quicksort<DType>(vals, array, left, pivot-1);
377
+
378
+ // recursively sort elements at least as big as the pivot
379
+ quicksort<DType>(vals, array, pivot+1, right);
380
+ }
381
+ }
382
+ }
383
+
384
+
385
+ }; // end of namespace smmp_sort
386
+
387
+
388
+ /*
389
+ * For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
390
+ * This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
391
+ * about stability).
392
+ *
393
+ * TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
394
+ *
395
+ * TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
396
+ * ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
397
+ * sort.
398
+ */
399
+ template <typename DType>
400
+ inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
401
+ for (size_t i = 0; i < n; ++i) {
402
+ if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing only one or two elements.
403
+ else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
404
+ smmp_sort::insertion_sort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
405
+ } else {
406
+ smmp_sort::quicksort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
407
+ }
408
+ }
409
+ }
410
+
411
+
412
// Moves the strictly-upper part of a row-major array out of U and into C, leaving zeros behind in U.
// U is unit, so the diagonal is neither copied nor cleared.
//
// For each of the M rows, columns i+1 .. N-1 of row i are copied from U to the same columns of C
// and then set to 0 in U. Consecutive rows are ldu elements apart in U and ldc elements apart in C.
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  DType* src_row = U;
  DType* dst_row = C;

  int row = 0;
  while (row != M) {
    // Start one past the diagonal element of this row.
    for (int col = row + 1; col < N; ++col) {
      dst_row[col] = src_row[col];
      src_row[col] = 0;
    }

    dst_row += ldc;
    src_row += ldu;
    ++row;
  }
}
428
+
429
+
430
+ /*
431
+ * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
432
+ * functions. This is probably really complicated.
433
+ *
434
+ * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
435
+ *
436
+ */
437
+
438
+ /*
439
+
440
+ template <bool RowMajor, bool Upper, typename DType>
441
+ static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
442
+
443
+ if (RowMajor) {
444
+ DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
445
+ DType tmp;
446
+ if (Upper) {
447
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
448
+ A12 = pA1[2], A13 = pA1[3],
449
+ A23 = pA2[3];
450
+
451
+ if (Diag == CblasNonUnit) {
452
+ pA0->inverse();
453
+ (pA1+1)->inverse();
454
+ (pA2+2)->inverse();
455
+ (pA3+3)->inverse();
456
+
457
+ pA0[1] = -A01 * pA1[1] * pA0[0];
458
+ pA1[2] = -A12 * pA2[2] * pA1[1];
459
+ pA2[3] = -A23 * pA3[3] * pA2[2];
460
+
461
+ pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
462
+ pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
463
+
464
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
465
+
466
+ } else {
467
+
468
+ pA0[1] = -A01;
469
+ pA1[2] = -A12;
470
+ pA2[3] = -A23;
471
+
472
+ pA0[2] = -(A01 * pA1[2] + A02);
473
+ pA1[3] = -(A12 * pA2[3] + A13);
474
+
475
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
476
+ }
477
+
478
+ } else { // Lower
479
+ DType A10 = pA1[0],
480
+ A20 = pA2[0], A21 = pA2[1],
481
+ A30 = PA3[0], A31 = pA3[1], A32 = pA3[2];
482
+ DType *B10 = pA1,
483
+ *B20 = pA2,
484
+ *B30 = pA3,
485
+ *B21 = pA2+1,
486
+ *B31 = pA3+1,
487
+ *B32 = pA3+2;
488
+
489
+
490
+ if (Diag == CblasNonUnit) {
491
+ pA0->inverse();
492
+ (pA1+1)->inverse();
493
+ (pA2+2)->inverse();
494
+ (pA3+3)->inverse();
495
+
496
+ *B10 = -A10 * pA0[0] * pA1[1];
497
+ *B21 = -A21 * pA1[1] * pA2[2];
498
+ *B32 = -A32 * pA2[2] * pA3[3];
499
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
500
+ *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
501
+ *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3;
502
+ } else {
503
+ *B10 = -A10;
504
+ *B21 = -A21;
505
+ *B32 = -A32;
506
+ *B20 = -(A20 + A21 * (*B10));
507
+ *B31 = -(A31 + A32 * (*B21));
508
+ *B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
509
+ }
510
+ }
511
+
512
+ } else {
513
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
514
+ }
515
+
516
+ return 0;
517
+
518
+ }
519
+
520
+
521
+ template <bool RowMajor, bool Upper, typename DType>
522
+ static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
523
+
524
+ if (RowMajor) {
525
+
526
+ DType tmp;
527
+
528
+ if (Upper) {
529
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
530
+ A12 = pA1[2], A13 = pA1[3];
531
+
532
+ DType *B01 = pA0 + 1,
533
+ *B02 = pA0 + 2,
534
+ *B12 = pA1 + 2;
535
+
536
+ if (Diag == CblasNonUnit) {
537
+ pA0->inverse();
538
+ (pA1+1)->inverse();
539
+ (pA2+2)->inverse();
540
+
541
+ *B01 = -A01 * pA1[1] * pA0[0];
542
+ *B12 = -A12 * pA2[2] * pA1[1];
543
+ *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
544
+ } else {
545
+ *B01 = -A01;
546
+ *B12 = -A12;
547
+ *B02 = -(A01 * (*B12) + A02);
548
+ }
549
+
550
+ } else { // Lower
551
+ DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
552
+ DType A10=pA1[0],
553
+ A20=pA2[0], A21=pA2[1];
554
+
555
+ DType *B10 = pA1,
556
+ *B20 = pA2;
557
+ *B21 = pA2+1;
558
+
559
+ if (Diag == CblasNonUnit) {
560
+ pA0->inverse();
561
+ (pA1+1)->inverse();
562
+ (pA2+2)->inverse();
563
+ *B10 = -A10 * pA0[0] * pA1[1];
564
+ *B21 = -A21 * pA1[1] * pA2[2];
565
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
566
+ } else {
567
+ *B10 = -A10;
568
+ *B21 = -A21;
569
+ *B20 = -(A20 + A21 * (*B10));
570
+ }
571
+ }
572
+
573
+
574
+ } else {
575
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
576
+ }
577
+
578
+ return 0;
579
+
580
+ }
581
+
582
+ template <bool RowMajor, bool Upper, bool Real, typename DType>
583
+ static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
584
+ DType *Age, *Atr;
585
+ DType tmp;
586
+ int Nleft, Nright;
587
+
588
+ int ierr = 0;
589
+
590
+ static const DType ONE = 1;
591
+ static const DType MONE -1;
592
+ static const DType NONE = -1;
593
+
594
+ if (RowMajor) {
595
+
596
+ // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
597
+ if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
598
+ Nleft = N >> 1;
599
+ #ifdef NB
600
+ if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
601
+ #endif
602
+
603
+ Nright = N - Nleft;
604
+
605
+ if (Upper) {
606
+ Age = A + Nleft;
607
+ Atr = A + (Nleft * (lda+1));
608
+
609
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
610
+ Nleft, Nright, ONE, Atr, lda, Age, lda);
611
+
612
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
613
+ Nleft, Nright, MONE, A, lda, Age, lda);
614
+
615
+ } else { // Lower
616
+ Age = A + ((Nleft*lda));
617
+ Atr = A + (Nleft * (lda+1));
618
+
619
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
620
+ Nright, Nleft, ONE, A, lda, Age, lda);
621
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
622
+ Nright, Nleft, MONE, Atr, lda, Age, lda);
623
+ }
624
+
625
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
626
+ if (ierr) return ierr;
627
+
628
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
629
+ if (ierr) return ierr + Nleft;
630
+
631
+ } else {
632
+ if (Real) {
633
+ if (N == 4) {
634
+ return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
635
+ } else if (N == 3) {
636
+ return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
637
+ } else if (N == 2) {
638
+ if (Diag == CblasNonUnit) {
639
+ A->inverse();
640
+ (A+(lda+1))->inverse();
641
+
642
+ if (Upper) {
643
+ *(A+1) *= *A; // TRI_MUL
644
+ *(A+1) *= *(A+lda+1); // TRI_MUL
645
+ } else {
646
+ *(A+lda) *= *A; // TRI_MUL
647
+ *(A+lda) *= *(A+lda+1); // TRI_MUL
648
+ }
649
+ }
650
+
651
+ if (Upper) *(A+1) = -*(A+1); // TRI_NEG
652
+ else *(A+lda) = -*(A+lda); // TRI_NEG
653
+ } else if (Diag == CblasNonUnit) A->inverse();
654
+ } else { // not real
655
+ if (Diag == CblasNonUnit) A->inverse();
656
+ }
657
+ }
658
+
659
+ } else {
660
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
661
+ }
662
+
663
+ return ierr;
664
+ }
665
+
666
+
667
+ template <bool RowMajor, bool Real, typename DType>
668
+ int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
669
+
670
+ if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
671
+
672
+ int jb, nb, I, ndown, iret;
673
+
674
+ const DType ONE = 1, NONE = -1;
675
+
676
+ int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
677
+ if (!iret && N > 1) {
678
+ jb = lwrk / N;
679
+ if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
680
+ else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
681
+ else nb = jb;
682
+ if (!nb) return -6; // need at least 1 row of workspace
683
+
684
+ // only first iteration will have partial block, unroll it
685
+
686
+ jb = N - (N/nb) * nb;
687
+ if (!jb) jb = nb;
688
+ I = N - jb;
689
+ A += lda * I;
690
+ trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
691
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
692
+ jb, N, ONE, wrk, jb, A, lda);
693
+
694
+ if (I) {
695
+ do {
696
+ I -= nb;
697
+ A -= nb * lda;
698
+ ndown = N-I;
699
+ trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
700
+ nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
701
+ nb, N, ONE, wrk, ndown, A, lda);
702
+ } while (I);
703
+ }
704
+
705
+ // Apply row interchanges
706
+
707
+ for (I = N - 2; I >= 0; --I) {
708
+ jb = ipiv[I];
709
+ if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
710
+ }
711
+ }
712
+
713
+ return iret;
714
+ }
715
+ */
716
+
717
/*
 * Macro for declaring LAPACK specializations of the getrf function.
 *
 *   type    - the DType being specialized
 *   call    - the specific function to call
 *   cast_as - what the DType* should be cast to in order to pass it to LAPACK
 *
 * The generated specialization forwards its arguments to call. A zero result is returned as-is;
 * any non-zero result is reported through rb_raise with rb_eArgError.
 */
#define LAPACK_GETRF(type, call, cast_as)                                                                        \
template <>                                                                                                      \
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) {   \
  const int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv);                                 \
  if (info) {                                                                                                    \
    rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info);                                           \
  }                                                                                                              \
  return info;                                                                                                   \
}
733
+
734
+ /* Specialize for ATLAS types */
735
+ /*LAPACK_GETRF(float, clapack_sgetrf, float)
736
+ LAPACK_GETRF(double, clapack_dgetrf, double)
737
+ LAPACK_GETRF(Complex64, clapack_cgetrf, void)
738
+ LAPACK_GETRF(Complex128, clapack_zgetrf, void)
739
+ */
740
+
741
+ }} // end namespace nm::math
742
+
743
+
744
+ #endif // MATH_H