nmatrix-atlas 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/nmatrix/data/complex.h +364 -0
- data/ext/nmatrix/data/data.h +638 -0
- data/ext/nmatrix/data/meta.h +64 -0
- data/ext/nmatrix/data/ruby_object.h +389 -0
- data/ext/nmatrix/math/asum.h +120 -0
- data/ext/nmatrix/math/cblas_enums.h +36 -0
- data/ext/nmatrix/math/cblas_templates_core.h +507 -0
- data/ext/nmatrix/math/gemm.h +241 -0
- data/ext/nmatrix/math/gemv.h +178 -0
- data/ext/nmatrix/math/getrf.h +255 -0
- data/ext/nmatrix/math/getrs.h +121 -0
- data/ext/nmatrix/math/imax.h +79 -0
- data/ext/nmatrix/math/laswp.h +165 -0
- data/ext/nmatrix/math/long_dtype.h +49 -0
- data/ext/nmatrix/math/math.h +744 -0
- data/ext/nmatrix/math/nrm2.h +160 -0
- data/ext/nmatrix/math/rot.h +117 -0
- data/ext/nmatrix/math/rotg.h +106 -0
- data/ext/nmatrix/math/scal.h +71 -0
- data/ext/nmatrix/math/trsm.h +332 -0
- data/ext/nmatrix/math/util.h +148 -0
- data/ext/nmatrix/nm_memory.h +60 -0
- data/ext/nmatrix/nmatrix.h +408 -0
- data/ext/nmatrix/ruby_constants.h +106 -0
- data/ext/nmatrix/storage/common.h +176 -0
- data/ext/nmatrix/storage/dense/dense.h +128 -0
- data/ext/nmatrix/storage/list/list.h +137 -0
- data/ext/nmatrix/storage/storage.h +98 -0
- data/ext/nmatrix/storage/yale/class.h +1139 -0
- data/ext/nmatrix/storage/yale/iterators/base.h +142 -0
- data/ext/nmatrix/storage/yale/iterators/iterator.h +130 -0
- data/ext/nmatrix/storage/yale/iterators/row.h +449 -0
- data/ext/nmatrix/storage/yale/iterators/row_stored.h +139 -0
- data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +168 -0
- data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +123 -0
- data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
- data/ext/nmatrix/storage/yale/yale.h +202 -0
- data/ext/nmatrix/types.h +54 -0
- data/ext/nmatrix/util/io.h +115 -0
- data/ext/nmatrix/util/sl_list.h +143 -0
- data/ext/nmatrix/util/util.h +78 -0
- data/ext/nmatrix_atlas/extconf.rb +250 -0
- data/ext/nmatrix_atlas/math_atlas.cpp +1206 -0
- data/ext/nmatrix_atlas/math_atlas/cblas_templates_atlas.h +72 -0
- data/ext/nmatrix_atlas/math_atlas/clapack_templates.h +332 -0
- data/ext/nmatrix_atlas/math_atlas/geev.h +82 -0
- data/ext/nmatrix_atlas/math_atlas/gesdd.h +83 -0
- data/ext/nmatrix_atlas/math_atlas/gesvd.h +81 -0
- data/ext/nmatrix_atlas/math_atlas/inc.h +47 -0
- data/ext/nmatrix_atlas/nmatrix_atlas.cpp +44 -0
- data/lib/nmatrix/atlas.rb +213 -0
- data/lib/nmatrix/lapack_ext_common.rb +69 -0
- data/spec/00_nmatrix_spec.rb +730 -0
- data/spec/01_enum_spec.rb +190 -0
- data/spec/02_slice_spec.rb +389 -0
- data/spec/03_nmatrix_monkeys_spec.rb +78 -0
- data/spec/2x2_dense_double.mat +0 -0
- data/spec/4x4_sparse.mat +0 -0
- data/spec/4x5_dense.mat +0 -0
- data/spec/blas_spec.rb +193 -0
- data/spec/elementwise_spec.rb +303 -0
- data/spec/homogeneous_spec.rb +99 -0
- data/spec/io/fortran_format_spec.rb +88 -0
- data/spec/io/harwell_boeing_spec.rb +98 -0
- data/spec/io/test.rua +9 -0
- data/spec/io_spec.rb +149 -0
- data/spec/lapack_core_spec.rb +482 -0
- data/spec/leakcheck.rb +16 -0
- data/spec/math_spec.rb +730 -0
- data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
- data/spec/nmatrix_yale_spec.rb +286 -0
- data/spec/plugins/atlas/atlas_spec.rb +242 -0
- data/spec/rspec_monkeys.rb +56 -0
- data/spec/rspec_spec.rb +34 -0
- data/spec/shortcuts_spec.rb +310 -0
- data/spec/slice_set_spec.rb +157 -0
- data/spec/spec_helper.rb +140 -0
- data/spec/stat_spec.rb +203 -0
- data/spec/test.pcd +20 -0
- data/spec/utm5940.mtx +83844 -0
- metadata +159 -0
@@ -0,0 +1,121 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == getrs.h
|
25
|
+
//
|
26
|
+
// getrs function in native C++.
|
27
|
+
//
|
28
|
+
|
29
|
+
/*
|
30
|
+
* Automatically Tuned Linear Algebra Software v3.8.4
|
31
|
+
* (C) Copyright 1999 R. Clint Whaley
|
32
|
+
*
|
33
|
+
* Redistribution and use in source and binary forms, with or without
|
34
|
+
* modification, are permitted provided that the following conditions
|
35
|
+
* are met:
|
36
|
+
* 1. Redistributions of source code must retain the above copyright
|
37
|
+
* notice, this list of conditions and the following disclaimer.
|
38
|
+
* 2. Redistributions in binary form must reproduce the above copyright
|
39
|
+
* notice, this list of conditions, and the following disclaimer in the
|
40
|
+
* documentation and/or other materials provided with the distribution.
|
41
|
+
* 3. The name of the ATLAS group or the names of its contributers may
|
42
|
+
* not be used to endorse or promote products derived from this
|
43
|
+
* software without specific written permission.
|
44
|
+
*
|
45
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
46
|
+
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
47
|
+
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
48
|
+
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
|
49
|
+
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
50
|
+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
51
|
+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
52
|
+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
53
|
+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
54
|
+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
55
|
+
* POSSIBILITY OF SUCH DAMAGE.
|
56
|
+
*
|
57
|
+
*/
|
58
|
+
|
59
|
+
#ifndef GETRS_H
|
60
|
+
#define GETRS_H
|
61
|
+
|
62
|
+
namespace nm { namespace math {
|
63
|
+
|
64
|
+
|
65
|
+
/*
 * Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
 *
 * From ATLAS 3.8.0.
 *
 * Order: row- or column-major storage layout of A and B.
 * Trans: solve with A itself (CblasNoTrans) or its (conjugate) transpose.
 * N:     order of A.  NRHS: number of right-hand-side columns in B.
 * A/lda: the combined L/U factors produced by getrf.
 * ipiv:  0-based pivot indices from getrf (note: LAPACK proper uses 1-based).
 * B/ldb: right-hand sides on input; overwritten in place with the solution X.
 *
 * Always returns 0 (ATLAS convention — argument errors are not diagnosed here).
 */
template <typename DType>
int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
          const int lda, const int* ipiv, DType* B, const int ldb)
{
  // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.

  if (!N || !NRHS) return 0; // empty system: nothing to solve

  const DType ONE = 1;

  if (Order == CblasColMajor) {
    if (Trans == CblasNoTrans) {
      // P*A = L*U, so A*X = B becomes L*(U*X) = P*B:
      // apply the row permutation, then the two triangular solves.
      nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
      nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
      nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
    } else {
      // Transposed solve: the factors are applied in the opposite order and
      // the permutation is undone last (inci = -1 walks the pivots backwards).
      nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
      nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
      nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
    }
  } else {
    // Row-major: note the solves act from the Right with the triangle and
    // transpose flags exchanged relative to the column-major branch —
    // presumably the usual row-major-as-transposed-column-major trick (ATLAS).
    if (Trans == CblasNoTrans) {
      nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
      nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
      nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
    } else {
      nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
      nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
      nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
    }
  }
  return 0;
}
|
103
|
+
|
104
|
+
|
105
|
+
/*
|
106
|
+
* Function signature conversion for calling LAPACK's getrs functions as directly as possible.
|
107
|
+
*
|
108
|
+
* For documentation: http://www.netlib.org/lapack/double/dgetrs.f
|
109
|
+
*
|
110
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
111
|
+
*/
|
112
|
+
template <typename DType>
|
113
|
+
inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
|
114
|
+
const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
|
115
|
+
return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
|
116
|
+
}
|
117
|
+
|
118
|
+
|
119
|
+
} } // end nm::math
|
120
|
+
|
121
|
+
#endif // GETRS_H
|
@@ -0,0 +1,79 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == imax.h
|
25
|
+
//
|
26
|
+
// BLAS level 1 function imax.
|
27
|
+
//
|
28
|
+
|
29
|
+
#ifndef IMAX_H
|
30
|
+
#define IMAX_H
|
31
|
+
|
32
|
+
namespace nm { namespace math {
|
33
|
+
|
34
|
+
/*
 * Returns the 0-based index of the element with the largest absolute value in
 * a strided vector (BLAS level 1 IAMAX).
 *
 * n:    number of elements to examine.
 * x:    the vector.
 * incx: stride between consecutive elements (must be positive).
 *
 * Returns -1 for invalid arguments (n < 1 or incx <= 0); ties keep the first
 * (lowest-index) maximum, matching BLAS semantics.
 */
template<typename DType>
inline int imax(const int n, const DType *x, const int incx) {

  if (n < 1 || incx <= 0) {
    return -1; // invalid arguments: out-of-band result, mirroring the original
  }
  if (n == 1) {
    return 0;  // a single element is trivially the maximum
  }

  DType dmax;
  int idx_of_max = 0; // renamed from `imax`, which shadowed the function name

  if (incx == 1) { // if incrementing by 1

    // FIX: was plain abs(x[0]), which resolves to the integer overload and
    // truncates for floating-point DTypes (e.g. abs(0.9) == 0), unlike the
    // std::abs used everywhere else in this function.
    dmax = std::abs(x[0]);

    for (int i = 1; i < n; ++i) {
      if (std::abs(x[i]) > dmax) {
        idx_of_max = i;
        dmax = std::abs(x[i]);
      }
    }

  } else { // if incrementing by more than 1

    dmax = std::abs(x[0]);

    for (int i = 1, ix = incx; i < n; ++i, ix += incx) {
      if (std::abs(x[ix]) > dmax) {
        idx_of_max = i;
        dmax = std::abs(x[ix]);
      }
    }
  }
  return idx_of_max;
}
|
71
|
+
|
72
|
+
template<typename DType>
|
73
|
+
inline int cblas_imax(const int n, const void* x, const int incx) {
|
74
|
+
return imax<DType>(n, reinterpret_cast<const DType*>(x), incx);
|
75
|
+
}
|
76
|
+
|
77
|
+
}} // end of namespace nm::math
|
78
|
+
|
79
|
+
#endif /* IMAX_H */
|
@@ -0,0 +1,165 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == laswp.h
|
25
|
+
//
|
26
|
+
// laswp function in native C++.
|
27
|
+
//
|
28
|
+
/*
|
29
|
+
* Automatically Tuned Linear Algebra Software v3.8.4
|
30
|
+
* (C) Copyright 1999 R. Clint Whaley
|
31
|
+
*
|
32
|
+
* Redistribution and use in source and binary forms, with or without
|
33
|
+
* modification, are permitted provided that the following conditions
|
34
|
+
* are met:
|
35
|
+
* 1. Redistributions of source code must retain the above copyright
|
36
|
+
* notice, this list of conditions and the following disclaimer.
|
37
|
+
* 2. Redistributions in binary form must reproduce the above copyright
|
38
|
+
* notice, this list of conditions, and the following disclaimer in the
|
39
|
+
* documentation and/or other materials provided with the distribution.
|
40
|
+
* 3. The name of the ATLAS group or the names of its contributers may
|
41
|
+
* not be used to endorse or promote products derived from this
|
42
|
+
* software without specific written permission.
|
43
|
+
*
|
44
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
45
|
+
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
46
|
+
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
47
|
+
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
|
48
|
+
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
49
|
+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
50
|
+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
51
|
+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
52
|
+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
53
|
+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
54
|
+
* POSSIBILITY OF SUCH DAMAGE.
|
55
|
+
*
|
56
|
+
*/
|
57
|
+
|
58
|
+
#ifndef LASWP_H
|
59
|
+
#define LASWP_H
|
60
|
+
|
61
|
+
namespace nm { namespace math {
|
62
|
+
|
63
|
+
|
64
|
+
/*
 * ATLAS function which performs row interchanges on a general rectangular matrix. Modeled after the LAPACK LASWP function.
 *
 * This version is templated for use by template <> getrf().
 *
 * N:     number of columns of A to which the interchanges apply.
 * A/lda: column-major matrix and its leading dimension.
 * K1,K2: apply pivots piv[K1] .. piv[K2-1] (0-based, half-open).
 * piv:   0-based pivot indices; row i is swapped with row piv[i].
 * inci:  stride through piv; negative applies the pivots in reverse order
 *        (used to undo a previous forward application).
 *
 * Columns are processed in panels of 32 to keep the swapped rows hot in cache.
 */
template <typename DType>
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
  //const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused.

