numo-tiny_linalg 0.0.3 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,8 +10,87 @@ module Numo
10
10
  module TinyLinalg # rubocop:disable Metrics/ModuleLength
11
11
  module_function
12
12
 
13
+ # Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
14
+ # by solving an ordinary or generalized eigenvalue problem.
15
+ #
16
+ # @example
17
+ # require 'numo/tiny_linalg'
18
+ #
19
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
20
+ #
21
+ # x = Numo::DFloat.new(5, 3).rand - 0.5
22
+ # c = x.dot(x.transpose)
23
+ # vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
24
+ #
25
+ # pp vals
26
+ # # =>
27
+ # # Numo::DFloat#shape=[3]
28
+ # # [0.118795, 0.434252, 0.903245]
29
+ #
30
+ # pp vecs
31
+ # # =>
32
+ # # Numo::DFloat#shape=[5,3]
33
+ # # [[0.154178, 0.60661, -0.382961],
34
+ # # [-0.349761, -0.141726, -0.513178],
35
+ # # [0.739633, -0.468202, 0.105933],
36
+ # # [0.0519655, -0.471436, -0.701507],
37
+ # # [-0.551488, -0.412883, 0.294371]]
38
+ #
39
+ # pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
40
+ # # => 3.3306690738754696e-16
41
+ #
42
+ # @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
43
+ # @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
44
+ # @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
45
+ # @param vals_range [Range/Array]
46
+ # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
47
+ # If nil, all eigenvalues and eigenvectors are computed.
48
+ # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
49
+ # @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
50
+ # @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors. The eigenvectors are nil if vals_only is true.
51
+ def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
52
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
53
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
54
+
55
+ bchr = blas_char(a)
56
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
57
+
58
+ unless b.nil?
59
+ raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
60
+ raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
61
+ raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
62
+ end
63
+
64
+ jobz = vals_only ? 'N' : 'V'
65
+ b = a.class.eye(a.shape[0]) if b.nil?
66
+ sy_he_gv = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
67
+
68
+ if vals_range.nil?
69
+ sy_he_gv << 'd' if turbo
70
+ vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz)
71
+ else
72
+ sy_he_gv << 'x'
73
+ il = vals_range.first + 1
74
+ iu = vals_range.last + 1
75
+ _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
76
+ sy_he_gv.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
77
+ )
78
+ end
79
+ vecs = nil if vals_only
80
+ [vals, vecs]
81
+ end
82
+
13
83
  # Computes the determinant of matrix.
14
84
  #
85
+ # @example
86
+ # require 'numo/tiny_linalg'
87
+ #
88
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
89
+ #
90
+ # a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
91
+ # pp (3.0 - Numo::Linalg.det(a)).abs
92
+ # # => 1.3322676295501878e-15
93
+ #
15
94
  # @param a [Numo::NArray] n-by-n square matrix.
16
95
  # @return [Float/Complex] The determinant of `a`.
17
96
  def det(a)
@@ -38,6 +117,21 @@ module Numo
38
117
 
39
118
  # Computes the inverse matrix of a square matrix.
40
119
  #
120
+ # @example
121
+ # require 'numo/tiny_linalg'
122
+ #
123
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
124
+ #
125
+ # a = Numo::DFloat.new(5, 5).rand
126
+ #
127
+ # inv_a = Numo::Linalg.inv(a)
128
+ #
129
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
130
+ # # => 7.019165976816745e-16
131
+ #
132
+ # pp inv_a.dot(a).sum
133
+ # # => 5.0
134
+ #
41
135
  # @param a [Numo::NArray] n-by-n square matrix.
42
136
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
43
137
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
@@ -64,6 +158,21 @@ module Numo
64
158
 
65
159
  # Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
66
160
  #
