nmatrix-lapacke 0.2.1 → 0.2.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/ext/nmatrix/data/data.h +7 -8
- data/ext/nmatrix/data/ruby_object.h +1 -4
- data/ext/nmatrix/math/asum.h +10 -31
- data/ext/nmatrix/math/cblas_templates_core.h +10 -10
- data/ext/nmatrix/math/getrf.h +2 -2
- data/ext/nmatrix/math/imax.h +12 -9
- data/ext/nmatrix/math/laswp.h +3 -3
- data/ext/nmatrix/math/long_dtype.h +16 -3
- data/ext/nmatrix/math/magnitude.h +54 -0
- data/ext/nmatrix/math/nrm2.h +19 -14
- data/ext/nmatrix/math/trsm.h +40 -36
- data/ext/nmatrix/math/util.h +14 -0
- data/ext/nmatrix/nmatrix.h +39 -1
- data/ext/nmatrix/storage/common.h +9 -3
- data/ext/nmatrix/storage/yale/class.h +1 -1
- data/ext/nmatrix_lapacke/extconf.rb +3 -136
- data/ext/nmatrix_lapacke/lapacke.cpp +104 -84
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeqrf.c +77 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeqrf_work.c +89 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cunmqr.c +88 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cunmqr_work.c +111 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeqrf.c +75 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeqrf_work.c +87 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dormqr.c +86 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dormqr_work.c +109 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeqrf.c +75 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeqrf_work.c +87 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sormqr.c +86 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sormqr_work.c +109 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeqrf.c +77 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeqrf_work.c +89 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zunmqr.c +88 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zunmqr_work.c +111 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_c_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_d_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_s_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_z_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/math_lapacke.cpp +149 -17
- data/ext/nmatrix_lapacke/math_lapacke/lapacke_templates.h +76 -0
- data/lib/nmatrix/lapacke.rb +118 -0
- data/spec/00_nmatrix_spec.rb +50 -1
- data/spec/02_slice_spec.rb +21 -21
- data/spec/blas_spec.rb +25 -3
- data/spec/math_spec.rb +233 -5
- data/spec/plugins/lapacke/lapacke_spec.rb +187 -0
- data/spec/shortcuts_spec.rb +145 -5
- data/spec/spec_helper.rb +24 -1
- metadata +38 -8
@@ -64,6 +64,82 @@ inline int lapacke_getrf(const enum CBLAS_ORDER order, const int m, const int n,
|
|
64
64
|
return getrf<DType>(order, m, n, static_cast<DType*>(a), lda, ipiv);
|
65
65
|
}
|
66
66
|
|
67
|
+
//geqrf
|
68
|
+
template <typename DType>
|
69
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, DType* a, const int lda, DType* tau) {
|
70
|
+
rb_raise(rb_eNotImpError, "lapacke_geqrf not implemented for non_BLAS dtypes.");
|
71
|
+
return 0;
|
72
|
+
}
|
73
|
+
|
74
|
+
template <>
|
75
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, float* a, const int lda, float* tau) {
|
76
|
+
return LAPACKE_sgeqrf(order, m, n, a, lda, tau);
|
77
|
+
}
|
78
|
+
|
79
|
+
template < >
|
80
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, double* a, const int lda, double* tau) {
|
81
|
+
return LAPACKE_dgeqrf(order, m, n, a, lda, tau);
|
82
|
+
}
|
83
|
+
|
84
|
+
template <>
|
85
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, Complex64* a, const int lda, Complex64* tau) {
|
86
|
+
return LAPACKE_cgeqrf(order, m, n, a, lda, tau);
|
87
|
+
}
|
88
|
+
|
89
|
+
template <>
|
90
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, Complex128* a, const int lda, Complex128* tau) {
|
91
|
+
return LAPACKE_zgeqrf(order, m, n, a, lda, tau);
|
92
|
+
}
|
93
|
+
|
94
|
+
template <typename DType>
|
95
|
+
inline int lapacke_geqrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, void* tau) {
|
96
|
+
return geqrf<DType>(order, m, n, static_cast<DType*>(a), lda, static_cast<DType*>(tau));
|
97
|
+
}
|
98
|
+
|
99
|
+
//ormqr
|
100
|
+
template <typename DType>
|
101
|
+
inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, DType* a, const int lda, DType* tau, DType* c, const int ldc) {
|
102
|
+
rb_raise(rb_eNotImpError, "lapacke_ormqr not implemented for non_BLAS dtypes.");
|
103
|
+
return 0;
|
104
|
+
}
|
105
|
+
|
106
|
+
template <>
|
107
|
+
inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, float* a, const int lda, float* tau, float* c, const int ldc) {
|
108
|
+
return LAPACKE_sormqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
109
|
+
}
|
110
|
+
|
111
|
+
template <>
|
112
|
+
inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, double* a, const int lda, double* tau, double* c, const int ldc) {
|
113
|
+
return LAPACKE_dormqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
114
|
+
}
|
115
|
+
|
116
|
+
template <typename DType>
|
117
|
+
inline int lapacke_ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, void* a, const int lda, void* tau, void* c, const int ldc) {
|
118
|
+
return ormqr<DType>(order, side, trans, m, n, k, static_cast<DType*>(a), lda, static_cast<DType*>(tau), static_cast<DType*>(c), ldc);
|
119
|
+
}
|
120
|
+
|
121
|
+
//unmqr
|
122
|
+
template <typename DType>
|
123
|
+
inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, DType* a, const int lda, DType* tau, DType* c, const int ldc) {
|
124
|
+
rb_raise(rb_eNotImpError, "lapacke_unmqr not implemented for non complex dtypes.");
|
125
|
+
return 0;
|
126
|
+
}
|
127
|
+
|
128
|
+
template <>
|
129
|
+
inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, Complex64* a, const int lda, Complex64* tau, Complex64* c, const int ldc) {
|
130
|
+
return LAPACKE_cunmqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
131
|
+
}
|
132
|
+
|
133
|
+
template <>
|
134
|
+
inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, Complex128* a, const int lda, Complex128* tau, Complex128* c, const int ldc) {
|
135
|
+
return LAPACKE_zunmqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
136
|
+
}
|
137
|
+
|
138
|
+
template <typename DType>
|
139
|
+
inline int lapacke_unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, void* a, const int lda, void* tau, void* c, const int ldc) {
|
140
|
+
return unmqr<DType>(order, side, trans, m, n, k, static_cast<DType*>(a), lda, static_cast<DType*>(tau), static_cast<DType*>(c), ldc);
|
141
|
+
}
|
142
|
+
|
67
143
|
//getri
|
68
144
|
template <typename DType>
|
69
145
|
inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
|
data/lib/nmatrix/lapacke.rb
CHANGED
@@ -232,4 +232,122 @@ class NMatrix
|
|
232
232
|
raise(ArgumentError, "#{opts[:form]} is not a valid form option")
|
233
233
|
end
|
234
234
|
end
|
235
|
+
|
236
|
+
#
# call-seq:
#     geqrf! -> shape.min x 1 NMatrix
#
# In-place QR factorization of a general M-by-N matrix +A+ (the receiver).
#
# The factorization is A = QR, where Q is orthogonal and R is upper triangular.
# +A+ is overwritten with the elements of R, and Q is represented by the
# elements below A's diagonal together with the array of scalar factors
# returned in the output NMatrix.
#
# Q is a product of elementary reflectors
#
#     Q = H(1) H(2) . . . H(k), where k = min(m,n),
#
# each of the form
#
#     H(i) = I - tau * v * v'
#
# http://www.netlib.org/lapack/explore-html/d3/d69/dgeqrf_8f.html
#
# Only works for dense matrices.
#
# * *Returns* :
#   - Vector TAU (shape.min x 1, same dtype as the receiver). Q and R are stored
#     in A; Q is represented by TAU and A.
# * *Raises* :
#   - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
#
def geqrf!
  unless dense?
    raise(StorageTypeError, "LAPACK functions only work on dense matrices")
  end

  rows, cols = shape[0], shape[1]
  tau = NMatrix.new([shape.min, 1], dtype: dtype)

  # Row-major storage, so the leading dimension of A is its column count.
  NMatrix::LAPACK::lapacke_geqrf(:row, rows, cols, self, cols, tau)
  tau
