numo-tiny_linalg 0.0.3 → 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -3
- data/ext/numo/tiny_linalg/blas/gemm.hpp +3 -49
- data/ext/numo/tiny_linalg/blas/gemv.hpp +2 -48
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/gesdd.hpp +11 -11
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +10 -30
- data/ext/numo/tiny_linalg/lapack/gesvd.hpp +12 -12
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/getri.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +137 -0
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +136 -0
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +5 -25
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +74 -21
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +30 -6
- data/ext/numo/tiny_linalg/util.hpp +100 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +203 -35
- metadata +9 -2
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -10,8 +10,87 @@ module Numo
|
|
10
10
|
module TinyLinalg # rubocop:disable Metrics/ModuleLength
|
11
11
|
module_function
|
12
12
|
|
13
|
+
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary or generalized eigenvalue problem.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   x = Numo::DFloat.new(5, 3).rand - 0.5
#   c = x.dot(x.transpose)
#   vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
#
#   pp vals
#   # =>
#   # Numo::DFloat#shape=[3]
#   # [0.118795, 0.434252, 0.903245]
#
#   pp vecs
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[0.154178, 0.60661, -0.382961],
#   #  [-0.349761, -0.141726, -0.513178],
#   #  [0.739633, -0.468202, 0.105933],
#   #  [0.0519655, -0.471436, -0.701507],
#   #  [-0.551488, -0.412883, 0.294371]]
#
#   # c has rank 3, so its three largest eigenpairs reconstruct it.
#   pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
#   # => 3.3306690738754696e-16
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix with the same shape as `a`.
#   If nil, the identity matrix is assumed (ordinary eigenvalue problem).
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range, Array]
#   The range of 0-based indices of the eigenvalues (in ascending order) and corresponding
#   eigenvectors to be returned. Both inclusive (`2..4`, `[2, 4]`) and exclusive (`2...5`)
#   ranges are accepted. If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for compatibility with Numo::Linalg, and is not used.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm.
#   If vals_range is given, this flag is ignored.
# @raise [ArgumentError] If `a` or `b` is not a square 2-dimensional array of a supported type,
#   or if `b` does not have the same shape as `a`.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors
#   (the eigenvectors are nil when vals_only is true).
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  unless b.nil?
    raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
    raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
    raise ArgumentError, 'input arrays a and b must have the same shape' if b.shape != a.shape
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
  end

  jobz = vals_only ? 'N' : 'V'
  b = a.class.eye(a.shape[0]) if b.nil?
  # Real types use the symmetric (sygv*) drivers, complex types the Hermitian (hegv*) ones.
  base = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"

  if vals_range.nil?
    # Build a new name rather than appending with `<<`. Interpolated string literals
    # are frozen on Ruby < 3.0 when `frozen_string_literal: true` is in effect.
    routine = turbo ? "#{base}d" : base
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(routine.to_sym, a.dup, b.dup, jobz: jobz)
  else
    # LAPACK's IL/IU are 1-based and inclusive.
    il = vals_range.first + 1
    iu = vals_range.last + 1
    # An exclusive Range (e.g. 2...5) does not include its last element.
    iu -= 1 if vals_range.is_a?(Range) && vals_range.exclude_end?
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      :"#{base}x", a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
    )
  end
  # With jobz 'N' the driver's array output does not hold eigenvectors.
  vecs = nil if vals_only
  [vals, vecs]
