nmatrix-lapacke 0.2.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/ext/nmatrix/data/complex.h +183 -159
- data/ext/nmatrix/data/data.h +306 -292
- data/ext/nmatrix/data/ruby_object.h +193 -193
- data/ext/nmatrix/math/math.h +3 -2
- data/ext/nmatrix/math/trsm.h +152 -152
- data/ext/nmatrix/nmatrix.h +30 -0
- data/ext/nmatrix/ruby_constants.h +35 -35
- data/ext/nmatrix/storage/common.h +4 -3
- data/ext/nmatrix/storage/dense/dense.h +8 -7
- data/ext/nmatrix/storage/list/list.h +7 -6
- data/ext/nmatrix/storage/storage.h +12 -11
- data/ext/nmatrix/storage/yale/class.h +2 -2
- data/ext/nmatrix/storage/yale/iterators/base.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/iterator.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +1 -0
- data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +2 -1
- data/ext/nmatrix/storage/yale/yale.h +7 -6
- data/ext/nmatrix/types.h +3 -2
- data/ext/nmatrix/util/sl_list.h +19 -18
- data/ext/nmatrix_lapacke/extconf.rb +15 -9
- data/ext/nmatrix_lapacke/math_lapacke.cpp +6 -6
- data/lib/nmatrix/lapacke.rb +31 -9
- data/spec/00_nmatrix_spec.rb +6 -0
- data/spec/math_spec.rb +77 -0
- data/spec/spec_helper.rb +9 -0
- metadata +4 -4
data/lib/nmatrix/lapacke.rb
CHANGED
@@ -195,19 +195,41 @@ class NMatrix
|
|
195
195
|
NMatrix::LAPACK::lapacke_potrf(:row, which, self.shape[0], self, self.shape[1])
|
196
196
|
end
|
197
197
|
|
198
|
-
def solve
|
198
|
+
def solve(b, opts = {})
|
199
199
|
raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
|
200
200
|
raise(ShapeError, "number of rows of b must equal number of cols of self") if
|
201
201
|
self.shape[1] != b.shape[0]
|
202
|
-
raise
|
203
|
-
raise
|
202
|
+
raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
|
203
|
+
raise(ArgumentError, "only works for non-integer, non-object dtypes") if
|
204
204
|
integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
|
205
205
|
|
206
|
-
|
207
|
-
|
208
|
-
n
|
209
|
-
|
210
|
-
|
211
|
-
|
206
|
+
opts = { form: :general }.merge(opts)
|
207
|
+
x = b.clone
|
208
|
+
n = self.shape[0]
|
209
|
+
nrhs = b.shape[1]
|
210
|
+
|
211
|
+
case opts[:form]
|
212
|
+
when :general
|
213
|
+
clone = self.clone
|
214
|
+
ipiv = NMatrix::LAPACK.lapacke_getrf(:row, n, n, clone, n)
|
215
|
+
NMatrix::LAPACK.lapacke_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, nrhs)
|
216
|
+
x
|
217
|
+
when :upper_tri, :upper_triangular
|
218
|
+
raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
|
219
|
+
complex_dtype? or b.complex_dtype?
|
220
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
|
221
|
+
x
|
222
|
+
when :lower_tri, :lower_triangular
|
223
|
+
raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
|
224
|
+
complex_dtype? or b.complex_dtype?
|
225
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :lower, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
|
226
|
+
x
|
227
|
+
when :pos_def, :positive_definite
|
228
|
+
u, l = self.factorize_cholesky
|
229
|
+
z = l.solve(b, form: :lower_tri)
|
230
|
+
u.solve(z, form: :upper_tri)
|
231
|
+
else
|
232
|
+
raise(ArgumentError, "#{opts[:form]} is not a valid form option")
|
233
|
+
end
|
212
234
|
end
|
213
235
|
end
|
data/spec/00_nmatrix_spec.rb
CHANGED
@@ -461,6 +461,12 @@ describe 'NMatrix' do
|
|
461
461
|
expect(n.transpose).to eq n
|
462
462
|
expect(n.transpose).not_to be n
|
463
463
|
end
|
464
|
+
|
465
|
+
it "should check permute argument if supplied for #{stype} matrix" do
|
466
|
+
n = NMatrix.new([2,2], [1,2,3,4], stype: stype)
|
467
|
+
expect{n.transpose *4 }.to raise_error(ArgumentError)
|
468
|
+
expect{n.transpose [1,1,2] }.to raise_error(ArgumentError)
|
469
|
+
end
|
464
470
|
end
|
465
471
|
end
|
466
472
|
|
data/spec/math_spec.rb
CHANGED
@@ -553,6 +553,7 @@ describe "math" do
|
|
553
553
|
context "#solve" do
|
554
554
|
NON_INTEGER_DTYPES.each do |dtype|
|
555
555
|
next if dtype == :object # LU factorization doesnt work for :object yet
|
556
|
+
|
556
557
|
it "solves linear equation for dtype #{dtype}" do
|
557
558
|
a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
|
558
559
|
b = NMatrix.new [2,1], [9,8], dtype: dtype
|
@@ -581,6 +582,82 @@ describe "math" do
|
|
581
582
|
expect(a.solve(b)).to eq(NMatrix.new [3,2], [1,0, 0,0, 2,2], dtype: dtype)
|
582
583
|
end
|
583
584
|
end
|
585
|
+
|
586
|
+
FLOAT_DTYPES.each do |dtype|
|
587
|
+
context "when form: :lower_tri" do
|
588
|
+
let(:a) { NMatrix.new([3,3], [1, 0, 0, 2, 0.5, 0, 3, 3, 9], dtype: dtype) }
|
589
|
+
|
590
|
+
it "solves a lower triangular linear system A * x = b with vector b" do
|
591
|
+
b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
|
592
|
+
x = a.solve(b, form: :lower_tri)
|
593
|
+
r = a.dot(x) - b
|
594
|
+
expect(r.abs.max).to be_within(1e-6).of(0.0)
|
595
|
+
end
|
596
|
+
|
597
|
+
it "solves a lower triangular linear system A * X = B with narrow B" do
|
598
|
+
b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
|
599
|
+
x = a.solve(b, form: :lower_tri)
|
600
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
601
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
602
|
+
end
|
603
|
+
|
604
|
+
it "solves a lower triangular linear system A * X = B with wide B" do
|
605
|
+
b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
|
606
|
+
x = a.solve(b, form: :lower_tri)
|
607
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
608
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
609
|
+
end
|
610
|
+
end
|
611
|
+
|
612
|
+
context "when form: :upper_tri" do
|
613
|
+
let(:a) { NMatrix.new([3,3], [3, 2, 1, 0, 2, 0.5, 0, 0, 9], dtype: dtype) }
|
614
|
+
|
615
|
+
it "solves an upper triangular linear system A * x = b with vector b" do
|
616
|
+
b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
|
617
|
+
x = a.solve(b, form: :upper_tri)
|
618
|
+
r = a.dot(x) - b
|
619
|
+
expect(r.abs.max).to be_within(1e-6).of(0.0)
|
620
|
+
end
|
621
|
+
|
622
|
+
it "solves an upper triangular linear system A * X = B with narrow B" do
|
623
|
+
b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
|
624
|
+
x = a.solve(b, form: :upper_tri)
|
625
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
626
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
627
|
+
end
|
628
|
+
|
629
|
+
it "solves an upper triangular linear system A * X = B with a wide B" do
|
630
|
+
b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
|
631
|
+
x = a.solve(b, form: :upper_tri)
|
632
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
633
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
634
|
+
end
|
635
|
+
end
|
636
|
+
|
637
|
+
context "when form: :pos_def" do
|
638
|
+
let(:a) { NMatrix.new([3,3], [4, 1, 2, 1, 5, 3, 2, 3, 6], dtype: dtype) }
|
639
|
+
|
640
|
+
it "solves a linear system A * X = b with positive definite A and vector b" do
|
641
|
+
b = NMatrix.new([3,1], [6,4,8], dtype: dtype)
|
642
|
+
begin
|
643
|
+
x = a.solve(b, form: :pos_def)
|
644
|
+
expect(x).to be_within(1e-6).of(NMatrix.new([3,1], [1,0,1], dtype: dtype))
|
645
|
+
rescue NotImplementedError
|
646
|
+
"Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
|
647
|
+
end
|
648
|
+
end
|
649
|
+
|
650
|
+
it "solves a linear system A * X = B with positive definite A and matrix B" do
|
651
|
+
b = NMatrix.new([3,2], [8,3,14,13,14,19], dtype: dtype)
|
652
|
+
begin
|
653
|
+
x = a.solve(b, form: :pos_def)
|
654
|
+
expect(x).to be_within(1e-6).of(NMatrix.new([3,2], [1,-1,2,1,1,3], dtype: dtype))
|
655
|
+
rescue NotImplementedError
|
656
|
+
"Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
|
657
|
+
end
|
658
|
+
end
|
659
|
+
end
|
660
|
+
end
|
584
661
|
end
|
585
662
|
|
586
663
|
context "#hessenberg" do
|
data/spec/spec_helper.rb
CHANGED
@@ -138,3 +138,12 @@ end
|
|
138
138
|
def integer_dtype? dtype
|
139
139
|
[:byte,:int8,:int16,:int32,:int64].include?(dtype)
|
140
140
|
end
|
141
|
+
|
142
|
+
# If a focus: true option is supplied to any test, running `rake spec focus=true`
|
143
|
+
# will run only the focused tests and nothing else.
|
144
|
+
if ENV["focus"] == "true"
|
145
|
+
RSpec.configure do |c|
|
146
|
+
c.filter_run :focus => true
|
147
|
+
end
|
148
|
+
end
|
149
|
+
|
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nmatrix-lapacke
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- John Woods
|
@@ -10,7 +10,7 @@ authors:
|
|
10
10
|
autorequire:
|
11
11
|
bindir: bin
|
12
12
|
cert_chain: []
|
13
|
-
date:
|
13
|
+
date: 2016-01-18 00:00:00.000000000 Z
|
14
14
|
dependencies:
|
15
15
|
- !ruby/object:Gem::Dependency
|
16
16
|
name: nmatrix
|
@@ -18,14 +18,14 @@ dependencies:
|
|
18
18
|
requirements:
|
19
19
|
- - '='
|
20
20
|
- !ruby/object:Gem::Version
|
21
|
-
version: 0.2.0
|
21
|
+
version: 0.2.1
|
22
22
|
type: :runtime
|
23
23
|
prerelease: false
|
24
24
|
version_requirements: !ruby/object:Gem::Requirement
|
25
25
|
requirements:
|
26
26
|
- - '='
|
27
27
|
- !ruby/object:Gem::Version
|
28
|
-
version: 0.2.0
|
28
|
+
version: 0.2.1
|
29
29
|
description: For using linear algebra fuctions provided by LAPACK and BLAS
|
30
30
|
email:
|
31
31
|
- john.o.woods@gmail.com
|