numo-linalg-alt 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +27 -0
- data/README.md +106 -0
- data/ext/numo/linalg/blas/dot.c +72 -0
- data/ext/numo/linalg/blas/dot.h +13 -0
- data/ext/numo/linalg/blas/dot_sub.c +71 -0
- data/ext/numo/linalg/blas/dot_sub.h +13 -0
- data/ext/numo/linalg/blas/gemm.c +184 -0
- data/ext/numo/linalg/blas/gemm.h +16 -0
- data/ext/numo/linalg/blas/gemv.c +161 -0
- data/ext/numo/linalg/blas/gemv.h +16 -0
- data/ext/numo/linalg/blas/nrm2.c +67 -0
- data/ext/numo/linalg/blas/nrm2.h +13 -0
- data/ext/numo/linalg/converter.c +67 -0
- data/ext/numo/linalg/converter.h +23 -0
- data/ext/numo/linalg/extconf.rb +99 -0
- data/ext/numo/linalg/lapack/geev.c +152 -0
- data/ext/numo/linalg/lapack/geev.h +15 -0
- data/ext/numo/linalg/lapack/gelsd.c +92 -0
- data/ext/numo/linalg/lapack/gelsd.h +15 -0
- data/ext/numo/linalg/lapack/geqrf.c +72 -0
- data/ext/numo/linalg/lapack/geqrf.h +15 -0
- data/ext/numo/linalg/lapack/gesdd.c +108 -0
- data/ext/numo/linalg/lapack/gesdd.h +15 -0
- data/ext/numo/linalg/lapack/gesv.c +99 -0
- data/ext/numo/linalg/lapack/gesv.h +15 -0
- data/ext/numo/linalg/lapack/gesvd.c +152 -0
- data/ext/numo/linalg/lapack/gesvd.h +15 -0
- data/ext/numo/linalg/lapack/getrf.c +71 -0
- data/ext/numo/linalg/lapack/getrf.h +15 -0
- data/ext/numo/linalg/lapack/getri.c +82 -0
- data/ext/numo/linalg/lapack/getri.h +15 -0
- data/ext/numo/linalg/lapack/getrs.c +110 -0
- data/ext/numo/linalg/lapack/getrs.h +15 -0
- data/ext/numo/linalg/lapack/heev.c +71 -0
- data/ext/numo/linalg/lapack/heev.h +15 -0
- data/ext/numo/linalg/lapack/heevd.c +71 -0
- data/ext/numo/linalg/lapack/heevd.h +15 -0
- data/ext/numo/linalg/lapack/heevr.c +111 -0
- data/ext/numo/linalg/lapack/heevr.h +15 -0
- data/ext/numo/linalg/lapack/hegv.c +94 -0
- data/ext/numo/linalg/lapack/hegv.h +15 -0
- data/ext/numo/linalg/lapack/hegvd.c +94 -0
- data/ext/numo/linalg/lapack/hegvd.h +15 -0
- data/ext/numo/linalg/lapack/hegvx.c +133 -0
- data/ext/numo/linalg/lapack/hegvx.h +15 -0
- data/ext/numo/linalg/lapack/hetrf.c +68 -0
- data/ext/numo/linalg/lapack/hetrf.h +15 -0
- data/ext/numo/linalg/lapack/lange.c +66 -0
- data/ext/numo/linalg/lapack/lange.h +15 -0
- data/ext/numo/linalg/lapack/orgqr.c +79 -0
- data/ext/numo/linalg/lapack/orgqr.h +15 -0
- data/ext/numo/linalg/lapack/potrf.c +70 -0
- data/ext/numo/linalg/lapack/potrf.h +15 -0
- data/ext/numo/linalg/lapack/potri.c +70 -0
- data/ext/numo/linalg/lapack/potri.h +15 -0
- data/ext/numo/linalg/lapack/potrs.c +94 -0
- data/ext/numo/linalg/lapack/potrs.h +15 -0
- data/ext/numo/linalg/lapack/syev.c +71 -0
- data/ext/numo/linalg/lapack/syev.h +15 -0
- data/ext/numo/linalg/lapack/syevd.c +71 -0
- data/ext/numo/linalg/lapack/syevd.h +15 -0
- data/ext/numo/linalg/lapack/syevr.c +111 -0
- data/ext/numo/linalg/lapack/syevr.h +15 -0
- data/ext/numo/linalg/lapack/sygv.c +93 -0
- data/ext/numo/linalg/lapack/sygv.h +15 -0
- data/ext/numo/linalg/lapack/sygvd.c +93 -0
- data/ext/numo/linalg/lapack/sygvd.h +15 -0
- data/ext/numo/linalg/lapack/sygvx.c +133 -0
- data/ext/numo/linalg/lapack/sygvx.h +15 -0
- data/ext/numo/linalg/lapack/sytrf.c +72 -0
- data/ext/numo/linalg/lapack/sytrf.h +15 -0
- data/ext/numo/linalg/lapack/trtrs.c +99 -0
- data/ext/numo/linalg/lapack/trtrs.h +15 -0
- data/ext/numo/linalg/lapack/ungqr.c +79 -0
- data/ext/numo/linalg/lapack/ungqr.h +15 -0
- data/ext/numo/linalg/linalg.c +290 -0
- data/ext/numo/linalg/linalg.h +85 -0
- data/ext/numo/linalg/util.c +95 -0
- data/ext/numo/linalg/util.h +17 -0
- data/lib/numo/linalg/version.rb +10 -0
- data/lib/numo/linalg.rb +1309 -0
- data/vendor/tmp/.gitkeep +0 -0
- metadata +146 -0
data/lib/numo/linalg.rb
ADDED
@@ -0,0 +1,1309 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'numo/narray'
|
4
|
+
require_relative 'linalg/version'
|
5
|
+
require_relative 'linalg/linalg'
|
6
|
+
|
7
|
+
# Ruby/Numo (NUmerical MOdules)
|
8
|
+
module Numo
|
9
|
+
# Numo::Linalg Alternative (numo-linalg-alt) is an alternative to Numo::Linalg.
|
10
|
+
module Linalg # rubocop:disable Metrics/ModuleLength
|
11
|
+
module_function
|
12
|
+
|
13
|
+
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
|
14
|
+
# by solving an ordinary or generalized eigenvalue problem.
|
15
|
+
#
|
16
|
+
# @example
|
17
|
+
# require 'numo/linalg'
|
18
|
+
#
|
19
|
+
# x = Numo::DFloat.new(5, 3).rand - 0.5
|
20
|
+
# c = x.dot(x.transpose)
|
21
|
+
# vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
|
22
|
+
#
|
23
|
+
# pp vals
|
24
|
+
# # =>
|
25
|
+
# # Numo::DFloat#shape=[3]
|
26
|
+
# # [0.118795, 0.434252, 0.903245]
|
27
|
+
#
|
28
|
+
# pp vecs
|
29
|
+
# # =>
|
30
|
+
# # Numo::DFloat#shape=[5,3]
|
31
|
+
# # [[0.154178, 0.60661, -0.382961],
|
32
|
+
# # [-0.349761, -0.141726, -0.513178],
|
33
|
+
# # [0.739633, -0.468202, 0.105933],
|
34
|
+
# # [0.0519655, -0.471436, -0.701507],
|
35
|
+
# # [-0.551488, -0.412883, 0.294371]]
|
36
|
+
#
|
37
|
+
# pp (c - vecs.dot(vals.diag).dot(vecs.transpose)).abs.max
|
38
|
+
# # => 3.3306690738754696e-16
|
39
|
+
#
|
40
|
+
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
|
41
|
+
# @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
|
42
|
+
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
|
43
|
+
# @param vals_range [Range/Array]
|
44
|
+
# The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
|
45
|
+
# If nil, all eigenvalues and eigenvectors are computed.
|
46
|
+
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
47
|
+
# @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
|
48
|
+
# @return [Array<Numo::NArray>] The eigenvalues and eigenvectors.
|
49
|
+
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/ParameterLists, Metrics/PerceivedComplexity, Lint/UnusedMethodArgument
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  b_given = !b.nil?
  raise Numo::NArray::ShapeError, 'input array b must be 2-dimensional' if b_given && b.ndim != 2
  raise Numo::NArray::ShapeError, 'input array b must be square' if b_given && b.shape[0] != b.shape[1]
  raise ArgumentError, "invalid array type: #{b.class}" if b_given && blas_char(b) == 'n'

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  jobz = vals_only ? 'N' : 'V'

  # NOTE: the routine name is extended with `+=` rather than `<<`. Under
  # `# frozen_string_literal: true`, interpolated literals are frozen on Ruby < 3.0,
  # so appending in place would raise FrozenError there.
  if b_given
    # Generalized problem: ?sygv* for real types, ?hegv* for complex types.
    fnc = %w[d s].include?(bchr) ? "#{bchr}sygv" : "#{bchr}hegv"
    if vals_range.nil?
      fnc += 'd' if turbo
      vecs, _b, vals, _info = Numo::Linalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
    else
      fnc += 'x'
      # LAPACK takes 1-based indices for the requested eigenvalue range.
      il = vals_range.first(1)[0] + 1
      iu = vals_range.last(1)[0] + 1
      _a, _b, _m, vals, vecs, _ifail, _info = Numo::Linalg::Lapack.send(
        fnc.to_sym, a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
      )
    end
  else
    # Standard problem: ?syev* for real types, ?heev* for complex types.
    fnc = %w[d s].include?(bchr) ? "#{bchr}syev" : "#{bchr}heev"
    if vals_range.nil?
      fnc += 'd' if turbo
      vecs, vals, _info = Numo::Linalg::Lapack.send(fnc.to_sym, a.dup, jobz: jobz)
    else
      fnc += 'r'
      # LAPACK takes 1-based indices for the requested eigenvalue range.
      il = vals_range.first(1)[0] + 1
      iu = vals_range.last(1)[0] + 1
      _a, _m, vals, vecs, _isuppz, _info = Numo::Linalg::Lapack.send(
        fnc.to_sym, a.dup, jobz: jobz, range: 'I', il: il, iu: iu
      )
    end
  end

  # With jobz == 'N' the routine does not compute eigenvectors, so do not expose the buffer.
  vecs = nil if vals_only

  [vals, vecs]