end
|
270
|
+
|
271
|
+
#
# call-seq:
#     ormqr(tau) -> NMatrix
#     ormqr(tau, side, transpose, c) -> NMatrix
#
# Returns the product Q * c or c * Q after a call to geqrf! used in QR factorization.
# +c+ is overwritten with the elements of the result NMatrix if supplied. Q is the
# orthogonal matrix represented by tau and the calling NMatrix.
#
# Only works on float types, use unmqr for complex types.
#
# == Arguments
#
# * +tau+ - vector containing scalar factors of elementary reflectors
# * +side+ - direction of multiplication [:left, :right]
# * +transpose+ - apply Q with or without transpose [false, :transpose]
# * +c+ - NMatrix multiplication argument that is overwritten; when omitted, c = identity
#
# * *Returns* :
#   - Q * c or c * Q, where Q may be transposed before multiplication.
#
# * *Raises* :
#   - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
#   - +TypeError+ -> Works only on floating point matrices, use unmqr for complex types
#   - +TypeError+ -> c must have the same dtype as the calling NMatrix
#
def ormqr(tau, side=:left, transpose=false, c=nil)
  unless dense?
    raise(StorageTypeError, "LAPACK functions only work on dense matrices")
  end
  if complex_dtype?
    raise(TypeError, "Works only on floating point matrices, use unmqr for complex types")
  end
  if c && c.dtype != dtype
    raise(TypeError, "c must have the same dtype as the calling NMatrix")
  end

  # Without c, multiply Q by the identity so that Q itself is produced.
  result = c || NMatrix.identity(shape[0], dtype: dtype)
  rows, cols = result.shape[0], result.shape[1]

  NMatrix::LAPACK::lapacke_ormqr(:row, side, transpose, rows, cols, tau.shape[0],
                                 self, shape[1], tau, result, cols)
  result
