numo-tiny_linalg 0.0.3 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -3
- data/ext/numo/tiny_linalg/blas/gemm.hpp +3 -49
- data/ext/numo/tiny_linalg/blas/gemv.hpp +2 -48
- data/ext/numo/tiny_linalg/lapack/geqrf.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/gesdd.hpp +11 -11
- data/ext/numo/tiny_linalg/lapack/gesv.hpp +10 -30
- data/ext/numo/tiny_linalg/lapack/gesvd.hpp +12 -12
- data/ext/numo/tiny_linalg/lapack/getrf.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/getri.hpp +9 -29
- data/ext/numo/tiny_linalg/lapack/hegv.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/hegvd.hpp +121 -0
- data/ext/numo/tiny_linalg/lapack/hegvx.hpp +137 -0
- data/ext/numo/tiny_linalg/lapack/orgqr.hpp +5 -25
- data/ext/numo/tiny_linalg/lapack/sygv.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygvd.hpp +112 -0
- data/ext/numo/tiny_linalg/lapack/sygvx.hpp +136 -0
- data/ext/numo/tiny_linalg/lapack/ungqr.hpp +5 -25
- data/ext/numo/tiny_linalg/tiny_linalg.cpp +74 -21
- data/ext/numo/tiny_linalg/tiny_linalg.hpp +30 -6
- data/ext/numo/tiny_linalg/util.hpp +100 -0
- data/lib/numo/tiny_linalg/version.rb +1 -1
- data/lib/numo/tiny_linalg.rb +203 -35
- metadata +9 -2
data/lib/numo/tiny_linalg.rb
CHANGED
@@ -10,8 +10,87 @@ module Numo
|
|
10
10
|
module TinyLinalg # rubocop:disable Metrics/ModuleLength
|
11
11
|
module_function
|
12
12
|
|
13
|
+
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary or generalized eigenvalue problem.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   x = Numo::DFloat.new(5, 3).rand - 0.5
#   c = x.dot(x.transpose)
#   vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
#
#   pp vals
#   # =>
#   # Numo::DFloat#shape=[3]
#   # [0.118795, 0.434252, 0.903245]
#
#   pp vecs
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[0.154178, 0.60661, -0.382961],
#   #  [-0.349761, -0.141726, -0.513178],
#   #  [0.739633, -0.468202, 0.105933],
#   #  [0.0519655, -0.471436, -0.701507],
#   #  [-0.551488, -0.412883, 0.294371]]
#
#   # c has rank 3, so the three largest eigenpairs reconstruct it.
#   pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
#   # => 3.3306690738754696e-16
#
# @param a [Numo::NArray] n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] n-by-n symmetric / Hermitian matrix. If nil, the identity matrix is used,
#   i.e. the ordinary eigenvalue problem is solved.
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range/Array]
#   The range of 0-based indices of the eigenvalues (in ascending order) and corresponding eigenvectors
#   to be returned. Both inclusive (`2..4`, `[2, 4]`) and exclusive (`2...5`) forms are accepted.
#   If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.eigh, and is not used.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm.
#   If vals_range is given, this flag is ignored.
# @raise [ArgumentError] If the inputs are not square 2-dimensional arrays of a supported type,
#   if `a` and `b` differ in shape, or if `vals_range` is empty or out of bounds.
# @return [Array<Numo::NArray, Numo::NArray>] The eigenvalues and eigenvectors (eigenvectors are nil if vals_only).
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/ParameterLists, Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  unless b.nil?
    raise ArgumentError, 'input array b must be 2-dimensional' if b.ndim != 2
    raise ArgumentError, 'input array b must be square' if b.shape[0] != b.shape[1]
    raise ArgumentError, 'input arrays a and b must have the same shape' if b.shape != a.shape
    raise ArgumentError, "invalid array type: #{b.class}" if blas_char(b) == 'n'
  end

  n = a.shape[0]
  jobz = vals_only ? 'N' : 'V'
  # The ordinary problem is solved as the generalized one with B = I.
  b = a.class.eye(n) if b.nil?
  # Real types use the ?sygv* family, complex types the ?hegv* family.
  base = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"

  if vals_range.nil?
    # Build the name without mutating `base`: interpolated literals are frozen
    # under `frozen_string_literal: true` on Ruby < 3.0.
    fnc = turbo ? "#{base}d" : base
    vecs, _b, vals, _info = Numo::TinyLinalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
  else
    # minmax yields the last *included* index, so exclusive ranges such as 0...3 work;
    # Range#last would return the excluded end and request one eigenvalue too many.
    first, last = vals_range.minmax
    if first.nil? || first.negative? || last < first || last >= n
      raise ArgumentError, "invalid vals_range: #{vals_range}"
    end

    # LAPACK's il / iu indices are 1-based.
    _a, _b, _m, vals, vecs, _ifail, _info = Numo::TinyLinalg::Lapack.send(
      :"#{base}x", a.dup, b.dup, jobz: jobz, range: 'I', il: first + 1, iu: last + 1
    )
  end
  vecs = nil if vals_only
  [vals, vecs]
