nmatrix-lapacke 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (185) hide show
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/data/complex.h +364 -0
  3. data/ext/nmatrix/data/data.h +638 -0
  4. data/ext/nmatrix/data/meta.h +64 -0
  5. data/ext/nmatrix/data/ruby_object.h +389 -0
  6. data/ext/nmatrix/math/asum.h +120 -0
  7. data/ext/nmatrix/math/cblas_enums.h +36 -0
  8. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  9. data/ext/nmatrix/math/gemm.h +241 -0
  10. data/ext/nmatrix/math/gemv.h +178 -0
  11. data/ext/nmatrix/math/getrf.h +255 -0
  12. data/ext/nmatrix/math/getrs.h +121 -0
  13. data/ext/nmatrix/math/imax.h +79 -0
  14. data/ext/nmatrix/math/laswp.h +165 -0
  15. data/ext/nmatrix/math/long_dtype.h +49 -0
  16. data/ext/nmatrix/math/math.h +744 -0
  17. data/ext/nmatrix/math/nrm2.h +160 -0
  18. data/ext/nmatrix/math/rot.h +117 -0
  19. data/ext/nmatrix/math/rotg.h +106 -0
  20. data/ext/nmatrix/math/scal.h +71 -0
  21. data/ext/nmatrix/math/trsm.h +332 -0
  22. data/ext/nmatrix/math/util.h +148 -0
  23. data/ext/nmatrix/nm_memory.h +60 -0
  24. data/ext/nmatrix/nmatrix.h +408 -0
  25. data/ext/nmatrix/ruby_constants.h +106 -0
  26. data/ext/nmatrix/storage/common.h +176 -0
  27. data/ext/nmatrix/storage/dense/dense.h +128 -0
  28. data/ext/nmatrix/storage/list/list.h +137 -0
  29. data/ext/nmatrix/storage/storage.h +98 -0
  30. data/ext/nmatrix/storage/yale/class.h +1139 -0
  31. data/ext/nmatrix/storage/yale/iterators/base.h +142 -0
  32. data/ext/nmatrix/storage/yale/iterators/iterator.h +130 -0
  33. data/ext/nmatrix/storage/yale/iterators/row.h +449 -0
  34. data/ext/nmatrix/storage/yale/iterators/row_stored.h +139 -0
  35. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +168 -0
  36. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +123 -0
  37. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  38. data/ext/nmatrix/storage/yale/yale.h +202 -0
  39. data/ext/nmatrix/types.h +54 -0
  40. data/ext/nmatrix/util/io.h +115 -0
  41. data/ext/nmatrix/util/sl_list.h +143 -0
  42. data/ext/nmatrix/util/util.h +78 -0
  43. data/ext/nmatrix_lapacke/extconf.rb +200 -0
  44. data/ext/nmatrix_lapacke/lapacke.cpp +100 -0
  45. data/ext/nmatrix_lapacke/lapacke/include/lapacke.h +16445 -0
  46. data/ext/nmatrix_lapacke/lapacke/include/lapacke_config.h +119 -0
  47. data/ext/nmatrix_lapacke/lapacke/include/lapacke_mangling.h +17 -0
  48. data/ext/nmatrix_lapacke/lapacke/include/lapacke_mangling_with_flags.h +17 -0
  49. data/ext/nmatrix_lapacke/lapacke/include/lapacke_utils.h +579 -0
  50. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeev.c +89 -0
  51. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeev_work.c +141 -0
  52. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgesdd.c +106 -0
  53. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgesdd_work.c +158 -0
  54. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgesvd.c +94 -0
  55. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgesvd_work.c +149 -0
  56. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgetrf.c +51 -0
  57. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgetrf_work.c +83 -0
  58. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgetri.c +77 -0
  59. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgetri_work.c +89 -0
  60. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgetrs.c +56 -0
  61. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgetrs_work.c +102 -0
  62. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cpotrf.c +50 -0
  63. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cpotrf_work.c +82 -0
  64. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cpotri.c +50 -0
  65. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cpotri_work.c +82 -0
  66. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cpotrs.c +55 -0
  67. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cpotrs_work.c +101 -0
  68. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeev.c +78 -0
  69. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeev_work.c +136 -0
  70. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgesdd.c +88 -0
  71. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgesdd_work.c +153 -0
  72. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgesvd.c +83 -0
  73. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgesvd_work.c +144 -0
  74. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgetrf.c +50 -0
  75. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgetrf_work.c +81 -0
  76. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgetri.c +75 -0
  77. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgetri_work.c +87 -0
  78. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgetrs.c +55 -0
  79. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgetrs_work.c +99 -0
  80. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dpotrf.c +50 -0
  81. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dpotrf_work.c +81 -0
  82. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dpotri.c +50 -0
  83. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dpotri_work.c +81 -0
  84. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dpotrs.c +54 -0
  85. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dpotrs_work.c +97 -0
  86. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeev.c +78 -0
  87. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeev_work.c +134 -0
  88. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgesdd.c +88 -0
  89. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgesdd_work.c +152 -0
  90. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgesvd.c +83 -0
  91. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgesvd_work.c +143 -0
  92. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgetrf.c +50 -0
  93. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgetrf_work.c +81 -0
  94. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgetri.c +75 -0
  95. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgetri_work.c +87 -0
  96. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgetrs.c +55 -0
  97. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgetrs_work.c +99 -0
  98. data/ext/nmatrix_lapacke/lapacke/src/lapacke_spotrf.c +50 -0
  99. data/ext/nmatrix_lapacke/lapacke/src/lapacke_spotrf_work.c +81 -0
  100. data/ext/nmatrix_lapacke/lapacke/src/lapacke_spotri.c +50 -0
  101. data/ext/nmatrix_lapacke/lapacke/src/lapacke_spotri_work.c +81 -0
  102. data/ext/nmatrix_lapacke/lapacke/src/lapacke_spotrs.c +54 -0
  103. data/ext/nmatrix_lapacke/lapacke/src/lapacke_spotrs_work.c +97 -0
  104. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeev.c +89 -0
  105. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeev_work.c +141 -0
  106. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgesdd.c +106 -0
  107. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgesdd_work.c +158 -0
  108. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgesvd.c +94 -0
  109. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgesvd_work.c +149 -0
  110. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgetrf.c +51 -0
  111. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgetrf_work.c +83 -0
  112. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgetri.c +77 -0
  113. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgetri_work.c +89 -0
  114. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgetrs.c +56 -0
  115. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgetrs_work.c +102 -0
  116. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zpotrf.c +50 -0
  117. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zpotrf_work.c +82 -0
  118. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zpotri.c +50 -0
  119. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zpotri_work.c +82 -0
  120. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zpotrs.c +55 -0
  121. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zpotrs_work.c +101 -0
  122. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_cge_nancheck.c +62 -0
  123. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_cge_trans.c +65 -0
  124. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_cpo_nancheck.c +43 -0
  125. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_cpo_trans.c +45 -0
  126. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_ctr_nancheck.c +85 -0
  127. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_ctr_trans.c +85 -0
  128. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_dge_nancheck.c +62 -0
  129. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_dge_trans.c +65 -0
  130. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_dpo_nancheck.c +43 -0
  131. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_dpo_trans.c +45 -0
  132. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_dtr_nancheck.c +85 -0
  133. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_dtr_trans.c +85 -0
  134. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_lsame.c +41 -0
  135. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_sge_nancheck.c +62 -0
  136. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_sge_trans.c +65 -0
  137. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_spo_nancheck.c +43 -0
  138. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_spo_trans.c +45 -0
  139. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_str_nancheck.c +85 -0
  140. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_str_trans.c +85 -0
  141. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_xerbla.c +46 -0
  142. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_zge_nancheck.c +62 -0
  143. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_zge_trans.c +65 -0
  144. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_zpo_nancheck.c +43 -0
  145. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_zpo_trans.c +45 -0
  146. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_ztr_nancheck.c +85 -0
  147. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_ztr_trans.c +85 -0
  148. data/ext/nmatrix_lapacke/lapacke_nmatrix.h +16 -0
  149. data/ext/nmatrix_lapacke/make_lapacke_cpp.rb +9 -0
  150. data/ext/nmatrix_lapacke/math_lapacke.cpp +967 -0
  151. data/ext/nmatrix_lapacke/math_lapacke/cblas_local.h +576 -0
  152. data/ext/nmatrix_lapacke/math_lapacke/cblas_templates_lapacke.h +51 -0
  153. data/ext/nmatrix_lapacke/math_lapacke/lapacke_templates.h +356 -0
  154. data/ext/nmatrix_lapacke/nmatrix_lapacke.cpp +42 -0
  155. data/lib/nmatrix/lapack_ext_common.rb +69 -0
  156. data/lib/nmatrix/lapacke.rb +213 -0
  157. data/spec/00_nmatrix_spec.rb +730 -0
  158. data/spec/01_enum_spec.rb +190 -0
  159. data/spec/02_slice_spec.rb +389 -0
  160. data/spec/03_nmatrix_monkeys_spec.rb +78 -0
  161. data/spec/2x2_dense_double.mat +0 -0
  162. data/spec/4x4_sparse.mat +0 -0
  163. data/spec/4x5_dense.mat +0 -0
  164. data/spec/blas_spec.rb +193 -0
  165. data/spec/elementwise_spec.rb +303 -0
  166. data/spec/homogeneous_spec.rb +99 -0
  167. data/spec/io/fortran_format_spec.rb +88 -0
  168. data/spec/io/harwell_boeing_spec.rb +98 -0
  169. data/spec/io/test.rua +9 -0
  170. data/spec/io_spec.rb +149 -0
  171. data/spec/lapack_core_spec.rb +482 -0
  172. data/spec/leakcheck.rb +16 -0
  173. data/spec/math_spec.rb +730 -0
  174. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  175. data/spec/nmatrix_yale_spec.rb +286 -0
  176. data/spec/plugins/lapacke/lapacke_spec.rb +303 -0
  177. data/spec/rspec_monkeys.rb +56 -0
  178. data/spec/rspec_spec.rb +34 -0
  179. data/spec/shortcuts_spec.rb +310 -0
  180. data/spec/slice_set_spec.rb +157 -0
  181. data/spec/spec_helper.rb +140 -0
  182. data/spec/stat_spec.rb +203 -0
  183. data/spec/test.pcd +20 -0
  184. data/spec/utm5940.mtx +83844 -0
  185. metadata +262 -0