end
|
311
|
+
|
312
|
+
#
# call-seq:
#     unmqr(tau) -> NMatrix
#     unmqr(tau, side, transpose, c) -> NMatrix
#
# Returns the product Q * c or c * Q after a call to geqrf! used in QR factorization.
# +c+ is overwritten with the elements of the result NMatrix if it is supplied. Q is the
# matrix represented by tau and the calling NMatrix.
#
# Only works on complex types, use ormqr for float types.
#
# == Arguments
#
# * +tau+ - vector containing scalar factors of elementary reflectors
# * +side+ - direction of multiplication [:left, :right]
# * +transpose+ - apply Q as Q or its complex conjugate [false, :complex_conjugate]
# * +c+ - NMatrix multiplication argument that is overwritten; when omitted, c = identity
#
# * *Returns* :
#   - Q * c or c * Q, where Q may be transformed to its complex conjugate before multiplication.
#
# * *Raises* :
#   - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
#   - +TypeError+ -> Works only on complex matrices, use ormqr for normal floating point matrices
#   - +TypeError+ -> c must have the same dtype as the calling NMatrix
#
def unmqr(tau, side=:left, transpose=false, c=nil)
  # This is the LAPACKE plugin; the message previously said "ATLAS" (copied from
  # the ATLAS plugin) and disagreed with geqrf!/ormqr in this file.
  raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense?
  raise(TypeError, "Works only on complex matrices, use ormqr for normal floating point matrices") unless self.complex_dtype?
  raise(TypeError, "c must have the same dtype as the calling NMatrix") if c and c.dtype != self.dtype

  # Default behaviour produces Q * I = Q if c is not supplied.
  result = c ? c : NMatrix.identity(self.shape[0], dtype: self.dtype)
  NMatrix::LAPACK::lapacke_unmqr(:row, side, transpose, result.shape[0], result.shape[1], tau.shape[0], self, self.shape[1], tau, result, result.shape[1])

  result
