nmatrix-lapacke 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. checksums.yaml +4 -4
  2. data/ext/nmatrix/data/data.h +7 -8
  3. data/ext/nmatrix/data/ruby_object.h +1 -4
  4. data/ext/nmatrix/math/asum.h +10 -31
  5. data/ext/nmatrix/math/cblas_templates_core.h +10 -10
  6. data/ext/nmatrix/math/getrf.h +2 -2
  7. data/ext/nmatrix/math/imax.h +12 -9
  8. data/ext/nmatrix/math/laswp.h +3 -3
  9. data/ext/nmatrix/math/long_dtype.h +16 -3
  10. data/ext/nmatrix/math/magnitude.h +54 -0
  11. data/ext/nmatrix/math/nrm2.h +19 -14
  12. data/ext/nmatrix/math/trsm.h +40 -36
  13. data/ext/nmatrix/math/util.h +14 -0
  14. data/ext/nmatrix/nmatrix.h +39 -1
  15. data/ext/nmatrix/storage/common.h +9 -3
  16. data/ext/nmatrix/storage/yale/class.h +1 -1
  17. data/ext/nmatrix_lapacke/extconf.rb +3 -136
  18. data/ext/nmatrix_lapacke/lapacke.cpp +104 -84
  19. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeqrf.c +77 -0
  20. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeqrf_work.c +89 -0
  21. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cunmqr.c +88 -0
  22. data/ext/nmatrix_lapacke/lapacke/src/lapacke_cunmqr_work.c +111 -0
  23. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeqrf.c +75 -0
  24. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeqrf_work.c +87 -0
  25. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dormqr.c +86 -0
  26. data/ext/nmatrix_lapacke/lapacke/src/lapacke_dormqr_work.c +109 -0
  27. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeqrf.c +75 -0
  28. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeqrf_work.c +87 -0
  29. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sormqr.c +86 -0
  30. data/ext/nmatrix_lapacke/lapacke/src/lapacke_sormqr_work.c +109 -0
  31. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeqrf.c +77 -0
  32. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeqrf_work.c +89 -0
  33. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zunmqr.c +88 -0
  34. data/ext/nmatrix_lapacke/lapacke/src/lapacke_zunmqr_work.c +111 -0
  35. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_c_nancheck.c +51 -0
  36. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_d_nancheck.c +51 -0
  37. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_s_nancheck.c +51 -0
  38. data/ext/nmatrix_lapacke/lapacke/utils/lapacke_z_nancheck.c +51 -0
  39. data/ext/nmatrix_lapacke/math_lapacke.cpp +149 -17
  40. data/ext/nmatrix_lapacke/math_lapacke/lapacke_templates.h +76 -0
  41. data/lib/nmatrix/lapacke.rb +118 -0
  42. data/spec/00_nmatrix_spec.rb +50 -1
  43. data/spec/02_slice_spec.rb +21 -21
  44. data/spec/blas_spec.rb +25 -3
  45. data/spec/math_spec.rb +233 -5
  46. data/spec/plugins/lapacke/lapacke_spec.rb +187 -0
  47. data/spec/shortcuts_spec.rb +145 -5
  48. data/spec/spec_helper.rb +24 -1
  49. metadata +38 -8
@@ -64,6 +64,82 @@ inline int lapacke_getrf(const enum CBLAS_ORDER order, const int m, const int n,
64
64
  return getrf<DType>(order, m, n, static_cast<DType*>(a), lda, ipiv);
65
65
  }
66
66
 
