numo-tiny_linalg 0.0.4 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -13,6 +13,32 @@ module Numo
13
13
  # Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
14
14
  # by solving an ordinary or generalized eigenvalue problem.
15
15
  #
16
+ # @example
17
+ # require 'numo/tiny_linalg'
18
+ #
19
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
20
+ #
21
+ # x = Numo::DFloat.new(5, 3).rand - 0.5
22
+ # c = x.dot(x.transpose)
23
+ # vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
24
+ #
25
+ # pp vals
26
+ # # =>
27
+ # # Numo::DFloat#shape=[3]
28
+ # # [0.118795, 0.434252, 0.903245]
29
+ #
30
+ # pp vecs
31
+ # # =>
32
+ # # Numo::DFloat#shape=[5,3]
33
+ # # [[0.154178, 0.60661, -0.382961],
34
+ # # [-0.349761, -0.141726, -0.513178],
35
+ # # [0.739633, -0.468202, 0.105933],
36
+ # # [0.0519655, -0.471436, -0.701507],
37
+ # # [-0.551488, -0.412883, 0.294371]]
38
+ #
39
+ # pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
40
+ # # => 3.3306690738754696e-16
41
+ #
16
42
  # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
17
43
  # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
18
44
  # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
@@ -56,6 +82,15 @@ module Numo
56
82
 
57
83
  # Computes the determinant of matrix.
58
84
  #
85
+ # @example
86
+ # require 'numo/tiny_linalg'
87
+ #
88
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
89
+ #
90
+ # a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
91
+ # pp (3.0 - Numo::Linalg.det(a)).abs
92
+ # # => 1.3322676295501878e-15
93
+ #
59
94
  # @param a [Numo::NArray] n-by-n square matrix.
60
95
  # @return [Float/Complex] The determinant of `a`.
61
96
  def det(a)
@@ -82,6 +117,21 @@ module Numo
82
117
 
83
118
  # Computes the inverse matrix of a square matrix.
84
119
  #
120
+ # @example
121
+ # require 'numo/tiny_linalg'
122
+ #
123
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
124
+ #
125
+ # a = Numo::DFloat.new(5, 5).rand
126
+ #
127
+ # inv_a = Numo::Linalg.inv(a)
128
+ #
129
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
130
+ # # => 7.019165976816745e-16
131
+ #
132
+ # pp inv_a.dot(a).sum
133
+ # # => 5.0
134
+ #
85
135
  # @param a [Numo::NArray] n-by-n square matrix.
86
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.inv, and is not used.
87
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.inv, and is not used.
@@ -108,6 +158,21 @@ module Numo
108
158
 
109
159
  # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
110
160
  #
161
+ # @example
162
+ # require 'numo/tiny_linalg'
163
+ #
164
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
165
+ #
166
+ # a = Numo::DFloat.new(5, 3).rand
167
+ #
168
+ # inv_a = Numo::Linalg.pinv(a)
169
+ #
170
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
171
+ # # => 1.1102230246251565e-15
172
+ #
173
+ # pp inv_a.dot(a).sum
174
+ # # => 3.0
175
+ #
111
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
112
177
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
113
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
@@ -123,6 +188,34 @@ module Numo
123
188
 
124
189
  # Compute QR decomposition of a matrix.
125
190
  #
191
+ # @example
192
+ # require 'numo/tiny_linalg'
193
+ #
194
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
195
+ #
196
+ # x = Numo::DFloat.new(5, 3).rand
197
+ #
198
+ # q, r = Numo::Linalg.qr(x, mode: 'economic')
199
+ #
200
+ # pp q
201
+ # # =>
202
+ # # Numo::DFloat#shape=[5,3]
203
+ # # [[-0.0574417, 0.635216, 0.707116],
204
+ # # [-0.187002, -0.073192, 0.422088],
205
+ # # [-0.502239, 0.634088, -0.537489],
206
+ # # [-0.0473292, 0.134867, -0.0223491],
207
+ # # [-0.840979, -0.413385, 0.180096]]
208
+ #
209
+ # pp r
210
+ # # =>
211
+ # # Numo::DFloat#shape=[3,3]
212
+ # # [[-1.07508, -0.821334, -0.484586],
213
+ # # [0, 0.513035, 0.451868],
214
+ # # [0, 0, 0.678737]]
215
+ #
216
+ # pp (q.dot(r) - x).abs.max
217
+ # # => 3.885780586188048e-16
218
+ #
126
219
  # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
127
220
  # @param mode [String] The mode of decomposition.
128
221
  # - "reduce" -- returns both Q [m, m] and R [m, n],
@@ -166,26 +259,74 @@ module Numo
166
259
 