  int nb = N >> 5;            // number of full 32-column panels

  const int mr   = N - (nb<<5); // leftover columns after the full panels
  const int incA = lda << 5;    // element offset from one 32-column panel to the next

  if (K2 < K1) return;          // empty pivot range: nothing to do

  int i1, i2;
  if (inci < 0) {               // reverse order: start from the last pivot
    piv -= (K2-1) * inci;
    i1 = K2 - 1;
    i2 = K1;
  } else {
    piv += K1 * inci;
    i1 = K1;
    i2 = K2-1;
  }

  if (nb) {

    do {
      const int* ipiv = piv;
      int i = i1;
      int KeepOn;

      do {
        int ip = *ipiv; ipiv += inci;

        if (ip != i) { // swap rows i and ip across this 32-column panel
          DType *a0 = &(A[i]),
                *a1 = &(A[ip]);

          // FIX: dropped the `register` storage class, which is ill-formed in
          // C++17 (it was a no-op hint on modern compilers anyway).
          for (int h = 32; h; h--) {
            DType r = *a0;
            *a0 = *a1;
            *a1 = r;

            a0 += lda;
            a1 += lda;
          }

        }
        if (inci > 0) KeepOn = (++i <= i2);
        else          KeepOn = (--i >= i2);

      } while (KeepOn);
      A += incA; // advance to the next 32-column panel
    } while (--nb);
  }