@@ -0,0 +1,69 @@
1
+ #--
2
+ # = NMatrix
3
+ #
4
+ # A linear algebra library for scientific computation in Ruby.
5
+ # NMatrix is part of SciRuby.
6
+ #
7
+ # NMatrix was originally inspired by and derived from NArray, by
8
+ # Masahiro Tanaka: http://narray.rubyforge.org
9
+ #
10
+ # == Copyright Information
11
+ #
12
+ # SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ # NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ #
15
+ # Please see LICENSE.txt for additional copyright notices.
16
+ #
17
+ # == Contributing
18
+ #
19
+ # By contributing source code to SciRuby, you agree to be bound by
20
+ # our Contributor Agreement:
21
+ #
22
+ # * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ #
24
+ # == lapack_ext_common.rb
25
+ #
26
+ # Contains functions shared by nmatrix-atlas and nmatrix-lapacke gems.
27
+ #++
28
+
29
+ class NMatrix
30
# Remembers +name+ as the LAPACK extension loaded into this process.
#
# Only one extension may register; a second call raises a RuntimeError
# naming both the newcomer and the extension that is already in place.
# Returns +name+ on success.
def NMatrix.register_lapack_extension(name)
  return @@lapack_extension = name unless defined?(@@lapack_extension)

  raise "Attempting to load #{name} when #{@@lapack_extension} is already loaded. You can only load one LAPACK extension."
end
37
+
38
+ alias_method :internal_dot, :dot
39
+
40
# Matrix product that routes dense 2-D operands through the BLAS bindings.
#
# When both operands are dense, two-dimensional NMatrix objects with
# compatible inner dimensions, the operands are upcast to a common dtype,
# copied if they are reference slices, and multiplied with cblas_gemv
# (single-column right operand) or cblas_gemm (everything else).
#
# Any other input is delegated to internal_dot (the aliased original #dot),
# which handles non-dense matrices, dim=1 dot products and error reporting
# for invalid input.
def dot(right_v)
  use_blas = right_v.is_a?(NMatrix) && stype == :dense && right_v.stype == :dense &&
             dim == 2 && right_v.dim == 2 && shape[1] == right_v.shape[0]
  return internal_dot(right_v) unless use_blas

  out_dtype = NMatrix.upcast(self.dtype, right_v.dtype)
  lhs = self.dtype == out_dtype ? self : cast(dtype: out_dtype)
  rhs = right_v.dtype == out_dtype ? right_v : right_v.cast(dtype: out_dtype)

  # The BLAS calls receive the matrices directly, so slices are copied first.
  lhs = lhs.clone if lhs.is_ref?
  rhs = rhs.clone if rhs.is_ref?

  rows  = lhs.shape[0]
  cols  = rhs.shape[1]
  inner = lhs.shape[1]
  product = NMatrix.new([rows, cols], dtype: out_dtype)

  if cols == 1
    NMatrix::BLAS.cblas_gemv(false, rows, inner, 1, lhs, inner, rhs, 1, 0, product, 1)
  else
    NMatrix::BLAS.cblas_gemm(:row, false, false, rows, cols, inner, 1, lhs, inner, rhs, cols, 0, product, cols)
  end

  product
end
69
+ end
@@ -0,0 +1,213 @@
1
+ #--
2
+ # = NMatrix
3
+ #
4
+ # A linear algebra library for scientific computation in Ruby.
5
+ # NMatrix is part of SciRuby.
6
+ #
7
+ # NMatrix was originally inspired by and derived from NArray, by
8
+ # Masahiro Tanaka: http://narray.rubyforge.org
9
+ #
10
+ # == Copyright Information
11
+ #
12
+ # SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ # NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ #
15
+ # Please see LICENSE.txt for additional copyright notices.
16
+ #
17
+ # == Contributing
18
+ #
19
+ # By contributing source code to SciRuby, you agree to be bound by
20
+ # our Contributor Agreement:
21
+ #
22
+ # * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ #
24
+ # == lapacke.rb
25
+ #
26
+ # ruby file for the nmatrix-lapacke gem. Loads the C extension and defines
27
+ # nice ruby interfaces for LAPACK functions.
28
+ #++
29
+
30
+ require 'nmatrix/nmatrix.rb' #need to have nmatrix required first or else bad things will happen
31
+ require_relative 'lapack_ext_common'
32
+
33
+ NMatrix.register_lapack_extension("nmatrix-lapacke")
34
+
35
+ require "nmatrix_lapacke.so"
36
+
37
+ class NMatrix
38
+ #Add functions from the LAPACKE C extension to the main LAPACK and BLAS modules.
39
+ #This will overwrite the original functions where applicable.
40
# Copies every singleton method of NMatrix::LAPACKE::LAPACK onto this module,
# replacing any method of the same name that was already defined here.
module LAPACK
  plugin = NMatrix::LAPACKE::LAPACK
  plugin.singleton_methods.each do |name|
    define_singleton_method(name, &plugin.method(name))
  end