161
+ # @example
162
+ # require 'numo/tiny_linalg'
163
+ #
164
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
165
+ #
166
+ # a = Numo::DFloat.new(5, 3).rand
167
+ #
168
+ # inv_a = Numo::Linalg.pinv(a)
169
+ #
170
+ # pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
171
+ # # => 1.1102230246251565e-15
172
+ #
173
+ # pp inv_a.dot(a).sum
174
+ # # => 3.0
175
+ #
67
176
  # @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
68
177
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
69
178
  # @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
@@ -79,6 +188,34 @@ module Numo
79
188
 
80
189
  # Compute QR decomposition of a matrix.
81
190
  #
191
+ # @example
192
+ # require 'numo/tiny_linalg'
193
+ #
194
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
195
+ #
196
+ # x = Numo::DFloat.new(5, 3).rand
197
+ #
198
+ # q, r = Numo::Linalg.qr(x, mode: 'economic')
199
+ #
200
+ # pp q
201
+ # # =>
202
+ # # Numo::DFloat#shape=[5,3]
203
+ # # [[-0.0574417, 0.635216, 0.707116],
204
+ # # [-0.187002, -0.073192, 0.422088],
205
+ # # [-0.502239, 0.634088, -0.537489],
206
+ # # [-0.0473292, 0.134867, -0.0223491],
207
+ # # [-0.840979, -0.413385, 0.180096]]
208
+ #
209
+ # pp r
210
+ # # =>
211
+ # # Numo::DFloat#shape=[3,3]
212
+ # # [[-1.07508, -0.821334, -0.484586],
213
+ # # [0, 0.513035, 0.451868],
214
+ # # [0, 0, 0.678737]]
215
+ #
216
+ # pp (q.dot(r) - x).abs.max
217
+ # # => 3.885780586188048e-16
218
+ #
82
219
  # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
83
220
  # @param mode [String] The mode of decomposition.
84
221
  # - "reduce" -- returns both Q [m, m] and R [m, n],
@@ -122,26 +259,74 @@ module Numo
122
259
 
123
260
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
124
261
  #
125
- # @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensinal NArray).
262
+ # @example
263
+ # require 'numo/tiny_linalg'
264
+ #
265
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
266
+ #
267
+ # a = Numo::DFloat.new(3, 3).rand
268
+ # b = Numo::DFloat.eye(3)
269
+ #
270
+ # x = Numo::Linalg.solve(a, b)
271
+ #
272
+ # pp x
273
+ # # =>
274
+ # # Numo::DFloat#shape=[3,3]
275
+ # # [[-2.12332, 4.74868, 0.326773],
276
+ # # [1.38043, -3.79074, 1.25355],
277
+ # # [0.775187, 1.41032, -0.613774]]
278
+ #
279
+ # pp (b - a.dot(x)).abs.max
280
+ # # => 2.1081041547796492e-16
281
+ #
282
+ # @param a [Numo::NArray] The n-by-n square matrix.
126
283
  # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
127
284
  # @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
128
285
  # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
129
286
  # @return [Numo::NArray] The solution vector / matrix `x`.
130
287
  def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
131
- case blas_char(a, b)
132
- when 'd'
133
- Lapack.dgesv(a.dup, b.dup)[1]
134
- when 's'
135
- Lapack.sgesv(a.dup, b.dup)[1]
136
- when 'z'
137
- Lapack.zgesv(a.dup, b.dup)[1]
138
- when 'c'
139
- Lapack.cgesv(a.dup, b.dup)[1]
140
- end
288
+ raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
289
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
290
+
291
+ bchr = blas_char(a, b)
292
+ raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
293
+
294
+ gesv = "#{bchr}gesv".to_sym
295
+ Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
141
296
  end
142
297
 
143
298
  # Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
144
299
  #
