numo-tiny_linalg 0.0.4 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -13,15 +13,41 @@ module Numo
13
13
  # Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
14
14
  # by solving an ordinary or generalized eigenvalue problem.
15
15
  #
16
- # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
17
- # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
16
+ # @example
17
+ # require 'numo/tiny_linalg'
18
+ #
19
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
20
+ #
21
+ # x = Numo::DFloat.new(5, 3).rand - 0.5
22
+ # c = x.dot(x.transpose)
23
+ # vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
24
+ #
25
+ # pp vals
26
+ # # =>
27
+ # # Numo::DFloat#shape=[3]
28
+ # # [0.118795, 0.434252, 0.903245]
29
+ #
30
+ # pp vecs
31
+ # # =>
32
+ # # Numo::DFloat#shape=[5,3]
33
+ # # [[0.154178, 0.60661, -0.382961],
34
+ # # [-0.349761, -0.141726, -0.513178],
35
+ # # [0.739633, -0.468202, 0.105933],
36
+ # # [0.0519655, -0.471436, -0.701507],
37
+ # # [-0.551488, -0.412883, 0.294371]]
38
+ #
39
+ # pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
40
+ # # => 3.3306690738754696e-16
41
+ #
42
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
43
+ # @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
18
44
  # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
19
45
  # @param vals_range [Range/Array]
20
46
  # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
21
47
  # If nil, all eigenvalues and eigenvectors are computed.
22
48
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
23
49
  # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
24
- # @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors.
50
+ # @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
25
51
  def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
26
52
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
27
53
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
@@ -44,8 +70,8 @@ module Numo
44
70
  vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
45
71
  else
46
72
  sy_he_gv << 'x'
47
- il = vals_range.first + 1
48
- iu = vals_range.last + 1
73
+ il = vals_range.first(1)[0] + 1
74
+ iu = vals_range.last(1)[0] + 1
49
75
  _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
50
76
  sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
51
77
  )
@@ -56,7 +82,16 @@ module Numo
56
82
 
57
83
  # Computes the determinant of matrix.
58
84
  #
59
- # @param a [Numo::NArray] n-by-n square matrix.
85
+ # @example
86
+ # require 'numo/tiny_linalg'
87
+ #
88
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
89
+ #
90
+ # a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
91
+ # pp (3.0 - Numo::Linalg.det(a)).abs
92
+ # # => 1.3322676295501878e-15
93
+ #
94
+ # @param a [Numo::NArray] The n-by-n square matrix.
60
95
  # @return [Float/Complex] The determinant of `a`.
61
96
  def det(a)
62
97
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
@@ -82,7 +117,22 @@ module Numo
82
117
 
83
118
  # Computes the inverse matrix of a square matrix.
84
119
  #
85
- # @param a [Numo::NArray] n-by-n square matrix.
120
+ # @example
121
+ # require 'numo/tiny_linalg'
122
+ #
123
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
124
+ #
125
+ # a = Numo::DFloat.new(5, 5).rand
126
+ #
127
+ # inv_a = Numo::Linalg.inv(a)
128
+ #
129
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
130
+ # # => 7.019165976816745e-16
131
+ #
132
+ # pp inv_a.dot(a).sum
133
+ # # => 5.0
134
+ #
135
+ # @param a [Numo::NArray] The n-by-n square matrix.
86
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
87
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
88
138
  # @return [Numo::NArray] The inverse matrix of `a`.
@@ -106,10 +156,25 @@ module Numo
106
156
  end
107
157
  end
108
158
 
109
- # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
159
+ # Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
160
+ #
161
+ # @example
162
+ # require 'numo/tiny_linalg'
163
+ #
164
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
165
+ #
166
+ # a = Numo::DFloat.new(5, 3).rand
167
+ #
168
+ # inv_a = Numo::Linalg.pinv(a)
169
+ #
170
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
171
+ # # => 1.1102230246251565e-15
172
+ #
173
+ # pp inv_a.dot(a).sum
174
+ # # => 3.0
110
175
  #