end
47
+
48
# Copies every singleton method of NMatrix::LAPACKE::BLAS onto this module,
# replacing any method of the same name that was already defined here.
module BLAS
  plugin = NMatrix::LAPACKE::BLAS
  plugin.singleton_methods.each do |name|
    define_singleton_method(name, &plugin.method(name))
  end
end
55
+
56
+ module LAPACK
57
+ class << self
58
# Solves a*x = b by running lapacke_potrf on a copy of +a+ and then
# lapacke_potrs on a copy of +b+; neither argument is modified.
#
# uplo:: passed through unchanged to both LAPACKE calls.
# a::    square, dense, 2-D matrix of a non-integer, non-object dtype.
# b::    dense right-hand side with as many rows as +a+ has columns.
#
# Returns the copy of +b+ after lapacke_potrs has run on it.
# Raises ShapeError, StorageTypeError or DataTypeError on invalid input.
def posv(uplo, a, b)
  a_is_square = a.dim == 2 && a.shape[0] == a.shape[1]
  raise(ShapeError, "a must be square") unless a_is_square
  raise(ShapeError, "number of rows of b must equal number of cols of a") if a.shape[1] != b.shape[0]
  raise(StorageTypeError, "only works with dense matrices") if a.stype != :dense || b.stype != :dense
  if [a, b].any? { |mat| mat.integer_dtype? || mat.object_dtype? }
    raise(DataTypeError, "only works for non-integer, non-object dtypes")
  end

  solution = b.clone
  factor   = a.clone
  order    = a.shape[0]
  rhs_cols = b.shape[1]

  lapacke_potrf(:row, uplo, order, factor, order)
  lapacke_potrs(:row, uplo, order, rhs_cols, factor, order, solution, rhs_cols)
  solution
end
73
+
74
# Computes eigenvalues and, optionally, eigenvectors of a square dense matrix
# via lapacke_geev. The input is cloned first so it is not overwritten.
#
# matrix:: dense, square, 2-D NMatrix.
# which::  :both (default), :left or :right, selecting which eigenvector
#          sets are computed. Any other value takes the :right return shape.
#
# Returns [eigenvalues, left, right] for :both, [eigenvalues, left] for
# :left, and [eigenvalues, right] otherwise.
# Raises StorageTypeError for non-dense input and ShapeError for non-square
# input.
def geev(matrix, which=:both)
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless matrix.dense?
  unless matrix.dim == 2 && matrix.shape[0] == matrix.shape[1]
    raise(ShapeError, "eigenvalues can only be computed for square matrices")
  end

  jobvl = [:both, :left].include?(which) ? :t : false
  jobvr = [:both, :right].include?(which) ? :t : false

  work  = matrix.clone # scratch copy handed to LAPACK
  order = matrix.shape[0]
  real_input = !matrix.complex_dtype?

  # For real dtypes +eigenvalues+ holds only the real parts and
  # +imag_eigenvalues+ the imaginary parts; for complex dtypes the latter is unused.
  eigenvalues      = NMatrix.new([order, 1], dtype: matrix.dtype)
  imag_eigenvalues = real_input ? NMatrix.new([order, 1], dtype: matrix.dtype) : nil
  left_output  = matrix.clone_structure if jobvl
  right_output = matrix.clone_structure if jobvr

  NMatrix::LAPACK.lapacke_geev(:row,
                               jobvl,            # compute left eigenvectors?
                               jobvr,            # compute right eigenvectors?
                               order,            # order of the matrix
                               work,             # input matrix (used as work)
                               order,            # leading dimension of matrix
                               eigenvalues,      # real part of computed eigenvalues
                               imag_eigenvalues, # imag part of computed eigenvalues
                               left_output,      # left eigenvectors, if applicable
                               order,            # leading dimension of left_output
                               right_output,     # right eigenvectors, if applicable
                               order)            # leading dimension of right_output

  if real_input
    # For real dtypes, when eigenvalues j and j+1 form a complex conjugate
    # pair, columns j and j+1 of the eigenvector matrices hold the real and
    # imaginary parts of the eigenvector belonging to eigenvalue j.
    complex_indices = (0...order).select { |idx| imag_eigenvalues[idx] != 0.0 }

    unless complex_indices.empty?
      # Merge real and imaginary parts, and widen the outputs to a complex dtype.
      eigenvalues  = eigenvalues + imag_eigenvalues*Complex(0.0,1.0)
      left_output  = left_output.cast(dtype: NMatrix.upcast(:complex64, matrix.dtype)) if left_output
      right_output = right_output.cast(dtype: NMatrix.upcast(:complex64, matrix.dtype)) if right_output
    end

    # Walk the pairs by their first index only.
    complex_indices.each_slice(2) do |pair|
      col = pair.first
      [right_output, left_output].each do |vectors|
        next unless vectors

        vectors[0...order, col]   = vectors[0...order, col] + vectors[0...order, col+1]*Complex(0.0,1.0)
        vectors[0...order, col+1] = vectors[0...order, col].complex_conjugate
      end
    end
  end

  case which
  when :both then [eigenvalues, left_output, right_output]
  when :left then [eigenvalues, left_output]
  else            [eigenvalues, right_output]
  end
end
144
+
145
# Runs lapacke_gesvd on +matrix+ with both job arguments set to :a and
# returns the container produced by alloc_svd_result, filled in by the call.
#
# +workspace_size+ is accepted but never read by this implementation.
# NOTE(review): +matrix+ is passed to LAPACK without being copied, so it is
# presumably overwritten -- confirm against the LAPACKE documentation.
def gesvd(matrix, workspace_size=1)
  result = alloc_svd_result(matrix)
  rows = matrix.shape[0]
  cols = matrix.shape[1]

  # Extra output buffer required by lapacke_gesvd, sized min(rows, cols).
  superb = NMatrix.new([[rows, cols].min], dtype: matrix.abs_dtype)

  NMatrix::LAPACK.lapacke_gesvd(:row, :a, :a, rows, cols, matrix, cols,
                                result[1], result[0], rows, result[2], cols, superb)
  result
end
156
+
157
# Runs lapacke_gesdd on +matrix+ with the job argument set to :a and returns
# the container produced by alloc_svd_result, filled in by the call.
#
# +workspace_size+ is accepted but never read by this implementation.
# NOTE(review): +matrix+ is passed to LAPACK without being copied, so it is
# presumably overwritten -- confirm against the LAPACKE documentation.
def gesdd(matrix, workspace_size=nil)
  result = alloc_svd_result(matrix)
  rows = matrix.shape[0]
  cols = matrix.shape[1]

  NMatrix::LAPACK.lapacke_gesdd(:row, :a, rows, cols, matrix, cols,
                                result[1], result[0], rows, result[2], cols)
  result