300
+ # @example
301
+ # require 'numo/tiny_linalg'
302
+ #
303
+ # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
304
+ #
305
+ # x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
306
+ # pp x
307
+ # # =>
308
+ # # Numo::DFloat#shape=[5,3]
309
+ # # [[0.104945, 0.0284236, 0.117406],
310
+ # # [0.862634, 0.210945, 0.922135],
311
+ # # [0.324507, 0.0752655, 0.339158],
312
+ # # [0.67085, 0.102594, 0.600882],
313
+ # # [0.404631, 0.116868, 0.46644]]
314
+ #
315
+ # s, u, vt = Numo::Linalg.svd(x, job: 'S')
316
+ #
317
+ # z = u.dot(s.diag).dot(vt)
318
+ # pp z
319
+ # # =>
320
+ # # Numo::DFloat#shape=[5,3]
321
+ # # [[0.104945, 0.0284236, 0.117406],
322
+ # # [0.862634, 0.210945, 0.922135],
323
+ # # [0.324507, 0.0752655, 0.339158],
324
+ # # [0.67085, 0.102594, 0.600882],
325
+ # # [0.404631, 0.116868, 0.46644]]
326
+ #
327
+ # pp (x - z).abs.max
328
+ # # => 4.440892098500626e-16
329
+ #
145
330
  # @param a [Numo::NArray] Matrix to be decomposed.
146
331
  # @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
147
332
  # @param job [String] Job option ('A', 'S', or 'N').
@@ -149,33 +334,16 @@ module Numo
149
334
  def svd(a, driver: 'svd', job: 'A')
150
335
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
151
336
 
337
+ bchr = blas_char(a)
338
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
339
+
152
340
  case driver.to_s
153
341
  when 'sdd'
154
- s, u, vt, info = case a
155
- when Numo::DFloat
156
- Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
157
- when Numo::SFloat
158
- Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
159
- when Numo::DComplex
160
- Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
161
- when Numo::SComplex
162
- Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
163
- else
164
- raise ArgumentError, "invalid array type: #{a.class}"
165
- end
342
+ gesdd = "#{bchr}gesdd".to_sym
343
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
166
344
  when 'svd'
167
- s, u, vt, info = case a
168
- when Numo::DFloat
169
- Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
170
- when Numo::SFloat
171
- Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
172
- when Numo::DComplex
173
- Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
174
- when Numo::SComplex
175
- Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
176
- else
177
- raise ArgumentError, "invalid array type: #{a.class}"
178
- end
345
+ gesvd = "#{bchr}gesvd".to_sym
346
+ s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
179
347
  else
180
348
  raise ArgumentError, "invalid driver: #{driver}"
181
349
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-tiny_linalg
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.3
4
+ version: 0.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-02 00:00:00.000000000 Z
11
+ date: 2023-08-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -52,10 +52,17 @@ files:
52
52
  - ext/numo/tiny_linalg/lapack/gesvd.hpp
53
53
  - ext/numo/tiny_linalg/lapack/getrf.hpp
54
54
  - ext/numo/tiny_linalg/lapack/getri.hpp
55
+ - ext/numo/tiny_linalg/lapack/hegv.hpp
56
+ - ext/numo/tiny_linalg/lapack/hegvd.hpp
57
+ - ext/numo/tiny_linalg/lapack/hegvx.hpp
55
58
  - ext/numo/tiny_linalg/lapack/orgqr.hpp
59
+ - ext/numo/tiny_linalg/lapack/sygv.hpp
60
+ - ext/numo/tiny_linalg/lapack/sygvd.hpp
61
+ - ext/numo/tiny_linalg/lapack/sygvx.hpp
56
62
  - ext/numo/tiny_linalg/lapack/ungqr.hpp
57
63
  - ext/numo/tiny_linalg/tiny_linalg.cpp
58
64
  - ext/numo/tiny_linalg/tiny_linalg.hpp
65
+ - ext/numo/tiny_linalg/util.hpp
59
66
  - lib/numo/tiny_linalg.rb
60
67
  - lib/numo/tiny_linalg/version.rb
61
68
  - vendor/tmp/.gitkeep