67
+ //geqrf
68
+ template <typename DType>
69
+ inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, DType* a, const int lda, DType* tau) {
70
+ rb_raise(rb_eNotImpError, "lapacke_geqrf not implemented for non_BLAS dtypes.");
71
+ return 0;
72
+ }
73
+
74
+ template <>
75
+ inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, float* a, const int lda, float* tau) {
76
+ return LAPACKE_sgeqrf(order, m, n, a, lda, tau);
77
+ }
78
+
79
+ template < >
80
+ inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, double* a, const int lda, double* tau) {
81
+ return LAPACKE_dgeqrf(order, m, n, a, lda, tau);
82
+ }
83
+
84
+ template <>
85
+ inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, Complex64* a, const int lda, Complex64* tau) {
86
+ return LAPACKE_cgeqrf(order, m, n, a, lda, tau);
87
+ }
88
+
89
+ template <>
90
+ inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, Complex128* a, const int lda, Complex128* tau) {
91
+ return LAPACKE_zgeqrf(order, m, n, a, lda, tau);
92
+ }
93
+
94
+ template <typename DType>
95
+ inline int lapacke_geqrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, void* tau) {
96
+ return geqrf<DType>(order, m, n, static_cast<DType*>(a), lda, static_cast<DType*>(tau));
97
+ }
98
+
99
+ //ormqr
100
+ template <typename DType>
101
+ inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, DType* a, const int lda, DType* tau, DType* c, const int ldc) {
102
+ rb_raise(rb_eNotImpError, "lapacke_ormqr not implemented for non_BLAS dtypes.");
103
+ return 0;
104
+ }
105
+
106
+ template <>
107
+ inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, float* a, const int lda, float* tau, float* c, const int ldc) {
108
+ return LAPACKE_sormqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
109
+ }
110
+
111
+ template <>
112
+ inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, double* a, const int lda, double* tau, double* c, const int ldc) {
113
+ return LAPACKE_dormqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
114
+ }
115
+
116
+ template <typename DType>
117
+ inline int lapacke_ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, void* a, const int lda, void* tau, void* c, const int ldc) {
118
+ return ormqr<DType>(order, side, trans, m, n, k, static_cast<DType*>(a), lda, static_cast<DType*>(tau), static_cast<DType*>(c), ldc);
119
+ }
120
+
121
+ //unmqr
122
+ template <typename DType>
123
+ inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, DType* a, const int lda, DType* tau, DType* c, const int ldc) {
124
+ rb_raise(rb_eNotImpError, "lapacke_unmqr not implemented for non complex dtypes.");
125
+ return 0;
126
+ }
127
+
128
+ template <>
129
+ inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, Complex64* a, const int lda, Complex64* tau, Complex64* c, const int ldc) {
130
+ return LAPACKE_cunmqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
131
+ }
132
+
133
+ template <>
134
+ inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, Complex128* a, const int lda, Complex128* tau, Complex128* c, const int ldc) {
135
+ return LAPACKE_zunmqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
136
+ }
137
+
138
+ template <typename DType>
139
+ inline int lapacke_unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, void* a, const int lda, void* tau, void* c, const int ldc) {
140
+ return unmqr<DType>(order, side, trans, m, n, k, static_cast<DType*>(a), lda, static_cast<DType*>(tau), static_cast<DType*>(c), ldc);
141
+ }
142
+
67
143
  //getri
68
144
  template <typename DType>
69
145
  inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
@@ -232,4 +232,122 @@ class NMatrix
232
232
  raise(ArgumentError, "#{opts[:form]} is not a valid form option")
233
233
  end
234
234
  end
235
+
236
#
# call-seq:
#     geqrf! -> shape.min x 1 NMatrix
#
# Computes the QR factorization A = QR of this general M-by-N matrix in place,
# where Q is orthogonal and R is upper triangular.
#
# Afterwards +self+ holds R on and above its diagonal. Q is encoded by the
# entries below the diagonal together with the returned vector of scalar
# factors (TAU), as a product of elementary reflectors:
#
#     Q = H(1) H(2) . . . H(k),  k = min(m, n)
#     H(i) = I - tau * v * v'
#
# Reference: http://www.netlib.org/lapack/explore-html/d3/d69/dgeqrf_8f.html
#
# Dense storage only.
#
# * *Returns* :
#   - TAU, a shape.min x 1 NMatrix with the same dtype as +self+.
# * *Raises* :
#   - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
#
def geqrf!
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless dense?

  rows, cols = shape
  scalar_factors = NMatrix.new([shape.min, 1], dtype: dtype)
  # Row-major storage: the leading dimension is the number of columns.
  NMatrix::LAPACK.lapacke_geqrf(:row, rows, cols, self, cols, scalar_factors)
  scalar_factors