end
|
351
|
+
|
352
|
+
|
235
353
|
end
|
data/spec/00_nmatrix_spec.rb
CHANGED
@@ -424,6 +424,13 @@ describe 'NMatrix' do
|
|
424
424
|
expect(n.reshape!([8,2]).eql?(n)).to eq(true) # because n itself changes
|
425
425
|
end
|
426
426
|
|
427
|
+
it "should do the reshape operation in place, changing dimension" do
  n = NMatrix.seq(4)
  a = n.reshape!([4,2,2])

  # reshape! both mutates the receiver and returns it
  expected = NMatrix.seq([4,2,2])
  [n, a].each { |m| expect(m).to eq(expected) }
end
|
433
|
+
|
427
434
|
it "reshape and reshape! must produce same result" do
|
428
435
|
n = NMatrix.seq(4)+1
|
429
436
|
a = NMatrix.seq(4)+1
|
@@ -432,7 +439,7 @@ describe 'NMatrix' do
|
|
432
439
|
|
433
440
|
it "should prevent a resize in place" do
  matrix = NMatrix.seq(4) + 1
  # a 4x4 (16-element) matrix cannot be reshaped in place to 5x2 (10 elements)
  expect do
    matrix.reshape!([5, 2])
  end.to raise_error(ArgumentError)
end
|
437
444
|
end
|
438
445
|
|
@@ -534,6 +541,26 @@ describe 'NMatrix' do
|
|
534
541
|
n = NMatrix.new([1,3,1], [1,2,3])
|
535
542
|
expect(n.dconcat(n)).to eq(NMatrix.new([1,3,2], [1,1,2,2,3,3]))
|
536
543
|
end
|
544
|
+
|
545
|
+
# The two examples below previously shared the identical description
# "should work on matrices with different size along concat dim", which made
# failure output and `rspec -e` filtering ambiguous; each now names the
# operation it exercises.
it "should hconcat matrices with different size along concat dim" do
  n = N[[1, 2, 3],
        [4, 5, 6]]
  m = N[[7],
        [8]]

  expect(n.hconcat(m)).to eq N[[1, 2, 3, 7], [4, 5, 6, 8]]
  expect(m.hconcat(n)).to eq N[[7, 1, 2, 3], [8, 4, 5, 6]]
end

