nmatrix-lapacke 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -195,19 +195,41 @@ class NMatrix
195
195
  NMatrix::LAPACK::lapacke_potrf(:row, which, self.shape[0], self, self.shape[1])
196
196
  end
197
197
 
198
# Solves the linear system self * X = b for X and returns X.
#
# self must be a square, dense, two-dimensional matrix. b must be dense, with
# as many rows as self has columns. b itself is never overwritten: every
# solver path writes its solution into a clone of b.
#
# @param b [NMatrix] right-hand side; its columns are the right-hand-side vectors
# @param opts [Hash] options
# @option opts [Symbol] :form (:general) structure of self to exploit:
#   :general                       - LU factorization (LAPACK getrf + getrs) on a clone of self
#   :upper_tri, :upper_triangular  - triangular solve with BLAS trsm (not for complex dtypes)
#   :lower_tri, :lower_triangular  - triangular solve with BLAS trsm (not for complex dtypes)
#   :pos_def, :positive_definite   - Cholesky factorization, then a lower and an upper
#                                    triangular solve
# @return [NMatrix] the solution X, with the same shape as b
# @raise [ShapeError] if self is not square or b's row count does not match self
# @raise [ArgumentError] for non-dense operands, integer/object dtypes, complex dtypes
#   with a triangular (or Cholesky) form, or an unknown :form
def solve(b, opts = {})
  raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
  raise(ShapeError, "number of rows of b must equal number of cols of self") if
    self.shape[1] != b.shape[0]
  # Both operands are handed directly to dense LAPACK/BLAS routines, so b must be dense as well.
  raise(ArgumentError, "only works with dense matrices") if self.stype != :dense || b.stype != :dense
  raise(ArgumentError, "only works for non-integer, non-object dtypes") if
    integer_dtype? || object_dtype? || b.integer_dtype? || b.object_dtype?

  opts = { form: :general }.merge(opts)
  n = self.shape[0]
  nrhs = b.shape[1]

  # Shared by the upper and lower triangular forms. trsm overwrites its B
  # argument with the solution, so it receives a clone of b; self is passed as A.
  triangular_solve = lambda do |uplo|
    raise(ArgumentError, "#{uplo} triangular solver does not work with complex dtypes") if
      complex_dtype? || b.complex_dtype?
    x = b.clone
    NMatrix::BLAS.cblas_trsm(:row, :left, uplo, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
    x
  end

  case opts[:form]
  when :general
    # getrf factors the clone in place and returns the pivot indices;
    # getrs then overwrites x with the solution.
    clone = self.clone
    x = b.clone
    ipiv = NMatrix::LAPACK.lapacke_getrf(:row, n, n, clone, n)
    NMatrix::LAPACK.lapacke_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, nrhs)
    x
  when :upper_tri, :upper_triangular
    triangular_solve.call(:upper)
  when :lower_tri, :lower_triangular
    triangular_solve.call(:lower)
  when :pos_def, :positive_definite
    # Assumes factorize_cholesky returns [U, L] with self == L * U (L lower,
    # U upper) -- confirm. Solve L * z = b, then U * x = z.
    u, l = self.factorize_cholesky
    z = l.solve(b, form: :lower_tri)
    u.solve(z, form: :upper_tri)
  else
    raise(ArgumentError, "#{opts[:form]} is not a valid form option")
  end
end
213
235
  end
@@ -461,6 +461,12 @@ describe 'NMatrix' do
461
461
  expect(n.transpose).to eq n
462
462
  expect(n.transpose).not_to be n
463
463
  end
464
+
465
# transpose must reject a non-Array permutation argument and an invalid
# permutation ([1,1,2] repeats an axis and has three entries for a 2-D matrix).
it "should check permute argument if supplied for #{stype} matrix" do
  n = NMatrix.new([2,2], [1,2,3,4], stype: stype)
  # Explicit parentheses avoid Ruby's "`*' interpreted as argument prefix"
  # warning; the previous `*4` splatted a bare Integer, i.e. passed 4.
  expect { n.transpose(4) }.to raise_error(ArgumentError)
  expect { n.transpose([1,1,2]) }.to raise_error(ArgumentError)
end
464
470
  end
465
471
  end
466
472
 
@@ -553,6 +553,7 @@ describe "math" do
553
553
  context "#solve" do
554
554
  NON_INTEGER_DTYPES.each do |dtype|
555
555
  next if dtype == :object # LU factorization doesnt work for :object yet
556
+
556
557
  it "solves linear equation for dtype #{dtype}" do
557
558
  a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
558
559
  b = NMatrix.new [2,1], [9,8], dtype: dtype
@@ -581,6 +582,82 @@ describe "math" do
581
582
  expect(a.solve(b)).to eq(NMatrix.new [3,2], [1,0, 0,0, 2,2], dtype: dtype)
582
583
  end
583
584
  end