111
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
112
- # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
177
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
113
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
114
179
  # @return [Numo::NArray] The pseudo-inverse of `a`.
115
180
  def pinv(a, driver: 'svd', rcond: nil)
@@ -121,7 +186,35 @@ module Numo
121
186
  u.dot(vh[0...rank, true]).conj.transpose
122
187
  end
123
188
 
124
- # Compute QR decomposition of a matrix.
189
+ # Computes the QR decomposition of a matrix.
190
+ #
191
+ # @example
192
+ # require 'numo/tiny_linalg'
193
+ #
194
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
195
+ #
196
+ # x = Numo::DFloat.new(5, 3).rand
197
+ #
198
+ # q, r = Numo::Linalg.qr(x, mode: 'economic')
199
+ #
200
+ # pp q
201
+ # # =>
202
+ # # Numo::DFloat#shape=[5,3]
203
+ # # [[-0.0574417, 0.635216, 0.707116],
204
+ # # [-0.187002, -0.073192, 0.422088],
205
+ # # [-0.502239, 0.634088, -0.537489],
206
+ # # [-0.0473292, 0.134867, -0.0223491],
207
+ # # [-0.840979, -0.413385, 0.180096]]
208
+ #
209
+ # pp r
210
+ # # =>
211
+ # # Numo::DFloat#shape=[3,3]
212
+ # # [[-1.07508, -0.821334, -0.484586],
213
+ # # [0, 0.513035, 0.451868],
214
+ # # [0, 0, 0.678737]]
215
+ #
216
+ # pp (q.dot(r) - x).abs.max
217
+ # # => 3.885780586188048e-16
125
218
  #
126
219
  # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
127
220
  # @param mode [String] The mode of decomposition.
@@ -129,9 +222,8 @@ module Numo
129
222
  # - "r" -- returns only R,
130
223
  # - "economic" -- returns both Q [m, n] and R [n, n],
131
224
  # - "raw" -- returns QR and TAU (LAPACK geqrf results).
132
- # @return [Numo::NArray] if mode='r'
133
- # @return [Array<Numo::NArray,Numo::NArray>] if mode='reduce' or mode='economic'
134
- # @return [Array<Numo::NArray,Numo::NArray>] if mode='raw' (LAPACK geqrf result)
225
+ # @return [Numo::NArray] if mode='r'.
226
+ # @return [Array<Numo::NArray>] if mode='reduce' or 'economic' or 'raw'.
135
227
  def qr(a, mode: 'reduce')
136
228
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
137
229
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)
@@ -166,60 +258,91 @@ module Numo
166
258
 
167
259
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
168
260
  #
169
- # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensinal NArray).
261
+ # @example
262
+ # require 'numo/tiny_linalg'
263
+ #
264
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
265
+ #
266
+ # a = Numo::DFloat.new(3, 3).rand
267
+ # b = Numo::DFloat.eye(3)
268
+ #
269
+ # x = Numo::Linalg.solve(a, b)
270
+ #
271
+ # pp x
272
+ # # =>
273
+ # # Numo::DFloat#shape=[3,3]
274
+ # # [[-2.12332, 4.74868, 0.326773],
275
+ # # [1.38043, -3.79074, 1.25355],
276
+ # # [0.775187, 1.41032, -0.613774]]
277
+ #
278
+ # pp (b - a.dot(x)).abs.max
279
+ # # => 2.1081041547796492e-16
280
+ #
281
+ # @param a [Numo::NArray] The n-by-n square matrix.
170
282
  # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
171
283
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
172
284
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
173
285
  # @return [Numo::NArray] The solution vector / matrix `x`.
174
286
  def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