it "should vconcat matrices with different size along concat dim" do
  n = N[[1, 2, 3],
        [4, 5, 6]]

  m = N[[7, 8, 9]]

  expect(n.vconcat(m)).to eq N[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  expect(m.vconcat(n)).to eq N[[7, 8, 9], [1, 2, 3], [4, 5, 6]]
end
|
537
564
|
end
|
538
565
|
|
539
566
|
context "#[]" do
|
@@ -631,6 +658,23 @@ describe 'NMatrix' do
|
|
631
658
|
end
|
632
659
|
end
|
633
660
|
|
661
|
+
context "#last" do
  # rank => [shape, elements, expected last element]
  {
    1 => [[1,4],   [1,2,3,4],         4],
    2 => [[2,2],   [4,8,12,16],       16],
    3 => [[2,2,2], [1,2,3,4,5,6,7,8], 8]
  }.each do |rank, (shape, elements, last_element)|
    it "returns the last element of a #{rank}-dimensional NMatrix" do
      n = NMatrix.new(shape, elements)
      expect(n.last).to eq(last_element)
    end
  end
end
|
677
|
+
|
634
678
|
context "#diagonal" do
|
635
679
|
ALL_DTYPES.each do |dtype|
|
636
680
|
before do
|
@@ -682,6 +726,11 @@ describe 'NMatrix' do
|
|
682
726
|
expect(@sample_matrix.repeat(2, 0)).to eq(NMatrix.new([4, 2], [1, 2, 3, 4, 1, 2, 3, 4]))
|
683
727
|
expect(@sample_matrix.repeat(2, 1)).to eq(NMatrix.new([2, 4], [1, 2, 1, 2, 3, 4, 3, 4]))
|
684
728
|
end
|
729
|
+
|
730
|
+
it "preserves dtype" do
  # repeat along both axes must keep the source dtype
  [0, 1].each do |axis|
    expect(@sample_matrix.repeat(2, axis).dtype).to eq(@sample_matrix.dtype)
  end
end
|
685
734
|
end
|
686
735
|
|
687
736
|
context "#meshgrid" do
|
data/spec/02_slice_spec.rb
CHANGED
@@ -127,9 +127,9 @@ describe "Slice operation" do
|
|
127
127
|
it "should have #is_ref? method" do
  by_reference = stype_matrix[0..1, 0..1]
  by_copy      = stype_matrix.slice(0..1, 0..1)

  expect(stype_matrix.is_ref?).to be false  # the source matrix itself
  expect(by_reference.is_ref?).to be true   # [] yields a reference slice
  expect(by_copy.is_ref?).to be false       # #slice yields a copy
end
|
134
134
|
|
135
135
|
it "reference should compare with non-reference" do
|
@@ -141,7 +141,7 @@ describe "Slice operation" do
|
|
141
141
|
context "with copying" do
|
142
142
|
it 'should return an NMatrix' do
  copied   = stype_matrix.slice(0..1, 0..1)
  expected = NMatrix.new([2,2], [0,1,3,4], dtype: :int32)
  expect(nm_eql(copied, expected)).to be true
end
|
146
146
|
|
147
147
|
it 'should return a copy of 2x2 matrix to self elements' do
|
@@ -183,19 +183,19 @@ describe "Slice operation" do
|
|
183
183
|
|
184
184
|
[:dense, :list, :yale].each do |cast_type|
|
185
185
|
it "should cast copied slice from #{stype.upcase} to #{cast_type.upcase}" do
  windows = [
    [1..2, 1..2], [0..1, 1..2], [1..2, 0..1], [0..1, 0..1], # square
    [0..2, 1..2], [1..2, 0..2]                              # non square
  ]
  windows.each do |rows, cols|
    casted = stype_matrix.slice(rows, cols).cast(cast_type, :int32)
    expect(nm_eql(casted, stype_matrix.slice(rows, cols))).to be true
  end

  # Full
  expect(nm_eql(stype_matrix.slice(0..2, 0..2).cast(cast_type, :int32), stype_matrix)).to be true
end
|
200
200
|
end
|
201
201
|
end
|
@@ -214,7 +214,7 @@ describe "Slice operation" do
|
|
214
214
|
context "by reference" do
|
215
215
|
it 'should return an NMatrix' do
  referenced = stype_matrix[0..1, 0..1]
  expected   = NMatrix.new([2,2], [0,1,3,4], dtype: :int32)
  expect(nm_eql(referenced, expected)).to be true
end
|
219
219
|
|
220
220
|
it 'should return a 2x2 matrix with refs to self elements' do
|
@@ -246,7 +246,7 @@ describe "Slice operation" do
|
|
246
246
|
|
247
247
|
it 'should slice again' do
  window   = stype_matrix[1..2, 1..2]
  expected = NVector.new(2, [7,8], dtype: :int32).transpose
  # slice the slice: row 1 of the 2x2 window
  expect(nm_eql(window[1, 0..1], expected)).to be true
end
|
251
251
|
|
252
252
|
it 'should be correct slice for range 0..2 and 0...3' do
|
@@ -320,7 +320,7 @@ describe "Slice operation" do
|
|
320
320
|
end
|
321
321
|
|
322
322
|
it "compares slices to scalars" do
  mask = stype_matrix[1, 0..2] > 2
  mask.each do |flag|
    expect(flag != 0).to be true
  end
end
|
325
325
|
|
326
326
|
it "iterates only over elements in the slice" do
|
@@ -367,20 +367,20 @@ describe "Slice operation" do
|
|
367
367
|
|
368
368
|
[:dense, :list, :yale].each do |cast_type|
|
369
369
|
it "should cast a square reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
  [[1..2, 1..2], [0..1, 1..2], [1..2, 0..1], [0..1, 0..1]].each do |rows, cols|
    expect(nm_eql(stype_matrix[rows, cols].cast(cast_type), stype_matrix[rows, cols])).to be true
  end
end
|
375
375
|
|
376
376
|
it "should cast a rectangular reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
  # Non square
  non_square = [
    [0..2, 1..2], # FIXME: memory problem.
    [1..2, 0..2]  # this one is fine
  ]
  non_square.each do |rows, cols|
    expect(nm_eql(stype_matrix[rows, cols].cast(cast_type), stype_matrix[rows, cols])).to be true
  end
end
|
381
381
|
|
382
382
|
it "should cast a square full-matrix reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
  full_ref = stype_matrix[0..2, 0..2]
  expect(nm_eql(full_ref.cast(cast_type), stype_matrix)).to be true
end
|
385
385
|
end
|
386
386
|
end
|
data/spec/blas_spec.rb
CHANGED
@@ -69,6 +69,18 @@ describe NMatrix::BLAS do
|
|
69
69
|
expect(b[0]).to eq(-15.0/2)
|
70
70
|
expect(b[1]).to eq(5)
|
71
71
|
expect(b[2]).to eq(-13)
|
72
|
+
|
73
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :lower, :transpose, :nounit, 3, 1, 1.0, a, 3, b, 1)
|
74
|
+
|
75
|
+
expect(b[0]).to eq(307.0/8)
|
76
|
+
expect(b[1]).to eq(57.0/2)
|
77
|
+
expect(b[2]).to eq(26.0)
|
78
|
+
|
79
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, :transpose, :unit, 3, 1, 1.0, a, 3, b, 1)
|
80
|
+
|
81
|
+
expect(b[0]).to eq(307.0/8)
|
82
|
+
expect(b[1]).to eq(763.0/16)
|
83
|
+
expect(b[2]).to eq(4269.0/64)
|
72
84
|
end
|
73
85
|
|
74
86
|
# trmm multiplies two matrices, where one of the two is required to be
|
@@ -174,9 +186,17 @@ describe NMatrix::BLAS do
|
|
174
186
|
|
175
187
|
it "exposes nrm2" do
|
176
188
|
pending("broken for :object") if dtype == :object
|
177
|
-
pending("Temporarily disable because the internal implementation of nrm2 is broken -WL 2015-05-17") if dtype == :complex64 || dtype == :complex128
|
178
189
|
|
179
|
-
|
190
|
+
if dtype =~ /complex/
|
191
|
+
x = NMatrix.new([3,1], [Complex(1,2),Complex(3,4),Complex(0,6)], dtype: dtype)
|
192
|
+
y = NMatrix.new([3,1], [Complex(0,0),Complex(0,0),Complex(0,0)], dtype: dtype)
|
193
|
+
nrm2 = 8.12403840463596
|
194
|
+
else
|
195
|
+
x = NMatrix.new([4,1], [2,-4,3,5], dtype: dtype)
|
196
|
+
y = NMatrix.new([3,1], [0,0,0], dtype: dtype)
|
197
|
+
nrm2 = 5.385164807134504
|
198
|
+
end
|
199
|
+
|
180
200
|
err = case dtype
|
181
201
|
when :float32, :complex64
|
182
202
|
1e-6
|
@@ -185,7 +205,9 @@ describe NMatrix::BLAS do
|
|
185
205
|
else
|
186
206
|
1e-14
|
187
207
|
end
|
188
|
-
|
208
|
+
|
209
|
+
expect(NMatrix::BLAS.nrm2(x, 1, 3)).to be_within(err).of(nrm2)
|
210
|
+
expect(NMatrix::BLAS.nrm2(y, 1, 3)).to be_within(err).of(0)
|
189
211
|
end
|
190
212
|
|
191
213
|
end
|
data/spec/math_spec.rb
CHANGED
@@ -285,6 +285,152 @@ describe "math" do
|
|
285
285
|
end
|
286
286
|
end
|
287
287
|
|
288
|
+
NON_INTEGER_DTYPES.each do |dtype|
  next if dtype == :object
  context dtype do
    # Absolute tolerance shared by every QR example for this dtype.
    err = case dtype
          when :float32, :complex64  then 1e-4
          when :float64, :complex128 then 1e-13
          end

    # Runs a.factorize_qr and yields [q, r]. When factorize_qr raises
    # NotImplementedError (lapacke plugin not available) the example is
    # marked pending instead of failing.
    def qr_or_pending(a)
      q, r = a.factorize_qr
      yield q, r
    rescue NotImplementedError
      pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
    end

    it "calculates QR decomposition using factorize_qr for a square matrix" do
      a = NMatrix.new(3, [12.0, -51.0,   4.0,
                           6.0, 167.0, -68.0,
                          -4.0,  24.0, -41.0], dtype: dtype)
      q_solution = NMatrix.new([3,3], Q_SOLUTION_ARRAY_2, dtype: dtype)
      r_solution = NMatrix.new([3,3], [-14.0,  -21.0,  14,
                                         0.0, -175,    70,
                                         0.0,    0.0, -35], dtype: dtype)

      qr_or_pending(a) do |q, r|
        expect(q).to be_within(err).of(q_solution)
        expect(r).to be_within(err).of(r_solution)
      end
    end

    it "calculates QR decomposition using factorize_qr for a tall and narrow rectangular matrix" do
      a = NMatrix.new([4,2], [34.0,  21.0,
                              23.0,  53.0,
                              26.0, 346.0,
                              23.0, 121.0], dtype: dtype)
      q_solution = NMatrix.new([4,4], Q_SOLUTION_ARRAY_1, dtype: dtype)
      r_solution = NMatrix.new([4,2], [-53.75872022286244, -255.06559574252242,
                                         0.0,                269.34836526051555,
                                         0.0,                  0.0,
                                         0.0,                  0.0], dtype: dtype)

      qr_or_pending(a) do |q, r|
        expect(q).to be_within(err).of(q_solution)
        expect(r).to be_within(err).of(r_solution)
      end
    end

    it "calculates QR decomposition using factorize_qr for a short and wide rectangular matrix" do
      a = NMatrix.new([3,4], [123,31,57,81,92,14,17,36,42,34,11,28], dtype: dtype)
      q_solution = NMatrix.new([3,3], Q_SOLUTION_ARRAY_3, dtype: dtype)
      r_solution = NMatrix.new([3,4], R_SOLUTION_ARRAY, dtype: dtype)

      qr_or_pending(a) do |q, r|
        expect(q).to be_within(err).of(q_solution)
        expect(r).to be_within(err).of(r_solution)
      end
    end

    it "calculates QR decomposition such that A - QR ~ 0" do
      a = NMatrix.new([3,3], [ 9.0, 0.0, 26.0,
                              12.0, 0.0, -7.0,
                               0.0, 4.0,  0.0], dtype: dtype)

      qr_or_pending(a) do |q, r|
        expect(q.dot(r)).to be_within(err).of(a)
      end
    end

    it "calculates the orthogonal matrix Q in QR decomposition" do
      a = N.new([2,2], [34.0, 21, 23, 53], dtype: dtype)

      qr_or_pending(a) do |q, _r|
        # Q is orthogonal if Q x Q.transpose = I
        product = q.dot(q.transpose)
        { [0,0] => 1, [1,0] => 0, [0,1] => 0, [1,1] => 1 }.each do |(i, j), identity_entry|
          expect(product[i, j]).to be_within(err).of(identity_entry)
        end
      end
    end
  end
