nmatrix-atlas 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (82)
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/data/complex.h +364 -0
  3. data/ext/nmatrix/data/data.h +638 -0
  4. data/ext/nmatrix/data/meta.h +64 -0
  5. data/ext/nmatrix/data/ruby_object.h +389 -0
  6. data/ext/nmatrix/math/asum.h +120 -0
  7. data/ext/nmatrix/math/cblas_enums.h +36 -0
  8. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  9. data/ext/nmatrix/math/gemm.h +241 -0
  10. data/ext/nmatrix/math/gemv.h +178 -0
  11. data/ext/nmatrix/math/getrf.h +255 -0
  12. data/ext/nmatrix/math/getrs.h +121 -0
  13. data/ext/nmatrix/math/imax.h +79 -0
  14. data/ext/nmatrix/math/laswp.h +165 -0
  15. data/ext/nmatrix/math/long_dtype.h +49 -0
  16. data/ext/nmatrix/math/math.h +744 -0
  17. data/ext/nmatrix/math/nrm2.h +160 -0
  18. data/ext/nmatrix/math/rot.h +117 -0
  19. data/ext/nmatrix/math/rotg.h +106 -0
  20. data/ext/nmatrix/math/scal.h +71 -0
  21. data/ext/nmatrix/math/trsm.h +332 -0
  22. data/ext/nmatrix/math/util.h +148 -0
  23. data/ext/nmatrix/nm_memory.h +60 -0
  24. data/ext/nmatrix/nmatrix.h +408 -0
  25. data/ext/nmatrix/ruby_constants.h +106 -0
  26. data/ext/nmatrix/storage/common.h +176 -0
  27. data/ext/nmatrix/storage/dense/dense.h +128 -0
  28. data/ext/nmatrix/storage/list/list.h +137 -0
  29. data/ext/nmatrix/storage/storage.h +98 -0
  30. data/ext/nmatrix/storage/yale/class.h +1139 -0
  31. data/ext/nmatrix/storage/yale/iterators/base.h +142 -0
  32. data/ext/nmatrix/storage/yale/iterators/iterator.h +130 -0
  33. data/ext/nmatrix/storage/yale/iterators/row.h +449 -0
  34. data/ext/nmatrix/storage/yale/iterators/row_stored.h +139 -0
  35. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +168 -0
  36. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +123 -0
  37. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  38. data/ext/nmatrix/storage/yale/yale.h +202 -0
  39. data/ext/nmatrix/types.h +54 -0
  40. data/ext/nmatrix/util/io.h +115 -0
  41. data/ext/nmatrix/util/sl_list.h +143 -0
  42. data/ext/nmatrix/util/util.h +78 -0
  43. data/ext/nmatrix_atlas/extconf.rb +250 -0
  44. data/ext/nmatrix_atlas/math_atlas.cpp +1206 -0
  45. data/ext/nmatrix_atlas/math_atlas/cblas_templates_atlas.h +72 -0
  46. data/ext/nmatrix_atlas/math_atlas/clapack_templates.h +332 -0
  47. data/ext/nmatrix_atlas/math_atlas/geev.h +82 -0
  48. data/ext/nmatrix_atlas/math_atlas/gesdd.h +83 -0
  49. data/ext/nmatrix_atlas/math_atlas/gesvd.h +81 -0
  50. data/ext/nmatrix_atlas/math_atlas/inc.h +47 -0
  51. data/ext/nmatrix_atlas/nmatrix_atlas.cpp +44 -0
  52. data/lib/nmatrix/atlas.rb +213 -0
  53. data/lib/nmatrix/lapack_ext_common.rb +69 -0
  54. data/spec/00_nmatrix_spec.rb +730 -0
  55. data/spec/01_enum_spec.rb +190 -0
  56. data/spec/02_slice_spec.rb +389 -0
  57. data/spec/03_nmatrix_monkeys_spec.rb +78 -0
  58. data/spec/2x2_dense_double.mat +0 -0
  59. data/spec/4x4_sparse.mat +0 -0
  60. data/spec/4x5_dense.mat +0 -0
  61. data/spec/blas_spec.rb +193 -0
  62. data/spec/elementwise_spec.rb +303 -0
  63. data/spec/homogeneous_spec.rb +99 -0
  64. data/spec/io/fortran_format_spec.rb +88 -0
  65. data/spec/io/harwell_boeing_spec.rb +98 -0
  66. data/spec/io/test.rua +9 -0
  67. data/spec/io_spec.rb +149 -0
  68. data/spec/lapack_core_spec.rb +482 -0
  69. data/spec/leakcheck.rb +16 -0
  70. data/spec/math_spec.rb +730 -0
  71. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  72. data/spec/nmatrix_yale_spec.rb +286 -0
  73. data/spec/plugins/atlas/atlas_spec.rb +242 -0
  74. data/spec/rspec_monkeys.rb +56 -0
  75. data/spec/rspec_spec.rb +34 -0
  76. data/spec/shortcuts_spec.rb +310 -0
  77. data/spec/slice_set_spec.rb +157 -0
  78. data/spec/spec_helper.rb +140 -0
  79. data/spec/stat_spec.rb +203 -0
  80. data/spec/test.pcd +20 -0
  81. data/spec/utm5940.mtx +83844 -0
  82. metadata +159 -0
