nmatrix-lapacke 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -195,19 +195,41 @@ class NMatrix
195
195
  NMatrix::LAPACK::lapacke_potrf(:row, which, self.shape[0], self, self.shape[1])
196
196
  end
197
197
 
198
- def solve b
198
+ def solve(b, opts = {})
199
199
  raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
200
200
  raise(ShapeError, "number of rows of b must equal number of cols of self") if
201
201
  self.shape[1] != b.shape[0]
202
- raise ArgumentError, "only works with dense matrices" if self.stype != :dense
203
- raise ArgumentError, "only works for non-integer, non-object dtypes" if
202
+ raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
203
+ raise(ArgumentError, "only works for non-integer, non-object dtypes") if
204
204
  integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
205
205
 
206
- x = b.clone
207
- clone = self.clone
208
- n = self.shape[0]
209
- ipiv = NMatrix::LAPACK.lapacke_getrf(:row, n, n, clone, n)
210
- NMatrix::LAPACK.lapacke_getrs(:row, :no_transpose, n, b.shape[1], clone, n, ipiv, x, b.shape[1])
211
- x
206
+ opts = { form: :general }.merge(opts)
207
+ x = b.clone
208
+ n = self.shape[0]
209
+ nrhs = b.shape[1]
210
+
211
+ case opts[:form]
212
+ when :general
213
+ clone = self.clone
214
+ ipiv = NMatrix::LAPACK.lapacke_getrf(:row, n, n, clone, n)
215
+ NMatrix::LAPACK.lapacke_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, nrhs)
216
+ x
217
+ when :upper_tri, :upper_triangular
218
+ raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
219
+ complex_dtype? or b.complex_dtype?
220
+ NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
221
+ x
222
+ when :lower_tri, :lower_triangular
223
+ raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
224
+ complex_dtype? or b.complex_dtype?
225
+ NMatrix::BLAS::cblas_trsm(:row, :left, :lower, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
226
+ x
227
+ when :pos_def, :positive_definite
228
+ u, l = self.factorize_cholesky
229
+ z = l.solve(b, form: :lower_tri)
230
+ u.solve(z, form: :upper_tri)
231
+ else
232
+ raise(ArgumentError, "#{opts[:form]} is not a valid form option")
233
+ end
212
234
  end
213
235
  end
@@ -461,6 +461,12 @@ describe 'NMatrix' do
461
461
  expect(n.transpose).to eq n
462
462
  expect(n.transpose).not_to be n
463
463
  end
464
+
465
+ it "should check permute argument if supplied for #{stype} matrix" do
466
+ n = NMatrix.new([2,2], [1,2,3,4], stype: stype)
467
+ expect{n.transpose *4 }.to raise_error(ArgumentError)
468
+ expect{n.transpose [1,1,2] }.to raise_error(ArgumentError)
469
+ end
464
470
  end
465
471
  end
466
472
 
@@ -553,6 +553,7 @@ describe "math" do
553
553
  context "#solve" do
554
554
  NON_INTEGER_DTYPES.each do |dtype|
555
555
  next if dtype == :object # LU factorization doesnt work for :object yet
556
+
556
557
  it "solves linear equation for dtype #{dtype}" do
557
558
  a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
558
559
  b = NMatrix.new [2,1], [9,8], dtype: dtype
@@ -581,6 +582,82 @@ describe "math" do
581
582
  expect(a.solve(b)).to eq(NMatrix.new [3,2], [1,0, 0,0, 2,2], dtype: dtype)
582
583
  end
583
584
  end
585
+
586
+ FLOAT_DTYPES.each do |dtype|
587
+ context "when form: :lower_tri" do
588
+ let(:a) { NMatrix.new([3,3], [1, 0, 0, 2, 0.5, 0, 3, 3, 9], dtype: dtype) }
589
+
590
+ it "solves a lower triangular linear system A * x = b with vector b" do
591
+ b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
592
+ x = a.solve(b, form: :lower_tri)
593
+ r = a.dot(x) - b
594
+ expect(r.abs.max).to be_within(1e-6).of(0.0)
595
+ end
596
+
597
+ it "solves a lower triangular linear system A * X = B with narrow B" do
598
+ b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
599
+ x = a.solve(b, form: :lower_tri)
600
+ r = (a.dot(x) - b).abs.to_flat_a
601
+ expect(r.max).to be_within(1e-6).of(0.0)
602
+ end
603
+
604
+ it "solves a lower triangular linear system A * X = B with wide B" do
605
+ b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
606
+ x = a.solve(b, form: :lower_tri)
607
+ r = (a.dot(x) - b).abs.to_flat_a
608
+ expect(r.max).to be_within(1e-6).of(0.0)
609
+ end
610
+ end
611
+
612
+ context "when form: :upper_tri" do
613
+ let(:a) { NMatrix.new([3,3], [3, 2, 1, 0, 2, 0.5, 0, 0, 9], dtype: dtype) }
614
+
615
+ it "solves an upper triangular linear system A * x = b with vector b" do
616
+ b = NMatrix.new([3,1], [1,2,3], dtype: dtype)
617
+ x = a.solve(b, form: :upper_tri)
618
+ r = a.dot(x) - b
619
+ expect(r.abs.max).to be_within(1e-6).of(0.0)
620
+ end
621
+
622
+ it "solves an upper triangular linear system A * X = B with narrow B" do
623
+ b = NMatrix.new([3,2], [1,2,3,4,5,6], dtype: dtype)
624
+ x = a.solve(b, form: :upper_tri)
625
+ r = (a.dot(x) - b).abs.to_flat_a
626
+ expect(r.max).to be_within(1e-6).of(0.0)
627
+ end
628
+
629
+ it "solves an upper triangular linear system A * X = B with a wide B" do
630
+ b = NMatrix.new([3,5], (1..15).to_a, dtype: dtype)
631
+ x = a.solve(b, form: :upper_tri)
632
+ r = (a.dot(x) - b).abs.to_flat_a
633
+ expect(r.max).to be_within(1e-6).of(0.0)
634
+ end
635
+ end
636
+
637
+ context "when form: :pos_def" do
638
+ let(:a) { NMatrix.new([3,3], [4, 1, 2, 1, 5, 3, 2, 3, 6], dtype: dtype) }
639
+
640
+ it "solves a linear system A * X = b with positive definite A and vector b" do
641
+ b = NMatrix.new([3,1], [6,4,8], dtype: dtype)
642
+ begin
643
+ x = a.solve(b, form: :pos_def)
644
+ expect(x).to be_within(1e-6).of(NMatrix.new([3,1], [1,0,1], dtype: dtype))
645
+ rescue NotImplementedError
646
+ "Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
647
+ end
648
+ end
649
+
650
+ it "solves a linear system A * X = B with positive definite A and matrix B" do
651
+ b = NMatrix.new([3,2], [8,3,14,13,14,19], dtype: dtype)
652
+ begin
653
+ x = a.solve(b, form: :pos_def)
654
+ expect(x).to be_within(1e-6).of(NMatrix.new([3,2], [1,-1,2,1,1,3], dtype: dtype))
655
+ rescue NotImplementedError
656
+ "Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
657
+ end
658
+ end
659
+ end
660
+ end
584
661
  end
585
662
 
586
663
  context "#hessenberg" do
@@ -138,3 +138,12 @@ end
138
138
  def integer_dtype? dtype
139
139
  [:byte,:int8,:int16,:int32,:int64].include?(dtype)
140
140
  end
141
+
142
+ # If a focus: true option is supplied to any test, running `rake spec focus=true`
143
+ # will run only the focused tests and nothing else.
144
+ if ENV["focus"] == "true"
145
+ RSpec.configure do |c|
146
+ c.filter_run :focus => true
147
+ end
148
+ end
149
+
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nmatrix-lapacke
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - John Woods
@@ -10,7 +10,7 @@ authors:
10
10
  autorequire:
11
11
  bindir: bin
12
12
  cert_chain: []
13
- date: 2015-08-25 00:00:00.000000000 Z
13
+ date: 2016-01-18 00:00:00.000000000 Z
14
14
  dependencies:
15
15
  - !ruby/object:Gem::Dependency
16
16
  name: nmatrix
@@ -18,14 +18,14 @@ dependencies:
18
18
  requirements:
19
19
  - - '='
20
20
  - !ruby/object:Gem::Version
21
- version: 0.2.0
21
+ version: 0.2.1
22
22
  type: :runtime
23
23
  prerelease: false
24
24
  version_requirements: !ruby/object:Gem::Requirement
25
25
  requirements:
26
26
  - - '='
27
27
  - !ruby/object:Gem::Version
28
- version: 0.2.0
28
+ version: 0.2.1
29
29
  description: For using linear algebra fuctions provided by LAPACK and BLAS
30
30
  email:
31
31
  - john.o.woods@gmail.com