end
|
433
|
+
|
288
434
|
ALL_DTYPES.each do |dtype|
|
289
435
|
next if dtype == :byte #doesn't work for unsigned types
|
290
436
|
next if dtype == :object
|
@@ -332,6 +478,49 @@ describe "math" do
|
|
332
478
|
end
|
333
479
|
end
|
334
480
|
|
481
|
+
ALL_DTYPES.each do |dtype|
  # :byte doesn't work for unsigned types; :object is not exercised
  next if dtype == :byte || dtype == :object

  context dtype do
    # integer matrices will return :float64, so they use the tight tolerance
    err = [:float32, :complex64].include?(dtype) ? 1e-4 : 1e-13

    it "should correctly find adjugate a matrix in place (bang)" do
      a = NMatrix.new(:dense, 2, [2, 3, 3, 5], dtype)
      b = NMatrix.new(:dense, 2, [5, -3, -3, 2], dtype)

      if a.integer_dtype?
        expect { a.adjugate! }.to raise_error(DataTypeError)
        next
      end

      # should return adjugate as well as modifying a
      returned = a.adjugate!
      expect(a).to be_within(err).of(b)
      expect(returned).to be_within(err).of(b)
    end

    it "should correctly find adjugate of a matrix out-of-place" do
      a = NMatrix.new(:dense, 3, [-3, 2, -5, -1, 0, -2, 3, -4, 1], dtype)
      expected_dtype = a.integer_dtype? ? :float64 : dtype
      b = NMatrix.new(:dense, 3, [-8, 18, -4, -5, 12, -1, 4, -6, 2], expected_dtype)

      [:adjoint, :adjugate].each do |method_name|
        expect(a.public_send(method_name)).to be_within(err).of(b)
      end
    end
  end