end
|
82
|
+
|
13
83
|
# Computes the determinant of matrix.
|
14
84
|
#
|
85
|
+
# @example
|
86
|
+
# require 'numo/tiny_linalg'
|
87
|
+
#
|
88
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
89
|
+
#
|
90
|
+
# a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
|
91
|
+
# pp (3.0 - Numo::Linalg.det(a)).abs
|
92
|
+
# # => 1.3322676295501878e-15
|
93
|
+
#
|
15
94
|
# @param a [Numo::NArray] n-by-n square matrix.
|
16
95
|
# @return [Float/Complex] The determinant of `a`.
|
17
96
|
def det(a)
|
@@ -38,6 +117,21 @@ module Numo
|
|
38
117
|
|
39
118
|
# Computes the inverse matrix of a square matrix.
|
40
119
|
#
|
120
|
+
# @example
|
121
|
+
# require 'numo/tiny_linalg'
|
122
|
+
#
|
123
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
124
|
+
#
|
125
|
+
# a = Numo::DFloat.new(5, 5).rand
|
126
|
+
#
|
127
|
+
# inv_a = Numo::Linalg.inv(a)
|
128
|
+
#
|
129
|
+
# pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
|
130
|
+
# # => 7.019165976816745e-16
|
131
|
+
#
|
132
|
+
# pp inv_a.dot(a).sum
|
133
|
+
# # => 5.0
|
134
|
+
#
|
41
135
|
# @param a [Numo::NArray] n-by-n square matrix.
|
42
136
|
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
43
137
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
@@ -64,6 +158,21 @@ module Numo
|
|
64
158
|
|
65
159
|
# Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
|
66
160
|
#
|
161
|
+
# @example
|
162
|
+
# require 'numo/tiny_linalg'
|
163
|
+
#
|
164
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
165
|
+
#
|
166
|
+
# a = Numo::DFloat.new(5, 3).rand
|
167
|
+
#
|
168
|
+
# inv_a = Numo::Linalg.pinv(a)
|
169
|
+
#
|
170
|
+
# pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
|
171
|
+
# # => 1.1102230246251565e-15
|
172
|
+
#
|
173
|
+
# pp inv_a.dot(a).sum
|
174
|
+
# # => 3.0
|
175
|
+
#
|
67
176
|
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
|
68
177
|
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
|
69
178
|
# @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
|
@@ -79,6 +188,34 @@ module Numo
|
|
79
188
|
|
80
189
|
# Compute QR decomposition of a matrix.
|
81
190
|
#
|
191
|
+
# @example
|
192
|
+
# require 'numo/tiny_linalg'
|
193
|
+
#
|
194
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
195
|
+
#
|
196
|
+
# x = Numo::DFloat.new(5, 3).rand
|
197
|
+
#
|
198
|
+
# q, r = Numo::Linalg.qr(x, mode: 'economic')
|
199
|
+
#
|
200
|
+
# pp q
|
201
|
+
# # =>
|
202
|
+
# # Numo::DFloat#shape=[5,3]
|
203
|
+
# # [[-0.0574417, 0.635216, 0.707116],
|
204
|
+
# # [-0.187002, -0.073192, 0.422088],
|
205
|
+
# # [-0.502239, 0.634088, -0.537489],
|
206
|
+
# # [-0.0473292, 0.134867, -0.0223491],
|
207
|
+
# # [-0.840979, -0.413385, 0.180096]]
|
208
|
+
#
|
209
|
+
# pp r
|
210
|
+
# # =>
|
211
|
+
# # Numo::DFloat#shape=[3,3]
|
212
|
+
# # [[-1.07508, -0.821334, -0.484586],
|
213
|
+
# # [0, 0.513035, 0.451868],
|
214
|
+
# # [0, 0, 0.678737]]
|
215
|
+
#
|
216
|
+
# pp (q.dot(r) - x).abs.max
|
217
|
+
# # => 3.885780586188048e-16
|
218
|
+
#
|
82
219
|
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
|
83
220
|
# @param mode [String] The mode of decomposition.
|
84
221
|
# - "reduce" -- returns both Q [m, m] and R [m, n],
|
@@ -122,26 +259,74 @@ module Numo
|
|
122
259
|
|
123
260
|
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   a = Numo::DFloat.new(3, 3).rand
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[-2.12332, 4.74868, 0.326773],
#   #  [1.38043, -3.79074, 1.25355],
#   #  [0.775187, 1.41032, -0.613774]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 2.1081041547796492e-16
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @raise [ArgumentError] If `a` is not square and 2-dimensional, if the array types are unsupported,
#   or if `b` is 0-dimensional or its number of rows differs from the order of `a`.
# @return [Numo::NArray] The solution vector / matrix `x`.
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  # Checked after the type check so a non-NArray `b` still gets the type error above.
  raise ArgumentError, 'input array b must be at least 1-dimensional' if b.ndim < 1
  raise ArgumentError, 'the number of rows of b must match the order of a' if b.shape[0] != a.shape[0]

  gesv = :"#{bchr}gesv"
  # The second element of the driver's result holds the solution.
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
end
|
142
297
|
|
143
298
|
# Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
144
299
|
#
|
300
|
+
# @example
|
301
|
+
# require 'numo/tiny_linalg'
|
302
|
+
#
|
303
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
304
|
+
#
|
305
|
+
# x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
|
306
|
+
# pp x
|
307
|
+
# # =>
|
308
|
+
# # Numo::DFloat#shape=[5,3]
|
309
|
+
# # [[0.104945, 0.0284236, 0.117406],
|
310
|
+
# # [0.862634, 0.210945, 0.922135],
|
311
|
+
# # [0.324507, 0.0752655, 0.339158],
|
312
|
+
# # [0.67085, 0.102594, 0.600882],
|
313
|
+
# # [0.404631, 0.116868, 0.46644]]
|
314
|
+
#
|
315
|
+
# s, u, vt = Numo::Linalg.svd(x, job: 'S')
|
316
|
+
#
|
317
|
+
# z = u.dot(s.diag).dot(vt)
|
318
|
+
# pp z
|
319
|
+
# # =>
|
320
|
+
# # Numo::DFloat#shape=[5,3]
|
321
|
+
# # [[0.104945, 0.0284236, 0.117406],
|
322
|
+
# # [0.862634, 0.210945, 0.922135],
|
323
|
+
# # [0.324507, 0.0752655, 0.339158],
|
324
|
+
# # [0.67085, 0.102594, 0.600882],
|
325
|
+
# # [0.404631, 0.116868, 0.46644]]
|
326
|
+
#
|
327
|
+
# pp (x - z).abs.max
|
328
|
+
# # => 4.440892098500626e-16
|
329
|
+
#
|
145
330
|
# @param a [Numo::NArray] Matrix to be decomposed.
|
146
331
|
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
|
147
332
|
# @param job [String] Job option ('A', 'S', or 'N').
|
@@ -149,33 +334,16 @@ module Numo
|
|
149
334
|
def svd(a, driver: 'svd', job: 'A')
|
150
335
|
raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
|
151
336
|
|
337
|
+
bchr = blas_char(a)
|
338
|
+
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
339
|
+
|
152
340
|
case driver.to_s
|
153
341
|
when 'sdd'
|
154
|
-
|
155
|
-
|
156
|
-
Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
|
157
|
-
when Numo::SFloat
|
158
|
-
Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
|
159
|
-
when Numo::DComplex
|
160
|
-
Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
|
161
|
-
when Numo::SComplex
|
162
|
-
Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
|
163
|
-
else
|
164
|
-
raise ArgumentError, "invalid array type: #{a.class}"
|
165
|
-
end
|
342
|
+
gesdd = "#{bchr}gesdd".to_sym
|
343
|
+
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
|
166
344
|
when 'svd'
|
167
|
-
|
168
|
-
|
169
|
-
Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
|
170
|
-
when Numo::SFloat
|
171
|
-
Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
|
172
|
-
when Numo::DComplex
|
173
|
-
Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
|
174
|
-
when Numo::SComplex
|
175
|
-
Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
|
176
|
-
else
|
177
|
-
raise ArgumentError, "invalid array type: #{a.class}"
|
178
|
-
end
|
345
|
+
gesvd = "#{bchr}gesvd".to_sym
|
346
|
+
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
|
179
347
|
else
|
180
348
|
raise ArgumentError, "invalid driver: #{driver}"
|
181
349
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.3
|
4
|
+
version: 0.1.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -52,10 +52,17 @@ files:
|
|
52
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
57
|
+
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
55
58
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
59
|
+
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
60
|
+
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
61
|
+
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|
56
62
|
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
57
63
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
58
64
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|
65
|
+
- ext/numo/tiny_linalg/util.hpp
|
59
66
|
- lib/numo/tiny_linalg.rb
|
60
67
|
- lib/numo/tiny_linalg/version.rb
|
61
68
|
- vendor/tmp/.gitkeep
|