numo-tiny_linalg 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,32 @@ module Numo
13
13
  # Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
14
14
  # by solving an ordinary or generalized eigenvalue problem.
15
15
  #
16
+ # @example
17
+ # require 'numo/tiny_linalg'
18
+ #
19
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
20
+ #
21
+ # x = Numo::DFloat.new(5, 3).rand - 0.5
22
+ # c = x.dot(x.transpose)
23
+ # vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
24
+ #
25
+ # pp vals
26
+ # # =>
27
+ # # Numo::DFloat#shape=[3]
28
+ # # [0.118795, 0.434252, 0.903245]
29
+ #
30
+ # pp vecs
31
+ # # =>
32
+ # # Numo::DFloat#shape=[5,3]
33
+ # # [[0.154178, 0.60661, -0.382961],
34
+ # # [-0.349761, -0.141726, -0.513178],
35
+ # # [0.739633, -0.468202, 0.105933],
36
+ # # [0.0519655, -0.471436, -0.701507],
37
+ # # [-0.551488, -0.412883, 0.294371]]
38
+ #
39
+ # pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
40
+ # # => 3.3306690738754696e-16
41
+ #
16
42
  # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
17
43
  # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
18
44
  # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
@@ -56,6 +82,15 @@ module Numo
56
82
 
57
83
  # Computes the determinant of matrix.
58
84
  #
85
+ # @example
86
+ # require 'numo/tiny_linalg'
87
+ #
88
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
89
+ #
90
+ # a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
91
+ # pp (3.0 - Numo::Linalg.det(a)).abs
92
+ # # => 1.3322676295501878e-15
93
+ #
59
94
  # @param a [Numo::NArray] n-by-n square matrix.
60
95
  # @return [Float/Complex] The determinant of `a`.
61
96
  def det(a)
@@ -82,6 +117,21 @@ module Numo
82
117
 
83
118
  # Computes the inverse matrix of a square matrix.
84
119
  #
120
+ # @example
121
+ # require 'numo/tiny_linalg'
122
+ #
123
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
124
+ #
125
+ # a = Numo::DFloat.new(5, 5).rand
126
+ #
127
+ # inv_a = Numo::Linalg.inv(a)
128
+ #
129
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
130
+ # # => 7.019165976816745e-16
131
+ #
132
+ # pp inv_a.dot(a).sum
133
+ # # => 5.0
134
+ #
85
135
  # @param a [Numo::NArray] n-by-n square matrix.
86
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.inv, and is not used.
87
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.inv, and is not used.
@@ -108,6 +158,21 @@ module Numo
108
158
 
109
159
  # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
110
160
  #
161
+ # @example
162
+ # require 'numo/tiny_linalg'
163
+ #
164
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
165
+ #
166
+ # a = Numo::DFloat.new(5, 3).rand
167
+ #
168
+ # inv_a = Numo::Linalg.pinv(a)
169
+ #
170
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
171
+ # # => 1.1102230246251565e-15
172
+ #
173
+ # pp inv_a.dot(a).sum
174
+ # # => 3.0
175
+ #
111
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
112
177
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
113
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
@@ -123,6 +188,34 @@ module Numo
123
188
 
124
189
  # Compute QR decomposition of a matrix.
125
190
  #
191
+ # @example
192
+ # require 'numo/tiny_linalg'
193
+ #
194
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
195
+ #
196
+ # x = Numo::DFloat.new(5, 3).rand
197
+ #
198
+ # q, r = Numo::Linalg.qr(x, mode: 'economic')
199
+ #
200
+ # pp q
201
+ # # =>
202
+ # # Numo::DFloat#shape=[5,3]
203
+ # # [[-0.0574417, 0.635216, 0.707116],
204
+ # # [-0.187002, -0.073192, 0.422088],
205
+ # # [-0.502239, 0.634088, -0.537489],
206
+ # # [-0.0473292, 0.134867, -0.0223491],
207
+ # # [-0.840979, -0.413385, 0.180096]]
208
+ #
209
+ # pp r
210
+ # # =>
211
+ # # Numo::DFloat#shape=[3,3]
212
+ # # [[-1.07508, -0.821334, -0.484586],
213
+ # # [0, 0.513035, 0.451868],
214
+ # # [0, 0, 0.678737]]
215
+ #
216
+ # pp (q.dot(r) - x).abs.max
217
+ # # => 3.885780586188048e-16
218
+ #
126
219
  # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
127
220
  # @param mode [String] The mode of decomposition.
128
221
  # - "reduce" -- returns both Q [m, m] and R [m, n],
@@ -166,26 +259,74 @@ module Numo
166
259
 
167
260
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
168
261
  #