  if (mr) { // remainder panel of mr (< 32) columns
    const int* ipiv = piv;
    int i = i1;
    int KeepOn;

    do {
      int ip = *ipiv; ipiv += inci;
      if (ip != i) {
        DType *a0 = &(A[i]),
              *a1 = &(A[ip]);

        for (int h = mr; h; h--) {
          DType r = *a0;
          *a0 = *a1;
          *a1 = r;

          a0 += lda;
          a1 += lda;
        }
      }

      if (inci > 0) KeepOn = (++i <= i2);
      else          KeepOn = (--i >= i2);

    } while (KeepOn);
  }
}
|
150
|
+
|
151
|
+
|
152
|
+
/*
|
153
|
+
* Function signature conversion for calling LAPACK's laswp functions as directly as possible.
|
154
|
+
*
|
155
|
+
* For documentation: http://www.netlib.org/lapack/double/dlaswp.f
|
156
|
+
*
|
157
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
158
|
+
*/
|
159
|
+
template <typename DType>
|
160
|
+
inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
|
161
|
+
laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
|
162
|
+
}
|
163
|
+
|
164
|
+
} } // namespace nm::math
|
165
|
+
#endif // LASWP_H
|
@@ -0,0 +1,49 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == long_dtype.h
|
25
|
+
//
|
26
|
+
// Declarations necessary for the native versions of GEMM and GEMV.
|
27
|
+
//
|
28
|
+
|
29
|
+
#ifndef LONG_DTYPE_H
|
30
|
+
#define LONG_DTYPE_H
|
31
|
+
|
32
|
+
namespace nm { namespace math {
|
33
|
+
// These allow an increase in precision for intermediate values of gemm and gemv.
|
34
|
+
// See also: http://stackoverflow.com/questions/11873694/how-does-one-increase-precision-in-c-templates-in-a-template-typename-dependen
|
35
|
+
// Primary template: maps each NMatrix dtype to a wider "accumulator" type.
// Only the explicit specializations below are defined; instantiating the
// primary template for an unsupported dtype is a compile error by design.
template <typename DType> struct LongDType;
// Integer dtypes widen to the next larger signed integer...
template <> struct LongDType<uint8_t> { typedef int16_t type; };
template <> struct LongDType<int8_t> { typedef int16_t type; };
template <> struct LongDType<int16_t> { typedef int32_t type; };
template <> struct LongDType<int32_t> { typedef int64_t type; };
// ...except int64_t, which has no wider standard integer and maps to itself.
template <> struct LongDType<int64_t> { typedef int64_t type; };
// float widens to double; double (like int64_t) maps to itself.
template <> struct LongDType<float> { typedef double type; };
template <> struct LongDType<double> { typedef double type; };
// Complex dtypes widen component-wise; Complex128 maps to itself.
template <> struct LongDType<Complex64> { typedef Complex128 type; };
template <> struct LongDType<Complex128> { typedef Complex128 type; };
// RubyObject already has arbitrary precision (Ruby numerics), so no widening.
template <> struct LongDType<RubyObject> { typedef RubyObject type; };
|
46
|
+
|
47
|
+
}} // end of namespace nm::math
|
48
|
+
|
49
|
+
#endif
|
@@ -0,0 +1,744 @@
|
|
1
|
+
/////////////////////////////////////////////////////////////////////
|
2
|
+
// = NMatrix
|
3
|
+
//
|
4
|
+
// A linear algebra library for scientific computation in Ruby.
|
5
|
+
// NMatrix is part of SciRuby.
|
6
|
+
//
|
7
|
+
// NMatrix was originally inspired by and derived from NArray, by
|
8
|
+
// Masahiro Tanaka: http://narray.rubyforge.org
|
9
|
+
//
|
10
|
+
// == Copyright Information
|
11
|
+
//
|
12
|
+
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
|
13
|
+
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
|
14
|
+
//
|
15
|
+
// Please see LICENSE.txt for additional copyright notices.
|
16
|
+
//
|
17
|
+
// == Contributing
|
18
|
+
//
|
19
|
+
// By contributing source code to SciRuby, you agree to be bound by
|
20
|
+
// our Contributor Agreement:
|
21
|
+
//
|
22
|
+
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
|
23
|
+
//
|
24
|
+
// == math.h
|
25
|
+
//
|
26
|
+
// Header file for math functions, interfacing with BLAS, etc.
|
27
|
+
//
|
28
|
+
// For instructions on adding CBLAS and CLAPACK functions, see the
|
29
|
+
// beginning of math.cpp.
|
30
|
+
//
|
31
|
+
// Some of these functions are from ATLAS. Here is the license for
|
32
|
+
// ATLAS:
|
33
|
+
//
|
34
|
+
/*
|
35
|
+
* Automatically Tuned Linear Algebra Software v3.8.4
|
36
|
+
* (C) Copyright 1999 R. Clint Whaley
|
37
|
+
*
|
38
|
+
* Redistribution and use in source and binary forms, with or without
|
39
|
+
* modification, are permitted provided that the following conditions
|
40
|
+
* are met:
|
41
|
+
* 1. Redistributions of source code must retain the above copyright
|
42
|
+
* notice, this list of conditions and the following disclaimer.
|
43
|
+
* 2. Redistributions in binary form must reproduce the above copyright
|
44
|
+
* notice, this list of conditions, and the following disclaimer in the
|
45
|
+
* documentation and/or other materials provided with the distribution.
|
46
|
+
* 3. The name of the ATLAS group or the names of its contributers may
|
47
|
+
* not be used to endorse or promote products derived from this
|
48
|
+
* software without specific written permission.
|
49
|
+
*
|
50
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
51
|
+
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
52
|
+
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
53
|
+
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
|
54
|
+
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
55
|
+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
56
|
+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
57
|
+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
58
|
+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
59
|
+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
60
|
+
* POSSIBILITY OF SUCH DAMAGE.
|
61
|
+
*
|
62
|
+
*/
|
63
|
+
|
64
|
+
#ifndef MATH_H
|
65
|
+
#define MATH_H
|
66
|
+
|
67
|
+
/*
|
68
|
+
* Standard Includes
|
69
|
+
*/
|
70
|
+
|
71
|
+
#include "cblas_enums.h"