end
166
+ end
167
+ end
168
+
169
# Passes this matrix itself to lapacke_getrf (row-major, leading dimension
# equal to the column count) and returns whatever that call returns, which
# the code treats as the pivot array.
#
# Raises StorageTypeError unless the matrix is dense.
def getrf!
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless dense?

  rows = shape[0]
  cols = shape[1]
  NMatrix::LAPACK.lapacke_getrf(:row, rows, cols, self, cols)
end
176
+
177
# Inverts this matrix in place using lapacke_getrf followed by lapacke_getri
# and returns self.
#
# Raises StorageTypeError unless the matrix is dense, ShapeError unless it
# is a square two-dimensional matrix, and DataTypeError for integer dtypes.
def invert!
  raise(StorageTypeError, "invert only works on dense matrices currently") unless dense?
  # Require dim == 2 as well (as potrf! and solve do): comparing only the
  # first two extents would let a shape such as [2,2,3] through to LAPACK.
  raise(ShapeError, "Cannot invert non-square matrix") unless dim == 2 && shape[0] == shape[1]
  raise(DataTypeError, "Cannot invert an integer matrix in-place") if integer_dtype?

  # Get the pivot array; factor the matrix
  n = shape[0]
  pivot = NMatrix::LAPACK.lapacke_getrf(:row, n, n, self, n)
  # Now calculate the inverse using the pivot array
  NMatrix::LAPACK.lapacke_getri(:row, n, self, n, pivot)

  self
end
190
+
191
# Passes this matrix itself to lapacke_potrf (row-major) and returns the
# result of that call.
#
# which:: forwarded unchanged as the second argument of lapacke_potrf.
#
# Raises StorageTypeError unless dense, ShapeError unless square and 2-D.
def potrf!(which)
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless dense?
  is_square = dim == 2 && shape[0] == shape[1]
  raise(ShapeError, "Cholesky decomposition only valid for square matrices") unless is_square

  order, leading_dim = shape[0], shape[1]
  NMatrix::LAPACK.lapacke_potrf(:row, which, order, self, leading_dim)
end
197
+
198
# Solves self * x = b by running lapacke_getrf on a copy of self and
# lapacke_getrs on a copy of +b+; neither operand is modified.
#
# b:: dense right-hand side with as many rows as self has columns, of a
#     non-integer, non-object dtype.
#
# Returns the copy of +b+ after lapacke_getrs has run on it.
# Raises ShapeError for non-square self or mismatched row counts, and
# ArgumentError for non-dense storage or unsupported dtypes.
def solve(b)
  raise(ShapeError, "Must be called on square matrix") unless dim == 2 && shape[0] == shape[1]
  raise(ShapeError, "number of rows of b must equal number of cols of self") if shape[1] != b.shape[0]
  # Both operands are handed to LAPACK, so both must be dense (posv checks
  # both as well); previously only self was checked.
  raise ArgumentError, "only works with dense matrices" if stype != :dense || b.stype != :dense
  raise ArgumentError, "only works for non-integer, non-object dtypes" if
    integer_dtype? || object_dtype? || b.integer_dtype? || b.object_dtype?

  x = b.clone
  clone = self.clone
  n = shape[0]
  ipiv = NMatrix::LAPACK.lapacke_getrf(:row, n, n, clone, n)
  NMatrix::LAPACK.lapacke_getrs(:row, :no_transpose, n, b.shape[1], clone, n, ipiv, x, b.shape[1])
  x
end
213
+ end
@@ -0,0 +1,730 @@
1
+ # = NMatrix
2
+ #
3
+ # A linear algebra library for scientific computation in Ruby.
4
+ # NMatrix is part of SciRuby.
5
+ #
6
+ # NMatrix was originally inspired by and derived from NArray, by
7
+ # Masahiro Tanaka: http://narray.rubyforge.org
8
+ #
9
+ # == Copyright Information
10
+ #
11
+ # SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
12
+ # NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
13
+ #
14
+ # Please see LICENSE.txt for additional copyright notices.
15
+ #
16
+ # == Contributing
17
+ #
18
+ # By contributing source code to SciRuby, you agree to be bound by
19
+ # our Contributor Agreement:
20
+ #
21
+ # * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
22
+ #
23
+ # == 00_nmatrix_spec.rb
24
+ #
25
+ # Basic tests for NMatrix. These should load first, as they're
26
+ # essential to NMatrix operation.
27
+ #
28
+ require 'spec_helper'
29
+
30
+ describe NMatrix do
31
# A 2x2 int64 matrix built from an explicit value list reports the shape,
# entries and dtype it was given.
it "creates a matrix with the new constructor" do
  values = [0,1,2,3]
  matrix = NMatrix.new([2,2], values.dup, dtype: :int64)

  expect(matrix.shape).to eq([2,2])
  expect(matrix.entries).to eq(values)
  expect(matrix.dtype).to eq(:int64)
end
37
+
38
# Reading one cell of a dense 4x4 matrix needs both indices; a single index
# raises ArgumentError.
it "adequately requires information to access a single entry of a dense matrix" do
  matrix = NMatrix.new(:dense, 4, (0..15).to_a, :float64)

  expect(matrix[0,0]).to eq(0)
  expect { matrix[0] }.to raise_error(ArgumentError)
end
43
+
44
# det_exact of [[1,2],[3,4]] is -2.
it "calculates exact determinants on small square matrices" do
  matrix = NMatrix.new(2, [1,2,3,4], stype: :dense, dtype: :int64)
  expect(matrix.det_exact).to eq(-2)
end
47
+
48
# det of a fixed 3x3 int64 matrix is 6.
it "calculates determinants" do
  matrix = NMatrix.new(3, [-2,2,3,-1,1,3,2,0,-1], stype: :dense, dtype: :int64)
  expect(matrix.det).to eq(6)
end
51
+
52
# An int64 dense matrix cast to the :object dtype still compares equal to
# the matrix it came from.
it "allows casting to Ruby objects" do
  source = NMatrix.new([3,3], [0,0,1,0,2,0,3,4,5], dtype: :int64, stype: :dense)
  as_objects = source.cast(:dense, :object)

  expect(as_objects).to eq(source)
end
57
+
58
# An :object dense matrix compares equal to its int64 cast.
it "allows casting from Ruby objects" do
  source = NMatrix.new(:dense, [3,3], [0,0,1,0,2,0,3,4,5], :object)
  as_int64 = source.cast(:dense, :int64)

  expect(source).to eq(as_int64)
end
63
+
64
# Smoke test: chains a series of stype/dtype casts and only requires that
# none of them raises. There is no assertion on the final value.
it "allows stype casting of a dim 2 matrix between dense, sparse, and list (different dtypes)" do
  cast_chain = [
    [:yale,  :int32],
    [:dense, :float64],
    [:list,  :object],
    [:dense, :int16],
    [:list,  :int32],
    [:yale,  :int64]
    # Disabled steps: [:list, :int32], [:dense, :int16], followed by
    # m.should.equal?(original)
  ]
  start = NMatrix.new(:dense, [3,3], [0,0,1,0,2,0,3,4,5], :int64)

  cast_chain.inject(start) { |matrix, (stype, dtype)| matrix.cast(stype, dtype) }
  # For some reason enabling the disabled steps causes weird garbage collector
  # problems. They won't work at all in IRB, but work fine when run in a
  # regular Ruby session.
