nmatrix-lapacke 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/nmatrix/data/complex.h +183 -159
- data/ext/nmatrix/data/data.h +306 -292
- data/ext/nmatrix/data/ruby_object.h +193 -193
- data/ext/nmatrix/math/math.h +3 -2
- data/ext/nmatrix/math/trsm.h +152 -152
- data/ext/nmatrix/nmatrix.h +30 -0
- data/ext/nmatrix/ruby_constants.h +35 -35
- data/ext/nmatrix/storage/common.h +4 -3
- data/ext/nmatrix/storage/dense/dense.h +8 -7
- data/ext/nmatrix/storage/list/list.h +7 -6
- data/ext/nmatrix/storage/storage.h +12 -11
- data/ext/nmatrix/storage/yale/class.h +2 -2
- data/ext/nmatrix/storage/yale/iterators/base.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/iterator.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +1 -0
- data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +2 -1
- data/ext/nmatrix/storage/yale/yale.h +7 -6
- data/ext/nmatrix/types.h +3 -2
- data/ext/nmatrix/util/sl_list.h +19 -18
- data/ext/nmatrix_lapacke/extconf.rb +15 -9
- data/ext/nmatrix_lapacke/math_lapacke.cpp +6 -6
- data/lib/nmatrix/lapacke.rb +31 -9
- data/spec/00_nmatrix_spec.rb +6 -0
- data/spec/math_spec.rb +77 -0
- data/spec/spec_helper.rb +9 -0
- metadata +4 -4
data/lib/nmatrix/lapacke.rb
CHANGED
@@ -195,19 +195,41 @@ class NMatrix
|
|
195
195
|
NMatrix::LAPACK::lapacke_potrf(:row, which, self.shape[0], self, self.shape[1])
|
196
196
|
end
|
197
197
|
|
198
|
-
def solve
|
198
|
+
def solve(b, opts = {})
|
199
199
|
raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
|
200
200
|
raise(ShapeError, "number of rows of b must equal number of cols of self") if
|
201
201
|
self.shape[1] != b.shape[0]
|
202
|
-
raise
|
203
|
-
raise
|
202
|
+
raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
|
203
|
+
raise(ArgumentError, "only works for non-integer, non-object dtypes") if
|
204
204
|
integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
|
205
205
|
|
206
|
-
|
207
|
-
|
208
|
-
n
|
209
|
-
|
210
|
-
|
211
|
-
|
206
|
+
opts = { form: :general }.merge(opts)
|
207
|
+
x = b.clone
|
208
|
+
n = self.shape[0]
|
209
|
+
nrhs = b.shape[1]
|
210
|
+
|
211
|
+
case opts[:form]
|
212
|
+
when :general
|
213
|
+
clone = self.clone
|
214
|
+
ipiv = NMatrix::LAPACK.lapacke_getrf(:row, n, n, clone, n)
|
215
|
+
NMatrix::LAPACK.lapacke_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, nrhs)
|
216
|
+
x
|
217
|
+
when :upper_tri, :upper_triangular
|
218
|
+
raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
|
219
|
+
complex_dtype? or b.complex_dtype?
|
220
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
|
221
|
+
x
|
222
|
+
when :lower_tri, :lower_triangular
|
223
|
+
raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
|
224
|
+
complex_dtype? or b.complex_dtype?
|
225
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :lower, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
|
226
|
+
x
|
227
|
+
when :pos_def, :positive_definite
|
228
|
+
u, l = self.factorize_cholesky
|
229
|
+
z = l.solve(b, form: :lower_tri)
|
230
|
+
u.solve(z, form: :upper_tri)
|
231
|
+
else
|
232
|
+
raise(ArgumentError, "#{opts[:form]} is not a valid form option")
|
233
|
+
end
|
212
234
|
end
|
213
235
|
end
|
data/spec/00_nmatrix_spec.rb
CHANGED
@@ -461,6 +461,12 @@ describe 'NMatrix' do
|
|
461
461
|
expect(n.transpose).to eq n
|
462
462
|
expect(n.transpose).not_to be n
|
463
463
|
end
|
464
|
+
|
465
|
+
it "should check permute argument if supplied for #{stype} matrix" do
|
466
|
+
n = NMatrix.new([2,2], [1,2,3,4], stype: stype)
|
467
|
+
expect{n.transpose *4 }.to raise_error(ArgumentError)
|
468
|
+
expect{n.transpose [1,1,2] }.to raise_error(ArgumentError)
|
469
|
+
end
|
464
470
|
end
|
465
471
|
end
|
466
472
|
|
data/spec/math_spec.rb
CHANGED
@@ -553,6 +553,7 @@ describe "math" do
|
|
553
553
|
context "#solve" do
|
554
554
|
NON_INTEGER_DTYPES.each do |dtype|
|
555
555
|
next if dtype == :object # LU factorization doesnt work for :object yet
|
556
|
+
|
556
557
|
it "solves linear equation for dtype #{dtype}" do
|
557
558
|
a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
|
558
559
|
b = NMatrix.new [2,1], [9,8], dtype: dtype
|
@@ -581,6 +582,82 @@ describe "math" do
|
|
581
582
|
expect(a.solve(b)).to eq(NMatrix.new [3,2], [1,0, 0,0, 2,2], dtype: dtype)
|
582
583
|
end
|
583
584
|
end
|
585
|
+
|
586
|
+
FLOAT_DTYPES.each do |dtype|
|
587
|
+
context "when form: :lower_tri" do
|
588
|
+
let(:a) { NMatrix.new([3,3], [1, 0, 0, 2, 0.5, 0, 3, 3, 9], dtype: dtype) }
|
589
|
+
|
590
|
+
it "solves a lower triangular linear system A * x = b with vector b" do
|
591
|
+
b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
|
592
|
+
x = a.solve(b, form: :lower_tri)
|
593
|
+
r = a.dot(x) - b
|
594
|
+
expect(r.abs.max).to be_within(1e-6).of(0.0)
|
595
|
+
end
|
596
|
+
|
597
|
+
it "solves a lower triangular linear system A * X = B with narrow B" do
|
598
|
+
b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
|
599
|
+
x = a.solve(b, form: :lower_tri)
|
600
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
601
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
602
|
+
end
|
603
|
+
|
604
|
+
it "solves a lower triangular linear system A * X = B with wide B" do
|
605
|
+
b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
|
606
|
+
x = a.solve(b, form: :lower_tri)
|
607
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
608
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
609
|
+
end
|
610
|
+
end
|
611
|
+
|
612
|
+
context "when form: :upper_tri" do
|
613
|
+
let(:a) { NMatrix.new([3,3], [3, 2, 1, 0, 2, 0.5, 0, 0, 9], dtype: dtype) }
|
614
|
+
|
615
|
+
it "solves an upper triangular linear system A * x = b with vector b" do
|
616
|
+
b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
|
617
|
+
x = a.solve(b, form: :upper_tri)
|
618
|
+
r = a.dot(x) - b
|
619
|
+
expect(r.abs.max).to be_within(1e-6).of(0.0)
|
620
|
+
end
|
621
|
+
|
622
|
+
it "solves an upper triangular linear system A * X = B with narrow B" do
|
623
|
+
b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
|
624
|
+
x = a.solve(b, form: :upper_tri)
|
625
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
626
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
627
|
+
end
|
628
|
+
|
629
|
+
it "solves an upper triangular linear system A * X = B with a wide B" do
|
630
|
+
b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
|
631
|
+
x = a.solve(b, form: :upper_tri)
|
632
|
+
r = (a.dot(x) - b).abs.to_flat_a
|
633
|
+
expect(r.max).to be_within(1e-6).of(0.0)
|
634
|
+
end
|
635
|
+
end
|
636
|
+
|
637
|
+
context "when form: :pos_def" do
|
638
|
+
let(:a) { NMatrix.new([3,3], [4, 1, 2, 1, 5, 3, 2, 3, 6], dtype: dtype) }
|
639
|
+
|
640
|
+
it "solves a linear system A * X = b with positive definite A and vector b" do
|
641
|
+
b = NMatrix.new([3,1], [6,4,8], dtype: dtype)
|
642
|
+
begin
|
643
|
+
x = a.solve(b, form: :pos_def)
|
644
|
+
expect(x).to be_within(1e-6).of(NMatrix.new([3,1], [1,0,1], dtype: dtype))
|
645
|
+
rescue NotImplementedError
|
646
|
+
"Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
|
647
|
+
end
|
648
|
+
end
|
649
|
+
|
650
|
+
it "solves a linear system A * X = B with positive definite A and matrix B" do
|
651
|
+
b = NMatrix.new([3,2], [8,3,14,13,14,19], dtype: dtype)
|
652
|
+
begin
|
653
|
+
x = a.solve(b, form: :pos_def)
|
654
|
+
expect(x).to be_within(1e-6).of(NMatrix.new([3,2], [1,-1,2,1,1,3], dtype: dtype))
|
655
|
+
rescue NotImplementedError
|
656
|
+
"Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
|
657
|
+
end
|
658
|
+
end
|
659
|
+
end
|
660
|
+
end
|
584
661
|
end
|
585
662
|
|
586
663
|
context "#hessenberg" do
|
data/spec/spec_helper.rb
CHANGED
@@ -138,3 +138,12 @@ end
|
|
138
138
|
def integer_dtype? dtype
|
139
139
|
[:byte,:int8,:int16,:int32,:int64].include?(dtype)
|
140
140
|
end
|
141
|
+
|
142
|
+
# If a focus: true option is supplied to any test, running `rake spec focus=true`
|
143
|
+
# will run only the focused tests and nothing else.
|
144
|
+
if ENV["focus"] == "true"
|
145
|
+
RSpec.configure do |c|
|
146
|
+
c.filter_run :focus => true
|
147
|
+
end
|
148
|
+
end
|
149
|
+
|
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nmatrix-lapacke
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- John Woods
|
@@ -10,7 +10,7 @@ authors:
|
|
10
10
|
autorequire:
|
11
11
|
bindir: bin
|
12
12
|
cert_chain: []
|
13
|
-
date:
|
13
|
+
date: 2016-01-18 00:00:00.000000000 Z
|
14
14
|
dependencies:
|
15
15
|
- !ruby/object:Gem::Dependency
|
16
16
|
name: nmatrix
|
@@ -18,14 +18,14 @@ dependencies:
|
|
18
18
|
requirements:
|
19
19
|
- - '='
|
20
20
|
- !ruby/object:Gem::Version
|
21
|
-
version: 0.2.0
|
21
|
+
version: 0.2.1
|
22
22
|
type: :runtime
|
23
23
|
prerelease: false
|
24
24
|
version_requirements: !ruby/object:Gem::Requirement
|
25
25
|
requirements:
|
26
26
|
- - '='
|
27
27
|
- !ruby/object:Gem::Version
|
28
|
-
version: 0.2.0
|
28
|
+
version: 0.2.1
|
29
29
|
description: For using linear algebra fuctions provided by LAPACK and BLAS
|
30
30
|
email:
|
31
31
|
- john.o.woods@gmail.com
|