end
270
+
271
+ #
272
#
# call-seq:
#     ormqr(tau) -> NMatrix
#     ormqr(tau, side, transpose, c) -> NMatrix
#
# After geqrf! has factorized +self+, returns Q * c or c * Q, where Q is the
# orthogonal matrix encoded by +tau+ and the receiver. When +c+ is supplied
# it is overwritten with the result.
#
# Real floating-point dtypes only; use unmqr for complex types.
#
# == Arguments
#
# * +tau+       - vector of scalar factors of the elementary reflectors.
# * +side+      - direction of multiplication [:left, :right].
# * +transpose+ - apply Q with or without transpose [false, :transpose].
# * +c+         - NMatrix multiplication argument, overwritten in place;
#                 when omitted, an identity matrix is used, so the result is Q.
#
# * *Returns* :
#   - Q * c or c * Q, where Q may be transposed before multiplication.
# * *Raises* :
#   - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
#   - +TypeError+ -> works only on floating point matrices; use unmqr for complex types.
#   - +TypeError+ -> c must have the same dtype as the calling NMatrix.
#
def ormqr(tau, side=:left, transpose=false, c=nil)
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless dense?
  raise(TypeError, "Works only on floating point matrices, use unmqr for complex types") if complex_dtype?
  raise(TypeError, "c must have the same dtype as the calling NMatrix") if c && c.dtype != dtype

  # Multiplying into an identity yields Q itself when no c is given.
  product = c || NMatrix.identity(shape[0], dtype: dtype)
  rows, cols = product.shape
  NMatrix::LAPACK.lapacke_ormqr(:row, side, transpose, rows, cols, tau.shape[0],
                                self, shape[1], tau, product, cols)
  product
end
311
+
312
+ #
313
#
# call-seq:
#     unmqr(tau) -> NMatrix
#     unmqr(tau, side, transpose, c) -> NMatrix
#
# After geqrf! has factorized +self+, returns Q * c or c * Q, where Q is the
# unitary (complex orthogonal) matrix encoded by +tau+ and the receiver. When
# +c+ is supplied it is overwritten with the result.
#
# Complex dtypes only; use ormqr for real floating-point types.
#
# == Arguments
#
# * +tau+       - vector of scalar factors of the elementary reflectors.
# * +side+      - direction of multiplication [:left, :right].
# * +transpose+ - apply Q as is or its complex conjugate [false, :complex_conjugate].
# * +c+         - NMatrix multiplication argument, overwritten in place;
#                 when omitted, an identity matrix is used, so the result is Q.
#
# * *Returns* :
#   - Q * c or c * Q, where Q may be transformed to its complex conjugate first.
# * *Raises* :
#   - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
#   - +TypeError+ -> works only on complex matrices; use ormqr for real floating point matrices.
#   - +TypeError+ -> c must have the same dtype as the calling NMatrix.
#
def unmqr(tau, side=:left, transpose=false, c=nil)
  # This wraps LAPACKE (not ATLAS); the wording matches geqrf! and ormqr.
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense?
  raise(TypeError, "Works only on complex matrices, use ormqr for normal floating point matrices") unless self.complex_dtype?
  raise(TypeError, "c must have the same dtype as the calling NMatrix") if c && c.dtype != self.dtype

  # Default behaviour produces Q * I = Q if c is not supplied.
  result = c || NMatrix.identity(self.shape[0], dtype: self.dtype)
  NMatrix::LAPACK::lapacke_unmqr(:row, side, transpose, result.shape[0], result.shape[1], tau.shape[0],
                                 self, self.shape[1], tau, result, result.shape[1])

  result
end
351
+
352
+
235
353
  end
@@ -424,6 +424,13 @@ describe 'NMatrix' do
424
424
  expect(n.reshape!([8,2]).eql?(n)).to eq(true) # because n itself changes
425
425
  end
426
426
 
427
it "should do the reshape operation in place, changing dimension" do
  expected = NMatrix.seq([4,2,2])
  matrix = NMatrix.seq(4)
  returned = matrix.reshape!([4,2,2])

  # reshape! must both mutate the receiver and return the reshaped matrix.
  [matrix, returned].each { |m| expect(m).to eq(expected) }
end
433
+
427
434
  it "reshape and reshape! must produce same result" do
428
435
  n = NMatrix.seq(4)+1
429
436
  a = NMatrix.seq(4)+1
@@ -432,7 +439,7 @@ describe 'NMatrix' do
432
439
 
433
440
  it "should prevent a resize in place" do
434
441
  n = NMatrix.seq(4)+1
435
- expect { n.reshape([5,2]) }.to raise_error(ArgumentError)
442
+ expect { n.reshape!([5,2]) }.to raise_error(ArgumentError)
436
443
  end
