numo-tiny_linalg 0.0.3 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,8 +10,87 @@ module Numo
10
10
  module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary or generalized eigenvalue problem.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   x = Numo::DFloat.new(5, 3).rand - 0.5
#   c = x.dot(x.transpose)
#   vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
#
#   pp vals
#   # =>
#   # Numo::DFloat#shape=[3]
#   # [0.118795, 0.434252, 0.903245]
#
#   pp vecs
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[0.154178, 0.60661, -0.382961],
#   #  [-0.349761, -0.141726, -0.513178],
#   #  [0.739633, -0.468202, 0.105933],
#   #  [0.0519655, -0.471436, -0.701507],
#   #  [-0.551488, -0.412883, 0.294371]]
#
#   # c = x * x^T has rank at most 3, so its three largest eigenpairs reconstruct it.
#   pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
#   # => 3.3306690738754696e-16
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range/Array]
#   The range of zero-based indices of the eigenvalues (in ascending order) and corresponding eigenvectors
#   to be returned. Both ends of an Array or an inclusive Range are included; an exclusive Range
#   (e.g. `2...5`) excludes its end. If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.eigh, and is not used.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm.
#   If vals_range is given, this flag is ignored.
# @raise [ArgumentError] If `a` or `b` is not a square 2-dimensional array of a supported type,
#   or if `a` and `b` have different sizes.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors
#   (the eigenvectors are nil when vals_only is true).
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  unless b.nil?
    raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
    raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
    # Report a size mismatch here instead of leaving it to the LAPACK wrapper.
    if a.shape[0] != b.shape[0]
      raise ArgumentError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}"
    end
  end

  jobz = vals_only ? 'N' : 'V'
  b = a.class.eye(a.shape[0]) if b.nil?
  # Real types use the symmetric drivers (?sygv*), complex types the Hermitian ones (?hegv*).
  sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"

  if vals_range.nil?
    # Build the driver name without mutating the interpolated string, which is frozen
    # under `frozen_string_literal` on Ruby < 3.0.
    driver = turbo ? :"#{sy_he_gv}d" : sy_he_gv.to_sym
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(driver, a.dup, b.dup, jobz: jobz)
  else
    # The ?gvx drivers take 1-based, inclusive index bounds.
    il = vals_range.first + 1
    iu = vals_range.last
    iu += 1 unless vals_range.is_a?(Range) && vals_range.exclude_end?
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      :"#{sy_he_gv}x", a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
    )
  end
  vecs = nil if vals_only
  [vals, vecs]
end
82
+
13
83
  # Computes the determinant of matrix.
14
84
  #
85
+ # @example
86
+ # require 'numo/tiny_linalg'
87
+ #
88
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
89
+ #
90
+ # a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
91
+ # pp (3.0 - Numo::Linalg.det(a)).abs
92
+ # # => 1.3322676295501878e-15
93
+ #
15
94
  # @param a [Numo::NArray] n-by-n square matrix.
16
95
  # @return [Float/Complex] The determinant of `a`.
17
96
  def det(a)
@@ -38,6 +117,21 @@ module Numo
38
117
 
39
118
  # Computes the inverse matrix of a square matrix.
40
119
  #
120
+ # @example
121
+ # require 'numo/tiny_linalg'
122
+ #
123
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
124
+ #
125
+ # a = Numo::DFloat.new(5, 5).rand
126
+ #
127
+ # inv_a = Numo::Linalg.inv(a)
128
+ #
129
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
130
+ # # => 7.019165976816745e-16
131
+ #
132
+ # pp inv_a.dot(a).sum
133
+ # # => 5.0
134
+ #
41
135
  # @param a [Numo::NArray] n-by-n square matrix.
42
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
43
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
@@ -64,6 +158,21 @@ module Numo
64
158
 
65
159
  # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
66
160
  #
161
+ # @example
162
+ # require 'numo/tiny_linalg'
163
+ #
164
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
165
+ #
166
+ # a = Numo::DFloat.new(5, 3).rand
167
+ #
168
+ # inv_a = Numo::Linalg.pinv(a)
169
+ #
170
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
171
+ # # => 1.1102230246251565e-15
172
+ #
173
+ # pp inv_a.dot(a).sum
174
+ # # => 3.0
175
+ #
67
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
68
177
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
69
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
@@ -79,6 +188,34 @@ module Numo
79
188
 
80
189
  # Compute QR decomposition of a matrix.
81
190
  #
191
+ # @example
192
+ # require 'numo/tiny_linalg'
193
+ #
194
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
195
+ #
196
+ # x = Numo::DFloat.new(5, 3).rand
197
+ #
198
+ # q, r = Numo::Linalg.qr(x, mode: 'economic')
199
+ #
200
+ # pp q
201
+ # # =>
202
+ # # Numo::DFloat#shape=[5,3]
203
+ # # [[-0.0574417, 0.635216, 0.707116],
204
+ # # [-0.187002, -0.073192, 0.422088],
205
+ # # [-0.502239, 0.634088, -0.537489],
206
+ # # [-0.0473292, 0.134867, -0.0223491],
207
+ # # [-0.840979, -0.413385, 0.180096]]
208
+ #
209
+ # pp r
210
+ # # =>
211
+ # # Numo::DFloat#shape=[3,3]
212
+ # # [[-1.07508, -0.821334, -0.484586],
213
+ # # [0, 0.513035, 0.451868],
214
+ # # [0, 0, 0.678737]]
215
+ #
216
+ # pp (q.dot(r) - x).abs.max
217
+ # # => 3.885780586188048e-16
218
+ #
82
219
  # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