end
|
95
|
+
|
96
|
+
# Computes the matrix or vector norm.
|
97
|
+
#
|
98
|
+
# | ord | matrix norm | vector norm |
|
99
|
+
# | ----- | ---------------------- | --------------------------- |
|
100
|
+
# | nil | Frobenius norm | 2-norm |
|
101
|
+
# | 'fro' | Frobenius norm | - |
|
102
|
+
# | 'nuc' | nuclear norm | - |
|
103
|
+
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max |
|
104
|
+
# | 0 | - | (x.ne 0).sum |
|
105
|
+
# | 1 | x.abs.sum(axis:-2).max | same as below |
|
106
|
+
# | 2 | 2-norm (max sing_vals) | same as below |
|
107
|
+
# | other | - | (x.abs**ord).sum**(1.0/ord) |
|
108
|
+
#
|
109
|
+
# @example
|
110
|
+
# require 'numo/linalg'
|
111
|
+
#
|
112
|
+
# # matrix norm
|
113
|
+
# x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
|
114
|
+
# pp Numo::Linalg.norm(x)
|
115
|
+
# # => 10
|
116
|
+
#
|
117
|
+
# # vector norm
|
118
|
+
# x = Numo::DFloat[3, -4]
|
119
|
+
# pp Numo::Linalg.norm(x)
|
120
|
+
# # => 5
|
121
|
+
#
|
122
|
+
# @param a [Numo::NArray] The matrix or vector (>= 1-dimensional NArray)
|
123
|
+
# @param ord [String/Numeric] The order of the norm.
|
124
|
+
# @param axis [Integer/Array] The applied axes.
|
125
|
+
# @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
|
126
|
+
# @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
|
127
|
+
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

  return 0.0 if a.empty?

  # for compatibility with Numo::Linalg.norm
  if ord.is_a?(String)
    if ord == 'inf'
      ord = Float::INFINITY
    elsif ord == '-inf'
      ord = -Float::INFINITY
    end
  end

  if axis.nil?
    # Fast paths: hand the whole array to a single BLAS/LAPACK routine when one applies.
    norm = case a.ndim
           when 1
             Numo::Linalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
           when 2
             if ord.nil? || ord == 'fro'
               Numo::Linalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
             elsif ord.is_a?(Numeric)
               if ord == 1
                 Numo::Linalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
               elsif !ord.infinite?.nil? && ord.infinite?.positive?
                 Numo::Linalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
               end
             end
           else
             if ord.nil?
               b = a.flatten.dup
               Numo::Linalg::Blas.send(:"#{blas_char(b)}nrm2", b)
             end
           end
    unless norm.nil?
      # Every axis was reduced here, so keepdims leaves a size-one dimension for each of them.
      norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
      return norm
    end
  end

  if axis.nil?
    axis = Array.new(a.ndim) { |d| d }
  else
    case axis
    when Integer
      axis = [axis]
    when Array, Numo::NArray
      axis = axis.flatten.to_a
    else
      raise ArgumentError, "invalid axis: #{axis}"
    end
  end

  raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless [1, 2].include?(axis.size)
  raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

  if axis.size == 1
    # Vector norm along a single axis; keepdims is handled by the reductions themselves.
    ord ||= 2
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

    ord_inf = ord.infinite?
    if ord_inf.nil?
      case ord
      when 0
        a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
      when 1
        a.abs.sum(axis: axis, keepdims: keepdims)
      else
        (a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
      end
    elsif ord_inf.positive?
      a.abs.max(axis: axis, keepdims: keepdims)
    else
      a.abs.min(axis: axis, keepdims: keepdims)
    end
  else
    # Matrix norm over the (row, column) axis pair.
    ord ||= 'fro'
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)

    # Normalize negative indices before the duplicate check so aliases of the same axis
    # (e.g. [0, -2] on a 2-D array) are rejected as well.
    r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }
    raise ArgumentError, "invalid axis: #{axis}" if r_axis == c_axis

    # With keepdims only the two normed axes collapse to size one; other (batch) axes keep
    # their extent. Computed here because the branches below adjust r_axis / c_axis.
    kept_shape = a.shape.each_with_index.map { |sz, d| [r_axis, c_axis].include?(d) ? 1 : sz }

    norm = if ord.is_a?(String)
             raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

             if ord == 'fro'
               Numo::NMath.sqrt((a.abs**2).sum(axis: axis))
             else
               b = a.transpose(c_axis, r_axis).dup
               gesvd = :"#{blas_char(b)}gesvd"
               s, = Numo::Linalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
               s.sum(axis: -1)
             end
           else
             ord_inf = ord.infinite?
             if ord_inf.nil?
               case ord
               when -2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::Linalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.min(axis: -1)
               when -1
                 # After summing over r_axis, a later c_axis shifts down by one.
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).min(axis: c_axis)
               when 1
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).max(axis: c_axis)
               when 2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::Linalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.max(axis: -1)
               else
                 raise ArgumentError, "invalid ord: #{ord}"
               end
             else
               # After summing over c_axis, a later r_axis shifts down by one.
               r_axis -= 1 if r_axis > c_axis
               if ord_inf.positive?
                 a.abs.sum(axis: c_axis).max(axis: r_axis)
               else
                 a.abs.sum(axis: c_axis).min(axis: r_axis)
               end
             end
           end
    if keepdims
      norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
      norm = norm.reshape(*kept_shape)
    end

    norm
  end