@@ -0,0 +1,83 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gesdd.h
25
+ //
26
+ // Header file for interface with LAPACK's xGESDD functions.
27
+ //
28
+
29
+ #ifndef GESDD_H
30
+ # define GESDD_H
31
+
32
+ extern "C" {
33
+
34
+ void sgesdd_(char*, int*, int*, float*, int*, float*, float*, int*, float*, int*, float*, int*, int*, int*);
35
+ void dgesdd_(char*, int*, int*, double*, int*, double*, double*, int*, double*, int*, double*, int*, int*, int*);
36
+ //the argument s is an array of real values and is returned as array of float/double
37
+ void cgesdd_(char*, int*, int*, nm::Complex64*, int*, float* s, nm::Complex64*, int*, nm::Complex64*, int*, nm::Complex64*, int*, float*, int*, int*);
38
+ void zgesdd_(char*, int*, int*, nm::Complex128*, int*, double* s, nm::Complex128*, int*, nm::Complex128*, int*, nm::Complex128*, int*, double*, int*, int*);
39
+ }
40
+
41
+ namespace nm {
42
+ namespace math {
43
+ namespace atlas {
44
+
45
+ template <typename DType, typename CType>
46
+ inline int gesdd(char jobz, int m, int n, DType* a, int lda, CType* s, DType* u, int ldu, DType* vt, int ldvt, DType* work, int lwork, int* iwork, CType* rwork) {
47
+ rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
48
+ return -1;
49
+ }
50
+
51
+ template <>
52
+ inline int gesdd(char jobz, int m, int n, float* a, int lda, float* s, float* u, int ldu, float* vt, int ldvt, float* work, int lwork, int* iwork, float* rwork) {
53
+ int info;
54
+ sgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info);
55
+ return info;
56
+ }
57
+
58
+ template <>
59
+ inline int gesdd(char jobz, int m, int n, double* a, int lda, double* s, double* u, int ldu, double* vt, int ldvt, double* work, int lwork, int* iwork, double* rwork) {
60
+ int info;
61
+ dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info);
62
+ return info;
63
+ }
64
+
65
+ template <>
66
+ inline int gesdd(char jobz, int m, int n, nm::Complex64* a, int lda, float* s, nm::Complex64* u, int ldu, nm::Complex64* vt, int ldvt, nm::Complex64* work, int lwork, int* iwork, float* rwork) {
67
+ int info;
68
+ cgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, iwork, &info);
69
+ return info;
70
+ }
71
+
72
+ template <>
73
+ inline int gesdd(char jobz, int m, int n, nm::Complex128* a, int lda, double* s, nm::Complex128* u, int ldu, nm::Complex128* vt, int ldvt, nm::Complex128* work, int lwork, int* iwork, double* rwork) {
74
+ int info;
75
+ zgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, iwork, &info);
76
+ return info;
77
+ }
78
+
79
+ } // end of namespace atlas
80
+ } // end of namespace math
81
+ } // end of namespace nm
82
+
83
+ #endif // GESDD_H
@@ -0,0 +1,81 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == gesvd.h
25
+ //
26
+ // Header file for interface with LAPACK's xGESVD functions.
27
+ //
28
+
29
+ #ifndef GESVD_H
30
+ # define GESVD_H
31
+
32
+ extern "C" {
33
+ void sgesvd_(char*, char*, int*, int*, float*, int*, float*, float*, int*, float*, int*, float*, int*, int*);
34
+ void dgesvd_(char*, char*, int*, int*, double*, int*, double*, double*, int*, double*, int*, double*, int*, int*);
35
+ //the argument s is an array of real values and is returned as array of float/double
36
+ void cgesvd_(char*, char*, int*, int*, nm::Complex64*, int*, float* s, nm::Complex64*, int*, nm::Complex64*, int*, nm::Complex64*, int*, float*, int*);
37
+ void zgesvd_(char*, char*, int*, int*, nm::Complex128*, int*, double* s, nm::Complex128*, int*, nm::Complex128*, int*, nm::Complex128*, int*, double*, int*);
38
+ }
39
+
40
+ namespace nm {
41
+ namespace math {
42
+ namespace atlas {
43
+
44
+ template <typename DType, typename CType>
45
+ inline int gesvd(char jobu, char jobvt, int m, int n, DType* a, int lda, CType* s, DType* u, int ldu, DType* vt, int ldvt, DType* work, int lwork, CType* rwork) {
46
+ rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
47
+ return -1;
48
+ }
49
+
50
+ template <>
51
+ inline int gesvd(char jobu, char jobvt, int m, int n, float* a, int lda, float* s, float* u, int ldu, float* vt, int ldvt, float* work, int lwork, float* rwork) {
52
+ int info;
53
+ sgesvd_(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, &info);
54
+ return info;
55
+ }
56
+
57
+ template <>
58
+ inline int gesvd(char jobu, char jobvt, int m, int n, double* a, int lda, double* s, double* u, int ldu, double* vt, int ldvt, double* work, int lwork, double* rwork) {
59
+ int info;
60
+ dgesvd_(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, &info);
61
+ return info;
62
+ }
63
+
64
+ template <>
65
+ inline int gesvd(char jobu, char jobvt, int m, int n, nm::Complex64* a, int lda, float* s, nm::Complex64* u, int ldu, nm::Complex64* vt, int ldvt, nm::Complex64* work, int lwork, float* rwork) {
66
+ int info;
67
+ cgesvd_(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, &info);
68
+ return info;
69
+ }
70
+
71
+ template <>
72
+ inline int gesvd(char jobu, char jobvt, int m, int n, nm::Complex128* a, int lda, double* s, nm::Complex128* u, int ldu, nm::Complex128* vt, int ldvt, nm::Complex128* work, int lwork, double* rwork) {
73
+ int info;
74
+ zgesvd_(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, &info);
75
+ return info;
76
+ }
77
+
78
+ } // end of namespace atlas
79
+ } // end of namespace math
80
+ } // end of namespace nm
81
+ #endif // GESVD_H
@@ -0,0 +1,47 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == inc.h
25
+ //
26
+ // Includes needed for LAPACK, CLAPACK, and CBLAS functions.
27
+ //
28
+
29
+ #ifndef INC_H
30
+ # define INC_H
31
+
32
+
33
+ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
34
+ #if defined HAVE_CBLAS_H
35
+ #include <cblas.h>
36
+ #elif defined HAVE_ATLAS_CBLAS_H
37
+ #include <atlas/cblas.h>
38
+ #endif
39
+
40
+ #if defined HAVE_CLAPACK_H
41
+ #include <clapack.h>
42
+ #elif defined HAVE_ATLAS_CLAPACK_H
43
+ #include <atlas/clapack.h>
44
+ #endif
45
+ }
46
+
47
+ #endif // INC_H
@@ -0,0 +1,44 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == nmatrix_atlas.cpp
25
+ //
26
+ // Main file for nmatrix_atlas extension
27
+ //
28
+
29
+ #include <ruby.h>
30
+
31
+ #include "nmatrix.h"
32
+
33
+ #include "math_atlas/inc.h"
34
+
35
+ #include "data/data.h"
36
+
37
extern "C" {

  // Defined elsewhere in this extension (presumably math_atlas.cpp);
  // sets up the ATLAS-backed math bindings.
  void nm_math_init_atlas(void);

  // Extension entry point, looked up by name when Ruby loads
  // nmatrix_atlas.so (see `require "nmatrix_atlas.so"` in atlas.rb).
  void Init_nmatrix_atlas(void) {
    nm_math_init_atlas();
  }

}
@@ -0,0 +1,213 @@
1
+ #--
2
+ # = NMatrix
3
+ #
4
+ # A linear algebra library for scientific computation in Ruby.
5
+ # NMatrix is part of SciRuby.
6
+ #
7
+ # NMatrix was originally inspired by and derived from NArray, by
8
+ # Masahiro Tanaka: http://narray.rubyforge.org
9
+ #
10
+ # == Copyright Information
11
+ #
12
+ # SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ # NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ #
15
+ # Please see LICENSE.txt for additional copyright notices.
16
+ #
17
+ # == Contributing
18
+ #
19
+ # By contributing source code to SciRuby, you agree to be bound by
20
+ # our Contributor Agreement:
21
+ #
22
+ # * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ #
24
+ # == atlas.rb
25
+ #
26
+ # ruby file for the nmatrix-atlas gem. Loads the C extension and defines
27
+ # nice ruby interfaces for ATLAS functions.
28
+ #++
29
+
30
# Load order matters here: the nmatrix core must be loaded first, then the
# shared LAPACK-extension helpers, then this extension registers itself
# before its compiled library is loaded.
require('nmatrix/nmatrix.rb')
require_relative('lapack_ext_common')