end
|
523
|
+
|
335
524
|
# TODO: Get it working with ROBJ too
|
336
525
|
[:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |left_dtype|
|
337
526
|
[:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |right_dtype|
|
@@ -702,7 +891,7 @@ describe "math" do
|
|
702
891
|
360, 96, 51, -14,
|
703
892
|
448,-231,-24,-87,
|
704
893
|
-1168, 595,234, 523],
|
705
|
-
dtype: answer_dtype,
|
894
|
+
dtype: answer_dtype,
|
706
895
|
stype: stype))
|
707
896
|
end
|
708
897
|
|
@@ -757,7 +946,6 @@ describe "math" do
|
|
757
946
|
|
758
947
|
context "determinants" do
|
759
948
|
ALL_DTYPES.each do |dtype|
|
760
|
-
next if dtype == :object
|
761
949
|
context dtype do
|
762
950
|
before do
|
763
951
|
@a = NMatrix.new([2,2], [1,2,
|
@@ -779,13 +967,19 @@ describe "math" do
|
|
779
967
|
end
|
780
968
|
end
|
781
969
|
# size label => [matrix ivar set in the before block, expected determinant]
{ "2x2" => [:@a, -2], "3x3" => [:@b, -8], "4x4" => [:@c, -18] }.each do |size, (ivar, expected_det)|
  it "computes the determinant of #{size} matrix" do
    # :object matrices are not checked here
    next if dtype == :object
    expect(instance_variable_get(ivar).det).to be_within(@err).of(expected_det)
  end
end
|
790
984
|
it "computes the exact determinant of 2x2 matrix" do
|
791
985
|
if dtype == :byte
|
@@ -804,4 +998,38 @@ describe "math" do
|
|
804
998
|
end
|
805
999
|
end
|
806
1000
|
end
|
1001
|
+
|
1002
|
+
context "#scale and #scale!" do
  [:dense, :list, :yale].each do |stype|
    ALL_DTYPES.each do |dtype|
      next if dtype == :object

      context "for #{dtype}" do
        before do
          @m = NMatrix.new([3, 3], [0, 1, 2,
                                    3, 4, 5,
                                    6, 7, 8], stype: stype, dtype: dtype)
        end

        # @m with every element doubled; lazily built, so the error branches never construct it
        let(:doubled) do
          NMatrix.new([3, 3], [0, 2, 4,
                               6, 8, 10,
                               12, 14, 16], stype: stype, dtype: dtype)
        end

        it "scales the matrix by a given factor and return the result" do
          unless integer_dtype?(dtype)
            expect(@m.scale(2.0)).to eq(doubled)
            next
          end
          expect { @m.scale(2.0) }.to raise_error(DataTypeError)
        end

        it "scales the matrix in place by a given factor" do
          if dtype == :int8
            expect { @m.scale!(2) }.to raise_error(DataTypeError)
            next
          end
          @m.scale!(2)
          expect(@m).to eq(doubled)
        end
      end
    end
  end
end
|
807
1035
|
end
|