83
220
  # @param mode [String] The mode of decomposition.
84
221
  # - "reduce" -- returns both Q [m, m] and R [m, n],
@@ -122,26 +259,74 @@ module Numo
122
259
 
123
260
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
124
261
  #
125
- # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensinal NArray).
262
+ # @example
263
+ # require 'numo/tiny_linalg'
264
+ #
265
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
266
+ #
267
+ # a = Numo::DFloat.new(3, 3).rand
268
+ # b = Numo::DFloat.eye(3)
269
+ #
270
+ # x = Numo::Linalg.solve(a, b)
271
+ #
272
+ # pp x
273
+ # # =>
274
+ # # Numo::DFloat#shape=[3,3]
275
+ # # [[-2.12332, 4.74868, 0.326773],
276
+ # # [1.38043, -3.79074, 1.25355],
277
+ # # [0.775187, 1.41032, -0.613774]]
278
+ #
279
+ # pp (b - a.dot(x)).abs.max
280
+ # # => 2.1081041547796492e-16
281
+ #
282
+ # @param a [Numo::NArray] The n-by-n square matrix.
126
283
  # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensinal NArray).
127
284
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
128
285
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
129
286
  # @return [Numo::NArray] The solusion vector / matrix `x`.
130
287
  def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
131
- case blas_char(a, b)
132
- when 'd'
133
- Lapack.dgesv(a.dup, b.dup)[1]
134
- when 's'
135
- Lapack.sgesv(a.dup, b.dup)[1]
136
- when 'z'
137
- Lapack.zgesv(a.dup, b.dup)[1]
138
- when 'c'
139
- Lapack.cgesv(a.dup, b.dup)[1]
140
- end
288
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
289
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
290
+
291
+ bchr = blas_char(a, b)
292
+ raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
293
+
294
+ gesv = "#{bchr}gesv".to_sym
295
+ Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
141
296
  end
142
297
 
143
298
  # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
144
299
  #
300
+ # @example
301
+ # require 'numo/tiny_linalg'
302
+ #
303
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
304
+ #
305
+ # x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
306
+ # pp x
307
+ # # =>
308
+ # # Numo::DFloat#shape=[5,3]
309
+ # # [[0.104945, 0.0284236, 0.117406],
310
+ # # [0.862634, 0.210945, 0.922135],
311
+ # # [0.324507, 0.0752655, 0.339158],
312
+ # # [0.67085, 0.102594, 0.600882],
313
+ # # [0.404631, 0.116868, 0.46644]]
314
+ #
315
+ # s, u, vt = Numo::Linalg.svd(x, job: 'S')
316
+ #
317
+ # z = u.dot(s.diag).dot(vt)
318
+ # pp z
319
+ # # =>
320
+ # # Numo::DFloat#shape=[5,3]
321
+ # # [[0.104945, 0.0284236, 0.117406],
322
+ # # [0.862634, 0.210945, 0.922135],
323
+ # # [0.324507, 0.0752655, 0.339158],
324
+ # # [0.67085, 0.102594, 0.600882],
325
+ # # [0.404631, 0.116868, 0.46644]]
326
+ #
327
+ # pp (x - z).abs.max
328
+ # # => 4.440892098500626e-16
329
+ #
145
330
  # @param a [Numo::NArray] Matrix to be decomposed.
146
331
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
147
332
  # @param job [String] Job option ('A', 'S', or 'N').
@@ -149,33 +334,16 @@ module Numo
149
334
  def svd(a, driver: 'svd', job: 'A')
150
335
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
151
336
 
337
+ bchr = blas_char(a)
338
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
339
+
152
340
  case driver.to_s
153
341
  when 'sdd'
154
- s, u, vt, info = case a
155
- when Numo::DFloat
156
- Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
157
- when Numo::SFloat
158
- Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
159
- when Numo::DComplex
160
- Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
161
- when Numo::SComplex
162
- Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
163
- else
164
- raise ArgumentError, "invalid array type: #{a.class}"
165
- end
342
+ gesdd = "#{bchr}gesdd".to_sym
343
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
166
344
  when 'svd'
167
- s, u, vt, info = case a
168
- when Numo::DFloat
169
- Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
170
- when Numo::SFloat
171
- Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
172
- when Numo::DComplex
173
- Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
174
- when Numo::SComplex
175
- Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
176
- else
177
- raise ArgumentError, "invalid array type: #{a.class}"
178
- end
345
+ gesvd = "#{bchr}gesvd".to_sym
346
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
179
347
  else
180
348
  raise ArgumentError, "invalid driver: #{driver}"
181
349
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.3
4
+ version: 0.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-02 00:00:00.000000000 Z
11
+ date: 2023-08-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -52,10 +52,17 @@ files:
52
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
54
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/hegv.hpp
56
+ - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
+ - ext/numo/tiny_linalg/lapack/hegvx.hpp
55
58
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
59
+ - ext/numo/tiny_linalg/lapack/sygv.hpp
60
+ - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
+ - ext/numo/tiny_linalg/lapack/sygvx.hpp
56
62
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
57
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
58
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
65
+ - ext/numo/tiny_linalg/util.hpp
59
66
  - lib/numo/tiny_linalg.rb
60
67
  - lib/numo/tiny_linalg/version.rb
61
68
  - vendor/tmp/.gitkeep