NMatrix.register_lapack_extension('nmatrix-atlas')

require('nmatrix_atlas.so')
36
+
37
+ class NMatrix
38
+
39
+ #Add functions from the ATLAS C extension to the main LAPACK and BLAS modules.
40
+ #This will overwrite the original functions where applicable.
41
# Expose every singleton method of NMatrix::ATLAS::LAPACK as a singleton
# method of NMatrix::LAPACK with the same name. An existing definition with
# the same name is replaced.
module LAPACK
  class << self
    atlas_module = NMatrix::ATLAS::LAPACK
    atlas_module.singleton_methods.each do |name|
      define_method(name, atlas_module.method(name).to_proc)
    end
  end
end

# Same delegation for the BLAS functions: NMatrix::ATLAS::BLAS -> NMatrix::BLAS.
module BLAS
  class << self
    atlas_module = NMatrix::ATLAS::BLAS
    atlas_module.singleton_methods.each do |name|
      define_method(name, atlas_module.method(name).to_proc)
    end
  end
end
56
+
57
+ module LAPACK
58
+ class << self
59
# Solves a * x = b for x, where +a+ is square, via Cholesky factorization
# (clapack_potrf followed by clapack_potrs). Neither +a+ nor +b+ is modified.
#
# * *Arguments* :
#   - +uplo+ -> passed through as the uplo argument of clapack_potrf/potrs
#   - +a+ -> square, 2-D, dense, non-integer, non-object matrix
#   - +b+ -> dense right-hand side with as many rows as +a+ has columns
# * *Returns* :
#   - the solution x as a new NMatrix
# * *Raises* :
#   - ShapeError, StorageTypeError, DataTypeError on invalid input
def posv(uplo, a, b)
  unless a.dim == 2 && a.shape[0] == a.shape[1]
    raise(ShapeError, "a must be square")
  end
  unless a.shape[1] == b.shape[0]
    raise(ShapeError, "number of rows of b must equal number of cols of a")
  end
  unless a.stype == :dense && b.stype == :dense
    raise(StorageTypeError, "only works with dense matrices")
  end
  if a.integer_dtype? || a.object_dtype? || b.integer_dtype? || b.object_dtype?
    raise(DataTypeError, "only works for non-integer, non-object dtypes")
  end

  rhs = b.clone
  factor = a.clone
  order = a.shape[0]
  rhs_count = b.shape[1]

  # Factor a copy of a in place.
  clapack_potrf(:row, uplo, order, factor, order)

  # ATLAS's row-major solve needs b transposed going in and the result
  # transposed coming out: http://math-atlas.sourceforge.net/faq.html#RowSolve
  rhs_t = rhs.transpose
  clapack_potrs(:row, uplo, order, rhs_count, factor, order, rhs_t, order)
  rhs_t.transpose