end
78
+
79
# An :object matrix created without initial values reads back nil.
it "fills dense Ruby object matrix with nil" do
  matrix = NMatrix.new([4,3], dtype: :object)
  expect(matrix[0,0]).to eq(nil)
end
83
+
84
# Writes every cell of a 4x3 float64 matrix one at a time (row-major), then
# reads every cell back in the same order.
it "fills dense with individual assignments" do
  rows = [
    [14.0,  9.0,  3.0],
    [ 2.0, 11.0, 15.0],
    [ 0.0, 12.0, 17.0],
    [ 5.0,  2.0,  3.0]
  ]
  n = NMatrix.new([4,3], dtype: :float64)

  rows.each_with_index do |row, i|
    row.each_with_index { |value, j| n[i,j] = value }
  end

  rows.each_with_index do |row, i|
    row.each_with_index { |value, j| expect(n[i,j]).to eq(value) }
  end
end
112
+
113
# Builds a 4x3 matrix from a flat value list and checks each cell against
# the list, interpreting it in row-major order.
it "fills dense with a single mass assignment" do
  flat = [14.0, 9.0, 3.0, 2.0, 11.0, 15.0, 0.0, 12.0, 17.0, 5.0, 2.0, 3.0]
  n = NMatrix.new([4,3], flat.dup)

  flat.each_with_index do |value, position|
    i, j = position.divmod(3)
    expect(n[i,j]).to eq(value)
  end
end
129
+
130
# Same as the plain mass-assignment example, but with an explicit :float32
# dtype; every value used here is exactly representable in float32.
it "fills dense with a single mass assignment, with dtype specified" do
  flat = [14.0, 9.0, 3.0, 2.0, 11.0, 15.0, 0.0, 12.0, 17.0, 5.0, 2.0, 3.0]
  m = NMatrix.new([4,3], flat.dup, dtype: :float32)

  flat.each_with_index do |value, position|
    i, j = position.divmod(3)
    expect(m[i,j]).to eq(value)
  end
end
146
+
147
# A matrix built from a size and dtype only defaults to dense storage and
# keeps the requested dtype.
it "dense handles missing initialization value" do
  { 3 => :int8, 4 => :float64 }.each do |size, dtype|
    matrix = NMatrix.new(size, dtype: dtype)

    expect(matrix.stype).to eq(:dense)
    expect(matrix.dtype).to eq(dtype)
  end
end
156
+
157
# Runs the same group of examples once per storage type. The second context
# (default value, shape/dim, rows/cols) is skipped for :yale.
[:dense, :list, :yale].each do |storage_type|
  context storage_type do
    it "can be duplicated" do
      n = NMatrix.new([2,3], 1.1, stype: storage_type, dtype: :float64)
      expect(n.stype).to eq(storage_type)

      n[0,0] = 0.0
      n[0,1] = 0.1
      n[1,0] = 1.0

      m = n.dup
      expect(m.shape).to eq(n.shape)
      expect(m.dim).to eq(n.dim)
      expect(m.object_id).not_to eq(n.object_id)
      expect(m.stype).to eq(storage_type)
      expect(m[0,0]).to eq(n[0,0])
      # Writing to the copy must not show through in the original.
      m[0,0] = 3.0
      expect(m[0,0]).not_to eq(n[0,0])
    end

    it "enforces shape boundaries" do
      expect { NMatrix.new([1,10], 0, dtype: :int8, stype: storage_type, default: 0)[1,0] }.to raise_error(RangeError)
      expect { NMatrix.new([1,10], 0, dtype: :int8, stype: storage_type, default: 0)[0,10] }.to raise_error(RangeError)
    end

    it "sets and gets" do
      n = NMatrix.new(2, 0, stype: storage_type, dtype: :int8)
      n[0,1] = 1
      expect(n[0,0]).to eq(0)
      expect(n[1,0]).to eq(0)
      expect(n[0,1]).to eq(1)
      expect(n[1,1]).to eq(0)
    end

    it "sets and gets references" do
      n = NMatrix.new(2, stype: storage_type, dtype: :int8, default: 0)
      expect(n[0,1] = 1).to eq(1)
      expect(n[0,1]).to eq(1)
    end

    # Tests Ruby object versus any C dtype (in this case we use :int64)
    [:object, :int64].each do |dtype|
      c = dtype == :object ? "Ruby object" : "non-Ruby object"
      context c do
        it "allows iteration of matrices" do
          n = nil
          if storage_type == :dense
            n = NMatrix.new(:dense, [3,3], [1,2,3,4,5,6,7,8,9], dtype)
          else
            n = NMatrix.new([3,4], 0, stype: storage_type, dtype: dtype)
            n[0,0] = 1
            n[0,1] = 2
            n[2,3] = 4
            n[2,0] = 3
          end

          ary = []
          n.each do |x|
            ary << x
          end

          if storage_type == :dense
            expect(ary).to eq([1,2,3,4,5,6,7,8,9])
          else
            expect(ary).to eq([1,2,0,0,0,0,0,0,3,0,0,4])
          end
        end

        it "allows storage-based iteration of matrices" do
          # Leftover STDERR debug prints of storage_type/dtype were removed
          # here; the context names already identify each combination.
          n = NMatrix.new([3,3], 0, stype: storage_type, dtype: dtype)
          n[0,0] = 1
          n[0,1] = 2
          n[2,0] = 5 if storage_type == :yale
          n[2,1] = 4
          n[2,2] = 3

          values = []
          is = []
          js = []

          n.each_stored_with_indices do |v,i,j|
            values << v
            is << i
            js << j
          end

          # Each storage type has its own expected visiting order.
          if storage_type == :yale
            expect(is).to eq([0,1,2,0,2,2])
            expect(js).to eq([0,1,2,1,0,1])
            expect(values).to eq([1,0,3,2,5,4])
          elsif storage_type == :list
            expect(values).to eq([1,2,4,3])
            expect(is).to eq([0,0,2,2])
            expect(js).to eq([0,1,1,2])
          elsif storage_type == :dense
            expect(values).to eq([1,2,0,0,0,0,0,4,3])
            expect(is).to eq([0,0,0,1,1,1,2,2,2])
            expect(js).to eq([0,1,2,0,1,2,0,1,2])
          end
        end
      end
    end
  end

  # dense and list, not yale
  context "(storage: #{storage_type})" do
    it "gets default value" do
      expect(NMatrix.new(3, 0, stype: storage_type)[1,1]).to eq(0)
      expect(NMatrix.new(3, 0.1, stype: storage_type)[1,1]).to eq(0.1)
      expect(NMatrix.new(3, 1, stype: storage_type)[1,1]).to eq(1)
    end

    it "returns shape and dim" do
      expect(NMatrix.new([3,2,8], 0, stype: storage_type).shape).to eq([3,2,8])
      expect(NMatrix.new([3,2,8], 0, stype: storage_type).dim).to eq(3)
    end

    it "returns number of rows and columns" do
      expect(NMatrix.new([7, 4], 3, stype: storage_type).rows).to eq(7)
      expect(NMatrix.new([7, 4], 3, stype: storage_type).cols).to eq(4)
    end
  end unless storage_type == :yale
