nmatrix-lapacke 0.2.1 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/nmatrix/data/data.h +7 -8
- data/ext/nmatrix/data/ruby_object.h +1 -4
- data/ext/nmatrix/math/asum.h +10 -31
- data/ext/nmatrix/math/cblas_templates_core.h +10 -10
- data/ext/nmatrix/math/getrf.h +2 -2
- data/ext/nmatrix/math/imax.h +12 -9
- data/ext/nmatrix/math/laswp.h +3 -3
- data/ext/nmatrix/math/long_dtype.h +16 -3
- data/ext/nmatrix/math/magnitude.h +54 -0
- data/ext/nmatrix/math/nrm2.h +19 -14
- data/ext/nmatrix/math/trsm.h +40 -36
- data/ext/nmatrix/math/util.h +14 -0
- data/ext/nmatrix/nmatrix.h +39 -1
- data/ext/nmatrix/storage/common.h +9 -3
- data/ext/nmatrix/storage/yale/class.h +1 -1
- data/ext/nmatrix_lapacke/extconf.rb +3 -136
- data/ext/nmatrix_lapacke/lapacke.cpp +104 -84
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeqrf.c +77 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cgeqrf_work.c +89 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cunmqr.c +88 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_cunmqr_work.c +111 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeqrf.c +75 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dgeqrf_work.c +87 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dormqr.c +86 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_dormqr_work.c +109 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeqrf.c +75 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sgeqrf_work.c +87 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sormqr.c +86 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_sormqr_work.c +109 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeqrf.c +77 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zgeqrf_work.c +89 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zunmqr.c +88 -0
- data/ext/nmatrix_lapacke/lapacke/src/lapacke_zunmqr_work.c +111 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_c_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_d_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_s_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/lapacke/utils/lapacke_z_nancheck.c +51 -0
- data/ext/nmatrix_lapacke/math_lapacke.cpp +149 -17
- data/ext/nmatrix_lapacke/math_lapacke/lapacke_templates.h +76 -0
- data/lib/nmatrix/lapacke.rb +118 -0
- data/spec/00_nmatrix_spec.rb +50 -1
- data/spec/02_slice_spec.rb +21 -21
- data/spec/blas_spec.rb +25 -3
- data/spec/math_spec.rb +233 -5
- data/spec/plugins/lapacke/lapacke_spec.rb +187 -0
- data/spec/shortcuts_spec.rb +145 -5
- data/spec/spec_helper.rb +24 -1
- metadata +38 -8
@@ -64,6 +64,82 @@ inline int lapacke_getrf(const enum CBLAS_ORDER order, const int m, const int n,
|
|
64
64
|
return getrf<DType>(order, m, n, static_cast<DType*>(a), lda, ipiv);
|
65
65
|
}
|
66
66
|
|
67
|
+
//geqrf
|
68
|
+
template <typename DType>
|
69
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, DType* a, const int lda, DType* tau) {
|
70
|
+
rb_raise(rb_eNotImpError, "lapacke_geqrf not implemented for non_BLAS dtypes.");
|
71
|
+
return 0;
|
72
|
+
}
|
73
|
+
|
74
|
+
template <>
|
75
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, float* a, const int lda, float* tau) {
|
76
|
+
return LAPACKE_sgeqrf(order, m, n, a, lda, tau);
|
77
|
+
}
|
78
|
+
|
79
|
+
template < >
|
80
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, double* a, const int lda, double* tau) {
|
81
|
+
return LAPACKE_dgeqrf(order, m, n, a, lda, tau);
|
82
|
+
}
|
83
|
+
|
84
|
+
template <>
|
85
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, Complex64* a, const int lda, Complex64* tau) {
|
86
|
+
return LAPACKE_cgeqrf(order, m, n, a, lda, tau);
|
87
|
+
}
|
88
|
+
|
89
|
+
template <>
|
90
|
+
inline int geqrf(const enum CBLAS_ORDER order, const int m, const int n, Complex128* a, const int lda, Complex128* tau) {
|
91
|
+
return LAPACKE_zgeqrf(order, m, n, a, lda, tau);
|
92
|
+
}
|
93
|
+
|
94
|
+
template <typename DType>
|
95
|
+
inline int lapacke_geqrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, void* tau) {
|
96
|
+
return geqrf<DType>(order, m, n, static_cast<DType*>(a), lda, static_cast<DType*>(tau));
|
97
|
+
}
|
98
|
+
|
99
|
+
//ormqr
|
100
|
+
template <typename DType>
|
101
|
+
inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, DType* a, const int lda, DType* tau, DType* c, const int ldc) {
|
102
|
+
rb_raise(rb_eNotImpError, "lapacke_ormqr not implemented for non_BLAS dtypes.");
|
103
|
+
return 0;
|
104
|
+
}
|
105
|
+
|
106
|
+
template <>
|
107
|
+
inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, float* a, const int lda, float* tau, float* c, const int ldc) {
|
108
|
+
return LAPACKE_sormqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
109
|
+
}
|
110
|
+
|
111
|
+
template <>
|
112
|
+
inline int ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, double* a, const int lda, double* tau, double* c, const int ldc) {
|
113
|
+
return LAPACKE_dormqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
114
|
+
}
|
115
|
+
|
116
|
+
template <typename DType>
|
117
|
+
inline int lapacke_ormqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, void* a, const int lda, void* tau, void* c, const int ldc) {
|
118
|
+
return ormqr<DType>(order, side, trans, m, n, k, static_cast<DType*>(a), lda, static_cast<DType*>(tau), static_cast<DType*>(c), ldc);
|
119
|
+
}
|
120
|
+
|
121
|
+
//unmqr
|
122
|
+
template <typename DType>
|
123
|
+
inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, DType* a, const int lda, DType* tau, DType* c, const int ldc) {
|
124
|
+
rb_raise(rb_eNotImpError, "lapacke_unmqr not implemented for non complex dtypes.");
|
125
|
+
return 0;
|
126
|
+
}
|
127
|
+
|
128
|
+
template <>
|
129
|
+
inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, Complex64* a, const int lda, Complex64* tau, Complex64* c, const int ldc) {
|
130
|
+
return LAPACKE_cunmqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
131
|
+
}
|
132
|
+
|
133
|
+
template <>
|
134
|
+
inline int unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, Complex128* a, const int lda, Complex128* tau, Complex128* c, const int ldc) {
|
135
|
+
return LAPACKE_zunmqr(order, side, trans, m, n, k, a, lda, tau, c, ldc);
|
136
|
+
}
|
137
|
+
|
138
|
+
template <typename DType>
|
139
|
+
inline int lapacke_unmqr(const enum CBLAS_ORDER order, char side, char trans, const int m, const int n, const int k, void* a, const int lda, void* tau, void* c, const int ldc) {
|
140
|
+
return unmqr<DType>(order, side, trans, m, n, k, static_cast<DType*>(a), lda, static_cast<DType*>(tau), static_cast<DType*>(c), ldc);
|
141
|
+
}
|
142
|
+
|
67
143
|
//getri
|
68
144
|
template <typename DType>
|
69
145
|
inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
|
data/lib/nmatrix/lapacke.rb
CHANGED
@@ -232,4 +232,122 @@ class NMatrix
|
|
232
232
|
raise(ArgumentError, "#{opts[:form]} is not a valid form option")
|
233
233
|
end
|
234
234
|
end
|
235
|
+
|
236
|
+
#
|
237
|
+
# call-seq:
|
238
|
+
# geqrf! -> shape.min x 1 NMatrix
|
239
|
+
#
|
240
|
+
# QR factorization of a general M-by-N matrix +A+.
|
241
|
+
#
|
242
|
+
# The QR factorization is A = QR, where Q is orthogonal and R is Upper Triangular
|
243
|
+
# +A+ is overwritten with the elements of R and Q with Q being represented by the
|
244
|
+
# elements below A's diagonal and an array of scalar factors in the output NMatrix.
|
245
|
+
#
|
246
|
+
# The matrix Q is represented as a product of elementary reflectors
|
247
|
+
# Q = H(1) H(2) . . . H(k), where k = min(m,n).
|
248
|
+
#
|
249
|
+
# Each H(i) has the form
|
250
|
+
#
|
251
|
+
# H(i) = I - tau * v * v'
|
252
|
+
#
|
253
|
+
# http://www.netlib.org/lapack/explore-html/d3/d69/dgeqrf_8f.html
|
254
|
+
#
|
255
|
+
# Only works for dense matrices.
|
256
|
+
#
|
257
|
+
# * *Returns* :
|
258
|
+
# - Vector TAU. Q and R are stored in A. Q is represented by TAU and A
|
259
|
+
# * *Raises* :
|
260
|
+
# - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
|
261
|
+
#
|
262
|
+
def geqrf!
|
263
|
+
raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense?
|
264
|
+
|
265
|
+
tau = NMatrix.new([self.shape.min,1], dtype: self.dtype)
|
266
|
+
NMatrix::LAPACK::lapacke_geqrf(:row, self.shape[0], self.shape[1], self, self.shape[1], tau)
|
267
|
+
|
268
|
+
tau
|
269
|
+
end
|
270
|
+
|
271
|
+
#
|
272
|
+
# call-seq:
|
273
|
+
# ormqr(tau) -> NMatrix
|
274
|
+
# ormqr(tau, side, transpose, c) -> NMatrix
|
275
|
+
#
|
276
|
+
# Returns the product Q * c or c * Q after a call to geqrf! used in QR factorization.
|
277
|
+
# +c+ is overwritten with the elements of the result NMatrix if supplied. Q is the orthogonal matrix
|
278
|
+
# represented by tau and the calling NMatrix
|
279
|
+
#
|
280
|
+
# Only works on float types, use unmqr for complex types.
|
281
|
+
#
|
282
|
+
# == Arguments
|
283
|
+
#
|
284
|
+
# * +tau+ - vector containing scalar factors of elementary reflectors
|
285
|
+
# * +side+ - direction of multiplication [:left, :right]
|
286
|
+
# * +transpose+ - apply Q with or without transpose [false, :transpose]
|
287
|
+
# * +c+ - NMatrix multplication argument that is overwritten, no argument assumes c = identity
|
288
|
+
#
|
289
|
+
# * *Returns* :
|
290
|
+
#
|
291
|
+
# - Q * c or c * Q Where Q may be transposed before multiplication.
|
292
|
+
#
|
293
|
+
#
|
294
|
+
# * *Raises* :
|
295
|
+
# - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
|
296
|
+
# - +TypeError+ -> Works only on floating point matrices, use unmqr for complex types
|
297
|
+
# - +TypeError+ -> c must have the same dtype as the calling NMatrix
|
298
|
+
#
|
299
|
+
def ormqr(tau, side=:left, transpose=false, c=nil)
|
300
|
+
raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense?
|
301
|
+
raise(TypeError, "Works only on floating point matrices, use unmqr for complex types") if self.complex_dtype?
|
302
|
+
raise(TypeError, "c must have the same dtype as the calling NMatrix") if c and c.dtype != self.dtype
|
303
|
+
|
304
|
+
|
305
|
+
#Default behaviour produces Q * I = Q if c is not supplied.
|
306
|
+
result = c ? c : NMatrix.identity(self.shape[0], dtype: self.dtype)
|
307
|
+
NMatrix::LAPACK::lapacke_ormqr(:row, side, transpose, result.shape[0], result.shape[1], tau.shape[0], self, self.shape[1], tau, result, result.shape[1])
|
308
|
+
|
309
|
+
result
|
310
|
+
end
|
311
|
+
|
312
|
+
#
|
313
|
+
# call-seq:
|
314
|
+
# unmqr(tau) -> NMatrix
|
315
|
+
# unmqr(tau, side, transpose, c) -> NMatrix
|
316
|
+
#
|
317
|
+
# Returns the product Q * c or c * Q after a call to geqrf! used in QR factorization.
|
318
|
+
# +c+ is overwritten with the elements of the result NMatrix if it is supplied. Q is the orthogonal matrix
|
319
|
+
# represented by tau and the calling NMatrix
|
320
|
+
#
|
321
|
+
# Only works on complex types, use ormqr for float types.
|
322
|
+
#
|
323
|
+
# == Arguments
|
324
|
+
#
|
325
|
+
# * +tau+ - vector containing scalar factors of elementary reflectors
|
326
|
+
# * +side+ - direction of multiplication [:left, :right]
|
327
|
+
# * +transpose+ - apply Q as Q or its complex conjugate [false, :complex_conjugate]
|
328
|
+
# * +c+ - NMatrix multplication argument that is overwritten, no argument assumes c = identity
|
329
|
+
#
|
330
|
+
# * *Returns* :
|
331
|
+
#
|
332
|
+
# - Q * c or c * Q Where Q may be transformed to its complex conjugate before multiplication.
|
333
|
+
#
|
334
|
+
#
|
335
|
+
# * *Raises* :
|
336
|
+
# - +StorageTypeError+ -> LAPACK functions only work on dense matrices.
|
337
|
+
# - +TypeError+ -> Works only on floating point matrices, use unmqr for complex types
|
338
|
+
# - +TypeError+ -> c must have the same dtype as the calling NMatrix
|
339
|
+
#
|
340
|
+
def unmqr(tau, side=:left, transpose=false, c=nil)
|
341
|
+
raise(StorageTypeError, "ATLAS functions only work on dense matrices") unless self.dense?
|
342
|
+
raise(TypeError, "Works only on complex matrices, use ormqr for normal floating point matrices") unless self.complex_dtype?
|
343
|
+
raise(TypeError, "c must have the same dtype as the calling NMatrix") if c and c.dtype != self.dtype
|
344
|
+
|
345
|
+
#Default behaviour produces Q * I = Q if c is not supplied.
|
346
|
+
result = c ? c : NMatrix.identity(self.shape[0], dtype: self.dtype)
|
347
|
+
NMatrix::LAPACK::lapacke_unmqr(:row, side, transpose, result.shape[0], result.shape[1], tau.shape[0], self, self.shape[1], tau, result, result.shape[1])
|
348
|
+
|
349
|
+
result
|
350
|
+
end
|
351
|
+
|
352
|
+
|
235
353
|
end
|
data/spec/00_nmatrix_spec.rb
CHANGED
@@ -424,6 +424,13 @@ describe 'NMatrix' do
|
|
424
424
|
expect(n.reshape!([8,2]).eql?(n)).to eq(true) # because n itself changes
|
425
425
|
end
|
426
426
|
|
427
|
+
it "should do the reshape operation in place, changing dimension" do
|
428
|
+
n = NMatrix.seq(4)
|
429
|
+
a = n.reshape!([4,2,2])
|
430
|
+
expect(n).to eq(NMatrix.seq([4,2,2]))
|
431
|
+
expect(a).to eq(NMatrix.seq([4,2,2]))
|
432
|
+
end
|
433
|
+
|
427
434
|
it "reshape and reshape! must produce same result" do
|
428
435
|
n = NMatrix.seq(4)+1
|
429
436
|
a = NMatrix.seq(4)+1
|
@@ -432,7 +439,7 @@ describe 'NMatrix' do
|
|
432
439
|
|
433
440
|
it "should prevent a resize in place" do
|
434
441
|
n = NMatrix.seq(4)+1
|
435
|
-
expect { n.reshape([5,2]) }.to raise_error(ArgumentError)
|
442
|
+
expect { n.reshape!([5,2]) }.to raise_error(ArgumentError)
|
436
443
|
end
|
437
444
|
end
|
438
445
|
|
@@ -534,6 +541,26 @@ describe 'NMatrix' do
|
|
534
541
|
n = NMatrix.new([1,3,1], [1,2,3])
|
535
542
|
expect(n.dconcat(n)).to eq(NMatrix.new([1,3,2], [1,1,2,2,3,3]))
|
536
543
|
end
|
544
|
+
|
545
|
+
it "should work on matrices with different size along concat dim" do
|
546
|
+
n = N[[1, 2, 3],
|
547
|
+
[4, 5, 6]]
|
548
|
+
m = N[[7],
|
549
|
+
[8]]
|
550
|
+
|
551
|
+
expect(n.hconcat(m)).to eq N[[1, 2, 3, 7], [4, 5, 6, 8]]
|
552
|
+
expect(m.hconcat(n)).to eq N[[7, 1, 2, 3], [8, 4, 5, 6]]
|
553
|
+
end
|
554
|
+
|
555
|
+
it "should work on matrices with different size along concat dim" do
|
556
|
+
n = N[[1, 2, 3],
|
557
|
+
[4, 5, 6]]
|
558
|
+
|
559
|
+
m = N[[7, 8, 9]]
|
560
|
+
|
561
|
+
expect(n.vconcat(m)).to eq N[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
562
|
+
expect(m.vconcat(n)).to eq N[[7, 8, 9], [1, 2, 3], [4, 5, 6]]
|
563
|
+
end
|
537
564
|
end
|
538
565
|
|
539
566
|
context "#[]" do
|
@@ -631,6 +658,23 @@ describe 'NMatrix' do
|
|
631
658
|
end
|
632
659
|
end
|
633
660
|
|
661
|
+
context "#last" do
|
662
|
+
it "returns the last element of a 1-dimensional NMatrix" do
|
663
|
+
n = NMatrix.new([1,4], [1,2,3,4])
|
664
|
+
expect(n.last).to eq(4)
|
665
|
+
end
|
666
|
+
|
667
|
+
it "returns the last element of a 2-dimensional NMatrix" do
|
668
|
+
n = NMatrix.new([2,2], [4,8,12,16])
|
669
|
+
expect(n.last).to eq(16)
|
670
|
+
end
|
671
|
+
|
672
|
+
it "returns the last element of a 3-dimensional NMatrix" do
|
673
|
+
n = NMatrix.new([2,2,2], [1,2,3,4,5,6,7,8])
|
674
|
+
expect(n.last).to eq(8)
|
675
|
+
end
|
676
|
+
end
|
677
|
+
|
634
678
|
context "#diagonal" do
|
635
679
|
ALL_DTYPES.each do |dtype|
|
636
680
|
before do
|
@@ -682,6 +726,11 @@ describe 'NMatrix' do
|
|
682
726
|
expect(@sample_matrix.repeat(2, 0)).to eq(NMatrix.new([4, 2], [1, 2, 3, 4, 1, 2, 3, 4]))
|
683
727
|
expect(@sample_matrix.repeat(2, 1)).to eq(NMatrix.new([2, 4], [1, 2, 1, 2, 3, 4, 3, 4]))
|
684
728
|
end
|
729
|
+
|
730
|
+
it "preserves dtype" do
|
731
|
+
expect(@sample_matrix.repeat(2, 0).dtype).to eq(@sample_matrix.dtype)
|
732
|
+
expect(@sample_matrix.repeat(2, 1).dtype).to eq(@sample_matrix.dtype)
|
733
|
+
end
|
685
734
|
end
|
686
735
|
|
687
736
|
context "#meshgrid" do
|
data/spec/02_slice_spec.rb
CHANGED
@@ -127,9 +127,9 @@ describe "Slice operation" do
|
|
127
127
|
it "should have #is_ref? method" do
|
128
128
|
a = stype_matrix[0..1, 0..1]
|
129
129
|
b = stype_matrix.slice(0..1, 0..1)
|
130
|
-
expect(stype_matrix.is_ref?).to
|
131
|
-
expect(a.is_ref?).to
|
132
|
-
expect(b.is_ref?).to
|
130
|
+
expect(stype_matrix.is_ref?).to be false
|
131
|
+
expect(a.is_ref?).to be true
|
132
|
+
expect(b.is_ref?).to be false
|
133
133
|
end
|
134
134
|
|
135
135
|
it "reference should compare with non-reference" do
|
@@ -141,7 +141,7 @@ describe "Slice operation" do
|
|
141
141
|
context "with copying" do
|
142
142
|
it 'should return an NMatrix' do
|
143
143
|
n = stype_matrix.slice(0..1,0..1)
|
144
|
-
expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to
|
144
|
+
expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to be true
|
145
145
|
end
|
146
146
|
|
147
147
|
it 'should return a copy of 2x2 matrix to self elements' do
|
@@ -183,19 +183,19 @@ describe "Slice operation" do
|
|
183
183
|
|
184
184
|
[:dense, :list, :yale].each do |cast_type|
|
185
185
|
it "should cast copied slice from #{stype.upcase} to #{cast_type.upcase}" do
|
186
|
-
expect(nm_eql(stype_matrix.slice(1..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(1..2,1..2))).to
|
187
|
-
expect(nm_eql(stype_matrix.slice(0..1, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..1,1..2))).to
|
188
|
-
expect(nm_eql(stype_matrix.slice(1..2, 0..1).cast(cast_type, :int32), stype_matrix.slice(1..2,0..1))).to
|
189
|
-
expect(nm_eql(stype_matrix.slice(0..1, 0..1).cast(cast_type, :int32), stype_matrix.slice(0..1,0..1))).to
|
186
|
+
expect(nm_eql(stype_matrix.slice(1..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(1..2,1..2))).to be true
|
187
|
+
expect(nm_eql(stype_matrix.slice(0..1, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..1,1..2))).to be true
|
188
|
+
expect(nm_eql(stype_matrix.slice(1..2, 0..1).cast(cast_type, :int32), stype_matrix.slice(1..2,0..1))).to be true
|
189
|
+
expect(nm_eql(stype_matrix.slice(0..1, 0..1).cast(cast_type, :int32), stype_matrix.slice(0..1,0..1))).to be true
|
190
190
|
|
191
191
|
# Non square
|
192
|
-
expect(nm_eql(stype_matrix.slice(0..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..2,1..2))).to
|
192
|
+
expect(nm_eql(stype_matrix.slice(0..2, 1..2).cast(cast_type, :int32), stype_matrix.slice(0..2,1..2))).to be true
|
193
193
|
#require 'pry'
|
194
194
|
#binding.pry if cast_type == :yale
|
195
|
-
expect(nm_eql(stype_matrix.slice(1..2, 0..2).cast(cast_type, :int32), stype_matrix.slice(1..2,0..2))).to
|
195
|
+
expect(nm_eql(stype_matrix.slice(1..2, 0..2).cast(cast_type, :int32), stype_matrix.slice(1..2,0..2))).to be true
|
196
196
|
|
197
197
|
# Full
|
198
|
-
expect(nm_eql(stype_matrix.slice(0..2, 0..2).cast(cast_type, :int32), stype_matrix)).to
|
198
|
+
expect(nm_eql(stype_matrix.slice(0..2, 0..2).cast(cast_type, :int32), stype_matrix)).to be true
|
199
199
|
end
|
200
200
|
end
|
201
201
|
end
|
@@ -214,7 +214,7 @@ describe "Slice operation" do
|
|
214
214
|
context "by reference" do
|
215
215
|
it 'should return an NMatrix' do
|
216
216
|
n = stype_matrix[0..1,0..1]
|
217
|
-
expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to
|
217
|
+
expect(nm_eql(n, NMatrix.new([2,2], [0,1,3,4], dtype: :int32))).to be true
|
218
218
|
end
|
219
219
|
|
220
220
|
it 'should return a 2x2 matrix with refs to self elements' do
|
@@ -246,7 +246,7 @@ describe "Slice operation" do
|
|
246
246
|
|
247
247
|
it 'should slice again' do
|
248
248
|
n = stype_matrix[1..2, 1..2]
|
249
|
-
expect(nm_eql(n[1,0..1], NVector.new(2, [7,8], dtype: :int32).transpose)).to
|
249
|
+
expect(nm_eql(n[1,0..1], NVector.new(2, [7,8], dtype: :int32).transpose)).to be true
|
250
250
|
end
|
251
251
|
|
252
252
|
it 'should be correct slice for range 0..2 and 0...3' do
|
@@ -320,7 +320,7 @@ describe "Slice operation" do
|
|
320
320
|
end
|
321
321
|
|
322
322
|
it "compares slices to scalars" do
|
323
|
-
(stype_matrix[1, 0..2] > 2).each { |e| expect(e != 0).to
|
323
|
+
(stype_matrix[1, 0..2] > 2).each { |e| expect(e != 0).to be true }
|
324
324
|
end
|
325
325
|
|
326
326
|
it "iterates only over elements in the slice" do
|
@@ -367,20 +367,20 @@ describe "Slice operation" do
|
|
367
367
|
|
368
368
|
[:dense, :list, :yale].each do |cast_type|
|
369
369
|
it "should cast a square reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
|
370
|
-
expect(nm_eql(stype_matrix[1..2, 1..2].cast(cast_type), stype_matrix[1..2,1..2])).to
|
371
|
-
expect(nm_eql(stype_matrix[0..1, 1..2].cast(cast_type), stype_matrix[0..1,1..2])).to
|
372
|
-
expect(nm_eql(stype_matrix[1..2, 0..1].cast(cast_type), stype_matrix[1..2,0..1])).to
|
373
|
-
expect(nm_eql(stype_matrix[0..1, 0..1].cast(cast_type), stype_matrix[0..1,0..1])).to
|
370
|
+
expect(nm_eql(stype_matrix[1..2, 1..2].cast(cast_type), stype_matrix[1..2,1..2])).to be true
|
371
|
+
expect(nm_eql(stype_matrix[0..1, 1..2].cast(cast_type), stype_matrix[0..1,1..2])).to be true
|
372
|
+
expect(nm_eql(stype_matrix[1..2, 0..1].cast(cast_type), stype_matrix[1..2,0..1])).to be true
|
373
|
+
expect(nm_eql(stype_matrix[0..1, 0..1].cast(cast_type), stype_matrix[0..1,0..1])).to be true
|
374
374
|
end
|
375
375
|
|
376
376
|
it "should cast a rectangular reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
|
377
377
|
# Non square
|
378
|
-
expect(nm_eql(stype_matrix[0..2, 1..2].cast(cast_type), stype_matrix[0..2,1..2])).to
|
379
|
-
expect(nm_eql(stype_matrix[1..2, 0..2].cast(cast_type), stype_matrix[1..2,0..2])).to
|
378
|
+
expect(nm_eql(stype_matrix[0..2, 1..2].cast(cast_type), stype_matrix[0..2,1..2])).to be true # FIXME: memory problem.
|
379
|
+
expect(nm_eql(stype_matrix[1..2, 0..2].cast(cast_type), stype_matrix[1..2,0..2])).to be true # this one is fine
|
380
380
|
end
|
381
381
|
|
382
382
|
it "should cast a square full-matrix reference-slice from #{stype.upcase} to #{cast_type.upcase}" do
|
383
|
-
expect(nm_eql(stype_matrix[0..2, 0..2].cast(cast_type), stype_matrix)).to
|
383
|
+
expect(nm_eql(stype_matrix[0..2, 0..2].cast(cast_type), stype_matrix)).to be true
|
384
384
|
end
|
385
385
|
end
|
386
386
|
end
|
data/spec/blas_spec.rb
CHANGED
@@ -69,6 +69,18 @@ describe NMatrix::BLAS do
|
|
69
69
|
expect(b[0]).to eq(-15.0/2)
|
70
70
|
expect(b[1]).to eq(5)
|
71
71
|
expect(b[2]).to eq(-13)
|
72
|
+
|
73
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :lower, :transpose, :nounit, 3, 1, 1.0, a, 3, b, 1)
|
74
|
+
|
75
|
+
expect(b[0]).to eq(307.0/8)
|
76
|
+
expect(b[1]).to eq(57.0/2)
|
77
|
+
expect(b[2]).to eq(26.0)
|
78
|
+
|
79
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, :transpose, :unit, 3, 1, 1.0, a, 3, b, 1)
|
80
|
+
|
81
|
+
expect(b[0]).to eq(307.0/8)
|
82
|
+
expect(b[1]).to eq(763.0/16)
|
83
|
+
expect(b[2]).to eq(4269.0/64)
|
72
84
|
end
|
73
85
|
|
74
86
|
# trmm multiplies two matrices, where one of the two is required to be
|
@@ -174,9 +186,17 @@ describe NMatrix::BLAS do
|
|
174
186
|
|
175
187
|
it "exposes nrm2" do
|
176
188
|
pending("broken for :object") if dtype == :object
|
177
|
-
pending("Temporarily disable because the internal implementation of nrm2 is broken -WL 2015-05-17") if dtype == :complex64 || dtype == :complex128
|
178
189
|
|
179
|
-
|
190
|
+
if dtype =~ /complex/
|
191
|
+
x = NMatrix.new([3,1], [Complex(1,2),Complex(3,4),Complex(0,6)], dtype: dtype)
|
192
|
+
y = NMatrix.new([3,1], [Complex(0,0),Complex(0,0),Complex(0,0)], dtype: dtype)
|
193
|
+
nrm2 = 8.12403840463596
|
194
|
+
else
|
195
|
+
x = NMatrix.new([4,1], [2,-4,3,5], dtype: dtype)
|
196
|
+
y = NMatrix.new([3,1], [0,0,0], dtype: dtype)
|
197
|
+
nrm2 = 5.385164807134504
|
198
|
+
end
|
199
|
+
|
180
200
|
err = case dtype
|
181
201
|
when :float32, :complex64
|
182
202
|
1e-6
|
@@ -185,7 +205,9 @@ describe NMatrix::BLAS do
|
|
185
205
|
else
|
186
206
|
1e-14
|
187
207
|
end
|
188
|
-
|
208
|
+
|
209
|
+
expect(NMatrix::BLAS.nrm2(x, 1, 3)).to be_within(err).of(nrm2)
|
210
|
+
expect(NMatrix::BLAS.nrm2(y, 1, 3)).to be_within(err).of(0)
|
189
211
|
end
|
190
212
|
|
191
213
|
end
|
data/spec/math_spec.rb
CHANGED
@@ -285,6 +285,152 @@ describe "math" do
|
|
285
285
|
end
|
286
286
|
end
|
287
287
|
|
288
|
+
NON_INTEGER_DTYPES.each do |dtype|
|
289
|
+
next if dtype == :object
|
290
|
+
context dtype do
|
291
|
+
|
292
|
+
it "calculates QR decomposition using factorize_qr for a square matrix" do
|
293
|
+
|
294
|
+
a = NMatrix.new(3, [12.0, -51.0, 4.0,
|
295
|
+
6.0, 167.0, -68.0,
|
296
|
+
-4.0, 24.0, -41.0] , dtype: dtype)
|
297
|
+
|
298
|
+
q_solution = NMatrix.new([3,3], Q_SOLUTION_ARRAY_2, dtype: dtype)
|
299
|
+
|
300
|
+
r_solution = NMatrix.new([3,3], [-14.0, -21.0, 14,
|
301
|
+
0.0, -175, 70,
|
302
|
+
0.0, 0.0, -35] , dtype: dtype)
|
303
|
+
|
304
|
+
err = case dtype
|
305
|
+
when :float32, :complex64
|
306
|
+
1e-4
|
307
|
+
when :float64, :complex128
|
308
|
+
1e-13
|
309
|
+
end
|
310
|
+
|
311
|
+
begin
|
312
|
+
q,r = a.factorize_qr
|
313
|
+
|
314
|
+
expect(q).to be_within(err).of(q_solution)
|
315
|
+
expect(r).to be_within(err).of(r_solution)
|
316
|
+
|
317
|
+
rescue NotImplementedError
|
318
|
+
pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
|
319
|
+
end
|
320
|
+
end
|
321
|
+
|
322
|
+
it "calculates QR decomposition using factorize_qr for a tall and narrow rectangular matrix" do
|
323
|
+
|
324
|
+
a = NMatrix.new([4,2], [34.0, 21.0,
|
325
|
+
23.0, 53.0,
|
326
|
+
26.0, 346.0,
|
327
|
+
23.0, 121.0] , dtype: dtype)
|
328
|
+
|
329
|
+
q_solution = NMatrix.new([4,4], Q_SOLUTION_ARRAY_1, dtype: dtype)
|
330
|
+
|
331
|
+
r_solution = NMatrix.new([4,2], [-53.75872022286244, -255.06559574252242,
|
332
|
+
0.0, 269.34836526051555,
|
333
|
+
0.0, 0.0,
|
334
|
+
0.0, 0.0] , dtype: dtype)
|
335
|
+
|
336
|
+
err = case dtype
|
337
|
+
when :float32, :complex64
|
338
|
+
1e-4
|
339
|
+
when :float64, :complex128
|
340
|
+
1e-13
|
341
|
+
end
|
342
|
+
|
343
|
+
begin
|
344
|
+
q,r = a.factorize_qr
|
345
|
+
|
346
|
+
expect(q).to be_within(err).of(q_solution)
|
347
|
+
expect(r).to be_within(err).of(r_solution)
|
348
|
+
|
349
|
+
rescue NotImplementedError
|
350
|
+
pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
|
351
|
+
end
|
352
|
+
end
|
353
|
+
|
354
|
+
it "calculates QR decomposition using factorize_qr for a short and wide rectangular matrix" do
|
355
|
+
|
356
|
+
a = NMatrix.new([3,4], [123,31,57,81,92,14,17,36,42,34,11,28], dtype: dtype)
|
357
|
+
|
358
|
+
q_solution = NMatrix.new([3,3], Q_SOLUTION_ARRAY_3, dtype: dtype)
|
359
|
+
|
360
|
+
r_solution = NMatrix.new([3,4], R_SOLUTION_ARRAY, dtype: dtype)
|
361
|
+
|
362
|
+
err = case dtype
|
363
|
+
when :float32, :complex64
|
364
|
+
1e-4
|
365
|
+
when :float64, :complex128
|
366
|
+
1e-13
|
367
|
+
end
|
368
|
+
|
369
|
+
begin
|
370
|
+
q,r = a.factorize_qr
|
371
|
+
|
372
|
+
expect(q).to be_within(err).of(q_solution)
|
373
|
+
expect(r).to be_within(err).of(r_solution)
|
374
|
+
|
375
|
+
rescue NotImplementedError
|
376
|
+
pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
|
377
|
+
end
|
378
|
+
end
|
379
|
+
|
380
|
+
it "calculates QR decomposition such that A - QR ~ 0" do
|
381
|
+
|
382
|
+
a = NMatrix.new([3,3], [ 9.0, 0.0, 26.0,
|
383
|
+
12.0, 0.0, -7.0,
|
384
|
+
0.0, 4.0, 0.0] , dtype: dtype)
|
385
|
+
|
386
|
+
err = case dtype
|
387
|
+
when :float32, :complex64
|
388
|
+
1e-4
|
389
|
+
when :float64, :complex128
|
390
|
+
1e-13
|
391
|
+
end
|
392
|
+
|
393
|
+
begin
|
394
|
+
q,r = a.factorize_qr
|
395
|
+
a_expected = q.dot(r)
|
396
|
+
|
397
|
+
expect(a_expected).to be_within(err).of(a)
|
398
|
+
|
399
|
+
rescue NotImplementedError
|
400
|
+
pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
|
401
|
+
end
|
402
|
+
end
|
403
|
+
|
404
|
+
|
405
|
+
it "calculates the orthogonal matrix Q in QR decomposition" do
|
406
|
+
|
407
|
+
a = N.new([2,2], [34.0, 21, 23, 53] , dtype: dtype)
|
408
|
+
|
409
|
+
err = case dtype
|
410
|
+
when :float32, :complex64
|
411
|
+
1e-4
|
412
|
+
when :float64, :complex128
|
413
|
+
1e-13
|
414
|
+
end
|
415
|
+
|
416
|
+
begin
|
417
|
+
q,r = a.factorize_qr
|
418
|
+
|
419
|
+
#Q is orthogonal if Q x Q.transpose = I
|
420
|
+
product = q.dot(q.transpose)
|
421
|
+
|
422
|
+
expect(product[0,0]).to be_within(err).of(1)
|
423
|
+
expect(product[1,0]).to be_within(err).of(0)
|
424
|
+
expect(product[0,1]).to be_within(err).of(0)
|
425
|
+
expect(product[1,1]).to be_within(err).of(1)
|
426
|
+
|
427
|
+
rescue NotImplementedError
|
428
|
+
pending "Suppressing a NotImplementedError when the lapacke plugin is not available"
|
429
|
+
end
|
430
|
+
end
|
431
|
+
end
|
432
|
+
end
|
433
|
+
|
288
434
|
ALL_DTYPES.each do |dtype|
|
289
435
|
next if dtype == :byte #doesn't work for unsigned types
|
290
436
|
next if dtype == :object
|
@@ -332,6 +478,49 @@ describe "math" do
|
|
332
478
|
end
|
333
479
|
end
|
334
480
|
|
481
|
+
ALL_DTYPES.each do |dtype|
|
482
|
+
next if dtype == :byte #doesn't work for unsigned types
|
483
|
+
next if dtype == :object
|
484
|
+
|
485
|
+
context dtype do
|
486
|
+
err = case dtype
|
487
|
+
when :float32, :complex64
|
488
|
+
1e-4
|
489
|
+
else #integer matrices will return :float64
|
490
|
+
1e-13
|
491
|
+
end
|
492
|
+
|
493
|
+
it "should correctly find adjugate a matrix in place (bang)" do
|
494
|
+
a = NMatrix.new(:dense, 2, [2, 3, 3, 5], dtype)
|
495
|
+
b = NMatrix.new(:dense, 2, [5, -3, -3, 2], dtype)
|
496
|
+
|
497
|
+
if a.integer_dtype?
|
498
|
+
expect{a.adjugate!}.to raise_error(DataTypeError)
|
499
|
+
else
|
500
|
+
#should return adjugate as well as modifying a
|
501
|
+
r = a.adjugate!
|
502
|
+
expect(a).to be_within(err).of(b)
|
503
|
+
expect(r).to be_within(err).of(b)
|
504
|
+
end
|
505
|
+
end
|
506
|
+
|
507
|
+
|
508
|
+
it "should correctly find adjugate of a matrix out-of-place" do
|
509
|
+
a = NMatrix.new(:dense, 3, [-3, 2, -5, -1, 0, -2, 3, -4, 1], dtype)
|
510
|
+
|
511
|
+
if a.integer_dtype?
|
512
|
+
b = NMatrix.new(:dense, 3, [-8, 18, -4, -5, 12, -1, 4, -6, 2], :float64)
|
513
|
+
else
|
514
|
+
b = NMatrix.new(:dense, 3, [-8, 18, -4, -5, 12, -1, 4, -6, 2], dtype)
|
515
|
+
end
|
516
|
+
|
517
|
+
expect(a.adjoint).to be_within(err).of(b)
|
518
|
+
expect(a.adjugate).to be_within(err).of(b)
|
519
|
+
end
|
520
|
+
|
521
|
+
end
|
522
|
+
end
|
523
|
+
|
335
524
|
# TODO: Get it working with ROBJ too
|
336
525
|
[:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |left_dtype|
|
337
526
|
[:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |right_dtype|
|
@@ -702,7 +891,7 @@ describe "math" do
|
|
702
891
|
360, 96, 51, -14,
|
703
892
|
448,-231,-24,-87,
|
704
893
|
-1168, 595,234, 523],
|
705
|
-
dtype: answer_dtype,
|
894
|
+
dtype: answer_dtype,
|
706
895
|
stype: stype))
|
707
896
|
end
|
708
897
|
|
@@ -757,7 +946,6 @@ describe "math" do
|
|
757
946
|
|
758
947
|
context "determinants" do
|
759
948
|
ALL_DTYPES.each do |dtype|
|
760
|
-
next if dtype == :object
|
761
949
|
context dtype do
|
762
950
|
before do
|
763
951
|
@a = NMatrix.new([2,2], [1,2,
|
@@ -779,13 +967,19 @@ describe "math" do
|
|
779
967
|
end
|
780
968
|
end
|
781
969
|
it "computes the determinant of 2x2 matrix" do
|
782
|
-
|
970
|
+
if dtype != :object
|
971
|
+
expect(@a.det).to be_within(@err).of(-2)
|
972
|
+
end
|
783
973
|
end
|
784
974
|
it "computes the determinant of 3x3 matrix" do
|
785
|
-
|
975
|
+
if dtype != :object
|
976
|
+
expect(@b.det).to be_within(@err).of(-8)
|
977
|
+
end
|
786
978
|
end
|
787
979
|
it "computes the determinant of 4x4 matrix" do
|
788
|
-
|
980
|
+
if dtype != :object
|
981
|
+
expect(@c.det).to be_within(@err).of(-18)
|
982
|
+
end
|
789
983
|
end
|
790
984
|
it "computes the exact determinant of 2x2 matrix" do
|
791
985
|
if dtype == :byte
|
@@ -804,4 +998,38 @@ describe "math" do
|
|
804
998
|
end
|
805
999
|
end
|
806
1000
|
end
|
1001
|
+
|
1002
|
+
context "#scale and #scale!" do
|
1003
|
+
[:dense,:list,:yale].each do |stype|
|
1004
|
+
ALL_DTYPES.each do |dtype|
|
1005
|
+
next if dtype == :object
|
1006
|
+
context "for #{dtype}" do
|
1007
|
+
before do
|
1008
|
+
@m = NMatrix.new([3, 3], [0, 1, 2,
|
1009
|
+
3, 4, 5,
|
1010
|
+
6, 7, 8], stype: stype, dtype: dtype)
|
1011
|
+
end
|
1012
|
+
it "scales the matrix by a given factor and return the result" do
|
1013
|
+
if integer_dtype? dtype
|
1014
|
+
expect{@m.scale 2.0}.to raise_error(DataTypeError)
|
1015
|
+
else
|
1016
|
+
expect(@m.scale 2.0).to eq(NMatrix.new([3, 3], [0, 2, 4,
|
1017
|
+
6, 8, 10,
|
1018
|
+
12, 14, 16], stype: stype, dtype: dtype))
|
1019
|
+
end
|
1020
|
+
end
|
1021
|
+
it "scales the matrix in place by a given factor" do
|
1022
|
+
if dtype == :int8
|
1023
|
+
expect{@m.scale! 2}.to raise_error(DataTypeError)
|
1024
|
+
else
|
1025
|
+
@m.scale! 2
|
1026
|
+
expect(@m).to eq(NMatrix.new([3, 3], [0, 2, 4,
|
1027
|
+
6, 8, 10,
|
1028
|
+
12, 14, 16], stype: stype, dtype: dtype))
|
1029
|
+
end
|
1030
|
+
end
|
1031
|
+
end
|
1032
|
+
end
|
1033
|
+
end
|
1034
|
+
end
|
807
1035
|
end
|