585
+
586
# Exercises NMatrix#solve's :lower_tri, :upper_tri and :pos_def forms for every
# floating-point dtype. The triangular forms are checked by their residual
# |A * X - B|; the positive-definite form is checked against known solutions.
FLOAT_DTYPES.each do |dtype|
  context "when form: :lower_tri" do
    let(:a) { NMatrix.new([3,3], [1, 0, 0, 2, 0.5, 0, 3, 3, 9], dtype: dtype) }

    it "solves a lower triangular linear system A * x = b with vector b" do
      b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
      x = a.solve(b, form: :lower_tri)
      # Flatten so the largest residual is a plain Float, as in the other examples.
      r = (a.dot(x) - b).abs.to_flat_a
      expect(r.max).to be_within(1e-6).of(0.0)
    end

    it "solves a lower triangular linear system A * X = B with narrow B" do
      b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
      x = a.solve(b, form: :lower_tri)
      r = (a.dot(x) - b).abs.to_flat_a
      expect(r.max).to be_within(1e-6).of(0.0)
    end

    it "solves a lower triangular linear system A * X = B with wide B" do
      b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
      x = a.solve(b, form: :lower_tri)
      r = (a.dot(x) - b).abs.to_flat_a
      expect(r.max).to be_within(1e-6).of(0.0)
    end
  end

  context "when form: :upper_tri" do
    let(:a) { NMatrix.new([3,3], [3, 2, 1, 0, 2, 0.5, 0, 0, 9], dtype: dtype) }

    it "solves an upper triangular linear system A * x = b with vector b" do
      b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
      x = a.solve(b, form: :upper_tri)
      # Flatten so the largest residual is a plain Float, as in the other examples.
      r = (a.dot(x) - b).abs.to_flat_a
      expect(r.max).to be_within(1e-6).of(0.0)
    end

    it "solves an upper triangular linear system A * X = B with narrow B" do
      b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
      x = a.solve(b, form: :upper_tri)
      r = (a.dot(x) - b).abs.to_flat_a
      expect(r.max).to be_within(1e-6).of(0.0)
    end

    it "solves an upper triangular linear system A * X = B with a wide B" do
      b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
      x = a.solve(b, form: :upper_tri)
      r = (a.dot(x) - b).abs.to_flat_a
      expect(r.max).to be_within(1e-6).of(0.0)
    end
  end

  context "when form: :pos_def" do
    let(:a) { NMatrix.new([3,3], [4, 1, 2, 1, 5, 3, 2, 3, 6], dtype: dtype) }

    it "solves a linear system A * X = b with positive definite A and vector b" do
      b = NMatrix.new([3,1], [6,4,8], dtype: dtype)
      begin
        x = a.solve(b, form: :pos_def)
        expect(x).to be_within(1e-6).of(NMatrix.new([3,1], [1,0,1], dtype: dtype))
      rescue NotImplementedError
        # Suppressing a NotImplementedError when the lapacke or atlas plugin is not available
      end
    end

    it "solves a linear system A * X = B with positive definite A and matrix B" do
      b = NMatrix.new([3,2], [8,3,14,13,14,19], dtype: dtype)
      begin
        x = a.solve(b, form: :pos_def)
        expect(x).to be_within(1e-6).of(NMatrix.new([3,2], [1,-1,2,1,1,3], dtype: dtype))
      rescue NotImplementedError
        # Suppressing a NotImplementedError when the lapacke or atlas plugin is not available
      end
    end
  end
end
584
661
  end
585
662
 
586
663
  context "#hessenberg" do
@@ -138,3 +138,12 @@ end
138
138
  def integer_dtype? dtype
139
139
  [:byte,:int8,:int16,:int32,:int64].include?(dtype)
140
140
  end
141
+
142
# Tag an example with `focus: true` and run `rake spec focus=true` to execute
# only the tagged examples. The filter is installed only when ENV["focus"] is "true".
RSpec.configure { |config| config.filter_run(focus: true) } if ENV["focus"] == "true"
149
+
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nmatrix-lapacke
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - John Woods
@@ -10,7 +10,7 @@ authors:
10
10
  autorequire:
11
11
  bindir: bin
12
12
  cert_chain: []
13
- date: 2015-08-25 00:00:00.000000000 Z
13
+ date: 2016-01-18 00:00:00.000000000 Z
14
14
  dependencies:
15
15
  - !ruby/object:Gem::Dependency
16
16
  name: nmatrix
@@ -18,14 +18,14 @@ dependencies:
18
18
  requirements:
19
19
  - - '='
20
20
  - !ruby/object:Gem::Version
21
- version: 0.2.0
21
+ version: 0.2.1
22
22
  type: :runtime
23
23
  prerelease: false
24
24
  version_requirements: !ruby/object:Gem::Requirement
25
25
  requirements:
26
26
  - - '='
27
27
  - !ruby/object:Gem::Version
28
- version: 0.2.0
28
+ version: 0.2.1
29
29
  description: For using linear algebra functions provided by LAPACK and BLAS
30
30
  email:
31
31
  - john.o.woods@gmail.com