175
- case blas_char(a, b)
176
- when 'd'
177
- Lapack.dgesv(a.dup, b.dup)[1]
178
- when 's'
179
- Lapack.sgesv(a.dup, b.dup)[1]
180
- when 'z'
181
- Lapack.zgesv(a.dup, b.dup)[1]
182
- when 'c'
183
- Lapack.cgesv(a.dup, b.dup)[1]
184
- end
287
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
288
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
289
+
290
+ bchr = blas_char(a, b)
291
+ raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
292
+
293
+ gesv = "#{bchr}gesv".to_sym
294
+ Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
185
295
  end
186
296
 
187
- # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
297
+ # Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
298
+ #
299
+ # @example
300
+ # require 'numo/tiny_linalg'
301
+ #
302
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
303
+ #
304
+ # x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
305
+ # pp x
306
+ # # =>
307
+ # # Numo::DFloat#shape=[5,3]
308
+ # # [[0.104945, 0.0284236, 0.117406],
309
+ # # [0.862634, 0.210945, 0.922135],
310
+ # # [0.324507, 0.0752655, 0.339158],
311
+ # # [0.67085, 0.102594, 0.600882],
312
+ # # [0.404631, 0.116868, 0.46644]]
313
+ #
314
+ # s, u, vt = Numo::Linalg.svd(x, job: 'S')
315
+ #
316
+ # z = u.dot(s.diag).dot(vt)
317
+ # pp z
318
+ # # =>
319
+ # # Numo::DFloat#shape=[5,3]
320
+ # # [[0.104945, 0.0284236, 0.117406],
321
+ # # [0.862634, 0.210945, 0.922135],
322
+ # # [0.324507, 0.0752655, 0.339158],
323
+ # # [0.67085, 0.102594, 0.600882],
324
+ # # [0.404631, 0.116868, 0.46644]]
325
+ #
326
+ # pp (x - z).abs.max
327
+ # # => 4.440892098500626e-16
188
328
  #
189
329
  # @param a [Numo::NArray] Matrix to be decomposed.
190
- # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
191
- # @param job [String] Job option ('A', 'S', or 'N').
192
- # @return [Array<Numo::NArray>] Singular values and singular vectors ([s, u, vt]).
330
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
331
+ # @param job [String] The job option ('A', 'S', or 'N').
332
+ # @return [Array<Numo::NArray>] The singular values and singular vectors ([s, u, vt]).
193
333
  def svd(a, driver: 'svd', job: 'A')
194
334
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
195
335
 
336
+ bchr = blas_char(a)
337
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
338
+
196
339
  case driver.to_s
197
340
  when 'sdd'
198
- s, u, vt, info = case a
199
- when Numo::DFloat
200
- Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
201
- when Numo::SFloat
202
- Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
203
- when Numo::DComplex
204
- Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
205
- when Numo::SComplex
206
- Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
207
- else
208
- raise ArgumentError, "invalid array type: #{a.class}"
209
- end
341
+ gesdd = "#{bchr}gesdd".to_sym
342
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
210
343
  when 'svd'
211
- s, u, vt, info = case a
212
- when Numo::DFloat
213
- Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
214
- when Numo::SFloat
215
- Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
216
- when Numo::DComplex
217
- Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
218
- when Numo::SComplex
219
- Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
220
- else
221
- raise ArgumentError, "invalid array type: #{a.class}"
222
- end
344
+ gesvd = "#{bchr}gesvd".to_sym
345
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
223
346
  else
224
347
  raise ArgumentError, "invalid driver: #{driver}"
225
348
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.4
4
+ version: 0.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-06 00:00:00.000000000 Z
11
+ date: 2023-08-07 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -62,6 +62,7 @@ files:
62
62
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
63
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
64
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
65
+ - ext/numo/tiny_linalg/util.hpp
65
66
  - lib/numo/tiny_linalg.rb
66
67
  - lib/numo/tiny_linalg/version.rb
67
68
  - vendor/tmp/.gitkeep