end
76
+
77
# Computes the eigenvalues and, optionally, the left and/or right
# eigenvectors of a square dense matrix with LAPACK's xGEEV.
#
# * *Arguments* :
#   - +matrix+ -> square, 2-D, dense NMatrix
#   - +which+ -> :both (default), :left or :right; selects which
#     eigenvectors are computed
# * *Returns* :
#   - [eigenvalues, left, right] for :both, [eigenvalues, left] for :left,
#     otherwise [eigenvalues, right]. For real dtypes with complex conjugate
#     eigenvalue pairs, the eigenvalues and eigenvectors are converted to a
#     complex dtype.
# * *Raises* :
#   - StorageTypeError unless +matrix+ is dense
#   - ShapeError unless +matrix+ is 2-D and square
def geev(matrix, which=:both)
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless matrix.dense?
  raise(ShapeError, "eigenvalues can only be computed for square matrices") unless matrix.dim == 2 && matrix.shape[0] == matrix.shape[1]

  jobvl = (which == :both || which == :left) ? :t : false
  jobvr = (which == :both || which == :right) ? :t : false

  n = matrix.shape[0]

  # Outputs
  eigenvalues = NMatrix.new([n, 1], dtype: matrix.dtype) # For real dtypes this holds only the real part of the eigenvalues.
  imag_eigenvalues = matrix.complex_dtype? ? nil : NMatrix.new([n, 1], dtype: matrix.dtype) # For complex dtypes, this is unused.
  left_output = jobvl ? matrix.clone_structure : nil
  right_output = jobvr ? matrix.clone_structure : nil

  # Workspace size. xGEEV requires LWORK >= 4*N for real dtypes when any
  # eigenvectors are requested (3*N otherwise) and LWORK >= 2*N for complex
  # dtypes. 2*N would be too small in the real case, so use 4*N, which
  # satisfies every dtype/job combination. LAPACK also requires LWORK >= 1.
  lwork = [1, 4 * n].max

  # lapack_geev is a pure LAPACK routine so it expects column-major matrices,
  # so we need to transpose the input as well as the output.
  temporary_matrix = matrix.transpose
  NMatrix::LAPACK::lapack_geev(jobvl, # compute left eigenvectors of A?
                               jobvr, # compute right eigenvectors of A? (left eigenvectors of A**T)
                               n, # order of the matrix
                               temporary_matrix, # input matrix (used as work)
                               n, # leading dimension of matrix
                               eigenvalues, # real part of computed eigenvalues
                               imag_eigenvalues, # imag part of computed eigenvalues
                               left_output, # left eigenvectors, if applicable
                               n, # leading dimension of left_output
                               right_output, # right eigenvectors, if applicable
                               n, # leading dimension of right_output
                               lwork)
  left_output = left_output.transpose if jobvl
  right_output = right_output.transpose if jobvr

  # For real dtypes, LAPACK packs a complex conjugate eigenvalue pair
  # (j, j+1) into two real columns: column j holds the real part and
  # column j+1 the imaginary part of the eigenvector for eigenvalue j.
  # Rebuild both complex eigenvectors (v and conj(v)) from them.
  if !matrix.complex_dtype?
    complex_indices = []
    n.times do |i|
      complex_indices << i if imag_eigenvalues[i] != 0.0
    end

    if !complex_indices.empty?
      # For real dtypes, put the real and imaginary parts together
      eigenvalues = eigenvalues + imag_eigenvalues*Complex(0.0,1.0)
      left_output = left_output.cast(dtype: NMatrix.upcast(:complex64, matrix.dtype)) if left_output
      right_output = right_output.cast(dtype: NMatrix.upcast(:complex64, matrix.dtype)) if right_output
    end

    # Combines columns j (real part) and j+1 (imaginary part) of +vectors+
    # into v in column j and conj(v) in column j+1.
    combine_pair = lambda do |vectors, j|
      vectors[0...n, j] = vectors[0...n, j] + vectors[0...n, j+1]*Complex(0.0,1.0)
      vectors[0...n, j+1] = vectors[0...n, j].complex_conjugate
    end

    # Indices come in consecutive pairs; only the first of each pair is needed.
    complex_indices.each_slice(2) do |i, _|
      combine_pair.call(right_output, i) if right_output
      combine_pair.call(left_output, i) if left_output
    end
  end

  if which == :both
    return [eigenvalues, left_output, right_output]
  elsif which == :left
    return [eigenvalues, left_output]
  else
    return [eigenvalues, right_output]
  end