169
- # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensinal NArray).
262
+ # @example
263
+ # require 'numo/tiny_linalg'
264
+ #
265
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
266
+ #
267
+ # a = Numo::DFloat.new(3, 3).rand
268
+ # b = Numo::DFloat.eye(3)
269
+ #
270
+ # x = Numo::Linalg.solve(a, b)
271
+ #
272
+ # pp x
273
+ # # =>
274
+ # # Numo::DFloat#shape=[3,3]
275
+ # # [[-2.12332, 4.74868, 0.326773],
276
+ # # [1.38043, -3.79074, 1.25355],
277
+ # # [0.775187, 1.41032, -0.613774]]
278
+ #
279
+ # pp (b - a.dot(x)).abs.max
280
+ # # => 2.1081041547796492e-16
281
+ #
282
+ # @param a [Numo::NArray] The n-by-n square matrix.
170
283
  # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
171
284
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
172
285
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solve, and is not used.
173
286
  # @return [Numo::NArray] The solution vector / matrix `x`.
174
287
  def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
175
- case blas_char(a, b)
176
- when 'd'
177
- Lapack.dgesv(a.dup, b.dup)[1]
178
- when 's'
179
- Lapack.sgesv(a.dup, b.dup)[1]
180
- when 'z'
181
- Lapack.zgesv(a.dup, b.dup)[1]
182
- when 'c'
183
- Lapack.cgesv(a.dup, b.dup)[1]
184
- end
288
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
289
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
290
+
291
+ bchr = blas_char(a, b)
292
+ raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
293
+
294
+ gesv = "#{bchr}gesv".to_sym
295
+ Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
185
296
  end
186
297
 
187
298
  # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
188
299
  #
300
+ # @example
301
+ # require 'numo/tiny_linalg'
302
+ #
303
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
304
+ #
305
+ # x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
306
+ # pp x
307
+ # # =>
308
+ # # Numo::DFloat#shape=[5,3]
309
+ # # [[0.104945, 0.0284236, 0.117406],
310
+ # # [0.862634, 0.210945, 0.922135],
311
+ # # [0.324507, 0.0752655, 0.339158],
312
+ # # [0.67085, 0.102594, 0.600882],
313
+ # # [0.404631, 0.116868, 0.46644]]
314
+ #
315
+ # s, u, vt = Numo::Linalg.svd(x, job: 'S')
316
+ #
317
+ # z = u.dot(s.diag).dot(vt)
318
+ # pp z
319
+ # # =>
320
+ # # Numo::DFloat#shape=[5,3]
321
+ # # [[0.104945, 0.0284236, 0.117406],
322
+ # # [0.862634, 0.210945, 0.922135],
323
+ # # [0.324507, 0.0752655, 0.339158],
324
+ # # [0.67085, 0.102594, 0.600882],
325
+ # # [0.404631, 0.116868, 0.46644]]
326
+ #
327
+ # pp (x - z).abs.max
328
+ # # => 4.440892098500626e-16
329
+ #
189
330
  # @param a [Numo::NArray] Matrix to be decomposed.
190
331
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
191
332
  # @param job [String] Job option ('A', 'S', or 'N').
@@ -193,33 +334,16 @@ module Numo
193
334
  def svd(a, driver: 'svd', job: 'A')
194
335
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
195
336
 
337
+ bchr = blas_char(a)
338
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
339
+
196
340
  case driver.to_s
197
341
  when 'sdd'
198
- s, u, vt, info = case a
199
- when Numo::DFloat
200
- Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
201
- when Numo::SFloat
202
- Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
203
- when Numo::DComplex
204
- Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
205
- when Numo::SComplex
206
- Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
207
- else
208
- raise ArgumentError, "invalid array type: #{a.class}"
209
- end
342
+ gesdd = "#{bchr}gesdd".to_sym
343
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
210
344
  when 'svd'
211
- s, u, vt, info = case a
212
- when Numo::DFloat
213
- Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
214
- when Numo::SFloat
215
- Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
216
- when Numo::DComplex
217
- Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
218
- when Numo::SComplex
219
- Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
220
- else
221
- raise ArgumentError, "invalid array type: #{a.class}"
222
- end
345
+ gesvd = "#{bchr}gesvd".to_sym
346
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
223
347
  else
224
348
  raise ArgumentError, "invalid driver: #{driver}"
225
349
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.4
4
+ version: 0.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
@@ -62,6 +62,7 @@ files:
62
62
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
63
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
64
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
65
+ - ext/numo/tiny_linalg/util.hpp
65
66
  - lib/numo/tiny_linalg.rb
66
67
  - lib/numo/tiny_linalg/version.rb
67
68
  - vendor/tmp/.gitkeep