end
|
82
|
+
|
13
83
|
# Computes the determinant of matrix.
|
14
84
|
#
|
85
|
+
# @example
|
86
|
+
# require 'numo/tiny_linalg'
|
87
|
+
#
|
88
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
89
|
+
#
|
90
|
+
# a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
|
91
|
+
# pp (3.0 - Numo::Linalg.det(a)).abs
|
92
|
+
# # => 1.3322676295501878e-15
|
93
|
+
#
|
15
94
|
# @param a [Numo::NArray] n-by-n square matrix.
|
16
95
|
# @return [Float/Complex] The determinant of `a`.
|
17
96
|
def det(a)
|
@@ -38,6 +117,21 @@ module Numo
|
|
38
117
|
|
39
118
|
# Computes the inverse matrix of a square matrix.
|
40
119
|
#
|
120
|
+
# @example
|
121
|
+
# require 'numo/tiny_linalg'
|
122
|
+
#
|
123
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
124
|
+
#
|
125
|
+
# a = Numo::DFloat.new(5, 5).rand
|
126
|
+
#
|
127
|
+
# inv_a = Numo::Linalg.inv(a)
|
128
|
+
#
|
129
|
+
# pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
|
130
|
+
# # => 7.019165976816745e-16
|
131
|
+
#
|
132
|
+
# pp inv_a.dot(a).sum
|
133
|
+
# # => 5.0
|
134
|
+
#
|
41
135
|
# @param a [Numo::NArray] n-by-n square matrix.
|
42
136
|
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
43
137
|
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
@@ -64,6 +158,21 @@ module Numo
|
|
64
158
|
|
65
159
|
# Compute the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
|
66
160
|
#
|
161
|
+
# @example
|
162
|
+
# require 'numo/tiny_linalg'
|
163
|
+
#
|
164
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
165
|
+
#
|
166
|
+
# a = Numo::DFloat.new(5, 3).rand
|
167
|
+
#
|
168
|
+
# inv_a = Numo::Linalg.pinv(a)
|
169
|
+
#
|
170
|
+
# pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
|
171
|
+
# # => 1.1102230246251565e-15
|
172
|
+
#
|
173
|
+
# pp inv_a.dot(a).sum
|
174
|
+
# # => 3.0
|
175
|
+
#
|
67
176
|
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
|
68
177
|
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
|
69
178
|
# @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
|
@@ -79,6 +188,34 @@ module Numo
|
|
79
188
|
|
80
189
|
# Compute QR decomposition of a matrix.
|
81
190
|
#
|
191
|
+
# @example
|
192
|
+
# require 'numo/tiny_linalg'
|
193
|
+
#
|
194
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
195
|
+
#
|
196
|
+
# x = Numo::DFloat.new(5, 3).rand
|
197
|
+
#
|
198
|
+
# q, r = Numo::Linalg.qr(x, mode: 'economic')
|
199
|
+
#
|
200
|
+
# pp q
|
201
|
+
# # =>
|
202
|
+
# # Numo::DFloat#shape=[5,3]
|
203
|
+
# # [[-0.0574417, 0.635216, 0.707116],
|
204
|
+
# # [-0.187002, -0.073192, 0.422088],
|
205
|
+
# # [-0.502239, 0.634088, -0.537489],
|
206
|
+
# # [-0.0473292, 0.134867, -0.0223491],
|
207
|
+
# # [-0.840979, -0.413385, 0.180096]]
|
208
|
+
#
|
209
|
+
# pp r
|
210
|
+
# # =>
|
211
|
+
# # Numo::DFloat#shape=[3,3]
|
212
|
+
# # [[-1.07508, -0.821334, -0.484586],
|
213
|
+
# # [0, 0.513035, 0.451868],
|
214
|
+
# # [0, 0, 0.678737]]
|
215
|
+
#
|
216
|
+
# pp (q.dot(r) - x).abs.max
|
217
|
+
# # => 3.885780586188048e-16
|
218
|
+
#
|
82
219
|
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
|
83
220
|
# @param mode [String] The mode of decomposition.
|
84
221
|
# - "reduce" -- returns both Q [m, m] and R [m, n],
|
@@ -122,26 +259,74 @@ module Numo
|
|
122
259
|
|
123
260
|
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `a`.
#
# @example
#   require 'numo/tiny_linalg'
#
#   Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
#   a = Numo::DFloat.new(3, 3).rand
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[-2.12332, 4.74868, 0.326773],
#   #  [1.38043, -3.79074, 1.25355],
#   #  [0.775187, 1.41032, -0.613774]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 2.1081041547796492e-16
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensional NArray).
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @raise [ArgumentError] If `a` is not a square matrix, `b` does not have n rows,
#   or the array types are not supported.
# @return [Numo::NArray] The solution vector / matrix `x`.
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
  raise ArgumentError, 'input array b must be at least 1-dimensional' if b.ndim < 1
  if b.shape[0] != a.shape[0]
    raise ArgumentError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}"
  end

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  gesv = :"#{bchr}gesv"
  # Pass copies so the caller's arrays are not modified; the second element
  # of the result is the overwritten right-hand side, i.e. the solution.
  Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