#include <algorithm> // std::min, std::max
#include <limits>    // std::numeric_limits
#include <vector>    // std::vector (scratch buffers in numbmm/symbmm)
|
75
|
+
|
76
|
+
/*
|
77
|
+
* Project Includes
|
78
|
+
*/
|
79
|
+
|
80
|
+
/*
|
81
|
+
* Macros
|
82
|
+
*/
|
83
|
+
#define REAL_RECURSE_LIMIT 4
|
84
|
+
|
85
|
+
/*
|
86
|
+
* Data
|
87
|
+
*/
|
88
|
+
|
89
|
+
|
90
|
+
extern "C" {
|
91
|
+
/*
|
92
|
+
* C accessors.
|
93
|
+
*/
|
94
|
+
|
95
|
+
void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
|
96
|
+
void nm_math_init_blas(void);
|
97
|
+
|
98
|
+
/*
|
99
|
+
* Pure math implementations.
|
100
|
+
*/
|
101
|
+
void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv);
|
102
|
+
void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
|
103
|
+
void nm_math_hessenberg(VALUE a);
|
104
|
+
void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
|
105
|
+
void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
|
106
|
+
}
|
107
|
+
|
108
|
+
|
109
|
+
namespace nm {
|
110
|
+
namespace math {
|
111
|
+
|
112
|
+
/*
|
113
|
+
* Types
|
114
|
+
*/
|
115
|
+
|
116
|
+
|
117
|
+
/*
|
118
|
+
* Functions
|
119
|
+
*/
|
120
|
+
|
121
|
+
// Yale: numeric matrix multiply c=a*b
|
122
|
+
template <typename DType>
|
123
|
+
inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
|
124
|
+
const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
|
125
|
+
const unsigned int max_lmn = std::max(std::max(m, n), l);
|
126
|
+
IType next[max_lmn];
|
127
|
+
DType sums[max_lmn];
|
128
|
+
|
129
|
+
DType v;
|
130
|
+
|
131
|
+
IType head, length, temp, ndnz = 0;
|
132
|
+
IType minmn = std::min(m,n);
|
133
|
+
IType minlm = std::min(l,m);
|
134
|
+
|
135
|
+
for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
|
136
|
+
next[idx] = std::numeric_limits<IType>::max();
|
137
|
+
sums[idx] = 0;
|
138
|
+
}
|
139
|
+
|
140
|
+
for (IType i = 0; i < n; ++i) { // walk down the rows
|
141
|
+
head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
|
142
|
+
length = 0;
|
143
|
+
|
144
|
+
for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
|
145
|
+
IType j;
|
146
|
+
|
147
|
+
if (jj == ia[i+1]) { // if we're in the last entry for this row:
|
148
|
+
if (!diaga || i >= minmn) continue;
|
149
|
+
j = i; // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
|
150
|
+
v = a[i];
|
151
|
+
} else {
|
152
|
+
j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
|
153
|
+
v = a[jj];
|
154
|
+
}
|
155
|
+
|
156
|
+
for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {
|
157
|
+
|
158
|
+
IType k;
|
159
|
+
|
160
|
+
if (kk == ib[j+1]) { // Get the column id for that entry
|
161
|
+
if (!diagb || j >= minlm) continue;
|
162
|
+
k = j;
|
163
|
+
sums[k] += v*b[k];
|
164
|
+
} else {
|
165
|
+
k = jb[kk];
|
166
|
+
sums[k] += v*b[kk];
|
167
|
+
}
|
168
|
+
|
169
|
+
if (next[k] == std::numeric_limits<IType>::max()) {
|
170
|
+
next[k] = head;
|
171
|
+
head = k;
|
172
|
+
++length;
|
173
|
+
}
|
174
|
+
} // end of kk loop
|
175
|
+
} // end of jj loop
|
176
|
+
|
177
|
+
for (IType jj = 0; jj < length; ++jj) {
|
178
|
+
if (sums[head] != 0) {
|
179
|
+
if (diagc && head == i) {
|
180
|
+
c[head] = sums[head];
|
181
|
+
} else {
|
182
|
+
jc[n+1+ndnz] = head;
|
183
|
+
c[n+1+ndnz] = sums[head];
|
184
|
+
++ndnz;
|
185
|
+
}
|
186
|
+
}
|
187
|
+
|
188
|
+
temp = head;
|
189
|
+
head = next[head];
|
190
|
+
|
191
|
+
next[temp] = std::numeric_limits<IType>::max();
|
192
|
+
sums[temp] = 0;
|
193
|
+
}
|
194
|
+
|
195
|
+
ic[i+1] = n+1+ndnz;
|
196
|
+
}
|
197
|
+
} /* numbmm_ */
|
198
|
+
|
199
|
+
|
200
|
+
/*
|
201
|
+
template <typename DType, typename IType>
|
202
|
+
inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
|
203
|
+
unsigned int n = c_storage->shape[0],
|
204
|
+
l = c_storage->shape[1];
|
205
|
+
|
206
|
+
// Create a working vector of dimension max(m,l,n) and initial value IType::max():
|
207
|
+
std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
|
208
|
+
|
209
|
+
for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
|
210
|
+
|
211
|
+
IType j, k;
|
212
|
+
size_t ndnz;
|
213
|
+
|
214
|
+
for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
|
215
|
+
j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
|
216
|
+
|
217
|
+
if (j >= m) {
|
218
|
+
if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
|
219
|
+
else break;
|
220
|
+
}
|
221
|
+
|
222
|
+
for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
|
223
|
+
if (j >= m) continue; // first of all, does B *have* a row j?
|
224
|
+
k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
|
225
|
+
|
226
|
+
if (k >= l) {
|
227
|
+
if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
|
228
|
+
else break;
|
229
|
+
}
|
230
|
+
|
231
|
+
if (mask[k] == )
|
232
|
+
}
|
233
|
+
|
234
|
+
}
|
235
|
+
}
|
236
|
+
}
|
237
|
+
*/
|
238
|
+
|
239
|
+
// Yale: Symbolic matrix multiply c=a*b
|
240
|
+
inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
|
241
|
+
const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
|
242
|
+
unsigned int max_lmn = std::max(std::max(m,n), l);
|
243
|
+
IType mask[max_lmn]; // INDEX in the SMMP paper.
|
244
|
+
IType j, k; /* Local variables */
|
245
|
+
size_t ndnz = n;
|
246
|
+
|
247
|
+
for (IType idx = 0; idx < max_lmn; ++idx)
|
248
|
+
mask[idx] = std::numeric_limits<IType>::max();
|
249
|
+
|
250
|
+
if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
|
251
|
+
if (diagc) ic[0] = n+1;
|
252
|
+
else ic[0] = 0;
|
253
|
+
}
|
254
|
+
|
255
|
+
IType minmn = std::min(m,n);
|
256
|
+
IType minlm = std::min(l,m);
|
257
|
+
|
258
|
+
for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
|
259
|
+
|
260
|
+
for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
|
261
|
+
|
262
|
+
// j <- column index given by JA[jj], or handle diagonal.
|
263
|
+
if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
|
264
|
+
if (!diaga || i >= minmn) continue;
|
265
|
+
j = i;
|
266
|
+
} else j = ja[jj];
|
267
|
+
|
268
|
+
for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
|
269
|
+
if (kk == ib[j+1]) {
|
270
|
+
if (!diagb || j >= minlm) continue;
|
271
|
+
k = j;
|
272
|
+
} else k = jb[kk];
|
273
|
+
|
274
|
+
if (mask[k] != i) {
|
275
|
+
mask[k] = i;
|
276
|
+
++ndnz;
|
277
|
+
}
|
278
|
+
}
|
279
|
+
}
|
280
|
+
|
281
|
+
if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
|
282
|
+
|
283
|
+
if (ic) ic[i+1] = ndnz;
|
284
|
+
}
|
285
|
+
|
286
|
+
return ndnz;
|
287
|
+
} /* symbmm_ */
|
288
|
+
|
289
|
+
|
290
|
+
// In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
|
291
|
+
namespace smmp_sort {
|
292
|
+
const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer
|
293
|
+
|
294
|
+
template <typename DType>
|
295
|
+
void print_array(DType* vals, IType* array, IType left, IType right) {
|
296
|
+
for (IType i = left; i <= right; ++i) {
|
297
|
+
std::cerr << array[i] << ":" << vals[i] << " ";
|
298
|
+
}
|
299
|
+
std::cerr << std::endl;
|
300
|
+
}
|
301
|
+
|
302
|
+
template <typename DType>
|
303
|
+
IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
|
304
|
+
IType pivotJ = array[pivot];
|
305
|
+
DType pivotV = vals[pivot];
|
306
|
+
|
307
|
+
// Swap pivot and right
|
308
|
+
array[pivot] = array[right];
|
309
|
+
vals[pivot] = vals[right];
|
310
|
+
array[right] = pivotJ;
|
311
|
+
vals[right] = pivotV;
|
312
|
+
|
313
|
+
IType store = left;
|
314
|
+
for (IType idx = left; idx < right; ++idx) {
|
315
|
+
if (array[idx] <= pivotJ) {
|
316
|
+
// Swap i and store
|
317
|
+
std::swap(array[idx], array[store]);
|
318
|
+
std::swap(vals[idx], vals[store]);
|
319
|
+
++store;
|
320
|
+
}
|
321
|
+
}
|
322
|
+
|
323
|
+
std::swap(array[store], array[right]);
|
324
|
+
std::swap(vals[store], vals[right]);
|
325
|
+
|
326
|
+
return store;
|
327
|
+
}
|
328
|
+
|
329
|
+
// Recommended to use the median of left, right, and mid for the pivot.
|
330
|
+
template <typename I>
|
331
|
+
inline I median(I a, I b, I c) {
|
332
|
+
if (a < b) {
|
333
|
+
if (b < c) return b; // a b c
|
334
|
+
if (a < c) return c; // a c b
|
335
|
+
return a; // c a b
|
336
|
+
|
337
|
+
} else { // a > b
|
338
|
+
if (a < c) return a; // b a c
|
339
|
+
if (b < c) return c; // b c a
|
340
|
+
return b; // c b a
|
341
|
+
}
|
342
|
+
}
|
343
|
+
|
344
|
+
|
345
|
+
// Insertion sort is more efficient than quicksort for small N
|
346
|
+
template <typename DType>
|
347
|
+
void insertion_sort(DType* vals, IType* array, IType left, IType right) {
|
348
|
+
for (IType idx = left; idx <= right; ++idx) {
|
349
|
+
IType col_to_insert = array[idx];
|
350
|
+
DType val_to_insert = vals[idx];
|
351
|
+
|
352
|
+
IType hole_pos = idx;
|
353
|
+
for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
|
354
|
+
array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
|
355
|
+
vals[hole_pos] = vals[hole_pos - 1]; // value goes along with it
|
356
|
+
}
|
357
|
+
|
358
|
+
array[hole_pos] = col_to_insert;
|
359
|
+
vals[hole_pos] = val_to_insert;
|
360
|
+
}
|
361
|
+
}
|
362
|
+
|
363
|
+
|
364
|
+
template <typename DType>
|
365
|
+
void quicksort(DType* vals, IType* array, IType left, IType right) {
|
366
|
+
|
367
|
+
if (left < right) {
|
368
|
+
if (right - left < THRESHOLD) {
|
369
|
+
insertion_sort(vals, array, left, right);
|
370
|
+
} else {
|
371
|
+
// choose any pivot such that left < pivot < right
|
372
|
+
IType pivot = median<IType>(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
|
373
|
+
pivot = partition(vals, array, left, right, pivot);
|
374
|
+
|
375
|
+
// recursively sort elements smaller than the pivot
|
376
|
+
quicksort<DType>(vals, array, left, pivot-1);
|
377
|
+
|
378
|
+
// recursively sort elements at least as big as the pivot
|
379
|
+
quicksort<DType>(vals, array, pivot+1, right);
|
380
|
+
}
|
381
|
+
}
|
382
|
+
}
|
383
|
+
|
384
|
+
|
385
|
+
}; // end of namespace smmp_sort
|
386
|
+
|
387
|
+
|
388
|
+
/*
|
389
|
+
* For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
|
390
|
+
* This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
|
391
|
+
* about stability).
|
392
|
+
*
|
393
|
+
* TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
|
394
|
+
*
|
395
|
+
* TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
|
396
|
+
* ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
|
397
|
+
* sort.
|
398
|
+
*/
|
399
|
+
template <typename DType>
|
400
|
+
inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
|
401
|
+
for (size_t i = 0; i < n; ++i) {
|
402
|
+
if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing only one or two elements.
|
403
|
+
else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
|
404
|
+
smmp_sort::insertion_sort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
|
405
|
+
} else {
|
406
|
+
smmp_sort::quicksort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
|
407
|
+
}
|
408
|
+
}
|
409
|
+
}
|
410
|
+
|
411
|
+
|
412
|
+
// Copies the strictly-upper-triangular part of the row-major M x N array U
// into C, zeroing each entry of U as it is copied. U is unit-diagonal, so the
// diagonal itself is neither copied nor cleared.
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  // Walk the rows, advancing both matrices by their leading dimensions.
  for (int row = 0; row < M; ++row, U += ldu, C += ldc) {
    for (int col = row+1; col < N; ++col) {
      C[col] = U[col];
      U[col] = 0;
    }
  }
}
|
428
|
+
|
429
|
+
|
430
|
+
/*
|
431
|
+
* Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
|
432
|
+
* functions. This is probably really complicated.
|
433
|
+
*
|
434
|
+
* Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
|
435
|
+
*
|
436
|
+
*/
|
437
|
+
|
438
|
+
/*
|
439
|
+
|
440
|
+
template <bool RowMajor, bool Upper, typename DType>
|
441
|
+
static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
442
|
+
|
443
|
+
if (RowMajor) {
|
444
|
+
DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
|
445
|
+
DType tmp;
|
446
|
+
if (Upper) {
|
447
|
+
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
448
|
+
A12 = pA1[2], A13 = pA1[3],
|
449
|
+
A23 = pA2[3];
|
450
|
+
|
451
|
+
if (Diag == CblasNonUnit) {
|
452
|
+
pA0->inverse();
|
453
|
+
(pA1+1)->inverse();
|
454
|
+
(pA2+2)->inverse();
|
455
|
+
(pA3+3)->inverse();
|
456
|
+
|
457
|
+
pA0[1] = -A01 * pA1[1] * pA0[0];
|
458
|
+
pA1[2] = -A12 * pA2[2] * pA1[1];
|
459
|
+
pA2[3] = -A23 * pA3[3] * pA2[2];
|
460
|
+
|
461
|
+
pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
|
462
|
+
pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
|
463
|
+
|
464
|
+
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
|
465
|
+
|
466
|
+
} else {
|
467
|
+
|
468
|
+
pA0[1] = -A01;
|
469
|
+
pA1[2] = -A12;
|
470
|
+
pA2[3] = -A23;
|
471
|
+
|
472
|
+
pA0[2] = -(A01 * pA1[2] + A02);
|
473
|
+
pA1[3] = -(A12 * pA2[3] + A13);
|
474
|
+
|
475
|
+
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
|
476
|
+
}
|
477
|
+
|
478
|
+
} else { // Lower
|
479
|
+
DType A10 = pA1[0],
|
480
|
+
A20 = pA2[0], A21 = pA2[1],
|
481
|
+
A30 = pA3[0], A31 = pA3[1], A32 = pA3[2];
|
482
|
+
DType *B10 = pA1,
|
483
|
+
*B20 = pA2,
|
484
|
+
*B30 = pA3,
|
485
|
+
*B21 = pA2+1,
|
486
|
+
*B31 = pA3+1,
|
487
|
+
*B32 = pA3+2;
|
488
|
+
|
489
|
+
|
490
|
+
if (Diag == CblasNonUnit) {
|
491
|
+
pA0->inverse();
|
492
|
+
(pA1+1)->inverse();
|
493
|
+
(pA2+2)->inverse();
|
494
|
+
(pA3+3)->inverse();
|
495
|
+
|
496
|
+
*B10 = -A10 * pA0[0] * pA1[1];
|
497
|
+
*B21 = -A21 * pA1[1] * pA2[2];
|
498
|
+
*B32 = -A32 * pA2[2] * pA3[3];
|
499
|
+
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
500
|
+
*B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
|
501
|
+
*B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3[3];
|
502
|
+
} else {
|
503
|
+
*B10 = -A10;
|
504
|
+
*B21 = -A21;
|
505
|
+
*B32 = -A32;
|
506
|
+
*B20 = -(A20 + A21 * (*B10));
|
507
|
+
*B31 = -(A31 + A32 * (*B21));
|
508
|
+
*B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
|
509
|
+
}
|
510
|
+
}
|
511
|
+
|
512
|
+
} else {
|
513
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
514
|
+
}
|
515
|
+
|
516
|
+
return 0;
|
517
|
+
|
518
|
+
}
|
519
|
+
|
520
|
+
|
521
|
+
template <bool RowMajor, bool Upper, typename DType>
|
522
|
+
static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
523
|
+
|
524
|
+
if (RowMajor) {
|
525
|
+
|
526
|
+
DType tmp;
|
527
|
+
|
528
|
+
if (Upper) {
  DType *pA0 = A, *pA1 = A + lda, *pA2 = A + 2*lda;
  DType A01 = pA0[1], A02 = pA0[2],
        A12 = pA1[2];
|
531
|
+
|
532
|
+
DType *B01 = pA0 + 1,
|
533
|
+
*B02 = pA0 + 2,
|
534
|
+
*B12 = pA1 + 2;
|
535
|
+
|
536
|
+
if (Diag == CblasNonUnit) {
|
537
|
+
pA0->inverse();
|
538
|
+
(pA1+1)->inverse();
|
539
|
+
(pA2+2)->inverse();
|
540
|
+
|
541
|
+
*B01 = -A01 * pA1[1] * pA0[0];
|
542
|
+
*B12 = -A12 * pA2[2] * pA1[1];
|
543
|
+
*B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
|
544
|
+
} else {
|
545
|
+
*B01 = -A01;
|
546
|
+
*B12 = -A12;
|
547
|
+
*B02 = -(A01 * (*B12) + A02);
|
548
|
+
}
|
549
|
+
|
550
|
+
} else { // Lower
|
551
|
+
DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
|
552
|
+
DType A10=pA1[0],
|
553
|
+
A20=pA2[0], A21=pA2[1];
|
554
|
+
|
555
|
+
DType *B10 = pA1,
|
556
|
+
*B20 = pA2,
|
557
|
+
*B21 = pA2+1;
|
558
|
+
|
559
|
+
if (Diag == CblasNonUnit) {
|
560
|
+
pA0->inverse();
|
561
|
+
(pA1+1)->inverse();
|
562
|
+
(pA2+2)->inverse();
|
563
|
+
*B10 = -A10 * pA0[0] * pA1[1];
|
564
|
+
*B21 = -A21 * pA1[1] * pA2[2];
|
565
|
+
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
566
|
+
} else {
|
567
|
+
*B10 = -A10;
|
568
|
+
*B21 = -A21;
|
569
|
+
*B20 = -(A20 + A21 * (*B10));
|
570
|
+
}
|
571
|
+
}
|
572
|
+
|
573
|
+
|
574
|
+
} else {
|
575
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
576
|
+
}
|
577
|
+
|
578
|
+
return 0;
|
579
|
+
|
580
|
+
}
|
581
|
+
|
582
|
+
template <bool RowMajor, bool Upper, bool Real, typename DType>
|
583
|
+
static int trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
|
584
|
+
DType *Age, *Atr;
|
585
|
+
DType tmp;
|
586
|
+
int Nleft, Nright;
|
587
|
+
|
588
|
+
int ierr = 0;
|
589
|
+
|
590
|
+
static const DType ONE = 1;
|
591
|
+
static const DType MONE = -1;
|
592
|
+
static const DType NONE = -1;
|
593
|
+
|
594
|
+
if (RowMajor) {
|
595
|
+
|
596
|
+
// FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
|
597
|
+
if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
|
598
|
+
Nleft = N >> 1;
|
599
|
+
#ifdef NB
|
600
|
+
if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
601
|
+
#endif
|
602
|
+
|
603
|
+
Nright = N - Nleft;
|
604
|
+
|
605
|
+
if (Upper) {
|
606
|
+
Age = A + Nleft;
|
607
|
+
Atr = A + (Nleft * (lda+1));
|
608
|
+
|
609
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
|
610
|
+
Nleft, Nright, ONE, Atr, lda, Age, lda);
|
611
|
+
|
612
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
|
613
|
+
Nleft, Nright, MONE, A, lda, Age, lda);
|
614
|
+
|
615
|
+
} else { // Lower
|
616
|
+
Age = A + ((Nleft*lda));
|
617
|
+
Atr = A + (Nleft * (lda+1));
|
618
|
+
|
619
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
|
620
|
+
Nright, Nleft, ONE, A, lda, Age, lda);
|
621
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
|
622
|
+
Nright, Nleft, MONE, Atr, lda, Age, lda);
|
623
|
+
}
|
624
|
+
|
625
|
+
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
|
626
|
+
if (ierr) return ierr;
|
627
|
+
|
628
|
+
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
|
629
|
+
if (ierr) return ierr + Nleft;
|
630
|
+
|
631
|
+
} else {
|
632
|
+
if (Real) {
|
633
|
+
if (N == 4) {
|
634
|
+
return trtri_4<RowMajor,Upper,DType>(Diag, A, lda);
|
635
|
+
} else if (N == 3) {
|
636
|
+
return trtri_3<RowMajor,Upper,DType>(Diag, A, lda);
|
637
|
+
} else if (N == 2) {
|
638
|
+
if (Diag == CblasNonUnit) {
|
639
|
+
A->inverse();
|
640
|
+
(A+(lda+1))->inverse();
|
641
|
+
|
642
|
+
if (Upper) {
|
643
|
+
*(A+1) *= *A; // TRI_MUL
|
644
|
+
*(A+1) *= *(A+lda+1); // TRI_MUL
|
645
|
+
} else {
|
646
|
+
*(A+lda) *= *A; // TRI_MUL
|
647
|
+
*(A+lda) *= *(A+lda+1); // TRI_MUL
|
648
|
+
}
|
649
|
+
}
|
650
|
+
|
651
|
+
if (Upper) *(A+1) = -*(A+1); // TRI_NEG
|
652
|
+
else *(A+lda) = -*(A+lda); // TRI_NEG
|
653
|
+
} else if (Diag == CblasNonUnit) A->inverse();
|
654
|
+
} else { // not real
|
655
|
+
if (Diag == CblasNonUnit) A->inverse();
|
656
|
+
}
|
657
|
+
}
|
658
|
+
|
659
|
+
} else {
|
660
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
661
|
+
}
|
662
|
+
|
663
|
+
return ierr;
|
664
|
+
}
|
665
|
+
|
666
|
+
|
667
|
+
template <bool RowMajor, bool Real, typename DType>
|
668
|
+
int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
|
669
|
+
|
670
|
+
if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
671
|
+
|
672
|
+
int jb, nb, I, ndown, iret;
|
673
|
+
|
674
|
+
const DType ONE = 1, NONE = -1;
|
675
|
+
|
676
|
+
iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
|
677
|
+
if (!iret && N > 1) {
|
678
|
+
jb = lwrk / N;
|
679
|
+
if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
|
680
|
+
else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
|
681
|
+
else nb = jb;
|
682
|
+
if (!nb) return -6; // need at least 1 row of workspace
|
683
|
+
|
684
|
+
// only first iteration will have partial block, unroll it
|
685
|
+
|
686
|
+
jb = N - (N/nb) * nb;
|
687
|
+
if (!jb) jb = nb;
|
688
|
+
I = N - jb;
|
689
|
+
A += lda * I;
|
690
|
+
trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
|
691
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
692
|
+
jb, N, ONE, wrk, jb, A, lda);
|
693
|
+
|
694
|
+
if (I) {
|
695
|
+
do {
|
696
|
+
I -= nb;
|
697
|
+
A -= nb * lda;
|
698
|
+
ndown = N-I;
|
699
|
+
trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
|
700
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
                      nb, N, ONE, wrk, ndown, A, lda);
|
702
|
+
} while (I);
|
703
|
+
}
|
704
|
+
|
705
|
+
// Apply row interchanges
|
706
|
+
|
707
|
+
for (I = N - 2; I >= 0; --I) {
|
708
|
+
jb = ipiv[I];
|
709
|
+
if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
|
710
|
+
}
|
711
|
+
}
|
712
|
+
|
713
|
+
return iret;
|
714
|
+
}
|
715
|
+
*/
|
716
|
+
|
717
|
+
/*
 * Macro for declaring LAPACK specializations of the getrf function.
 *
 * type is the DType; call is the specific function to call; cast_as is what the DType* should be
 * cast to in order to pass it to LAPACK.
 *
 * The generated specialization returns clapack's info code (0 on success).
 * On failure it raises a Ruby ArgError whose message follows the LAPACK info
 * convention: info < 0 means argument -info had an illegal value, while
 * info > 0 means U(info,info) is exactly zero (the matrix is singular).
 * rb_raise does not return; the trailing return only satisfies the compiler.
 */
#define LAPACK_GETRF(type, call, cast_as) \
template <> \
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) { \
  int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv); \
  if (!info) return info; \
  if (info < 0) rb_raise(rb_eArgError, "getrf: illegal value in argument %d\n", -info); \
  else rb_raise(rb_eArgError, "getrf: U(%d,%d) is exactly zero, so the matrix is singular\n", info, info); \
  return info; \
}
|
733
|
+
|
734
|
+
/* Specialize for ATLAS types */
|
735
|
+
/*LAPACK_GETRF(float, clapack_sgetrf, float)
|
736
|
+
LAPACK_GETRF(double, clapack_dgetrf, double)
|
737
|
+
LAPACK_GETRF(Complex64, clapack_cgetrf, void)
|
738
|
+
LAPACK_GETRF(Complex128, clapack_zgetrf, void)
|
739
|
+
*/
|
740
|
+
|
741
|
+
}} // end namespace nm::math
|
742
|
+
|
743
|
+
|
744
|
+
#endif // MATH_H
|