end
150
+
151
# Computes the singular value decomposition of +matrix+ (m x n) with LAPACK's
# xGESVD, computing all columns of U and all rows of V^T.
#
# * *Arguments* :
#   - +matrix+ -> the matrix to decompose
#   - +workspace_size+ -> LWORK passed to LAPACK. Values below LAPACK's
#     documented minimum (including the default of 1) or nil are raised to
#     that minimum, as gesdd does.
# * *Returns* :
#   - the array from alloc_svd_result: [U, S, V^T], with U and V^T transposed
#     back to row-major order
def gesvd(matrix, workspace_size=1)
  result = alloc_svd_result(matrix)

  m = matrix.shape[0]
  n = matrix.shape[1]

  # xGESVD requires LWORK >= max(1, 3*min(m,n) + max(m,n), 5*min(m,n)) for
  # real dtypes. The complex minimum (2*min(m,n) + max(m,n)) is never larger,
  # so this bound is valid for every dtype.
  min_mn, max_mn = [m, n].minmax
  min_workspace_size = [1, 3 * min_mn + max_mn, 5 * min_mn].max
  workspace_size = min_workspace_size if workspace_size.nil? || workspace_size < min_workspace_size

  # This is a pure LAPACK function so it expects column-major matrices.
  # So we need to transpose the input as well as the output.
  matrix = matrix.transpose
  NMatrix::LAPACK::lapack_gesvd(:a, :a, m, n, matrix, m, result[1], result[0], m, result[2], n, workspace_size)
  result[0] = result[0].transpose
  result[2] = result[2].transpose
  result