end
|
142
297
|
|
143
298
|
# Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
144
299
|
#
|
300
|
+
# @example
|
301
|
+
# require 'numo/tiny_linalg'
|
302
|
+
#
|
303
|
+
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
|
304
|
+
#
|
305
|
+
# x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
|
306
|
+
# pp x
|
307
|
+
# # =>
|
308
|
+
# # Numo::DFloat#shape=[5,3]
|
309
|
+
# # [[0.104945, 0.0284236, 0.117406],
|
310
|
+
# # [0.862634, 0.210945, 0.922135],
|
311
|
+
# # [0.324507, 0.0752655, 0.339158],
|
312
|
+
# # [0.67085, 0.102594, 0.600882],
|
313
|
+
# # [0.404631, 0.116868, 0.46644]]
|
314
|
+
#
|
315
|
+
# s, u, vt = Numo::Linalg.svd(x, job: 'S')
|
316
|
+
#
|
317
|
+
# z = u.dot(s.diag).dot(vt)
|
318
|
+
# pp z
|
319
|
+
# # =>
|
320
|
+
# # Numo::DFloat#shape=[5,3]
|
321
|
+
# # [[0.104945, 0.0284236, 0.117406],
|
322
|
+
# # [0.862634, 0.210945, 0.922135],
|
323
|
+
# # [0.324507, 0.0752655, 0.339158],
|
324
|
+
# # [0.67085, 0.102594, 0.600882],
|
325
|
+
# # [0.404631, 0.116868, 0.46644]]
|
326
|
+
#
|
327
|
+
# pp (x - z).abs.max
|
328
|
+
# # => 4.440892098500626e-16
|
329
|
+
#
|
145
330
|
# @param a [Numo::NArray] Matrix to be decomposed.
|
146
331
|
# @param driver [String] LAPACK driver to be used ('svd' or 'sdd').
|
147
332
|
# @param job [String] Job option ('A', 'S', or 'N').
|
@@ -149,33 +334,16 @@ module Numo
|
|
149
334
|
def svd(a, driver: 'svd', job: 'A')
|
150
335
|
raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)
|
151
336
|
|
337
|
+
bchr = blas_char(a)
|
338
|
+
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
|
339
|
+
|
152
340
|
case driver.to_s
|
153
341
|
when 'sdd'
|
154
|
-
|
155
|
-
|
156
|
-
Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
|
157
|
-
when Numo::SFloat
|
158
|
-
Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
|
159
|
-
when Numo::DComplex
|
160
|
-
Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
|
161
|
-
when Numo::SComplex
|
162
|
-
Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
|
163
|
-
else
|
164
|
-
raise ArgumentError, "invalid array type: #{a.class}"
|
165
|
-
end
|
342
|
+
gesdd = "#{bchr}gesdd".to_sym
|
343
|
+
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
|
166
344
|
when 'svd'
|
167
|
-
|
168
|
-
|
169
|
-
Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
|
170
|
-
when Numo::SFloat
|
171
|
-
Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
|
172
|
-
when Numo::DComplex
|
173
|
-
Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
|
174
|
-
when Numo::SComplex
|
175
|
-
Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
|
176
|
-
else
|
177
|
-
raise ArgumentError, "invalid array type: #{a.class}"
|
178
|
-
end
|
345
|
+
gesvd = "#{bchr}gesvd".to_sym
|
346
|
+
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
|
179
347
|
else
|
180
348
|
raise ArgumentError, "invalid driver: #{driver}"
|
181
349
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-tiny_linalg
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.3
|
4
|
+
version: 0.1.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-08-
|
11
|
+
date: 2023-08-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -52,10 +52,17 @@ files:
|
|
52
52
|
- ext/numo/tiny_linalg/lapack/gesvd.hpp
|
53
53
|
- ext/numo/tiny_linalg/lapack/getrf.hpp
|
54
54
|
- ext/numo/tiny_linalg/lapack/getri.hpp
|
55
|
+
- ext/numo/tiny_linalg/lapack/hegv.hpp
|
56
|
+
- ext/numo/tiny_linalg/lapack/hegvd.hpp
|
57
|
+
- ext/numo/tiny_linalg/lapack/hegvx.hpp
|
55
58
|
- ext/numo/tiny_linalg/lapack/orgqr.hpp
|
59
|
+
- ext/numo/tiny_linalg/lapack/sygv.hpp
|
60
|
+
- ext/numo/tiny_linalg/lapack/sygvd.hpp
|
61
|
+
- ext/numo/tiny_linalg/lapack/sygvx.hpp
|
56
62
|
- ext/numo/tiny_linalg/lapack/ungqr.hpp
|
57
63
|
- ext/numo/tiny_linalg/tiny_linalg.cpp
|
58
64
|
- ext/numo/tiny_linalg/tiny_linalg.hpp
|
65
|
+
- ext/numo/tiny_linalg/util.hpp
|
59
66
|
- lib/numo/tiny_linalg.rb
|
60
67
|
- lib/numo/tiny_linalg/version.rb
|
61
68
|
- vendor/tmp/.gitkeep
|