437
444
  end
438
445
 
@@ -534,6 +541,26 @@ describe 'NMatrix' do
534
541
  n = NMatrix.new([1,3,1], [1,2,3])
535
542
  expect(n.dconcat(n)).to eq(NMatrix.new([1,3,2], [1,1,2,2,3,3]))
536
543
  end
544
+
545
it "should work on matrices with different size along concat dim" do
  # The operands differ in column count (3 vs 1) but share a row count.
  wide   = N[[1, 2, 3],
             [4, 5, 6]]
  narrow = N[[7],
             [8]]

  expect(wide.hconcat(narrow)).to eq N[[1, 2, 3, 7], [4, 5, 6, 8]]
  expect(narrow.hconcat(wide)).to eq N[[7, 1, 2, 3], [8, 4, 5, 6]]
end
554
+
555
# Distinct description: the preceding hconcat example already uses
# "should work on matrices with different size along concat dim".
it "should vconcat matrices with different size along concat dim" do
  # The operands differ in row count (2 vs 1) but share a column count.
  n = N[[1, 2, 3],
        [4, 5, 6]]

  m = N[[7, 8, 9]]

  expect(n.vconcat(m)).to eq N[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  expect(m.vconcat(n)).to eq N[[7, 8, 9], [1, 2, 3], [4, 5, 6]]
end
537
564
  end
538
565
 
539
566
  context "#[]" do
@@ -631,6 +658,23 @@ describe 'NMatrix' do
631
658
  end
632
659
  end
633
660
 
661
context "#last" do
  # description => [shape, elements, expected last element]
  {
    "returns the last element of a 1-dimensional NMatrix" => [[1,4], [1,2,3,4], 4],
    "returns the last element of a 2-dimensional NMatrix" => [[2,2], [4,8,12,16], 16],
    "returns the last element of a 3-dimensional NMatrix" => [[2,2,2], [1,2,3,4,5,6,7,8], 8]
  }.each do |description, (shape, elements, expected)|
    it description do
      expect(NMatrix.new(shape, elements).last).to eq(expected)
    end
  end
end
677
+
634
678
  context "#diagonal" do
635
679
  ALL_DTYPES.each do |dtype|
636
680
  before do
@@ -682,6 +726,11 @@ describe 'NMatrix' do
682
726
  expect(@sample_matrix.repeat(2, 0)).to eq(NMatrix.new([4, 2], [1, 2, 3, 4, 1, 2, 3, 4]))
683
727
  expect(@sample_matrix.repeat(2, 1)).to eq(NMatrix.new([2, 4], [1, 2, 1, 2, 3, 4, 3, 4]))
684
728
  end
729
+
730
it "preserves dtype" do
  # Repeating along either axis must keep the source dtype.
  [0, 1].each do |axis|
    expect(@sample_matrix.repeat(2, axis).dtype).to eq(@sample_matrix.dtype)
  end
end
685
734
  end
686
735
 
687
736
  context "#meshgrid" do
@@ -127,9 +127,9 @@ describe "Slice operation" do
127
127
  it "should have #is_ref? method" do
128
128
  a = stype_matrix[0..1, 0..1]
129
129
  b = stype_matrix.slice(0..1, 0..1)
130
- expect(stype_matrix.is_ref?).to be_false
131
- expect(a.is_ref?).to be_true
132
- expect(b.is_ref?).to be_false
130
+ expect(stype_matrix.is_ref?).to be false
131
+ expect(a.is_ref?).to be true
132
+ expect(b.is_ref?).to be false
133
133
  end
134
134
 
135
135
  it "reference should compare with non-reference" do
@@ -141,7 +141,7 @@ describe "Slice operation" do
141
141
  context "with copying" do
142
142
  it 'should return an NMatrix' do
143
143
  n = stype_matrix.slice(0..1,0..1)
144
- expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to be_true
144
+ expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to be true
145
145
  end
146
146
 
147
147
  it 'should return a copy of 2x2 matrix to self elements' do
@@ -183,19 +183,19 @@ describe "Slice operation" do
183
183
 
184
184
  [:dense, :list, :yale].each do |cast_type|
185
185
  it "should cast copied slice from #{stype.upcase} to #{cast_type.upcase}" do
186
- expect(nm_eql(stype_matrix.slice(1..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(1..2,1..2))).to be_true
187
- expect(nm_eql(stype_matrix.slice(0..1, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..1,1..2))).to be_true
188
- expect(nm_eql(stype_matrix.slice(1..2, 0..1).cast(cast_type, :int32), stype_matrix.slice(1..2,0..1))).to be_true
189
- expect(nm_eql(stype_matrix.slice(0..1, 0..1).cast(cast_type, :int32), stype_matrix.slice(0..1,0..1))).to be_true
186
+ expect(nm_eql(stype_matrix.slice(1..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(1..2,1..2))).to be true
187
+ expect(nm_eql(stype_matrix.slice(0..1, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..1,1..2))).to be true
188
+ expect(nm_eql(stype_matrix.slice(1..2, 0..1).cast(cast_type, :int32), stype_matrix.slice(1..2,0..1))).to be true
189
+ expect(nm_eql(stype_matrix.slice(0..1, 0..1).cast(cast_type, :int32), stype_matrix.slice(0..1,0..1))).to be true
190
190
 
191
191
  # Non square
192
- expect(nm_eql(stype_matrix.slice(0..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..2,1..2))).to be_true
192
+ expect(nm_eql(stype_matrix.slice(0..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..2,1..2))).to be true
193
193
  #require 'pry'
194
194
  #binding.pry if cast_type == :yale
195
- expect(nm_eql(stype_matrix.slice(1..2, 0..2).cast(cast_type, :int32), stype_matrix.slice(1..2,0..2))).to be_true
195
+ expect(nm_eql(stype_matrix.slice(1..2, 0..2).cast(cast_type, :int32), stype_matrix.slice(1..2,0..2))).to be true
196
196
 
197
197
  # Full
198
- expect(nm_eql(stype_matrix.slice(0..2, 0..2).cast(cast_type, :int32), stype_matrix)).to be_true
198
+ expect(nm_eql(stype_matrix.slice(0..2, 0..2).cast(cast_type, :int32), stype_matrix)).to be true
199
199
  end
200
200
  end
201
201
  end
@@ -214,7 +214,7 @@ describe "Slice operation" do
214
214
  context "by reference" do
215
215
  it 'should return an NMatrix' do
216
216
  n = stype_matrix[0..1,0..1]
217
- expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to be_true
217
+ expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to be true
218
218
  end
219
219
 
220
220
  it 'should return a 2x2 matrix with refs to self elements' do
@@ -246,7 +246,7 @@ describe "Slice operation" do
246
246
 
247
247
  it 'should slice again' do
248
248
  n = stype_matrix[1..2, 1..2]
249
- expect(nm_eql(n[1,0..1], NVector.new(2, [7,8], dtype: :int32).transpose)).to be_true
249
+ expect(nm_eql(n[1,0..1], NVector.new(2, [7,8], dtype: :int32).transpose)).to be true
250
250
  end
251
251
 
252
252
  it 'should be correct slice for range 0..2 and 0...3' do
@@ -320,7 +320,7 @@ describe "Slice operation" do
320
320
  end
321
321
 
322
322
  it "compares slices to scalars" do
323
- (stype_matrix[1, 0..2] > 2).each { |e| expect(e != 0).to be_true }
323
+ (stype_matrix[1, 0..2] > 2).each { |e| expect(e != 0).to be true }
324
324
  end
325
325
 
326
326
  it "iterates only over elements in the slice" do
@@ -367,20 +367,20 @@ describe "Slice operation" do
367
367
 
368
368
  [:dense, :list, :yale].each do |cast_type|
369
369
  it "should cast a square reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
370
- expect(nm_eql(stype_matrix[1..2, 1..2].cast(cast_type), stype_matrix[1..2,1..2])).to be_true
371
- expect(nm_eql(stype_matrix[0..1, 1..2].cast(cast_type), stype_matrix[0..1,1..2])).to be_true
372
- expect(nm_eql(stype_matrix[1..2, 0..1].cast(cast_type), stype_matrix[1..2,0..1])).to be_true
373
- expect(nm_eql(stype_matrix[0..1, 0..1].cast(cast_type), stype_matrix[0..1,0..1])).to be_true
370
+ expect(nm_eql(stype_matrix[1..2, 1..2].cast(cast_type), stype_matrix[1..2,1..2])).to be true
371
+ expect(nm_eql(stype_matrix[0..1, 1..2].cast(cast_type), stype_matrix[0..1,1..2])).to be true
372
+ expect(nm_eql(stype_matrix[1..2, 0..1].cast(cast_type), stype_matrix[1..2,0..1])).to be true
373
+ expect(nm_eql(stype_matrix[0..1, 0..1].cast(cast_type), stype_matrix[0..1,0..1])).to be true
374
374
  end
375
375
 
376
376
  it "should cast a rectangular reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
377
377
  # Non square
378
- expect(nm_eql(stype_matrix[0..2, 1..2].cast(cast_type), stype_matrix[0..2,1..2])).to be_true # FIXME: memory problem.
379
- expect(nm_eql(stype_matrix[1..2, 0..2].cast(cast_type), stype_matrix[1..2,0..2])).to be_true # this one is fine
378
+ expect(nm_eql(stype_matrix[0..2, 1..2].cast(cast_type), stype_matrix[0..2,1..2])).to be true # FIXME: memory problem.
379
+ expect(nm_eql(stype_matrix[1..2, 0..2].cast(cast_type), stype_matrix[1..2,0..2])).to be true # this one is fine
380
380
  end
381
381
 
382
382
  it "should cast a square full-matrix reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
383
- expect(nm_eql(stype_matrix[0..2, 0..2].cast(cast_type), stype_matrix)).to be_true
383
+ expect(nm_eql(stype_matrix[0..2, 0..2].cast(cast_type), stype_matrix)).to be true
384
384
  end
385
385
  end
386
386
  end
data/spec/blas_spec.rb CHANGED
@@ -69,6 +69,18 @@ describe NMatrix::BLAS do
69
69
  expect(b[0]).to eq(-15.0/2)
70
70
  expect(b[1]).to eq(5)
71
71
  expect(b[2]).to eq(-13)
72
+
73
+ NMatrix::BLAS::cblas_trsm(:row, :left, :lower, :transpose, :nounit, 3, 1, 1.0, a, 3, b, 1)
74
+
75
+ expect(b[0]).to eq(307.0/8)
76
+ expect(b[1]).to eq(57.0/2)
77
+ expect(b[2]).to eq(26.0)
78
+
79
+ NMatrix::BLAS::cblas_trsm(:row, :left, :upper, :transpose, :unit, 3, 1, 1.0, a, 3, b, 1)
80
+
81
+ expect(b[0]).to eq(307.0/8)
82
+ expect(b[1]).to eq(763.0/16)
83
+ expect(b[2]).to eq(4269.0/64)
72
84
  end
73
85
 
74
86
  # trmm multiplies two matrices, where one of the two is required to be
@@ -174,9 +186,17 @@ describe NMatrix::BLAS do
174
186
 
175
187
  it "exposes nrm2" do
176
188
  pending("broken for :object") if dtype == :object
177
- pending("Temporarily disable because the internal implementation of nrm2 is broken -WL 2015-05-17") if dtype == :complex64 || dtype == :complex128
178
189
 
179
- x = NMatrix.new([4,1], [2,-4,3,5], dtype: dtype)
190
+ if dtype =~ /complex/
191
+ x = NMatrix.new([3,1], [Complex(1,2),Complex(3,4),Complex(0,6)], dtype: dtype)
192
+ y = NMatrix.new([3,1], [Complex(0,0),Complex(0,0),Complex(0,0)], dtype: dtype)
193
+ nrm2 = 8.12403840463596
194
+ else
195
+ x = NMatrix.new([4,1], [2,-4,3,5], dtype: dtype)
196
+ y = NMatrix.new([3,1], [0,0,0], dtype: dtype)
197
+ nrm2 = 5.385164807134504
198
+ end
199
+
180
200
  err = case dtype
181
201
  when :float32, :complex64
182
202
  1e-6
@@ -185,7 +205,9 @@ describe NMatrix::BLAS do
185
205
  else
186
206
  1e-14
187
207
  end
188
- expect(NMatrix::BLAS.nrm2(x, 1, 3)).to be_within(err).of(5.385164807134504)
208
+
209
+ expect(NMatrix::BLAS.nrm2(x, 1, 3)).to be_within(err).of(nrm2)
210
+ expect(NMatrix::BLAS.nrm2(y, 1, 3)).to be_within(err).of(0)
189
211
  end
190
212
 
191
213
  end
data/spec/math_spec.rb CHANGED
@@ -285,6 +285,152 @@ describe "math" do
285
285
  end
286
286
  end
287
287
 
288
NON_INTEGER_DTYPES.each do |dtype|
  next if dtype == :object

  context dtype do
    # Single-precision dtypes get a looser tolerance than double-precision ones.
    tolerance = case dtype
                when :float32, :complex64 then 1e-4
                when :float64, :complex128 then 1e-13
                end

    it "calculates QR decomposition using factorize_qr for a square matrix" do
      matrix = NMatrix.new(3, [12.0, -51.0,   4.0,
                                6.0, 167.0, -68.0,
                               -4.0,  24.0, -41.0], dtype: dtype)
      expected_q = NMatrix.new([3,3], Q_SOLUTION_ARRAY_2, dtype: dtype)
      expected_r = NMatrix.new([3,3], [-14.0, -21.0,  14,
                                         0.0,  -175,  70,
                                         0.0,   0.0, -35], dtype: dtype)

      begin
        q, r = matrix.factorize_qr
        expect(q).to be_within(tolerance).of(expected_q)
        expect(r).to be_within(tolerance).of(expected_r)
      rescue NotImplementedError
        pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
      end
    end

    it "calculates QR decomposition using factorize_qr for a tall and narrow rectangular matrix" do
      matrix = NMatrix.new([4,2], [34.0,  21.0,
                                   23.0,  53.0,
                                   26.0, 346.0,
                                   23.0, 121.0], dtype: dtype)
      expected_q = NMatrix.new([4,4], Q_SOLUTION_ARRAY_1, dtype: dtype)
      expected_r = NMatrix.new([4,2], [-53.75872022286244, -255.06559574252242,
                                        0.0,                269.34836526051555,
                                        0.0,                0.0,
                                        0.0,                0.0], dtype: dtype)

      begin
        q, r = matrix.factorize_qr
        expect(q).to be_within(tolerance).of(expected_q)
        expect(r).to be_within(tolerance).of(expected_r)
      rescue NotImplementedError
        pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
      end
    end

    it "calculates QR decomposition using factorize_qr for a short and wide rectangular matrix" do
      matrix = NMatrix.new([3,4], [123,31,57,81,92,14,17,36,42,34,11,28], dtype: dtype)
      expected_q = NMatrix.new([3,3], Q_SOLUTION_ARRAY_3, dtype: dtype)
      expected_r = NMatrix.new([3,4], R_SOLUTION_ARRAY, dtype: dtype)

      begin
        q, r = matrix.factorize_qr
        expect(q).to be_within(tolerance).of(expected_q)
        expect(r).to be_within(tolerance).of(expected_r)
      rescue NotImplementedError
        pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
      end
    end

    it "calculates QR decomposition such that A - QR ~ 0" do
      matrix = NMatrix.new([3,3], [ 9.0, 0.0, 26.0,
                                   12.0, 0.0, -7.0,
                                    0.0, 4.0,  0.0], dtype: dtype)

      begin
        q, r = matrix.factorize_qr
        # Multiplying the factors back together must reproduce the input.
        expect(q.dot(r)).to be_within(tolerance).of(matrix)
      rescue NotImplementedError
        pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
      end
    end

    it "calculates the orthogonal matrix Q in QR decomposition" do
      matrix = N.new([2,2], [34.0, 21, 23, 53], dtype: dtype)

      begin
        q, _r = matrix.factorize_qr
        # Q is orthogonal iff Q x Q.transpose is the identity.
        gram = q.dot(q.transpose)
        { [0,0] => 1, [1,0] => 0, [0,1] => 0, [1,1] => 1 }.each do |(row, col), expected|
          expect(gram[row, col]).to be_within(tolerance).of(expected)
        end
      rescue NotImplementedError
        pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
      end
    end
  end
end
433
+
288
434
  ALL_DTYPES.each do |dtype|
289
435
  next if dtype == :byte #doesn't work for unsigned types
290
436
  next if dtype == :object
@@ -332,6 +478,49 @@ describe "math" do
332
478
  end
333
479
  end
334
480
 
481
ALL_DTYPES.each do |dtype|
  next if dtype == :byte #doesn't work for unsigned types
  next if dtype == :object

  context dtype do
    err = case dtype
          when :float32, :complex64
            1e-4
          else #integer matrices will return :float64
            1e-13
          end

    # Description fixed: "find adjugate a matrix" -> "find adjugate of a matrix".
    it "should correctly find adjugate of a matrix in place (bang)" do
      a = NMatrix.new(:dense, 2, [2, 3, 3, 5], dtype)
      b = NMatrix.new(:dense, 2, [5, -3, -3, 2], dtype)

      if a.integer_dtype?
        # Presumably rejected because an integer matrix's adjugate is produced
        # as :float64 (see the out-of-place example) -- confirm in adjugate!.
        expect{a.adjugate!}.to raise_error(DataTypeError)
      else
        #should return adjugate as well as modifying a
        r = a.adjugate!
        expect(a).to be_within(err).of(b)
        expect(r).to be_within(err).of(b)
      end
    end

    it "should correctly find adjugate of a matrix out-of-place" do
      a = NMatrix.new(:dense, 3, [-3, 2, -5, -1, 0, -2, 3, -4, 1], dtype)

      # Integer inputs are expected to come back as :float64.
      answer_dtype = a.integer_dtype? ? :float64 : dtype
      b = NMatrix.new(:dense, 3, [-8, 18, -4, -5, 12, -1, 4, -6, 2], answer_dtype)

      expect(a.adjoint).to be_within(err).of(b)
      expect(a.adjugate).to be_within(err).of(b)
    end
  end
end
523
+
335
524
  # TODO: Get it working with ROBJ too
336
525
  [:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |left_dtype|
337
526
  [:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |right_dtype|
@@ -702,7 +891,7 @@ describe "math" do
702
891
  360, 96, 51, -14,
703
892
  448,-231,-24,-87,
704
893
  -1168, 595,234, 523],
705
- dtype: answer_dtype,
894
+ dtype: answer_dtype,
706
895
  stype: stype))
707
896
  end
708
897
 
@@ -757,7 +946,6 @@ describe "math" do
757
946
 
758
947
  context "determinants" do
759
948
  ALL_DTYPES.each do |dtype|
760
- next if dtype == :object
761
949
  context dtype do
762
950
  before do
763
951
  @a = NMatrix.new([2,2], [1,2,
@@ -779,13 +967,19 @@ describe "math" do
779
967
  end
780
968
  end
781
969
  it "computes the determinant of 2x2 matrix" do
782
- expect(@a.det).to be_within(@err).of(-2)
970
+ if dtype != :object
971
+ expect(@a.det).to be_within(@err).of(-2)
972
+ end
783
973
  end
784
974
  it "computes the determinant of 3x3 matrix" do
785
- expect(@b.det).to be_within(@err).of(-8)
975
+ if dtype != :object
976
+ expect(@b.det).to be_within(@err).of(-8)
977
+ end
786
978
  end
787
979
  it "computes the determinant of 4x4 matrix" do
788
- expect(@c.det).to be_within(@err).of(-18)
980
+ if dtype != :object
981
+ expect(@c.det).to be_within(@err).of(-18)
982
+ end
789
983
  end
790
984
  it "computes the exact determinant of 2x2 matrix" do
791
985
  if dtype == :byte
@@ -804,4 +998,38 @@ describe "math" do
804
998
  end
805
999
  end
806
1000
  end
1001
+
1002
context "#scale and #scale!" do
  [:dense,:list,:yale].each do |stype|
    ALL_DTYPES.each do |dtype|
      next if dtype == :object

      context "for #{dtype}" do
        # Expected result of scaling @m by 2, built only when an example asks for it.
        let(:doubled) do
          NMatrix.new([3, 3], (0..8).map { |v| v * 2 }, stype: stype, dtype: dtype)
        end

        before do
          @m = NMatrix.new([3, 3], (0..8).to_a, stype: stype, dtype: dtype)
        end

        it "scales the matrix by a given factor and return the result" do
          if integer_dtype? dtype
            expect { @m.scale(2.0) }.to raise_error(DataTypeError)
          else
            expect(@m.scale(2.0)).to eq(doubled)
          end
        end

        it "scales the matrix in place by a given factor" do
          if dtype == :int8
            expect { @m.scale!(2) }.to raise_error(DataTypeError)
          else
            @m.scale!(2)
            expect(@m).to eq(doubled)
          end
        end
      end
    end
  end
end
807
1035
  end