167
260
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
168
261
  #
169
- # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensinal NArray).
262
+ # @example
263
+ # require 'numo/tiny_linalg'
264
+ #
265
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
266
+ #
267
+ # a = Numo::DFloat.new(3, 3).rand
268
+ # b = Numo::DFloat.eye(3)
269
+ #
270
+ # x = Numo::Linalg.solve(a, b)
271
+ #
272
+ # pp x
273
+ # # =>
274
+ # # Numo::DFloat#shape=[3,3]
275
+ # # [[-2.12332, 4.74868, 0.326773],
276
+ # # [1.38043, -3.79074, 1.25355],
277
+ # # [0.775187, 1.41032, -0.613774]]
278
+ #
279
+ # pp (b - a.dot(x)).abs.max
280
+ # # => 2.1081041547796492e-16
281
+ #
282
+ # @param a [Numo::NArray] The n-by-n square matrix.
170
283
  # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
171
284
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
172
285
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
173
286
  # @return [Numo::NArray] The solution vector / matrix `x`.
174
287
  def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
175
- case blas_char(a, b)
176
- when 'd'
177
- Lapack.dgesv(a.dup, b.dup)[1]
178
- when 's'
179
- Lapack.sgesv(a.dup, b.dup)[1]
180
- when 'z'
181
- Lapack.zgesv(a.dup, b.dup)[1]
182
- when 'c'
183
- Lapack.cgesv(a.dup, b.dup)[1]
184
- end
288
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
289
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
290
+
291
+ bchr = blas_char(a, b)
292
+ raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
293
+
294
+ gesv = "#{bchr}gesv".to_sym
295
+ Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
185
296
  end
186
297
 
187
298
  # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
188
299
  #
300
+ # @example
301
+ # require 'numo/tiny_linalg'
302
+ #
303
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
304
+ #
305
+ # x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
306
+ # pp x
307
+ # # =>
308
+ # # Numo::DFloat#shape=[5,3]
309
+ # # [[0.104945, 0.0284236, 0.117406],
310
+ # # [0.862634, 0.210945, 0.922135],
311
+ # # [0.324507, 0.0752655, 0.339158],
312
+ # # [0.67085, 0.102594, 0.600882],
313
+ # # [0.404631, 0.116868, 0.46644]]
314
+ #
315
+ # s, u, vt = Numo::Linalg.svd(x, job: 'S')
316
+ #
317
+ # z = u.dot(s.diag).dot(vt)
318
+ # pp z
319
+ # # =>
320
+ # # Numo::DFloat#shape=[5,3]
321
+ # # [[0.104945, 0.0284236, 0.117406],
322
+ # # [0.862634, 0.210945, 0.922135],
323
+ # # [0.324507, 0.0752655, 0.339158],
324
+ # # [0.67085, 0.102594, 0.600882],
325
+ # # [0.404631, 0.116868, 0.46644]]
326
+ #
327
+ # pp (x - z).abs.max
328
+ # # => 4.440892098500626e-16
329
+ #
189
330
  # @param a [Numo::NArray] Matrix to be decomposed.
190
331
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
191
332
  # @param job [String] Job option ('A', 'S', or 'N').
@@ -193,33 +334,16 @@ module Numo
193
334
  def svd(a, driver: 'svd', job: 'A')
194
335
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
195
336
 
337
+ bchr = blas_char(a)
338
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
339
+
196
340
  case driver.to_s
197
341
  when 'sdd'
198
- s, u, vt, info = case a
199
- when Numo::DFloat
200
- Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
201
- when Numo::SFloat
202
- Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
203
- when Numo::DComplex
204
- Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
205
- when Numo::SComplex
206
- Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
207
- else
208
- raise ArgumentError, "invalid array type: #{a.class}"
209
- end
342
+ gesdd = "#{bchr}gesdd".to_sym
343
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
210
344
  when 'svd'
211
- s, u, vt, info = case a
212
- when Numo::DFloat
213
- Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
214
- when Numo::SFloat
215
- Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
216
- when Numo::DComplex
217
- Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
218
- when Numo::SComplex
219
- Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
220
- else
221
- raise ArgumentError, "invalid array type: #{a.class}"
222
- end
345
+ gesvd = "#{bchr}gesvd".to_sym
346
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
223
347
  else
224
348
  raise ArgumentError, "invalid driver: #{driver}"
225
349
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.4
4
+ version: 0.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
@@ -62,6 +62,7 @@ files:
62
62
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
63
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
64
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
65
+ - ext/numo/tiny_linalg/util.hpp
65
66
  - lib/numo/tiny_linalg.rb
66
67
  - lib/numo/tiny_linalg/version.rb
67
68
  - vendor/tmp/.gitkeep