end
|
260
|
+
|
261
|
+
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
|
262
|
+
#
|
263
|
+
# @example
|
264
|
+
# require 'numo/linalg'
|
265
|
+
#
|
266
|
+
# s = Numo::DFloat.new(3, 3).rand - 0.5
|
267
|
+
# a = s.transpose.dot(s)
|
268
|
+
# u = Numo::Linalg.cholesky(a)
|
269
|
+
#
|
270
|
+
# pp u
|
271
|
+
# # =>
|
272
|
+
# # Numo::DFloat#shape=[3,3]
|
273
|
+
# # [[0.532006, 0.338183, -0.18036],
|
274
|
+
# # [0, 0.325153, 0.011721],
|
275
|
+
# # [0, 0, 0.436738]]
|
276
|
+
#
|
277
|
+
# pp (a - u.transpose.dot(u)).abs.max
|
278
|
+
# # => 1.3877787807814457e-17
|
279
|
+
#
|
280
|
+
# l = Numo::Linalg.cholesky(a, uplo: 'L')
|
281
|
+
#
|
282
|
+
# pp l
|
283
|
+
# # =>
|
284
|
+
# # Numo::DFloat#shape=[3,3]
|
285
|
+
# # [[0.532006, 0, 0],
|
286
|
+
# # [0.338183, 0.325153, 0],
|
287
|
+
# # [-0.18036, 0.011721, 0.436738]]
|
288
|
+
#
|
289
|
+
# pp (a - l.dot(l.transpose)).abs.max
|
290
|
+
# # => 1.3877787807814457e-17
|
291
|
+
#
|
292
|
+
# @param a [Numo::NArray] The n-by-n symmetric matrix.
|
293
|
+
# @param uplo [String] Whether to compute the upper- or lower-triangular Cholesky factor ('U' or 'L').
|
294
|
+
# @return [Numo::NArray] The upper- or lower-triangular Cholesky factor of a.
|
295
|
+
def cholesky(a, uplo: 'U')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
  # Reject an unknown uplo before running the factorization instead of after it.
  raise ArgumentError, "invalid uplo: #{uplo}" unless %w[U L].include?(uplo)

  fnc = :"#{bchr}potrf"
  c, info = Numo::Linalg::Lapack.send(fnc, a.dup, uplo: uplo)
  # potrf: info < 0 flags an illegal argument. info > 0 means the leading minor of that order
  # is not positive definite, so the returned factor is incomplete and must not be used.
  raise "the #{-info}-th argument of potrf had illegal value" if info.negative?
  if info.positive?
    raise "the leading minor of order #{info} is not positive definite, and the factorization could not be completed."
  end

  # Keep only the requested triangle of the factorization result.
  uplo == 'U' ? c.triu : c.tril