end
282
+
283
+
284
# Dense construction works both with an initial value and with only a dtype;
# the latter merely has to be readable without raising.
it "handles dense construction" do
  zero_filled = NMatrix.new(3,0)
  expect(zero_filled[1,1]).to eq(0)

  expect { NMatrix.new(3,dtype: :int8)[1,1] }.not_to raise_error
end
288
+
289
# Sets two cells of a 3x3 list matrix, casts it to yale/int32, and checks
# every cell of the result.
it "converts from list to yale properly" do
  m = NMatrix.new(3, 0, stype: :list)
  m[0,2] = 333
  m[2,2] = 777
  n = m.cast(:yale, :int32)
  # Debugging aids, kept disabled:
  #puts n.capacity
  #n.extend NMatrix::YaleFunctions
  #puts n.yale_ija.inspect
  #puts n.yale_a.inspect

  expected_rows = [
    [0, 0, 333],
    [0, 0, 0],
    [0, 0, 777]
  ]
  expected_rows.each_with_index do |row, i|
    row.each_with_index { |value, j| expect(n[i,j]).to eq(value) }
  end
end
309
+
310
# Calls #each without a block on two matrices and advances both results in
# lock-step with #next until one raises StopIteration. The example has no
# explicit expectation; it only has to complete without another error.
it "should return an enumerator when each is called without a block" do
  enumerators = [NMatrix.new(2, 1), NMatrix.new(2, [-1,0,1,0])].map(&:each)

  arctangents = []
  # Kernel#loop ends quietly when StopIteration is raised.
  loop { arctangents << Math.atan2(*enumerators.map(&:next)) }