end
165
+
166
# Computes the singular value decomposition of +matrix+ (m x n) with LAPACK's
# xGESDD, computing all columns of U and all rows of V^T.
#
# * *Arguments* :
#   - +matrix+ -> the matrix to decompose
#   - +workspace_size+ -> LWORK passed to LAPACK. If nil or smaller than
#     min*(6 + 4*min) + max (min/max of the shape), that value is used.
# * *Returns* :
#   - the array from alloc_svd_result: [U, S, V^T], with U and V^T transposed
#     back to row-major order
def gesdd(matrix, workspace_size=nil)
  shortest, longest = matrix.shape.minmax
  min_workspace_size = shortest * (6 + 4 * shortest) + longest
  if workspace_size.nil? || workspace_size < min_workspace_size
    workspace_size = min_workspace_size
  end

  result = alloc_svd_result(matrix)

  rows, cols = matrix.shape

  # LAPACK works on column-major data, so hand it the transpose and transpose
  # U and V^T back afterwards.
  column_major = matrix.transpose
  NMatrix::LAPACK::lapack_gesdd(:a, rows, cols, column_major, rows, result[1], result[0], rows, result[2], cols, workspace_size)
  [0, 2].each { |idx| result[idx] = result[idx].transpose }
  result
end
183
+ end
184
+ end
185
+
186
# Inverts this matrix in place.
#
# When CLAPACK is available, the matrix is LU-factored with clapack_getrf
# and inverted with clapack_getri (both row-major). Otherwise it falls back
# to __inverse__.
#
# * *Returns* :
#   - +self+ on the CLAPACK path, otherwise the result of __inverse__(self, true)
# * *Raises* :
#   - StorageTypeError unless the matrix is dense
#   - ShapeError unless the matrix is 2-D and square
#   - DataTypeError for integer dtypes
def invert!
  raise(StorageTypeError, "invert only works on dense matrices currently") unless self.dense?
  # Check the rank as well: comparing shape[0] and shape[1] alone would let
  # e.g. a 2x2x3 matrix through to getrf/getri (geev and potrf! check it too).
  raise(ShapeError, "Cannot invert non-square matrix") unless self.dim == 2 && shape[0] == shape[1]
  raise(DataTypeError, "Cannot invert an integer matrix in-place") if self.integer_dtype?

  # Even though we are using the ATLAS plugin, we still might be missing
  # CLAPACK (and thus clapack_getri) if we are on OS X.
  if NMatrix.has_clapack?
    # Get the pivot array; factor the matrix.
    # We can't use getrf! here since it doesn't have the clapack behavior,
    # so it doesn't play nicely with clapack_getri.
    n = self.shape[0]
    pivot = NMatrix::LAPACK::clapack_getrf(:row, n, n, self, n)
    # Now calculate the inverse using the pivot array
    NMatrix::LAPACK::clapack_getri(:row, n, self, n, pivot)
    self
  else
    __inverse__(self,true)
  end
end
206
+
207
# Computes the Cholesky factorization of this matrix in place with
# clapack_potrf, in row-major order.
#
# * *Arguments* :
#   - +which+ -> passed through as clapack_potrf's uplo argument
#     (presumably :upper or :lower -- confirm against the C binding)
# * *Returns* :
#   - whatever NMatrix::LAPACK::clapack_potrf returns
# * *Raises* :
#   - StorageTypeError unless dense; ShapeError unless 2-D and square
def potrf!(which)
  unless self.dense?
    raise(StorageTypeError, "ATLAS functions only work on dense matrices")
  end
  unless self.dim == 2 && self.shape[0] == self.shape[1]
    raise(ShapeError, "Cholesky decomposition only valid for square matrices")
  end

  order = self.shape[0]
  leading_dim = self.shape[1]
  NMatrix::LAPACK::clapack_potrf(:row, which, order, self, leading_dim)
end
213
+ end