end
|
314
|
+
|
315
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` with the Cholesky factorization of `A`.
|
316
|
+
#
|
317
|
+
# @example
|
318
|
+
# require 'numo/linalg'
|
319
|
+
#
|
320
|
+
# s = Numo::DFloat.new(3, 3).rand - 0.5
|
321
|
+
# a = s.transpose.dot(s)
|
322
|
+
# u = Numo::Linalg.cholesky(a)
|
323
|
+
#
|
324
|
+
# b = Numo::DFloat.new(3).rand
|
325
|
+
# x = Numo::Linalg.cho_solve(u, b)
|
326
|
+
#
|
327
|
+
# puts (b - a.dot(x)).abs.max
|
328
|
+
# # => 0.0
|
329
|
+
#
|
330
|
+
# @param a [Numo::NArray] The n-by-n cholesky factor.
|
331
|
+
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
|
332
|
+
# @param uplo [String] Whether the given Cholesky factor `a` is upper- ('U') or lower-triangular ('L').
|
333
|
+
# @return [Numo::NArray] The solution vector or matrix `X`.
|
334
|
+
def cho_solve(a, b, uplo: 'U')
  # Shape validation: a must be a square factor whose order matches the rows of b.
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  rows_a = a.shape[0]
  rows_b = b.shape[0]
  unless rows_a == rows_b
    raise Numo::NArray::ShapeError, "incompatible dimensions: a.shape[0] = #{rows_a} != b.shape[0] = #{rows_b}"
  end

  type_char = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}" if type_char == 'n'

  # potrs overwrites its right-hand side, so hand it a copy of b.
  solution, = Numo::Linalg::Lapack.send(:"#{type_char}potrs", a, b.dup, uplo: uplo)
  solution
end
|
346
|
+
|
347
|
+
# Computes the determinant of matrix.
|
348
|
+
#
|
349
|
+
# @example
|
350
|
+
# require 'numo/linalg'
|
351
|
+
#
|
352
|
+
# a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
|
353
|
+
# pp (3.0 - Numo::Linalg.det(a)).abs
|
354
|
+
# # => 1.3322676295501878e-15
|
355
|
+
#
|
356
|
+
# @param a [Numo::NArray] The n-by-n square matrix.
|
357
|
+
# @return [Float/Complex] The determinant of `a`.
|
358
|
+
def det(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = :"#{bchr}getrf"
  lu, piv, info = Numo::Linalg::Lapack.send(getrf, a.dup)

  # A negative info reports an illegal argument to getrf.
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  # A positive info only means U has an exactly zero diagonal entry. The factorization still
  # completed, so the product below correctly yields 0 for a singular matrix instead of failing.
  # det(A) = det(P) * det(L) * det(U), where L is unit lower-triangular (det(L) == 1).
  det_u = lu.diagonal.prod
  # Each pivot entry that differs from its 1-based position is a row swap, flipping the sign.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
|
379
|
+
|
380
|
+
# Computes the inverse matrix of a square matrix.
|
381
|
+
#
|
382
|
+
# @example
|
383
|
+
# require 'numo/linalg'
|
384
|
+
#
|
385
|
+
# a = Numo::DFloat.new(5, 5).rand
|
386
|
+
#
|
387
|
+
# inv_a = Numo::Linalg.inv(a)
|
388
|
+
#
|
389
|
+
# pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
|
390
|
+
# # => 7.019165976816745e-16
|
391
|
+
#
|
392
|
+
# pp inv_a.dot(a).sum
|
393
|
+
# # => 5.0
|
394
|
+
#
|
395
|
+
# @param a [Numo::NArray] The n-by-n square matrix.
|
396
|
+
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
397
|
+
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
398
|
+
# @return [Numo::NArray] The inverse matrix of `a`.
|
399
|
+
def inv(a, driver: 'getrf', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  type_char = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if type_char == 'n'

  # LU-factorize a copy of the input, then invert from the factors with getri.
  lu, piv, info = Numo::Linalg::Lapack.send(:"#{type_char}getrf", a.dup)
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
  raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?

  getri_result = Numo::Linalg::Lapack.send(:"#{type_char}getri", lu, piv)
  getri_result[0]
end
|
418
|
+
|
419
|
+
# Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
|
420
|
+
#
|
421
|
+
# @example
|
422
|
+
# require 'numo/linalg'
|
423
|
+
#
|
424
|
+
# a = Numo::DFloat.new(5, 3).rand
|
425
|
+
#
|
426
|
+
# inv_a = Numo::Linalg.pinv(a)
|
427
|
+
#
|
428
|
+
# pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
|
429
|
+
# # => 1.1102230246251565e-15
|
430
|
+
#
|
431
|
+
# pp inv_a.dot(a).sum
|
432
|
+
# # => 3.0
|
433
|
+
#
|
434
|
+
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
|
435
|
+
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
|
436
|
+
# @param rcond [Float] The threshold value for small singular values of `a`, default value is `a.shape.max * EPS`.
|
437
|
+
# @return [Numo::NArray] The pseudo-inverse of `a`.
|
438
|
+
def pinv(a, driver: 'svd', rcond: nil)
  s, u, vh = svd(a, driver: driver, job: 'S')
  # Singular values below tol * s[0] (the largest one) are treated as zero.
  tol = rcond.nil? ? a.shape.max * s.class::EPSILON : rcond
  rank = s.gt(tol * s[0]).count

  kept = 0...rank
  # pinv(A) = V * S^-1 * U^H, built here as (U * S^-1 * V^H)^H.
  scaled_u = u[true, kept] / s[kept]
  scaled_u.dot(vh[kept, true]).conj.transpose
end
|
446
|
+
|
447
|
+
# Computes the QR decomposition of a matrix.
|
448
|
+
#
|
449
|
+
# @example
|
450
|
+
# require 'numo/linalg'
|
451
|
+
#
|
452
|
+
# x = Numo::DFloat.new(5, 3).rand
|
453
|
+
#
|
454
|
+
# q, r = Numo::Linalg.qr(x, mode: 'economic')
|
455
|
+
#
|
456
|
+
# pp q
|
457
|
+
# # =>
|
458
|
+
# # Numo::DFloat#shape=[5,3]
|
459
|
+
# # [[-0.0574417, 0.635216, 0.707116],
|
460
|
+
# # [-0.187002, -0.073192, 0.422088],
|
461
|
+
# # [-0.502239, 0.634088, -0.537489],
|
462
|
+
# # [-0.0473292, 0.134867, -0.0223491],
|
463
|
+
# # [-0.840979, -0.413385, 0.180096]]
|
464
|
+
#
|
465
|
+
# pp r
|
466
|
+
# # =>
|
467
|
+
# # Numo::DFloat#shape=[3,3]
|
468
|
+
# # [[-1.07508, -0.821334, -0.484586],
|
469
|
+
# # [0, 0.513035, 0.451868],
|
470
|
+
# # [0, 0, 0.678737]]
|
471
|
+
#
|
472
|
+
# pp (q.dot(r) - x).abs.max
|
473
|
+
# # => 3.885780586188048e-16
|
474
|
+
#
|
475
|
+
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
|
476
|
+
# @param mode [String] The mode of decomposition.
|
477
|
+
# - "reduce" -- returns both Q [m, m] and R [m, n],
|
478
|
+
# - "r" -- returns only R,
|
479
|
+
# - "economic" -- returns both Q [m, n] and R [n, n],
|
480
|
+
# - "raw" -- returns QR and TAU (LAPACK geqrf results).
|
481
|
+
# @return [Numo::NArray] if mode='r'.
|
482
|
+
# @return [Array<Numo::NArray>] if mode='reduce' or 'economic' or 'raw'.
|
483
|
+
def qr(a, mode: 'reduce')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  geqrf = :"#{bchr}geqrf"
  qr, tau, = Numo::Linalg::Lapack.send(geqrf, a.dup)

  return [qr, tau] if mode == 'raw'

  m, n = qr.shape
  # 'raw' has already returned, so only 'economic' trims R to its leading n rows.
  r = m > n && mode == 'economic' ? qr[0...n, true].triu : qr.triu

  return r if mode == 'r'

  # Real types use ?orgqr, complex types use ?ungqr to form Q from the reflectors.
  org_ung_qr = %w[d s].include?(bchr) ? :"#{bchr}orgqr" : :"#{bchr}ungqr"

  q = if m < n
        Numo::Linalg::Lapack.send(org_ung_qr, qr[true, 0...m], tau)[0]
      elsif mode == 'economic'
        Numo::Linalg::Lapack.send(org_ung_qr, qr, tau)[0]
      else
        # Embed the m-by-n factor into an m-by-m buffer. The buffer uses the class of the
        # LAPACK result (qr), not of the input, so the reflectors are not truncated when `a`
        # has a different type.
        qqr = qr.class.zeros(m, m)
        qqr[0...m, 0...n] = qr
        Numo::Linalg::Lapack.send(org_ung_qr, qqr, tau)[0]
      end

  [q, r]
end
|
514
|
+
|
515
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `A`.
|
516
|
+
#
|
517
|
+
# @example
|
518
|
+
# require 'numo/linalg'
|
519
|
+
#
|
520
|
+
# a = Numo::DFloat.new(3, 3).rand
|
521
|
+
# b = Numo::DFloat.eye(3)
|
522
|
+
#
|
523
|
+
# x = Numo::Linalg.solve(a, b)
|
524
|
+
#
|
525
|
+
# pp x
|
526
|
+
# # =>
|
527
|
+
# # Numo::DFloat#shape=[3,3]
|
528
|
+
# # [[-2.12332, 4.74868, 0.326773],
|
529
|
+
# # [1.38043, -3.79074, 1.25355],
|
530
|
+
# # [0.775187, 1.41032, -0.613774]]
|
531
|
+
#
|
532
|
+
# pp (b - a.dot(x)).abs.max
|
533
|
+
# # => 2.1081041547796492e-16
|
534
|
+
#
|
535
|
+
# @param a [Numo::NArray] The n-by-n square matrix.
|
536
|
+
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
|
537
|
+
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
538
|
+
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
|
539
|
+
# @return [Numo::NArray] The solution vector / matrix `X`.
|
540
|
+
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  type_char = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if type_char == 'n'

  # gesv overwrites both operands, so pass copies; its second result is the solution.
  gesv_result = Numo::Linalg::Lapack.send(:"#{type_char}gesv", a.dup, b.dup)
  gesv_result[1]
end
|
550
|
+
|
551
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix.
|
552
|
+
#
|
553
|
+
# @example
|
554
|
+
# require 'numo/linalg'
|
555
|
+
#
|
556
|
+
# a = Numo::DFloat.new(3, 3).rand.triu
|
557
|
+
# b = Numo::DFloat.eye(3)
|
558
|
+
#
|
559
|
+
# x = Numo::Linalg.solve_triangular(a, b)
|
560
|
+
#
|
561
|
+
# pp x
|
562
|
+
# # =>
|
563
|
+
# # Numo::DFloat#shape=[3,3]
|
564
|
+
# # [[16.1932, -52.0604, 30.5283],
|
565
|
+
# # [0, 8.61765, -17.9585],
|
566
|
+
# # [0, 0, 6.05735]]
|
567
|
+
#
|
568
|
+
# pp (b - a.dot(x)).abs.max
|
569
|
+
# # => 4.071100642430302e-16
|
570
|
+
#
|
571
|
+
# @param a [Numo::NArray] The n-by-n triangular matrix.
|
572
|
+
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
|
573
|
+
# @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
|
574
|
+
# @return [Numo::NArray] The solution vector / matrix `X`.
|
575
|
+
def solve_triangular(a, b, lower: false)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  trtrs = :"#{bchr}trtrs"
  uplo = lower ? 'L' : 'U'
  x, info = Numo::Linalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
  # info < 0 is minus the index of the illegal argument; report it as a positive position.
  raise "wrong value is given to the #{-info}-th argument of #{trtrs} used internally" if info.negative?
  # info > 0: the info-th diagonal entry of a is zero, so a is singular and x was not computed.
  if info.positive?
    raise "the #{info}-th diagonal element of a is zero, indicating that the matrix is singular " \
          'and the solution could not be computed.'
  end

  x
end
|
589
|
+
|
590
|
+
# Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
|
591
|
+
#
|
592
|
+
# @example
|
593
|
+
# require 'numo/linalg'
|
594
|
+
#
|
595
|
+
# x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
|
596
|
+
# pp x
|
597
|
+
# # =>
|
598
|
+
# # Numo::DFloat#shape=[5,3]
|
599
|
+
# # [[0.104945, 0.0284236, 0.117406],
|
600
|
+
# # [0.862634, 0.210945, 0.922135],
|
601
|
+
# # [0.324507, 0.0752655, 0.339158],
|
602
|
+
# # [0.67085, 0.102594, 0.600882],
|
603
|
+
# # [0.404631, 0.116868, 0.46644]]
|
604
|
+
#
|
605
|
+
# s, u, vt = Numo::Linalg.svd(x, job: 'S')
|
606
|
+
#
|
607
|
+
# z = u.dot(s.diag).dot(vt)
|
608
|
+
# pp z
|
609
|
+
# # =>
|
610
|
+
# # Numo::DFloat#shape=[5,3]
|
611
|
+
# # [[0.104945, 0.0284236, 0.117406],
|
612
|
+
# # [0.862634, 0.210945, 0.922135],
|
613
|
+
# # [0.324507, 0.0752655, 0.339158],
|
614
|
+
# # [0.67085, 0.102594, 0.600882],
|
615
|
+
# # [0.404631, 0.116868, 0.46644]]
|
616
|
+
#
|
617
|
+
# pp (x - z).abs.max
|
618
|
+
# # => 4.440892098500626e-16
|
619
|
+
#
|
620
|
+
# @param a [Numo::NArray] Matrix to be decomposed.
|
621
|
+
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
|
622
|
+
# @param job [String] The job option ('A', 'S', or 'N').
|
623
|
+
# @return [Array<Numo::NArray>] The singular values and singular vectors ([s, u, vt]).
|
624
|
+
def svd(a, driver: 'svd', job: 'A')
  # \A anchors at the start of the string; ^ would also match after any embedded newline.
  raise ArgumentError, "invalid job: #{job}" unless /\A[ASN]/i.match?(job.to_s)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  case driver.to_s
  when 'sdd'
    gesdd = :"#{bchr}gesdd"
    s, u, vt, info = Numo::Linalg::Lapack.send(gesdd, a.dup, jobz: job)
  when 'svd'
    gesvd = :"#{bchr}gesvd"
    s, u, vt, info = Numo::Linalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end

  # gesdd reports a NaN in the input as info == -4. This must be checked before the generic
  # negative-info branch, which would otherwise shadow it. For gesvd, -4 has no special meaning.
  raise 'input array has a NAN entry' if info == -4 && driver.to_s == 'sdd'
  raise "the #{info.abs}-th argument had illegal value" if info.negative?
  raise 'svd did not converge' if info.positive?

  [s, u, vt]
end
|
647
|
+
|
648
|
+
# Computes the matrix multiplication of two arrays.
|
649
|
+
#
|
650
|
+
# @example
|
651
|
+
# require 'numo/linalg'
|
652
|
+
#
|
653
|
+
# a = Numo::DFloat[[1, 0], [0, 1]]
|
654
|
+
# b = Numo::DFloat[[4, 1], [2, 2]]
|
655
|
+
# pp Numo::Linalg.matmul(a, b)
|
656
|
+
# # =>
|
657
|
+
# # Numo::DFloat#shape=[2,2]
|
658
|
+
# # [[4, 1],
|
659
|
+
# # [2, 2]]
|
660
|
+
#
|
661
|
+
# @param a [Numo::NArray] The first array.
|
662
|
+
# @param b [Numo::NArray] The second array.
|
663
|
+
# @return [Numo::NArray] The matrix product of `a` and `b`.
|
664
|
+
def matmul(a, b)
  # Delegate both operands to the BLAS gemm dispatcher.
  operands = [a, b]
  Numo::Linalg::Blas.call(:gemm, *operands)
end
|
667
|
+
|
668
|
+
# Computes the matrix `a` raised to the power of `n`.
|
669
|
+
#
|
670
|
+
# @example
|
671
|
+
# require 'numo/linalg'
|
672
|
+
#
|
673
|
+
# a = Numo::DFloat[[1, 2], [3, 4]]
|
674
|
+
# pp Numo::Linalg.matrix_power(a, 3)
|
675
|
+
# # =>
|
676
|
+
# # Numo::DFloat#shape=[2,2]
|
677
|
+
# # [[37, 54],
|
678
|
+
# # [81, 118]]
|
679
|
+
#
|
680
|
+
# @param a [Numo::NArray] The square matrix.
|
681
|
+
# @param n [Integer] The exponent.
|
682
|
+
# @return [Numo::NArray] The matrix `a` raised to the power of `n`.
|
683
|
+
def matrix_power(a, n)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
  raise ArgumentError, "exponent n must be an integer: #{n}" unless n.is_a?(Integer)

  return a.class.eye(a.shape[0]) if n.zero?

  # A negative exponent raises the inverse matrix to the power |n|.
  base = n.negative? ? inv(a) : a
  exponent = n.abs
  result = nil
  # Exponentiation by squaring: O(log |n|) matrix products instead of |n| - 1.
  loop do
    if exponent.odd?
      # The first factor is copied so the caller never receives `a` (or inv(a)) itself.
      result = result.nil? ? base.dup : Numo::Linalg.matmul(result, base)
    end
    exponent >>= 1
    break if exponent.zero?

    base = Numo::Linalg.matmul(base, base)
  end
  result
end
|
701
|
+
|
702
|
+
# Computes the singular values of a matrix.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 2, 3], [2, 4, 6], [-1, 1, -1]]
#   pp Numo::Linalg.svdvals(a)
#   # => Numo::DFloat#shape=[3]
#   # [8.38434, 1.64402, 5.41675e-17]
#
# @param a [Numo::NArray] Matrix to be decomposed.
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
# @return [Numo::NArray] The singular values of `a`.
# @raise [ArgumentError] if `a` has an unsupported array type or `driver` is unknown.
# @raise [RuntimeError] if LAPACK reports a NaN entry, an illegal argument, or non-convergence.
def svdvals(a, driver: 'sdd')
  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  case driver.to_s
  when 'sdd'
    gesdd = :"#{bchr}gesdd"
    s, _u, _vt, info = Numo::Linalg::Lapack.send(gesdd, a.dup, jobz: 'N')
    # gesdd signals a NaN entry in `a` with INFO = -4. This must be checked before the generic
    # negative-INFO case below, which would otherwise always take precedence.
    raise 'input array has a NAN entry' if info == -4
  when 'svd'
    gesvd = :"#{bchr}gesvd"
    s, _u, _vt, info = Numo::Linalg::Lapack.send(gesvd, a.dup, jobu: 'N', jobvt: 'N')
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end

  raise "the #{info.abs}-th argument had illegal value" if info.negative?
  raise 'svd did not converge' if info.positive?

  s
end
|
736
|
+
|
737
|
+
# Computes an orthonormal basis for the range (column space) of `A` using SVD.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 2, 3], [2, 4, 6], [-1, 1, -1]]
#   u = Numo::Linalg.orth(a)
#   pp u
#   # =>
#   # Numo::DFloat#shape=[3,2]
#   # [[-0.446229, -0.0296535],
#   #  [-0.892459, -0.059307],
#   #  [0.0663073, -0.997799]]
#   pp u.transpose.dot(u)
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[1, -1.97749e-16],
#   #  [-1.97749e-16, 1]]
#
# @param a [Numo::NArray] The m-by-n input matrix.
# @param rcond [Float] Relative threshold for small singular values of `a`: singular values greater than
#   `rcond * s.max` are kept. A nil or negative value selects the default `a.shape.max * EPS`.
# @return [Numo::NArray] The orthonormal basis for the range of `a` (the leading left singular vectors).
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
def orth(a, rcond: nil)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2

  svals, left_vecs, = svd(a, driver: 'sdd', job: 'S')
  use_default = rcond.nil? || rcond.negative?
  rel_tol = use_default ? a.shape.max * svals.class::EPSILON : rcond
  cutoff = rel_tol * svals.max
  n_kept = svals.gt(cutoff).count
  left_vecs[true, 0...n_kept].dup
end
|
771
|
+
|
772
|
+
# Computes an orthonormal basis for the null space of `A` using SVD.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 5).rand - 0.5
#   n = Numo::Linalg.null_space(a)
#   pp n
#   # =>
#   # Numo::DFloat#shape=[5,2]
#   # [[0.214096, -0.404277],
#   #  [-0.482225, -0.51557],
#   #  [-0.584394, -0.246804],
#   #  [0.596612, -0.351468],
#   #  [-0.155434, 0.621535]]
#   pp n.transpose.dot(n)
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[1, 1.31078e-16],
#   #  [1.31078e-16, 1]]
#
# @param a [Numo::NArray] The m-by-n input matrix.
# @param rcond [Float] Relative threshold for small singular values of `a`: singular values greater than
#   `rcond * s.max` count toward the rank. A nil or negative value selects the default `a.shape.max * EPS`.
# @return [Numo::NArray] The orthonormal basis for the null space of `a`, one basis vector per column
#   (the conjugate transpose of the trailing rows of `V^H`).
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
def null_space(a, rcond: nil)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2

  svals, _left, right_h = svd(a, driver: 'sdd', job: 'A')
  use_default = rcond.nil? || rcond.negative?
  rel_tol = use_default ? a.shape.max * svals.class::EPSILON : rcond
  rank = svals.gt(rel_tol * svals.max).count
  n_rows = right_h.shape[0]
  trailing = right_h[rank...n_rows, true]
  trailing.conj.transpose.dup
end
|
808
|
+
|
809
|
+
# Computes the LU decomposition of a matrix using partial pivoting (`A = P * L * U`).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 4).rand - 0.5
#   pm, l, u = Numo::Linalg.lu(a)
#   error = (pm.dot(l).dot(u) - a).abs.max
#   pp error
#   # => 5.551115123125783e-17
#
#   l, u = Numo::Linalg.lu(a, permute_l: true)
#   error = (l.dot(u) - a).abs.max
#   pp error
#   # => 5.551115123125783e-17
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param permute_l [Boolean] If true, returns `L` with the permutation applied (`P * L`).
# @return [Array<Numo::NArray>] if `permute_l` is `false`, the m-by-m permutation matrix `P`, the m-by-k unit
#   lower-triangular matrix `L`, and the k-by-n upper-triangular matrix `U` ([P, L, U]), where k = min(m, n).
#   If `permute_l` is `true`, the permuted lower-triangular matrix `P * L` and `U` ([P * L, U]).
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
def lu(a, permute_l: false)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2

  n_rows, n_cols = a.shape
  n_diag = [n_rows, n_cols].min
  packed, ipiv = lu_fact(a)

  # `L` has an implicit unit diagonal in the packed getrf output.
  lower = packed.tril
  lower[lower.diag_indices] = 1
  lower = lower[true, 0...n_diag].dup

  upper = packed.triu[0...n_diag, 0...n_cols].dup

  # Replay the 1-based row interchanges as column swaps on the identity to build `P`.
  perm = a.class.eye(n_rows)
  ipiv.to_a.each_with_index do |row, col|
    swap = [row - 1, col]
    perm[true, swap] = perm[true, swap.reverse].dup
  end

  return [perm.dot(lower), upper] if permute_l

  [perm, lower, upper]
end
|
844
|
+
|
845
|
+
# Computes the packed LU decomposition of a matrix using partial pivoting (LAPACK getrf).
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed (not modified; a copy is factorized).
# @return [Array<Numo::NArray>] The packed LU factors and the 1-based pivot indices ([lu, piv]).
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if getrf reports an illegal argument or an exactly-zero diagonal entry of `U`.
def lu_fact(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  packed, pivots, info = Numo::Linalg::Lapack.send(:"#{bchr}getrf", a.dup)

  if info.negative?
    raise "the #{info.abs}-th argument of getrf had illegal value"
  elsif info.positive?
    raise "the U(#{info}, #{info}) is exactly zero. The factorization has been completed."
  end

  [packed, pivots]
end
|
863
|
+
|
864
|
+
# Solves linear equation `A * x = b` or `A * X = B` for `x` using the LU decomposition of `A`.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 3).rand
#   b = Numo::DFloat.eye(3)
#   lu, ipiv = Numo::Linalg.lu_fact(a)
#   x = Numo::Linalg.lu_solve(lu, ipiv, b)
#
#   puts (b - a.dot(x)).abs.max
#   # => 2.220446049250313e-16
#
# @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
# @param ipiv [Numo::Int32/Int64] The pivot indices from `lu_fact`.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param trans [String] The type of system to be solved.
#   - 'N': solve `A * x = b` (No transpose),
#   - 'T': solve `A^T * x = b` (Transpose),
#   - 'C': solve `A^H * x = b` (Conjugate transpose).
# @return [Numo::NArray] The solution vector / matrix `X`.
# @raise [Numo::NArray::ShapeError] if `lu` is not square 2-dimensional or does not match `b`.
# @raise [ArgumentError] if `trans` is invalid or `lu` has an unsupported array type.
# @raise [RuntimeError] if getrs reports an illegal argument.
def lu_solve(lu, ipiv, b, trans: 'N')
  raise Numo::NArray::ShapeError, 'input array lu must be 2-dimensional' if lu.ndim != 2
  raise Numo::NArray::ShapeError, 'input array lu must be square' if lu.shape[0] != lu.shape[1]
  raise Numo::NArray::ShapeError, "incompatible dimensions: lu.shape[0] = #{lu.shape[0]} != b.shape[0] = #{b.shape[0]}" if lu.shape[0] != b.shape[0]
  raise ArgumentError, 'trans must be "N", "T", or "C"' unless %w[N T C].include?(trans)

  bchr = blas_char(lu)
  raise ArgumentError, "invalid array type: #{lu.class}" if bchr == 'n'

  getrs = :"#{bchr}getrs"
  # `trans` must be forwarded so that getrs solves the requested system ('N', 'T' or 'C');
  # `b` is copied because getrs overwrites its right-hand side with the solution.
  x, info = Numo::Linalg::Lapack.send(getrs, lu, ipiv, b.dup, trans: trans)

  raise "the #{info.abs}-th argument of getrs had illegal value" if info.negative?

  x
end
|
901
|
+
|
902
|
+
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix (LAPACK potrf).
#
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian positive-definite matrix (not modified).
# @param uplo [String] The part of the matrix to be used ('U' or 'L').
# @return [Numo::NArray] The upper- / lower-triangular matrix `U` / `L`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if `uplo` is not 'U' or 'L', or `a` has an unsupported array type.
# @raise [RuntimeError] if potrf reports a non-positive-definite leading minor or an illegal argument.
def cho_fact(a, uplo: 'U')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]
  raise ArgumentError, 'uplo must be "U" or "L"' unless %w[U L].include?(uplo)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  potrf = :"#{bchr}potrf"
  factor, info = Numo::Linalg::Lapack.send(potrf, a.dup, uplo: uplo)
  return factor if info.zero?

  if info.positive?
    raise "the #{info}-th leading minor of the array is not positive definite, and the factorization could not be completed."
  end

  raise "the #{-info}-th argument of #{potrf} had illegal value"
end
|
923
|
+
|
924
|
+
# Computes the eigenvalues and right and/or left eigenvectors of a general square matrix (LAPACK geev).
#
# For real input (DFloat / SFloat), geev returns the real and imaginary parts of the eigenvalues
# separately; they are combined here into a complex array. If some eigenvalues are complex, the
# requested eigenvector arrays are converted to complex, and each conjugate pair of columns is
# rebuilt as `v(j) = V(:, j) + i * V(:, j+1)` and `v(j+1) = conj(v(j))`. If every eigenvalue is real,
# the eigenvectors stay real arrays.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(5, 5).rand - 0.5
#   w, _vl, vr = Numo::Linalg.eig(a)
#   error = (a.dot(vr) - vr.dot(w.diag)).abs.max
#   pp error
#   # => 4.718447854656915e-16
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param left [Boolean] The flag indicating whether to compute the left eigenvectors.
# @param right [Boolean] The flag indicating whether to compute the right eigenvectors.
# @return [Array<Numo::NArray>] The eigenvalues, left eigenvectors, and right eigenvectors ([w, vl, vr]);
#   `vl` / `vr` is nil when `left` / `right` is false.
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
# @raise [ArgumentError] if `a` is not square or has an unsupported array type.
# @raise [RuntimeError] if geev reports an illegal argument or fails to converge.
def eig(a, left: false, right: true) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

  # 'V' asks geev to compute the corresponding eigenvectors, 'N' skips them.
  jobvl = left ? 'V' : 'N'
  jobvr = right ? 'V' : 'N'

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  fnc = :"#{bchr}geev"
  if %w[z c].include?(bchr)
    # Complex routines return the eigenvalues directly as a complex array.
    w, vl, vr, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvl: jobvl, jobvr: jobvr)
  else
    # Real routines return real (wr) and imaginary (wi) parts separately.
    wr, wi, vl, vr, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvl: jobvl, jobvr: jobvr)
  end

  raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
  raise 'the QR algorithm failed to compute all the eigenvalues.' if info.positive?

  if %w[d s].include?(bchr)
    w = wr + (wi * 1.0i)
    # Columns j with a positive imaginary eigenvalue part start a conjugate pair (j, j+1).
    ids = wi.gt(0).where
    unless ids.empty?
      cast_class = bchr == 'd' ? Numo::DComplex : Numo::SComplex
      if left
        # Column j holds the real part and column j+1 the imaginary part of the pair's eigenvector.
        tmp = cast_class.cast(vl)
        tmp[true, ids].imag = vl[true, ids + 1]
        tmp[true, ids + 1] = tmp[true, ids].conj
        vl = tmp
      end
      if right
        # Same reconstruction as above for the right eigenvectors.
        tmp = cast_class.cast(vr)
        tmp[true, ids].imag = vr[true, ids + 1]
        tmp[true, ids + 1] = tmp[true, ids].conj
        vr = tmp
      end
    end
  end

  [w, left ? vl : nil, right ? vr : nil]
end
|
981
|
+
|
982
|
+
# Computes the eigenvalues of a general square matrix (LAPACK geev without eigenvectors).
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @return [Numo::NArray] The eigenvalues. For real input they are combined from the real and imaginary
#   parts returned by geev into a complex array.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if geev reports an illegal argument or fails to converge.
def eigvals(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  geev = :"#{bchr}geev"
  w, info =
    case bchr
    when 'z', 'c'
      vals, _vl, _vr, status = Numo::Linalg::Lapack.send(geev, a.dup, jobvl: 'N', jobvr: 'N')
      [vals, status]
    else
      re, im, _vl, _vr, status = Numo::Linalg::Lapack.send(geev, a.dup, jobvl: 'N', jobvr: 'N')
      [re + (im * 1.0i), status]
    end

  raise "the #{info.abs}-th argument of #{geev} had illegal value" if info.negative?
  raise 'the QR algorithm failed to compute all the eigenvalues.' if info.positive?

  w
end
|
1006
|
+
|
1007
|
+
# Computes the eigenvalues of a symmetric / Hermitian matrix by solving an ordinary / generalized
# eigenvalue problem. This is `eigh` with `vals_only: true`, returning only the eigenvalues.
#
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
# @param vals_range [Range/Array]
#   The range of indices of the eigenvalues (in ascending order) to be returned.
#   If nil, all eigenvalues are computed.
# @param uplo [String] Passed through to `eigh`; kept for compatibility with Numo::Linalg.
# @param turbo [Boolean] The flag indicating whether to use a divide and conquer algorithm (passed through
#   to `eigh`). If vals_range is given, this flag is ignored.
# @return [Numo::NArray] The eigenvalues.
def eigvalsh(a, b = nil, vals_range: nil, uplo: 'U', turbo: false)
  vals, = eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo)
  vals
end
|
1020
|
+
|
1021
|
+
# Computes the Bunch-Kaufman decomposition of a symmetric / Hermitian matrix.
# The factorization has the form `A = U * D * U^T` or `A = L * D * L^T`,
# where `U` (or `L`) is a product of permutation and unit upper
# (lower) triangular matrices, and `D` is a block diagonal matrix.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(5, 5).rand
#   a = 0.5 * (a + a.transpose)
#   u, d, _perm = Numo::Linalg.ldl(a)
#   error = (a - u.dot(d).dot(u.transpose)).abs.max
#   pp error
#   # => 5.551115123125783e-17
#
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
# @param uplo [String] The part of the matrix to be used ('U' or 'L').
# @param hermitian [Boolean] The flag indicating whether `a` is Hermitian (only relevant for complex input).
# @return [Array<Numo::NArray>] The permuted upper (lower) triangular matrix, the block diagonal matrix,
#   and the permutation indices.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if `uplo` is not 'U' or 'L', or `a` has an unsupported array type.
# @raise [RuntimeError] if sytrf / hetrf reports an illegal argument or an exactly singular `D`.
def ldl(a, uplo: 'U', hermitian: true)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
  # _lud_permutation tests `uplo == 'U'` and treats any other value as 'L', so reject other values
  # here instead of unpacking a factor that differs from the one LAPACK computed.
  raise ArgumentError, 'uplo must be "U" or "L"' unless %w[U L].include?(uplo)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  # hetrf is only used for complex Hermitian input; real and complex-symmetric input use sytrf.
  complex = %w[c z].include?(bchr)
  fnc = complex && hermitian ? :"#{bchr}hetrf" : :"#{bchr}sytrf"
  # sytrf / hetrf overwrite `lud` in place with the packed factorization.
  lud = a.dup
  ipiv, info = Numo::Linalg::Lapack.send(fnc, lud, uplo: uplo)

  raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
  raise "the D(#{info}, #{info}) is exactly zero. The factorization has been completed, but the block diagonal matrix D is exactly singular." if info.positive?

  _lud_permutation(lud, ipiv, uplo: uplo, hermitian: hermitian)
end
|
1057
|
+
|
1058
|
+
# Computes the condition number of a matrix.
#
# @param a [Numo::NArray] The input matrix.
# @param ord [String/Symbol/Integer] The order of the norm.
#   nil or 2: ratio of the largest to the smallest singular value; -2: ratio of the smallest to the largest
#   singular value. Any other value is passed to `norm` (e.g. 'fro': Frobenius norm, 'inf': infinity norm,
#   1: 1-norm), and the result is `norm(a) * norm(inv(a))`.
# @return [Numo::NArray] The condition number of the matrix.
def cond(a, ord = nil)
  case ord
  when nil, 2, -2
    svals = svdvals(a)
    largest = svals[false, 0]
    smallest = svals[false, -1]
    ord == -2 ? smallest / largest : largest / smallest
  else
    a_inv = inv(a)
    [a, a_inv].map { |mat| norm(mat, ord, axis: [-2, -1]) }.reduce(:*)
  end
end
|
1077
|
+
|
1078
|
+
# Computes the sign and natural logarithm of the determinant of a matrix.
#
# A singular matrix (a diagonal entry of `U` is exactly zero) yields `[0, -Infinity]`
# instead of raising.
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @return [Array<Float/Complex>] The sign and natural logarithm of the determinant of `a` ([sign, logdet]).
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if getrf reports an illegal argument.
def slogdet(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  # getrf is called directly rather than through lu_fact: lu_fact raises when U(i, i) is exactly
  # zero (INFO > 0), but for slogdet that case is a valid singular input.
  getrf = :"#{bchr}getrf"
  lu, ipiv, info = Numo::Linalg::Lapack.send(getrf, a.dup)
  raise "the #{info.abs}-th argument of getrf had illegal value" if info.negative?

  dg = lu.diagonal
  return 0, (-Float::INFINITY) if info.positive? || dg.eq(0).any?

  # Every pivot that differs from its own (1-based) index is one row interchange,
  # and each interchange flips the sign of the determinant.
  idx = ipiv.class.new(ipiv.shape[-1]).seq(1)
  n_swaps = ipiv.ne(idx).count(axis: -1)
  sign = ((-1.0)**(n_swaps % 2)) * (dg / dg.abs).prod

  logdet = Numo::NMath.log(dg.abs).sum(axis: -1)

  [sign, logdet]
end
|
1095
|
+
|
1096
|
+
# Computes the rank of a matrix using SVD.
#
# Arrays with fewer than two dimensions have rank 1 if any element is non-zero, and 0 otherwise.
#
# @param a [Numo::NArray] The input matrix.
# @param tol [Float] The threshold for small singular values of `a`. When omitted, it is
#   `max(singular values) * max(m, n) * EPS`.
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
# @return [Integer] The rank of the matrix (the number of singular values greater than `tol`).
def matrix_rank(a, tol: nil, driver: 'svd')
  return (a.ne(0).any? ? 1 : 0) if a.ndim < 2

  svals = svdvals(a, driver: driver)
  unless tol
    largest_dim = a.shape[-2..].max
    eps = svals.class::EPSILON
    tol = svals.max(axis: -1, keepdims: true) * (largest_dim * eps)
  end
  svals.gt(tol).count(axis: -1)
end
|
1109
|
+
|
1110
|
+
# Computes the least-squares solution to a linear matrix equation (LAPACK gelsd).
#
# @param a [Numo::NArray] The m-by-n input matrix.
# @param b [Numo::NArray] The m-dimensional right-hand side vector or the m-by-nrhs right-hand side matrix.
# @param driver [String] The LAPACK driver to be used (This argument is ignored, 'lsd' is always used).
# @param rcond [Float] The threshold value for small singular values of `a`.
# @return [Array<Numo::NArray, Float/Complex, Integer, Numo::NArray>] The least-squares solution matrix / vector `x`,
#   the sum of squared residuals, the effective rank of `a`, and the singular values of `a`.
#   The residuals are only computed when m > n and `a` has full column rank; otherwise an empty array is returned.
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional or its row count differs from that of `b`.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if gelsd reports an illegal argument or the SVD fails to converge.
def lstsq(a, b, driver: 'lsd', rcond: nil) # rubocop:disable Lint/UnusedMethodArgument, Metrics/AbcSize
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}" if a.shape[0] != b.shape[0]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  m, n = a.shape
  # gelsd writes the solution back into its right-hand-side argument, so `x` must be a private copy
  # of `b` with at least n rows. It is built with the class of `a` because the LAPACK routine (and
  # hence the element type it works on) is selected from `a`.
  if m < n
    if b.ndim == 1
      x = a.class.zeros(n)
      x[0...b.size] = b
    else
      x = a.class.zeros(n, b.shape[1])
      x[0...b.shape[0], 0...b.shape[1]] = b
    end
  else
    x = a.class.cast(b).dup
  end

  fnc = :"#{bchr}gelsd"
  s, rank, info = Numo::Linalg::Lapack.send(fnc, a.dup, x, rcond: rcond)

  raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
  raise 'the algorithm for computing the SVD failed to converge' if info.positive?

  resids = x.class[]
  if m > n
    # For an overdetermined full-rank problem, rows n...m of the overwritten RHS hold the residual components.
    if rank == n
      resids = if b.ndim == 1
                 (x[n..].abs**2).sum(axis: 0)
               else
                 (x[n..-1, true].abs**2).sum(axis: 0)
               end
    end
    x = if b.ndim == 1
          x[false, 0...n]
        else
          x[false, 0...n, true]
        end
  end

  [x, resids, rank, s]
end
|
1162
|
+
|
1163
|
+
# Computes the matrix exponential using a scaling and squaring algorithm with a Pade approximation.
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param ord [Integer] The order of the Padé approximation.
# @return [Numo::NArray] The matrix exponential of `a`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
#
# Reference:
# - C. Moler and C. Van Loan, "Nineteen Dubious Ways to Compute the Exponential of a Matrix, Twenty-Five Years Later," SIAM Review, vol. 45, no. 1, pp. 3-49, 2003.
def expm(a, ord = 8) # rubocop:disable Metrics/AbcSize
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  # Scaling step: choose n_sqr from the infinity norm (maximum absolute row sum) so that a / 2**n_sqr
  # is small enough for the Pade approximant. The largest absolute entry alone can underestimate
  # this norm by a factor of up to n, leaving the matrix insufficiently scaled.
  norm = a.abs.sum(axis: 1).max
  n_sqr = norm.positive? ? [0, Math.log2(norm).to_i + 1].max : 0
  # Divide by a Float so that integer input is promoted rather than truncated.
  a /= 2.0**n_sqr

  # Diagonal Pade approximant: nume = sum c_k * A^k, deno = sum c_k * (-A)^k.
  x = a.dup
  c = 0.5
  sgn = 1
  nume = a.class.eye(a.shape[0]) + (c * a)
  deno = a.class.eye(a.shape[0]) - (c * a)
  (2..ord).each do |k|
    c *= (ord - k + 1).fdiv(k * ((2 * ord) - k + 1))
    x = a.dot(x)
    c_x = c * x
    nume += c_x
    deno += sgn * c_x
    sgn = -sgn
  end

  # Squaring step: undo the scaling by squaring the approximant n_sqr times.
  a_expm = Numo::Linalg.solve(deno, nume)
  n_sqr.times { a_expm = a_expm.dot(a_expm) }
  a_expm
end
|
1197
|
+
|
1198
|
+
# Computes the inverse of a matrix using its LU decomposition (LAPACK getri).
#
# @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
# @param ipiv [Numo::Int32] The pivot indices from `lu_fact`.
# @return [Numo::NArray] The inverse of the matrix `A`.
# @raise [Numo::NArray::ShapeError] if `lu` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if `lu` has an unsupported array type.
# @raise [RuntimeError] if getri reports an illegal argument or a singular matrix.
def lu_inv(lu, ipiv)
  raise Numo::NArray::ShapeError, 'input array lu must be 2-dimensional' if lu.ndim != 2
  raise Numo::NArray::ShapeError, 'input array lu must be square' if lu.shape[0] != lu.shape[1]

  bchr = blas_char(lu)
  # Report the class of the argument actually received (`lu`).
  raise ArgumentError, "invalid array type: #{lu.class}" if bchr == 'n'

  fnc = :"#{bchr}getri"
  inv, info = Numo::Linalg::Lapack.send(fnc, lu.dup, ipiv)

  raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
  raise 'the matrix is singular and its inverse could not be computed' if info.positive?

  inv
end
|
1215
|
+
|
1216
|
+
# Computes the inverse of a symmetric / Hermitian positive-definite matrix using its Cholesky decomposition
# (LAPACK potri).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 5).rand - 0.5
#   a = a.dot(a.transpose)
#   c = Numo::Linalg.cho_fact(a)
#   tri_inv_a = Numo::Linalg.cho_inv(c)
#   tri_inv_a = tri_inv_a.triu
#   inv_a = tri_inv_a + tri_inv_a.transpose - tri_inv_a.diagonal.diag
#   error = (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
#   pp error
#   # => 1.923726113137665e-15
#
# @param a [Numo::NArray] The Cholesky decomposition of the n-by-n symmetric / Hermitian positive-definite matrix.
# @param uplo [String] The part of the matrix to be used ('U' or 'L').
# @return [Numo::NArray] The upper- / lower-triangular matrix of the inverse of `a`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if potri reports an illegal argument or a zero diagonal element in the factor.
def cho_inv(a, uplo: 'U')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  fnc = :"#{bchr}potri"
  inv, info = Numo::Linalg::Lapack.send(fnc, a.dup, uplo: uplo)

  raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
  raise "the (#{info}, #{info})-th element of the factor U or L is zero, and the inverse could not be computed." if info.positive?

  inv
end
|
1246
|
+
|
1247
|
+
# @!visibility private
# Unpacks the in-place output of sytrf / hetrf into explicit LDL^T factors for `ldl`.
#
# The pivot interchanges recorded in `ipiv` are replayed on the unit triangular factor and on a
# permutation vector. For 'U' they are replayed from the first index upward; for any other `uplo`
# they are replayed from the last index downward.
#
# @param lud [Numo::NArray] The n-by-n packed factorization written by sytrf / hetrf.
# @param ipiv [Numo::NArray] The 1-based pivot indices from sytrf / hetrf (assumed integer array — TODO confirm type);
#   positive entries are 1-by-1 pivots, and equal negative entries on consecutive indices mark a 2-by-2 block.
# @param uplo [String] 'U' unpacks the upper-triangular factor; any other value is treated as 'L'.
# @param hermitian [Boolean] Whether the mirrored off-diagonal entry of a 2-by-2 block of `D` is conjugated.
# @return [Array<Numo::NArray>] The row-permuted unit triangular factor, the block diagonal matrix `D`,
#   and the permutation indices (`perm.sort_index`).
def _lud_permutation(lud, ipiv, uplo: 'U', hermitian: true) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
  n = lud.shape[0]
  d = lud.class.zeros(n, n)
  # Starts as the identity permutation [0, 1, ..., n-1] and records every row interchange.
  perm = Numo::Int32.new(n).seq
  if uplo == 'U'
    # The unit diagonal of U is implicit in the packed output.
    u = lud.triu.tap { |m| m[m.diag_indices] = 1 }
    # If IPIV(k) > 0, then rows and columns k and IPIV(k) were interchanged
    # and D(k,k) is a 1-by-1 diagonal block.
    # IF UPLO = 'U' and If IPIV(k) = IPIV(k-1) < 0, then
    # rows and columns k-1 and -IPIV(k) were interchanged
    # and D(k-1:k,k-1:k) is a 2-by-2 diagonal block.
    # `changed_2x2` is true right after a 2-by-2 block ending at k has been handled, so index k+1
    # is not paired with k again even if IPIV(k+1) == IPIV(k).
    changed_2x2 = false
    n.times do |k|
      d[k, k] = lud[k, k]
      if ipiv[k].positive?
        i = ipiv[k] - 1
        # Swap rows i and k over the columns 0..k of U.
        u[[i, k], 0..k] = u[[k, i], 0..k].dup
        perm[[i, k]] = perm[[k, i]].dup
      elsif k.positive? && ipiv[k].negative? && ipiv[k] == ipiv[k - 1] && !changed_2x2
        i = -ipiv[k] - 1
        # The off-diagonal entry of the 2-by-2 block is stored in lud's strict upper part:
        # copy it into D (mirrored, conjugated if Hermitian) and clear it from U.
        d[k - 1, k] = lud[k - 1, k]
        d[k, k - 1] = hermitian ? d[k - 1, k].conj : d[k - 1, k]
        u[k - 1, k] = 0
        u[[i, k - 1], 0..k] = u[[k - 1, i], 0..k].dup
        perm[[i, k - 1]] = perm[[k - 1, i]].dup
        changed_2x2 = true
        next
      end
      changed_2x2 = false if changed_2x2
    end
    [u, d, perm.sort_index]
  else
    # The unit diagonal of L is implicit in the packed output.
    l = lud.tril.tap { |m| m[m.diag_indices] = 1 }
    # If UPLO = 'L' and IPIV(k) = IPIV(k+1) < 0, then
    # rows and columns k+1 and -IPIV(k) were interchanged
    # and D(k:k+1,k:k+1) is a 2-by-2 diagonal block.
    # Mirrors the 'U' branch, scanning from the last index downward.
    changed_2x2 = false
    (n - 1).downto(0) do |k|
      d[k, k] = lud[k, k]
      if ipiv[k].positive?
        i = ipiv[k] - 1
        # Swap rows i and k over the columns k...n of L.
        l[[i, k], k...n] = l[[k, i], k...n].dup
        perm[[i, k]] = perm[[k, i]].dup
      elsif k < n - 1 && ipiv[k].negative? && ipiv[k] == ipiv[k + 1] && !changed_2x2
        i = -ipiv[k] - 1
        # Move the block's off-diagonal entry from L into D (mirrored, conjugated if Hermitian).
        d[k + 1, k] = lud[k + 1, k]
        d[k, k + 1] = hermitian ? d[k + 1, k].conj : d[k + 1, k]
        l[k + 1, k] = 0
        l[[i, k + 1], k...n] = l[[k + 1, i], k...n].dup
        perm[[i, k + 1]] = perm[[k + 1, i]].dup
        changed_2x2 = true
        next
      end
      changed_2x2 = false if changed_2x2
    end
    [l, d, perm.sort_index]
  end
end
|
1306
|
+
|
1307
|
+
# Keep the LDL unpacking helper (called from `ldl`) out of the module's public singleton API.
private_class_method :_lud_permutation
|
1308
|
+
end
|
1309
|
+
end
|