end
321
+
322
+ context "dense" do
323
+ it "should return the matrix being iterated over when each is called with a block" do
324
+ a = NMatrix.new(2, 1)
325
+ val = (a.each { })
326
+ expect(val).to eq(a)
327
+ end
328
+
329
+ it "should return the matrix being iterated over when each_stored_with_indices is called with a block" do
330
+ a = NMatrix.new(2,1)
331
+ val = (a.each_stored_with_indices { })
332
+ expect(val).to eq(a)
333
+ end
334
+ end
335
+
336
+ [:list, :yale].each do |storage_type|
337
+ context storage_type do
338
+ it "should return the matrix being iterated over when each_stored_with_indices is called with a block" do
339
+ n = NMatrix.new([2,3], 1.1, stype: storage_type, dtype: :float64, default: 0)
340
+ val = (n.each_stored_with_indices { })
341
+ expect(val).to eq(n)
342
+ end
343
+
344
+ it "should return an enumerator when each_stored_with_indices is called without a block" do
345
+ n = NMatrix.new([2,3], 1.1, stype: storage_type, dtype: :float64, default: 0)
346
+ val = n.each_stored_with_indices
347
+ expect(val).to be_a Enumerator
348
+ end
349
+ end
350
+ end
351
+
352
+ it "should iterate through element 256 without a segfault" do
353
+ t = NVector.random(256)
354
+ t.each { |x| x + 0 }
355
+ end
356
+ end
357
+
358
+
359
+ describe 'NMatrix' do
360
+ context "#upper_triangle" do
361
+ it "should create a copy with the lower corner set to zero" do
362
+ n = NMatrix.seq(4)+1
363
+ expect(n.upper_triangle).to eq(NMatrix.new(4, [1,2,3,4,0,6,7,8,0,0,11,12,0,0,0,16]))
364
+ expect(n.upper_triangle(2)).to eq(NMatrix.new(4, [1,2,3,4,5,6,7,8,9,10,11,12,0,14,15,16]))
365
+ end
366
+ end
367
+
368
+ context "#lower_triangle" do
369
+ it "should create a copy with the lower corner set to zero" do
370
+ n = NMatrix.seq(4)+1
371
+ expect(n.lower_triangle).to eq(NMatrix.new(4, [1,0,0,0,5,6,0,0,9,10,11,0,13,14,15,16]))
372
+ expect(n.lower_triangle(2)).to eq(NMatrix.new(4, [1,2,3,0,5,6,7,8,9,10,11,12,13,14,15,16]))
373
+ end
374
+ end
375
+
376
+ context "#upper_triangle!" do
377
+ it "should create a copy with the lower corner set to zero" do
378
+ n = NMatrix.seq(4)+1
379
+ expect(n.upper_triangle!).to eq(NMatrix.new(4, [1,2,3,4,0,6,7,8,0,0,11,12,0,0,0,16]))
380
+ n = NMatrix.seq(4)+1
381
+ expect(n.upper_triangle!(2)).to eq(NMatrix.new(4, [1,2,3,4,5,6,7,8,9,10,11,12,0,14,15,16]))
382
+ end
383
+ end
384
+
385
+ context "#lower_triangle!" do
386
+ it "should create a copy with the lower corner set to zero" do
387
+ n = NMatrix.seq(4)+1
388
+ expect(n.lower_triangle!).to eq(NMatrix.new(4, [1,0,0,0,5,6,0,0,9,10,11,0,13,14,15,16]))
389
+ n = NMatrix.seq(4)+1
390
+ expect(n.lower_triangle!(2)).to eq(NMatrix.new(4, [1,2,3,0,5,6,7,8,9,10,11,12,13,14,15,16]))
391
+ end
392
+ end
393
+
394
+ context "#rank" do
395
+ it "should get the rank of a 2-dimensional matrix" do
396
+ n = NMatrix.seq([2,3])
397
+ expect(n.rank(0, 0)).to eq(N[[0,1,2]])
398
+ end
399
+
400
+ it "should raise an error when the rank is out of bounds" do
401
+ n = NMatrix.seq([2,3])
402
+ expect { n.rank(2, 0) }.to raise_error(RangeError)
403
+ end
404
+ end
405
+
406
+ context "#reshape" do
407
+ it "should change the shape of a matrix without the contents changing" do
408
+ n = NMatrix.seq(4)+1
409
+ expect(n.reshape([8,2]).to_flat_array).to eq(n.to_flat_array)
410
+ end
411
+
412
+ it "should permit a change of dimensionality" do
413
+ n = NMatrix.seq(4)+1
414
+ expect(n.reshape([8,1,2]).to_flat_array).to eq(n.to_flat_array)
415
+ end
416
+
417
+ it "should prevent a resize" do
418
+ n = NMatrix.seq(4)+1
419
+ expect { n.reshape([5,2]) }.to raise_error(ArgumentError)
420
+ end
421
+
422
+ it "should do the reshape operation in place" do
423
+ n = NMatrix.seq(4)+1
424
+ expect(n.reshape!([8,2]).eql?(n)).to eq(true) # because n itself changes
425
+ end
426
+
427
+ it "reshape and reshape! must produce same result" do
427
+ n = NMatrix.seq(4)+1
428
+ a = NMatrix.seq(4)+1
429
+ expect(n.reshape!([8,2])==a.reshape(8,2)).to eq(true) # bang and non-bang reshape of equal inputs must compare equal
430
+ end
432
+
433
+ it "should prevent a resize in place" do
434
+ n = NMatrix.seq(4)+1
435
+ expect { n.reshape([5,2]) }.to raise_error(ArgumentError)
436
+ end
437
+ end
438
+
439
+ context "#transpose" do
440
+ [:dense, :list, :yale].each do |stype|
441
+ context(stype) do
442
+ it "should transpose a #{stype} matrix (2-dimensional)" do
443
+ n = NMatrix.seq(4, stype: stype)
444
+ expect(n.transpose.to_a.flatten).to eq([0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15])
445
+ end
446
+ end
447
+ end
448
+
449
+ [:dense, :list].each do |stype|
450
+ context(stype) do
451
+ it "should transpose a #{stype} matrix (3-dimensional)" do
452
+ n = NMatrix.new([4,4,1], [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], stype: stype)
453
+ expect(n.transpose([2,1,0]).to_flat_array).to eq([0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15])
454
+ expect(n.transpose([1,0,2]).to_flat_array).to eq([0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15])
455
+ expect(n.transpose([0,2,1]).to_flat_array).to eq(n.to_flat_array) # for dense, make this reshape!
456
+ end
457
+ end
458
+
459
+ it "should just copy a 1-dimensional #{stype} matrix" do
460
+ n = NMatrix.new([3], [1,2,3], stype: stype)
461
+ expect(n.transpose).to eq n
462
+ expect(n.transpose).not_to be n
463
+ end
464
+ end
465
+ end
466
+
467
+ context "#dot_product" do
467
+ [:dense].each do |stype| # only dense is exercised: list storage transpose is not yet implemented
468
+ context(stype) do # yale is left out as well: it supports only 2-dimensional matrices
469
+ it "should work like vector product on a #{stype} (1-dimensional)" do
470
+ m = NMatrix.new([3], [1,2,3], stype: stype)
471
+ expect(m.dot(m)).to eq (NMatrix.new([1],[14]))
472
+ end
473
+ end
474
+ end
475
+ end
477
+
478
+ context "#==" do
479
+ [:dense, :list, :yale].each do |left|
480
+ [:dense, :list, :yale].each do |right|
481
+ context ("#{left}?#{right}") do
482
+ it "tests equality of two equal matrices" do
483
+ n = NMatrix.new([3,4], [0,0,1,2,0,0,3,4,0,0,0,0], stype: left)
484
+ m = NMatrix.new([3,4], [0,0,1,2,0,0,3,4,0,0,0,0], stype: right)
485
+
486
+ expect(n==m).to eq(true)
487
+ end
488
+
489
+ it "tests equality of two unequal matrices" do
490
+ n = NMatrix.new([3,4], [0,0,1,2,0,0,3,4,0,0,0,1], stype: left)
491
+ m = NMatrix.new([3,4], [0,0,1,2,0,0,3,4,0,0,0,0], stype: right)
492
+
493
+ expect(n==m).to eq(false)
494
+ end
495
+
496
+ it "tests equality of matrices with different shapes" do
497
+ n = NMatrix.new([2,2], [1,2, 3,4], stype: left)
498
+ m = NMatrix.new([2,3], [1,2, 3,4, 5,6], stype: right)
499
+ x = NMatrix.new([1,4], [1,2, 3,4], stype: right)
500
+
501
+ expect{n==m}.to raise_error(ShapeError)
502
+ expect{n==x}.to raise_error(ShapeError)
503
+ end
504
+
505
+ it "tests equality of matrices with different dimension" do
506
+ n = NMatrix.new([2,1], [1,2], stype: left)
507
+ m = NMatrix.new([2], [1,2], stype: right)
508
+
509
+ expect{n==m}.to raise_error(ShapeError)
510
+ end if left != :yale && right != :yale # yale must have dimension 2
511
+ end
512
+ end
513
+ end
514
+ end
515
+
516
+ context "#concat" do
517
+ it "should default to horizontal concatenation" do
518
+ n = NMatrix.new([1,3], [1,2,3])
519
+ expect(n.concat(n)).to eq(NMatrix.new([1,6], [1,2,3,1,2,3]))
520
+ end
521
+
522
+ it "should permit vertical concatenation" do
523
+ n = NMatrix.new([1,3], [1,2,3])
524
+ expect(n.vconcat(n)).to eq(NMatrix.new([2,3], [1,2,3]))
525
+ end
526
+
527
+ it "should permit depth concatenation on tensors" do
528
+ n = NMatrix.new([1,3,1], [1,2,3])
529
+ expect(n.dconcat(n)).to eq(NMatrix.new([1,3,2], [1,1,2,2,3,3]))
530
+ end
531
+ end
532
+
533
+ context "#[]" do
534
+ it "should return values based on indices" do
535
+ n = NMatrix.new([2,5], [1,2,3,4,5,6,7,8,9,0])
536
+ expect(n[1,0]).to eq 6
537
+ expect(n[1,0..3]).to eq NMatrix.new([1,4],[6,7,8,9])
538
+ end
539
+
540
+ it "should work for negative indices" do
541
+ n = NMatrix.new([1,5], [1,2,3,4,5])
542
+ expect(n[-1]).to eq(5)
543
+ expect(n[0,0..-2]).to eq(NMatrix.new([1,4],[1,2,3,4]))
544
+ end
545
+ end
546
+
547
+ context "#complex_conjugate!" do
548
+ [:dense, :yale, :list].each do |stype|
549
+ context(stype) do
550
+ it "should work in-place for complex dtypes" do
551
+ pending("not yet implemented for list stype") if stype == :list
552
+ n = NMatrix.new([2,3], [Complex(2,3)], stype: stype, dtype: :complex128)
553
+ n.complex_conjugate!
554
+ expect(n).to eq(NMatrix.new([2,3], [Complex(2,-3)], stype: stype, dtype: :complex128))
555
+ end
556
+
557
+ [:object, :int64].each do |dtype|
558
+ it "should work in-place for non-complex dtypes" do
559
+ pending("not yet implemented for list stype") if stype == :list
560
+ n = NMatrix.new([2,3], 1, stype: stype, dtype: dtype)
561
+ n.complex_conjugate!
562
+ expect(n).to eq(NMatrix.new([2,3], [1], stype: stype, dtype: dtype))
563
+ end
564
+ end
565
+ end
566
+ end
567
+ end
568
+
569
+ context "#complex_conjugate" do
570
+ [:dense, :yale, :list].each do |stype|
571
+ context(stype) do
572
+ it "should work out-of-place for complex dtypes" do
573
+ pending("not yet implemented for list stype") if stype == :list
574
+ n = NMatrix.new([2,3], [Complex(2,3)], stype: stype, dtype: :complex128)
575
+ expect(n.complex_conjugate).to eq(NMatrix.new([2,3], [Complex(2,-3)], stype: stype, dtype: :complex128))
576
+ end
577
+
578
+ [:object, :int64].each do |dtype|
579
+ it "should work out-of-place for non-complex dtypes" do
580
+ pending("not yet implemented for list stype") if stype == :list
581
+ n = NMatrix.new([2,3], 1, stype: stype, dtype: dtype)
582
+ expect(n.complex_conjugate).to eq(NMatrix.new([2,3], [1], stype: stype, dtype: dtype))
583
+ end
584
+ end
585
+ end
586
+ end
587
+ end
588
+
589
+ context "#inject" do
590
+ it "should sum columns of yale matrix correctly" do
591
+ n = NMatrix.new([4, 3], stype: :yale, default: 0)
592
+ n[0,0] = 1
593
+ n[1,1] = 2
594
+ n[2,2] = 4
595
+ n[3,2] = 8
596
+ column_sums = []
597
+ n.cols.times do |i|
598
+ column_sums << n.col(i).inject(:+)
599
+ end
600
+ expect(column_sums).to eq([1, 2, 12])
601
+ end
602
+ end
603
+
604
+ context "#index" do
605
+ it "returns index of first occurence of an element for a vector" do
606
+ n = NMatrix.new([5], [0,22,22,11,11])
607
+
608
+ expect(n.index(22)).to eq([1])
609
+ end
610
+
611
+ it "returns index of first occurence of an element for 2-D matrix" do
612
+ n = NMatrix.new([3,3], [23,11,23,
613
+ 44, 2, 0,
614
+ 33, 0, 32])
615
+
616
+ expect(n.index(0)).to eq([1,2])
617
+ end
618
+
619
+ it "returns index of first occerence of an element for N-D matrix" do
620
+ n = NMatrix.new([3,3,3], [23,11,23, 44, 2, 0, 33, 0, 32,
621
+ 23,11,23, 44, 2, 0, 33, 0, 32,
622
+ 23,11,23, 44, 2, 0, 33, 0, 32])
623
+
624
+ expect(n.index(44)).to eq([0,1,0])
625
+ end
626
+ end
627
+
628
+ context "#diagonal" do
629
+ ALL_DTYPES.each do |dtype|
630
+ before do
631
+ @square_matrix = NMatrix.new([3,3], [
632
+ 23,11,23,
633
+ 44, 2, 0,
634
+ 33, 0, 32
635
+ ], dtype: dtype
636
+ )
637
+
638
+ @rect_matrix = NMatrix.new([4,3], [
639
+ 23,11,23,
640
+ 44, 2, 0,
641
+ 33, 0,32,
642
+ 11,22,33
643
+ ], dtype: dtype
644
+ )
645
+ end
646
+
647
+ it "returns main diagonal for square matrix" do
648
+ expect(@square_matrix.diagonal).to eq(NMatrix.new [3], [23,2,32])
649
+ end
650
+
651
+ it "returns main diagonal for rectangular matrix" do
652
+ expect(@rect_matrix.diagonal).to eq(NMatrix.new [3], [23,2,32])
653
+ end
654
+
655
+ it "returns anti-diagonal for square matrix" do
656
+ expect(@square_matrix.diagonal(false)).to eq(NMatrix.new [3], [23,2,33])
657
+ end
658
+
659
+ it "returns anti-diagonal for rectangular matrix" do
660
+ expect(@square_matrix.diagonal(false)).to eq(NMatrix.new [3], [23,2,33])
661
+ end
662
+ end
663
+ end
664
+
665
+ context "#repeat" do
666
+ before do
667
+ @sample_matrix = NMatrix.new([2, 2], [1, 2, 3, 4])
668
+ end
669
+
670
+ it "checks count argument" do
671
+ expect{@sample_matrix.repeat(1, 0)}.to raise_error(ArgumentError)
672
+ expect{@sample_matrix.repeat(-2, 0)}.to raise_error(ArgumentError)
673
+ end
674
+
675
+ it "returns repeated matrix" do
676
+ expect(@sample_matrix.repeat(2, 0)).to eq(NMatrix.new([4, 2], [1, 2, 3, 4, 1, 2, 3, 4]))
677
+ expect(@sample_matrix.repeat(2, 1)).to eq(NMatrix.new([2, 4], [1, 2, 1, 2, 3, 4, 3, 4]))
678
+ end
679
+ end
680
+
681
+ context "#meshgrid" do
682
+ before do
683
+ @x, @y, @z = [1, 2, 3], NMatrix.new([2, 1], [4, 5]), [6, 7]
684
+ @two_dim = NMatrix.new([2, 2], [1, 2, 3, 4])
685
+ @two_dim_array = [[4], [5]]
686
+ @expected_result = [NMatrix.new([2, 3], [1, 2, 3, 1, 2, 3]), NMatrix.new([2, 3], [4, 4, 4, 5, 5, 5])]
687
+ @expected_for_ij = [NMatrix.new([3, 2], [1, 1, 2, 2, 3, 3]), NMatrix.new([3, 2], [4, 5, 4, 5, 4, 5])]
688
+ @expected_for_sparse = [NMatrix.new([1, 3], [1, 2, 3]), NMatrix.new([2, 1], [4, 5])]
689
+ @expected_for_sparse_ij = [NMatrix.new([3, 1], [1, 2, 3]), NMatrix.new([1, 2], [4, 5])]
690
+ @expected_3dim = [NMatrix.new([1, 3, 1], [1, 2, 3]).repeat(2, 0).repeat(2, 2),
691
+ NMatrix.new([2, 1, 1], [4, 5]).repeat(3, 1).repeat(2, 2),
692
+ NMatrix.new([1, 1, 2], [6, 7]).repeat(2, 0).repeat(3, 1)]
693
+ @expected_3dim_sparse_ij = [NMatrix.new([3, 1, 1], [1, 2, 3]),
694
+ NMatrix.new([1, 2, 1], [4, 5]),
695
+ NMatrix.new([1, 1, 2], [6, 7])]
696
+ end
697
+
698
+ it "checks arrays count" do
699
+ expect{NMatrix.meshgrid([@x])}.to raise_error(ArgumentError)
700
+ expect{NMatrix.meshgrid([])}.to raise_error(ArgumentError)
701
+ end
702
+
703
+ it "flattens input arrays before use" do
704
+ expect(NMatrix.meshgrid([@two_dim, @two_dim_array])).to eq(NMatrix.meshgrid([@two_dim.to_flat_array, @two_dim_array.flatten]))
705
+ end
706
+
707
+ it "returns new NMatrixes" do
708
+ expect(NMatrix.meshgrid([@x, @y])).to eq(@expected_result)
709
+ end
710
+
711
+ it "has option :sparse" do
712
+ expect(NMatrix.meshgrid([@x, @y], sparse: true)).to eq(@expected_for_sparse)
713
+ end
714
+
715
+ it "has option :indexing" do
716
+ expect(NMatrix.meshgrid([@x, @y], indexing: :ij)).to eq(@expected_for_ij)
717
+ expect(NMatrix.meshgrid([@x, @y], indexing: :xy)).to eq(@expected_result)
718
+ expect{NMatrix.meshgrid([@x, @y], indexing: :not_ij_not_xy)}.to raise_error(ArgumentError)
719
+ end
720
+
721
+ it "works well with both options set" do
722
+ expect(NMatrix.meshgrid([@x, @y], sparse: true, indexing: :ij)).to eq(@expected_for_sparse_ij)
723
+ end
724
+
725
+ it "is able to take more than two arrays as arguments and works well with options" do
726
+ expect(NMatrix.meshgrid([@x, @y, @z])).to eq(@expected_3dim)
727
+ expect(NMatrix.meshgrid([@x, @y, @z], sparse: true, indexing: :ij)).to eq(@expected_3dim_sparse_ij)
728
+ end
729
+ end
730
+ end