numo-linalg-alt 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +5 -0
  3. data/CODE_OF_CONDUCT.md +84 -0
  4. data/LICENSE.txt +27 -0
  5. data/README.md +106 -0
  6. data/ext/numo/linalg/blas/dot.c +72 -0
  7. data/ext/numo/linalg/blas/dot.h +13 -0
  8. data/ext/numo/linalg/blas/dot_sub.c +71 -0
  9. data/ext/numo/linalg/blas/dot_sub.h +13 -0
  10. data/ext/numo/linalg/blas/gemm.c +184 -0
  11. data/ext/numo/linalg/blas/gemm.h +16 -0
  12. data/ext/numo/linalg/blas/gemv.c +161 -0
  13. data/ext/numo/linalg/blas/gemv.h +16 -0
  14. data/ext/numo/linalg/blas/nrm2.c +67 -0
  15. data/ext/numo/linalg/blas/nrm2.h +13 -0
  16. data/ext/numo/linalg/converter.c +67 -0
  17. data/ext/numo/linalg/converter.h +23 -0
  18. data/ext/numo/linalg/extconf.rb +99 -0
  19. data/ext/numo/linalg/lapack/geev.c +152 -0
  20. data/ext/numo/linalg/lapack/geev.h +15 -0
  21. data/ext/numo/linalg/lapack/gelsd.c +92 -0
  22. data/ext/numo/linalg/lapack/gelsd.h +15 -0
  23. data/ext/numo/linalg/lapack/geqrf.c +72 -0
  24. data/ext/numo/linalg/lapack/geqrf.h +15 -0
  25. data/ext/numo/linalg/lapack/gesdd.c +108 -0
  26. data/ext/numo/linalg/lapack/gesdd.h +15 -0
  27. data/ext/numo/linalg/lapack/gesv.c +99 -0
  28. data/ext/numo/linalg/lapack/gesv.h +15 -0
  29. data/ext/numo/linalg/lapack/gesvd.c +152 -0
  30. data/ext/numo/linalg/lapack/gesvd.h +15 -0
  31. data/ext/numo/linalg/lapack/getrf.c +71 -0
  32. data/ext/numo/linalg/lapack/getrf.h +15 -0
  33. data/ext/numo/linalg/lapack/getri.c +82 -0
  34. data/ext/numo/linalg/lapack/getri.h +15 -0
  35. data/ext/numo/linalg/lapack/getrs.c +110 -0
  36. data/ext/numo/linalg/lapack/getrs.h +15 -0
  37. data/ext/numo/linalg/lapack/heev.c +71 -0
  38. data/ext/numo/linalg/lapack/heev.h +15 -0
  39. data/ext/numo/linalg/lapack/heevd.c +71 -0
  40. data/ext/numo/linalg/lapack/heevd.h +15 -0
  41. data/ext/numo/linalg/lapack/heevr.c +111 -0
  42. data/ext/numo/linalg/lapack/heevr.h +15 -0
  43. data/ext/numo/linalg/lapack/hegv.c +94 -0
  44. data/ext/numo/linalg/lapack/hegv.h +15 -0
  45. data/ext/numo/linalg/lapack/hegvd.c +94 -0
  46. data/ext/numo/linalg/lapack/hegvd.h +15 -0
  47. data/ext/numo/linalg/lapack/hegvx.c +133 -0
  48. data/ext/numo/linalg/lapack/hegvx.h +15 -0
  49. data/ext/numo/linalg/lapack/hetrf.c +68 -0
  50. data/ext/numo/linalg/lapack/hetrf.h +15 -0
  51. data/ext/numo/linalg/lapack/lange.c +66 -0
  52. data/ext/numo/linalg/lapack/lange.h +15 -0
  53. data/ext/numo/linalg/lapack/orgqr.c +79 -0
  54. data/ext/numo/linalg/lapack/orgqr.h +15 -0
  55. data/ext/numo/linalg/lapack/potrf.c +70 -0
  56. data/ext/numo/linalg/lapack/potrf.h +15 -0
  57. data/ext/numo/linalg/lapack/potri.c +70 -0
  58. data/ext/numo/linalg/lapack/potri.h +15 -0
  59. data/ext/numo/linalg/lapack/potrs.c +94 -0
  60. data/ext/numo/linalg/lapack/potrs.h +15 -0
  61. data/ext/numo/linalg/lapack/syev.c +71 -0
  62. data/ext/numo/linalg/lapack/syev.h +15 -0
  63. data/ext/numo/linalg/lapack/syevd.c +71 -0
  64. data/ext/numo/linalg/lapack/syevd.h +15 -0
  65. data/ext/numo/linalg/lapack/syevr.c +111 -0
  66. data/ext/numo/linalg/lapack/syevr.h +15 -0
  67. data/ext/numo/linalg/lapack/sygv.c +93 -0
  68. data/ext/numo/linalg/lapack/sygv.h +15 -0
  69. data/ext/numo/linalg/lapack/sygvd.c +93 -0
  70. data/ext/numo/linalg/lapack/sygvd.h +15 -0
  71. data/ext/numo/linalg/lapack/sygvx.c +133 -0
  72. data/ext/numo/linalg/lapack/sygvx.h +15 -0
  73. data/ext/numo/linalg/lapack/sytrf.c +72 -0
  74. data/ext/numo/linalg/lapack/sytrf.h +15 -0
  75. data/ext/numo/linalg/lapack/trtrs.c +99 -0
  76. data/ext/numo/linalg/lapack/trtrs.h +15 -0
  77. data/ext/numo/linalg/lapack/ungqr.c +79 -0
  78. data/ext/numo/linalg/lapack/ungqr.h +15 -0
  79. data/ext/numo/linalg/linalg.c +290 -0
  80. data/ext/numo/linalg/linalg.h +85 -0
  81. data/ext/numo/linalg/util.c +95 -0
  82. data/ext/numo/linalg/util.h +17 -0
  83. data/lib/numo/linalg/version.rb +10 -0
  84. data/lib/numo/linalg.rb +1309 -0
  85. data/vendor/tmp/.gitkeep +0 -0
  86. metadata +146 -0
@@ -0,0 +1,1309 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+ require_relative 'linalg/version'
5
+ require_relative 'linalg/linalg'
6
+
7
+ # Ruby/Numo (NUmerical MOdules)
8
+ module Numo
9
+ # Numo::Linalg Alternative (numo-linalg-alt) is an alternative to Numo::Linalg.
10
+ module Linalg # rubocop:disable Metrics/ModuleLength
11
+ module_function
12
+
13
# Computes the eigenvalues and eigenvectors of a symmetric / Hermitian matrix
# by solving an ordinary (`a * v = lambda * v`) or generalized (`a * v = lambda * b * v`) eigenvalue problem.
#
# @example
#   require 'numo/linalg'
#
#   x = Numo::DFloat.new(5, 3).rand - 0.5
#   c = x.dot(x.transpose)
#   vals, vecs = Numo::Linalg.eigh(c, vals_range: [2, 4])
#
#   pp vals
#   # =>
#   # Numo::DFloat#shape=[3]
#   # [0.118795, 0.434252, 0.903245]
#
#   pp vecs
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[0.154178, 0.60661, -0.382961],
#   #  [-0.349761, -0.141726, -0.513178],
#   #  [0.739633, -0.468202, 0.105933],
#   #  [0.0519655, -0.471436, -0.701507],
#   #  [-0.551488, -0.412883, 0.294371]]
#
#   # every returned pair satisfies c * v = lambda * v
#   pp (c.dot(vecs) - vecs.dot(vals.diag)).abs.max
#   # => a value on the order of machine epsilon
#
# @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
# @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
# @param vals_only [Boolean] The flag indicating whether to return only eigenvalues.
# @param vals_range [Range/Array]
#   The range of indices (0-based, inclusive, eigenvalues in ascending order) of the eigenvalues and
#   corresponding eigenvectors to be returned. If nil, all eigenvalues and eigenvectors are computed.
# @param uplo [String] This argument is for compatibility with Numo::Linalg, and is not used.
# @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm.
#   If vals_range is given, this flag is ignored.
# @return [Array<Numo::NArray>] The eigenvalues and eigenvectors (eigenvectors are nil when vals_only is true).
# @raise [Numo::NArray::ShapeError] if `a` or `b` is not a square 2-dimensional array.
# @raise [ArgumentError] if `a` or `b` has an unsupported array type.
def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/ParameterLists, Metrics/PerceivedComplexity, Lint/UnusedMethodArgument
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  b_given = !b.nil?
  raise Numo::NArray::ShapeError, 'input array b must be 2-dimensional' if b_given && b.ndim != 2
  raise Numo::NArray::ShapeError, 'input array b must be square' if b_given && b.shape[0] != b.shape[1]
  raise ArgumentError, "invalid array type: #{b.class}" if b_given && blas_char(b) == 'n'

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  jobz = vals_only ? 'N' : 'V'
  real = %w[d s].include?(bchr)

  unless vals_range.nil?
    # LAPACK's il / iu are 1-based inclusive indices into the ascending eigenvalue list.
    il = vals_range.first(1)[0] + 1
    iu = vals_range.last(1)[0] + 1
  end

  # Routine names are built as whole literals rather than appended to with `<<`:
  # with `frozen_string_literal: true`, Ruby < 3.0 freezes interpolated strings,
  # so mutating them would raise FrozenError.
  if b_given
    base = real ? "#{bchr}sygv" : "#{bchr}hegv"
    if vals_range.nil?
      fnc = turbo ? "#{base}d" : base
      vecs, _b, vals, _info = Numo::Linalg::Lapack.send(fnc.to_sym, a.dup, b.dup, jobz: jobz)
    else
      _a, _b, _m, vals, vecs, _ifail, _info = Numo::Linalg::Lapack.send(
        :"#{base}x", a.dup, b.dup, jobz: jobz, range: 'I', il: il, iu: iu
      )
    end
  else
    base = real ? "#{bchr}syev" : "#{bchr}heev"
    if vals_range.nil?
      fnc = turbo ? "#{base}d" : base
      vecs, vals, _info = Numo::Linalg::Lapack.send(fnc.to_sym, a.dup, jobz: jobz)
    else
      _a, _m, vals, vecs, _isuppz, _info = Numo::Linalg::Lapack.send(
        :"#{base}r", a.dup, jobz: jobz, range: 'I', il: il, iu: iu
      )
    end
  end

  vecs = nil if vals_only

  [vals, vecs]
end
95
+
96
# Computes the matrix or vector norm.
#
#   | ord   | matrix norm            | vector norm                 |
#   | ----- | ---------------------- | --------------------------- |
#   | nil   | Frobenius norm         | 2-norm                      |
#   | 'fro' | Frobenius norm         | -                           |
#   | 'nuc' | nuclear norm           | -                           |
#   | 'inf' | x.abs.sum(axis:-1).max | x.abs.max                   |
#   | 0     | -                      | (x.ne 0).sum                |
#   | 1     | x.abs.sum(axis:-2).max | same as below               |
#   | 2     | 2-norm (max sing_vals) | same as below               |
#   | other | -                      | (x.abs**ord).sum**(1.0/ord) |
#
# @example
#   require 'numo/linalg'
#
#   # matrix norm
#   x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
#   pp Numo::Linalg.norm(x)
#   # => 10
#
#   # vector norm
#   x = Numo::DFloat[3, -4]
#   pp Numo::Linalg.norm(x)
#   # => 5
#
# @param a [Numo::NArray] The matrix or vector (>= 1-dimensional NArray).
# @param ord [String/Numeric] The order of the norm.
# @param axis [Integer/Array] The applied axes.
# @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
# @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
# @raise [ArgumentError] if `ord` or `axis` is invalid, including two axes that name the same dimension.
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

  return 0.0 if a.empty?

  # for compatibility with Numo::Linalg.norm
  if ord.is_a?(String)
    if ord == 'inf'
      ord = Float::INFINITY
    elsif ord == '-inf'
      ord = -Float::INFINITY
    end
  end

  if axis.nil?
    # Fast paths: whole-array reductions that map directly onto BLAS nrm2 / LAPACK lange.
    # Combinations not handled here leave `norm` nil and fall through to the generic code below.
    norm = case a.ndim
           when 1
             Numo::Linalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
           when 2
             if ord.nil? || ord == 'fro'
               Numo::Linalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
             elsif ord.is_a?(Numeric)
               if ord == 1
                 Numo::Linalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
               elsif !ord.infinite?.nil? && ord.infinite?.positive?
                 Numo::Linalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
               end
             end
           else
             if ord.nil?
               b = a.flatten.dup
               Numo::Linalg::Blas.send(:"#{blas_char(b)}nrm2", b)
             end
           end
    unless norm.nil?
      # A whole-array reduction collapses every axis, so keepdims means all-ones shape.
      norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
      return norm
    end
  end

  if axis.nil?
    axis = Array.new(a.ndim) { |d| d }
  else
    case axis
    when Integer
      axis = [axis]
    when Array, Numo::NArray
      axis = axis.flatten.to_a
    else
      raise ArgumentError, "invalid axis: #{axis}"
    end
  end

  raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless [1, 2].include?(axis.size)
  raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

  if axis.size == 1
    # vector norm along a single axis
    ord ||= 2
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

    ord_inf = ord.infinite?
    if ord_inf.nil?
      case ord
      when 0
        a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
      when 1
        a.abs.sum(axis: axis, keepdims: keepdims)
      else
        (a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
      end
    elsif ord_inf.positive?
      a.abs.max(axis: axis, keepdims: keepdims)
    else
      a.abs.min(axis: axis, keepdims: keepdims)
    end
  else
    # matrix norm over a pair of axes
    ord ||= 'fro'
    raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)

    # Normalize negative indices before the duplicate check: e.g. [0, -2] on a matrix names axis 0 twice.
    r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }
    raise ArgumentError, "invalid axis: #{axis}" if r_axis == c_axis

    # With keepdims only the two reduced axes collapse to size one; any other axes keep their extent.
    kept_shape = a.shape.dup
    kept_shape[r_axis] = 1
    kept_shape[c_axis] = 1

    norm = if ord.is_a?(String)
             raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

             if ord == 'fro'
               Numo::NMath.sqrt((a.abs**2).sum(axis: axis))
             else
               b = a.transpose(c_axis, r_axis).dup
               gesvd = :"#{blas_char(b)}gesvd"
               s, = Numo::Linalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
               s.sum(axis: -1)
             end
           else
             ord_inf = ord.infinite?
             if ord_inf.nil?
               case ord
               when -2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::Linalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.min(axis: -1)
               when -1
                 # the first reduction removes r_axis, shifting later axes down by one
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).min(axis: c_axis)
               when 1
                 c_axis -= 1 if c_axis > r_axis
                 a.abs.sum(axis: r_axis).max(axis: c_axis)
               when 2
                 b = a.transpose(c_axis, r_axis).dup
                 gesvd = :"#{blas_char(b)}gesvd"
                 s, = Numo::Linalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
                 s.max(axis: -1)
               else
                 raise ArgumentError, "invalid ord: #{ord}"
               end
             else
               # the first reduction removes c_axis, shifting later axes down by one
               r_axis -= 1 if r_axis > c_axis
               if ord_inf.positive?
                 a.abs.sum(axis: c_axis).max(axis: r_axis)
               else
                 a.abs.sum(axis: c_axis).min(axis: r_axis)
               end
             end
           end
    if keepdims
      norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
      norm = norm.reshape(*kept_shape)
    end

    norm
  end
end
260
+
261
# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
#
# @example
#   require 'numo/linalg'
#
#   s = Numo::DFloat.new(3, 3).rand - 0.5
#   a = s.transpose.dot(s)
#   u = Numo::Linalg.cholesky(a)
#
#   pp u
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.532006, 0.338183, -0.18036],
#   #  [0, 0.325153, 0.011721],
#   #  [0, 0, 0.436738]]
#
#   pp (a - u.transpose.dot(u)).abs.max
#   # => 1.3877787807814457e-17
#
#   l = Numo::Linalg.cholesky(a, uplo: 'L')
#
#   pp l
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.532006, 0, 0],
#   #  [0.338183, 0.325153, 0],
#   #  [-0.18036, 0.011721, 0.436738]]
#
#   pp (a - l.dot(l.transpose)).abs.max
#   # => 1.3877787807814457e-17
#
# @param a [Numo::NArray] The n-by-n symmetric matrix.
# @param uplo [String] Whether to compute the upper- or lower-triangular Cholesky factor ('U' or 'L').
# @return [Numo::NArray] The upper- or lower-triangular Cholesky factor of a.
# @raise [Numo::NArray::ShapeError] if `a` is not a square 2-dimensional array.
# @raise [ArgumentError] if `uplo` is not 'U' or 'L', or `a` has an unsupported array type.
# @raise [RuntimeError] if potrf rejects an argument or `a` is not positive definite.
def cholesky(a, uplo: 'U')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
  # Validate before calling LAPACK so an invalid flag never reaches potrf.
  raise ArgumentError, "invalid uplo: #{uplo}" unless %w[U L].include?(uplo)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  fnc = :"#{bchr}potrf"
  c, info = Numo::Linalg::Lapack.send(fnc, a.dup, uplo: uplo)
  raise "the #{-info}-th argument of potrf had illegal value" if info.negative?
  # info > 0: the leading minor of that order is not positive definite, so the
  # returned factor is incomplete and must not be handed back as a result.
  if info.positive?
    raise "the leading minor of order #{info} is not positive definite, " \
          'and the factorization could not be completed.'
  end

  uplo == 'U' ? c.triu : c.tril
end
314
+
315
# Solves the linear equation `A * x = b` or `A * X = B` for `x`, given the Cholesky factor of `A`.
#
# @example
#   require 'numo/linalg'
#
#   s = Numo::DFloat.new(3, 3).rand - 0.5
#   a = s.transpose.dot(s)
#   u = Numo::Linalg.cholesky(a)
#
#   b = Numo::DFloat.new(3).rand
#   x = Numo::Linalg.cho_solve(u, b)
#
#   puts (b - a.dot(x)).abs.max
#   # => 0.0
#
# @param a [Numo::NArray] The n-by-n Cholesky factor of `A` (as returned by `cholesky`).
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param uplo [String] Whether `a` is the upper- ('U') or lower-triangular ('L') Cholesky factor.
# @return [Numo::NArray] The solution vector or matrix `X`.
# @raise [Numo::NArray::ShapeError] if `a` is not square 2-dimensional or its size does not match `b`.
# @raise [ArgumentError] if the arrays have an unsupported type.
def cho_solve(a, b, uplo: 'U')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]
  unless a.shape[0] == b.shape[0]
    raise Numo::NArray::ShapeError,
          "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}"
  end

  prefix = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}" if prefix == 'n'

  # potrs only reads the factor, but overwrites its right-hand side, so b is copied.
  solution, = Numo::Linalg::Lapack.send(:"#{prefix}potrs", a, b.dup, uplo: uplo)
  solution
end
346
+
347
# Computes the determinant of a square matrix from its LU factorization.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[0, 2, 3], [4, 5, 6], [7, 8, 9]]
#   pp (3.0 - Numo::Linalg.det(a)).abs
#   # => 1.3322676295501878e-15
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @return [Float/Complex] The determinant of `a` (zero when `a` is exactly singular).
# @raise [Numo::NArray::ShapeError] if `a` is not a square 2-dimensional array.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if getrf reports an illegal argument.
def det(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  getrf = :"#{bchr}getrf"
  lu, piv, info = Numo::Linalg::Lapack.send(getrf, a.dup)
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?

  # info > 0 means U(info, info) is exactly zero. The factorization is still complete,
  # so the product of U's diagonal correctly yields a zero determinant (no error).
  det_u = lu.diagonal.prod
  # Every row interchange recorded in the 1-based pivot vector flips the sign.
  det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
  det_u * det_p
end
379
+
380
# Computes the inverse of a square matrix via LU factorization (getrf followed by getri).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(5, 5).rand
#
#   inv_a = Numo::Linalg.inv(a)
#
#   pp (inv_a.dot(a) - Numo::DFloat.eye(5)).abs.max
#   # => 7.019165976816745e-16
#
#   pp inv_a.dot(a).sum
#   # => 5.0
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param driver [String] Accepted for compatibility with Numo::Linalg; ignored.
# @param uplo [String] Accepted for compatibility with Numo::Linalg; ignored.
# @return [Numo::NArray] The inverse matrix of `a`.
# @raise [Numo::NArray::ShapeError] if `a` is not a square 2-dimensional array.
# @raise [ArgumentError] if `a` has an unsupported array type.
# @raise [RuntimeError] if `a` is singular or getrf reports an illegal argument.
def inv(a, driver: 'getrf', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  prefix = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if prefix == 'n'

  lapack = Numo::Linalg::Lapack
  lu, piv, info = lapack.send(:"#{prefix}getrf", a.dup)
  raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
  raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?

  lapack.send(:"#{prefix}getri", lu, piv)[0]
end
418
+
419
# Computes the Moore-Penrose pseudo-inverse of a matrix from its singular value decomposition.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(5, 3).rand
#
#   inv_a = Numo::Linalg.pinv(a)
#
#   pp (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
#   # => 1.1102230246251565e-15
#
#   pp inv_a.dot(a).sum
#   # => 3.0
#
# @param a [Numo::NArray] The m-by-n matrix to be pseudo-inverted.
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
# @param rcond [Float] Relative cutoff for small singular values (compared against the largest one);
#   defaults to `a.shape.max * EPS`.
# @return [Numo::NArray] The n-by-m pseudo-inverse of `a`.
def pinv(a, driver: 'svd', rcond: nil)
  s, u, vh = svd(a, driver: driver, job: 'S')
  relative_tol = rcond.nil? ? a.shape.max * s.class::EPSILON : rcond
  # singular values come back in descending order, so s[0] is the largest
  rank = s.gt(relative_tol * s[0]).count

  # Divide each retained left singular vector by its singular value, then form
  # (U_r * S_r^-1 * Vh_r)^H = V_r * S_r^-1 * U_r^H.
  scaled_u = u[true, 0...rank] / s[0...rank]
  scaled_u.dot(vh[0...rank, true]).conj.transpose
end
446
+
447
# Computes the QR decomposition of a matrix with Householder reflections (geqrf + orgqr / ungqr).
#
# @example
#   require 'numo/linalg'
#
#   x = Numo::DFloat.new(5, 3).rand
#
#   q, r = Numo::Linalg.qr(x, mode: 'economic')
#
#   pp q
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[-0.0574417, 0.635216, 0.707116],
#   #  [-0.187002, -0.073192, 0.422088],
#   #  [-0.502239, 0.634088, -0.537489],
#   #  [-0.0473292, 0.134867, -0.0223491],
#   #  [-0.840979, -0.413385, 0.180096]]
#
#   pp r
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[-1.07508, -0.821334, -0.484586],
#   #  [0, 0.513035, 0.451868],
#   #  [0, 0, 0.678737]]
#
#   pp (q.dot(r) - x).abs.max
#   # => 3.885780586188048e-16
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition.
#   - "reduce" -- returns both Q [m, m] and R [m, n],
#   - "r" -- returns only R,
#   - "economic" -- returns both Q [m, n] and R [n, n],
#   - "raw" -- returns QR and TAU (LAPACK geqrf results).
# @return [Numo::NArray] if mode='r'.
# @return [Array<Numo::NArray>] if mode='reduce' or 'economic' or 'raw'.
def qr(a, mode: 'reduce')
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[reduce r economic raw].include?(mode)

  prefix = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if prefix == 'n'

  lapack = Numo::Linalg::Lapack
  packed, tau, = lapack.send(:"#{prefix}geqrf", a.dup)
  return [packed, tau] if mode == 'raw'

  m, n = packed.shape
  # 'raw' has already returned, so only 'economic' trims R to its leading n rows.
  r = if m > n && mode == 'economic'
        packed[0...n, true].triu
      else
        packed.triu
      end
  return r if mode == 'r'

  generator = %w[d s].include?(prefix) ? :"#{prefix}orgqr" : :"#{prefix}ungqr"
  q = if m < n
        lapack.send(generator, packed[true, 0...m], tau)[0]
      elsif mode == 'economic'
        lapack.send(generator, packed, tau)[0]
      else
        # Embed the m-by-n reflectors in an m-by-m array so the full square Q is generated.
        padded = a.class.zeros(m, m)
        padded[0...m, 0...n] = packed
        lapack.send(generator, padded, tau)[0]
      end

  [q, r]
end
514
+
515
# Solves the linear equation `A * x = b` or `A * X = B` for `x`, where `A` is a general square matrix (gesv).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 3).rand
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[-2.12332, 4.74868, 0.326773],
#   #  [1.38043, -3.79074, 1.25355],
#   #  [0.775187, 1.41032, -0.613774]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 2.1081041547796492e-16
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param driver [String] Accepted for compatibility with Numo::Linalg; ignored.
# @param uplo [String] Accepted for compatibility with Numo::Linalg; ignored.
# @return [Numo::NArray] The solution vector / matrix `X`.
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  prefix = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if prefix == 'n'

  # gesv overwrites both inputs (A with its LU factors, B with X), hence the copies;
  # the solution is the second element of its result.
  outputs = Numo::Linalg::Lapack.send(:"#{prefix}gesv", a.dup, b.dup)
  outputs[1]
end
550
+
551
# Solves the linear equation `A * x = b` or `A * X = B` for `x`, assuming `A` is a triangular matrix (trtrs).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 3).rand.triu
#   b = Numo::DFloat.eye(3)
#
#   x = Numo::Linalg.solve_triangular(a, b)
#
#   pp x
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[16.1932, -52.0604, 30.5283],
#   #  [0, 8.61765, -17.9585],
#   #  [0, 0, 6.05735]]
#
#   pp (b - a.dot(x)).abs.max
#   # => 4.071100642430302e-16
#
# @param a [Numo::NArray] The n-by-n triangular matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
# @return [Numo::NArray] The solution vector / matrix `X`.
# @raise [Numo::NArray::ShapeError] if `a` is not a square 2-dimensional array.
# @raise [ArgumentError] if the arrays have an unsupported type.
# @raise [RuntimeError] if trtrs rejects an argument or `a` has a zero diagonal element.
def solve_triangular(a, b, lower: false)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

  trtrs = :"#{bchr}trtrs"
  uplo = lower ? 'L' : 'U'
  # trtrs only reads A but overwrites B with the solution, so only b is copied.
  x, info = Numo::Linalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
  # LAPACK reports an illegal i-th argument as info = -i.
  raise "wrong value is given to the #{-info}-th argument of #{trtrs} used internally" if info.negative?
  # info = i > 0: the i-th diagonal element is zero and X was not computed.
  if info.positive?
    raise "the #{info}-th diagonal element of a is zero, indicating that the matrix is singular " \
          'and the solution could not be computed.'
  end

  x
end
589
+
590
# Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
#
# @example
#   require 'numo/linalg'
#
#   x = Numo::DFloat.new(5, 2).rand.dot(Numo::DFloat.new(2, 3).rand)
#   pp x
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[0.104945, 0.0284236, 0.117406],
#   #  [0.862634, 0.210945, 0.922135],
#   #  [0.324507, 0.0752655, 0.339158],
#   #  [0.67085, 0.102594, 0.600882],
#   #  [0.404631, 0.116868, 0.46644]]
#
#   s, u, vt = Numo::Linalg.svd(x, job: 'S')
#
#   z = u.dot(s.diag).dot(vt)
#   pp z
#   # =>
#   # Numo::DFloat#shape=[5,3]
#   # [[0.104945, 0.0284236, 0.117406],
#   #  [0.862634, 0.210945, 0.922135],
#   #  [0.324507, 0.0752655, 0.339158],
#   #  [0.67085, 0.102594, 0.600882],
#   #  [0.404631, 0.116868, 0.46644]]
#
#   pp (x - z).abs.max
#   # => 4.440892098500626e-16
#
# @param a [Numo::NArray] Matrix to be decomposed.
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
# @param job [String] The job option ('A', 'S', or 'N').
# @return [Array<Numo::NArray>] The singular values and singular vectors ([s, u, vt]).
# @raise [ArgumentError] if `job`, `driver`, or the array type is invalid.
# @raise [RuntimeError] if LAPACK reports an illegal argument, a NaN input (sdd driver), or no convergence.
def svd(a, driver: 'svd', job: 'A')
  raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  case driver.to_s
  when 'sdd'
    gesdd = :"#{bchr}gesdd"
    s, u, vt, info = Numo::Linalg::Lapack.send(gesdd, a.dup, jobz: job)
  when 'svd'
    gesvd = :"#{bchr}gesvd"
    s, u, vt, info = Numo::Linalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end

  # gesdd signals a NaN in its 4th argument (A) with info == -4. This must be tested before
  # the generic negative-info check, which would otherwise always win. For gesvd, -4 is a
  # genuine illegal-argument report (its 4th argument is N), so the check is sdd-only.
  raise 'input array has a NAN entry' if driver.to_s == 'sdd' && info == -4
  raise "the #{info.abs}-th argument had illegal value" if info.negative?
  raise 'svd did not converge' if info.positive?

  [s, u, vt]
end
647
+
648
# Computes the matrix product of two arrays by dispatching to the BLAS gemm routine.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 0], [0, 1]]
#   b = Numo::DFloat[[4, 1], [2, 2]]
#   pp Numo::Linalg.matmul(a, b)
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[4, 1],
#   #  [2, 2]]
#
# @param a [Numo::NArray] The left operand.
# @param b [Numo::NArray] The right operand.
# @return [Numo::NArray] The matrix product `a * b`.
def matmul(a, b)
  blas = Numo::Linalg::Blas
  blas.call(:gemm, a, b)
end
667
+
668
# Computes the matrix `a` raised to the (integer) power of `n` using exponentiation by squaring.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 2], [3, 4]]
#   pp Numo::Linalg.matrix_power(a, 3)
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[37, 54],
#   #  [81, 118]]
#
# @param a [Numo::NArray] The square matrix.
# @param n [Integer] The exponent. Zero yields the identity; negative values use the inverse of `a`.
# @return [Numo::NArray] The matrix `a` raised to the power of `n` (always a new array).
# @raise [Numo::NArray::ShapeError] if `a` is not a square 2-dimensional array.
# @raise [ArgumentError] if `n` is not an Integer.
def matrix_power(a, n)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
  raise ArgumentError, "exponent n must be an integer: #{n}" unless n.is_a?(Integer)

  return a.class.eye(a.shape[0]) if n.zero?

  base = n.positive? ? a : inv(a)
  exponent = n.abs
  result = nil
  # Square-and-multiply: O(log |n|) matrix products instead of |n| - 1.
  # All factors are powers of the same matrix, so they commute and the order is irrelevant.
  loop do
    if exponent.odd?
      # dup on the first factor so the caller never receives (an alias of) its own input
      result = result.nil? ? base.dup : Numo::Linalg.matmul(result, base)
    end
    exponent >>= 1
    break if exponent.zero?

    base = Numo::Linalg.matmul(base, base)
  end
  result
end
701
+
702
# Computes the singular values of a matrix (without forming the singular vectors).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 2, 3], [2, 4, 6], [-1, 1, -1]]
#   pp Numo::Linalg.svdvals(a)
#   # => Numo::DFloat#shape=[3]
#   # [8.38434, 1.64402, 5.41675e-17]
#
# @param a [Numo::NArray] Matrix to be decomposed.
# @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
# @return [Numo::NArray] The singular values of `a`.
# @raise [ArgumentError] if `driver` or the array type is invalid.
# @raise [RuntimeError] if LAPACK reports an illegal argument, a NaN input (sdd driver), or no convergence.
def svdvals(a, driver: 'sdd')
  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  case driver.to_s
  when 'sdd'
    gesdd = :"#{bchr}gesdd"
    s, _u, _vt, info = Numo::Linalg::Lapack.send(gesdd, a.dup, jobz: 'N')
  when 'svd'
    gesvd = :"#{bchr}gesvd"
    s, _u, _vt, info = Numo::Linalg::Lapack.send(gesvd, a.dup, jobu: 'N', jobvt: 'N')
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end

  # gesdd signals a NaN in its 4th argument (A) with info == -4. This must be tested before
  # the generic negative-info check, which would otherwise always win. For gesvd, -4 is a
  # genuine illegal-argument report (its 4th argument is N), so the check is sdd-only.
  raise 'input array has a NAN entry' if driver.to_s == 'sdd' && info == -4
  raise "the #{info.abs}-th argument had illegal value" if info.negative?
  raise 'svd did not converge' if info.positive?

  s
end
736
+
737
# Computes an orthonormal basis for the range (column space) of `A` from its SVD.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 2, 3], [2, 4, 6], [-1, 1, -1]]
#   u = Numo::Linalg.orth(a)
#   pp u
#   # =>
#   # Numo::DFloat#shape=[3,2]
#   # [[-0.446229, -0.0296535],
#   #  [-0.892459, -0.059307],
#   #  [0.0663073, -0.997799]]
#   pp u.transpose.dot(u)
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[1, -1.97749e-16],
#   #  [-1.97749e-16, 1]]
#
# @param a [Numo::NArray] The m-by-n input matrix.
# @param rcond [Float] Relative cutoff for small singular values (compared against the largest one);
#   nil or a negative value selects the default `a.shape.max * EPS`.
# @return [Numo::NArray] The orthonormal basis for the range of `a`, one basis vector per column.
def orth(a, rcond: nil)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2

  s, u, = svd(a, driver: 'sdd', job: 'S')
  use_default = rcond.nil? || rcond.negative?
  tol = use_default ? a.shape.max * s.class::EPSILON : rcond
  # the numerical rank is the count of singular values above the relative cutoff
  rank = s.gt(tol * s.max).count
  u[true, 0...rank].dup
end
771
+
772
# Computes an orthonormal basis for the null space of `A` from its SVD.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(3, 5).rand - 0.5
#   n = Numo::Linalg.null_space(a)
#   pp n
#   # =>
#   # Numo::DFloat#shape=[5,2]
#   # [[0.214096, -0.404277],
#   #  [-0.482225, -0.51557],
#   #  [-0.584394, -0.246804],
#   #  [0.596612, -0.351468],
#   #  [-0.155434, 0.621535]]
#   pp n.transpose.dot(n)
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[1, 1.31078e-16],
#   #  [1.31078e-16, 1]]
#
# @param a [Numo::NArray] The m-by-n input matrix.
# @param rcond [Float] Relative cutoff for small singular values (compared against the largest one);
#   nil or a negative value selects the default `a.shape.max * EPS`.
# @return [Numo::NArray] The orthonormal basis for the null space of `a`, one basis vector per column.
def null_space(a, rcond: nil)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2

  s, _u, vt = svd(a, driver: 'sdd', job: 'A')
  tol = rcond
  tol = a.shape.max * s.class::EPSILON if tol.nil? || tol.negative?
  rank = s.gt(tol * s.max).count
  # Rows of V^H past the numerical rank span the null space; return them as columns.
  row_count = vt.shape[0]
  vt[rank...row_count, true].conj.transpose.dup
end
808
+
809
+ # Computes the LU decomposition of a matrix using partial pivoting.
810
+ #
811
+ # @example
812
+ # require 'numo/linalg'
813
+ #
814
+ # a = Numo::DFloat.new(3, 4).rand - 0.5
815
+ # pm, l, u = Numo::Linalg.lu(a)
816
+ # error = (pm.dot(l).dot(u) - a).abs.max
817
+ # pp error
818
+ # # => 5.551115123125783e-17
819
+ #
820
+ # l, u = Numo::Linalg.lu(a, permute_l: true)
821
+ # error = (l.dot(u) - a).abs.max
822
+ # pp error
823
+ # # => 5.551115123125783e-17
824
+ #
825
+ # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
826
+ # @param permute_l [Boolean] If true, returns `L` with the permutation applied.
827
+ # @return [Array<Numo::NArray>] if `permute_l` is `false`, the permutation matrix `P`, lower-triangular matrix `L`, and
828
+ # upper-triangular matrix `U` ([P, L, U]). if `permute_l` is `true`, the permuted lower-triangular matrix `L` and
829
+ # upper-triangular matrix `U` ([L, U]).
830
+ def lu(a, permute_l: false)
831
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
832
+
833
+ m, n = a.shape
834
+ k = [m, n].min
835
+ lu, piv = lu_fact(a)
836
+ l = lu.tril.tap { |nary| nary[nary.diag_indices] = 1 }[true, 0...k].dup
837
+ u = lu.triu[0...k, 0...n].dup
838
+ perm = a.class.eye(m).tap do |nary|
839
+ piv.to_a.each_with_index { |i, j| nary[true, [i - 1, j]] = nary[true, [j, i - 1]].dup }
840
+ end
841
+
842
+ permute_l ? [perm.dot(l), u] : [perm, l, u]
843
+ end
844
+
845
+ # Computes the LU decomposition of a matrix using partial pivoting.
846
+ #
847
+ # @param a [Numo::NArray] The m-by-n matrix to be decomposed.
848
+ # @return [Array<Numo::NArray>] The LU decomposition and pivot indices ([lu, piv]).
849
+ def lu_fact(a)
850
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
851
+
852
+ bchr = blas_char(a)
853
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
854
+
855
+ getrf = :"#{bchr}getrf"
856
+ lu, piv, info = Numo::Linalg::Lapack.send(getrf, a.dup)
857
+
858
+ raise "the #{info.abs}-th argument of getrf had illegal value" if info.negative?
859
+ raise "the U(#{info}, #{info}) is exactly zero. The factorization has been completed." if info.positive?
860
+
861
+ [lu, piv]
862
+ end
863
+
864
+ # Solves linear equation `A * x = b` or `A * X = B` for `x` using the LU decomposition of `A`.
865
+ #
866
+ # @example
867
+ # require 'numo/linalg'
868
+ #
869
+ # a = Numo::DFloat.new(3, 3).rand
870
+ # b = Numo::DFloat.eye(3)
871
+ # lu, ipiv = Numo::Linalg.lu_fact(a)
872
+ # x = Numo::Linalg.lu_solve(lu, ipiv, b)
873
+ #
874
+ # puts (b - a.dot(x)).abs.max
875
+ # => 2.220446049250313e-16
876
+ #
877
+ # @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
878
+ # @param ipiv [Numo::Int32/Int64] The pivot indices from `lu_fact`.
879
+ # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
880
+ # @param trans [String] The type of system to be solved.
881
+ # - 'N': solve `A * x = b` (No transpose),
882
+ # - 'T': solve `A^T * x = b` (Transpose),
883
+ # - 'C': solve `A^H * x = b` (Conjugate transpose).
884
+ # @return [Numo::NArray] The solusion vector / matrix `X`.
885
+ def lu_solve(lu, ipiv, b, trans: 'N')
886
+ raise Numo::NArray::ShapeError, 'input array lu must be 2-dimensional' if lu.ndim != 2
887
+ raise Numo::NArray::ShapeError, 'input array lu must be square' if lu.shape[0] != lu.shape[1]
888
+ raise Numo::NArray::ShapeError, "incompatible dimensions: lu.shape[0] = #{lu.shape[0]} != b.shape[0] = #{b.shape[0]}" if lu.shape[0] != b.shape[0]
889
+ raise ArgumentError, 'trans must be "N", "T", or "C"' unless %w[N T C].include?(trans)
890
+
891
+ bchr = blas_char(lu)
892
+ raise ArgumentError, "invalid array type: #{lu.class}" if bchr == 'n'
893
+
894
+ getrs = :"#{bchr}getrs"
895
+ x, info = Numo::Linalg::Lapack.send(getrs, lu, ipiv, b.dup)
896
+
897
+ raise "the #{info.abs}-th argument of getrs had illegal value" if info.negative?
898
+
899
+ x
900
+ end
901
+
902
+ # Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
903
+ #
904
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian positive-definite matrix.
905
+ # @param uplo [String] The part of the matrix to be used ('U' or 'L').
906
+ # @return [Numo::NArray] The upper- / lower-triangular matrix `U` / `L`.
907
+ def cho_fact(a, uplo: 'U')
908
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
909
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
910
+ raise ArgumentError, 'uplo must be "U" or "L"' unless %w[U L].include?(uplo)
911
+
912
+ bchr = blas_char(a)
913
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
914
+
915
+ fnc = :"#{bchr}potrf"
916
+ c, info = Numo::Linalg::Lapack.send(fnc, a.dup, uplo: uplo)
917
+
918
+ raise "the #{info}-th leading minor of the array is not positive definite, and the factorization could not be completed." if info.positive?
919
+ raise "the #{-info}-th argument of #{fnc} had illegal value" if info.negative?
920
+
921
+ c
922
+ end
923
+
924
+ # Computes the eigenvalues and right and/or left eigenvectors of a general square matrix.
925
+ #
926
+ # @example
927
+ # require 'numo/linalg'
928
+ #
929
+ # a = Numo::DFloat.new(5, 5).rand - 0.5
930
+ # w, _vl, vr = Numo::Linalg.eig(a)
931
+ # error = (a.dot(vr) - vr.dot(w.diag)).abs.max
932
+ # pp error
933
+ # # => 4.718447854656915e-16
934
+ #
935
+ # @param a [Numo::NArray] The n-by-n square matrix.
936
+ # @param left [Boolean] The flag indicating whether to compute the left eigenvectors.
937
+ # @param right [Boolean] The flag indicating whether to compute the right eigenvectors.
938
+ # @return [Array<Numo::NArray>] The eigenvalues, left eigenvectors, and right eigenvectors.
939
+ def eig(a, left: false, right: true) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
940
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
941
+ raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]
942
+
943
+ jobvl = left ? 'V' : 'N'
944
+ jobvr = right ? 'V' : 'N'
945
+
946
+ bchr = blas_char(a)
947
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
948
+
949
+ fnc = :"#{bchr}geev"
950
+ if %w[z c].include?(bchr)
951
+ w, vl, vr, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvl: jobvl, jobvr: jobvr)
952
+ else
953
+ wr, wi, vl, vr, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvl: jobvl, jobvr: jobvr)
954
+ end
955
+
956
+ raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
957
+ raise 'the QR algorithm failed to compute all the eigenvalues.' if info.positive?
958
+
959
+ if %w[d s].include?(bchr)
960
+ w = wr + (wi * 1.0i)
961
+ ids = wi.gt(0).where
962
+ unless ids.empty?
963
+ cast_class = bchr == 'd' ? Numo::DComplex : Numo::SComplex
964
+ if left
965
+ tmp = cast_class.cast(vl)
966
+ tmp[true, ids].imag = vl[true, ids + 1]
967
+ tmp[true, ids + 1] = tmp[true, ids].conj
968
+ vl = tmp
969
+ end
970
+ if right
971
+ tmp = cast_class.cast(vr)
972
+ tmp[true, ids].imag = vr[true, ids + 1]
973
+ tmp[true, ids + 1] = tmp[true, ids].conj
974
+ vr = tmp
975
+ end
976
+ end
977
+ end
978
+
979
+ [w, left ? vl : nil, right ? vr : nil]
980
+ end
981
+
982
+ # Computes the eigenvalues of a general square matrix.
983
+ #
984
+ # @param a [Numo::NArray] The n-by-n square matrix.
985
+ # @return [Numo::NArray] The eigenvalues.
986
+ def eigvals(a)
987
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
988
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
989
+
990
+ bchr = blas_char(a)
991
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
992
+
993
+ fnc = :"#{bchr}geev"
994
+ if %w[z c].include?(bchr)
995
+ w, _vl, _vr, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvl: 'N', jobvr: 'N')
996
+ else
997
+ wr, wi, _vl, _vr, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvl: 'N', jobvr: 'N')
998
+ w = wr + (wi * 1.0i)
999
+ end
1000
+
1001
+ raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
1002
+ raise 'the QR algorithm failed to compute all the eigenvalues.' if info.positive?
1003
+
1004
+ w
1005
+ end
1006
+
1007
+ # Computes the eigenvalues of a symmetric / Hermitian matrix by solving an ordinary / generalized eigenvalue problem.
1008
+ #
1009
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
1010
+ # @param b [Numo::NArray] The n-by-n symmetric / Hermitian matrix. If nil, identity matrix is assumed.
1011
+ # @param vals_range [Range/Array]
1012
+ # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
1013
+ # If nil, all eigenvalues and eigenvectors are computed.
1014
+ # @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
1015
+ # @param turbo [Bool] The flag indicating whether to use a divide and conquer algorithm. If vals_range is given, this flag is ignored.
1016
+ # @return [Numo::NArray] The eigenvalues.
1017
+ def eigvalsh(a, b = nil, vals_range: nil, uplo: 'U', turbo: false)
1018
+ eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo)[0]
1019
+ end
1020
+
1021
+ # Computes the Bunch-Kaufman decomposition of a symmetric / Hermitian matrix.
1022
+ # The factorization has the form `A = U * D * U^T` or `A = L * D * L^T`,
1023
+ # where `U` (or `L`) is a product of permutation and unit upper
1024
+ # (lower) triangular matrices, and `D` is a block diagonal matrix.
1025
+ #
1026
+ # @example
1027
+ # require 'numo/linalg'
1028
+ #
1029
+ # a = Numo::DFloat.new(5, 5).rand
1030
+ # a = 0.5 * (a + a.transpose)
1031
+ # u, d, _perm = Numo::Linalg.ldl(a)
1032
+ # error = (a - u.dot(d).dot(u.transpose)).abs.max
1033
+ # pp error
1034
+ # # => 5.551115123125783e-17
1035
+ #
1036
+ # @param a [Numo::NArray] The n-by-n symmetric / Hermitian matrix.
1037
+ # @param uplo [String] The part of the matrix to be used ('U' or 'L').
1038
+ # @param hermitian [Boolean] The flag indicating whether `a` is Hermitian.
1039
+ # @return [Array<Numo::NArray>] The permutated upper (lower) triangular matrix, the block diagonal matrix, and the permutation indices.
1040
+ def ldl(a, uplo: 'U', hermitian: true)
1041
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
1042
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
1043
+
1044
+ bchr = blas_char(a)
1045
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
1046
+
1047
+ complex = bchr =~ /c|z/
1048
+ fnc = complex && hermitian ? :"#{bchr}hetrf" : :"#{bchr}sytrf"
1049
+ lud = a.dup
1050
+ ipiv, info = Numo::Linalg::Lapack.send(fnc, lud, uplo: uplo)
1051
+
1052
+ raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
1053
+ raise 'the factorization has been completed' if info.positive?
1054
+
1055
+ _lud_permutation(lud, ipiv, uplo: uplo, hermitian: hermitian)
1056
+ end
1057
+
1058
+ # Compute the condition number of a matrix.
1059
+ #
1060
+ # @param a [Numo::NArray] The input matrix.
1061
+ # @param ord [String/Symbol/Integer] The order of the norm.
1062
+ # nil or 2: 2-norm using singular values, 'fro': Frobenius norm, 'info': infinity norm, and 1: 1-norm.
1063
+ # @return [Numo::NArray] The condition number of the matrix.
1064
+ def cond(a, ord = nil)
1065
+ if ord.nil? || ord == 2 || ord == -2
1066
+ svals = svdvals(a)
1067
+ if ord == -2
1068
+ svals[false, -1] / svals[false, 0]
1069
+ else
1070
+ svals[false, 0] / svals[false, -1]
1071
+ end
1072
+ else
1073
+ inv_a = inv(a)
1074
+ norm(a, ord, axis: [-2, -1]) * norm(inv_a, ord, axis: [-2, -1])
1075
+ end
1076
+ end
1077
+
1078
+ # Computes the sign and natural logarithm of the determinant of a matrix.
1079
+ #
1080
+ # @param a [Numo::NArray] The n-by-n square matrix.
1081
+ # @return [Array<Float/Complex>] The sign and natural logarithm of the determinant of `a` ([sign, logdet]).
1082
+ def slogdet(a)
1083
+ lu, ipiv = lu_fact(a)
1084
+ dg = lu.diagonal
1085
+ return 0, (-Float::INFINITY) if dg.eq(0).any?
1086
+
1087
+ idx = ipiv.class.new(ipiv.shape[-1]).seq(1)
1088
+ n_nonzero = ipiv.ne(idx).count(axis: -1)
1089
+ sign = ((-1.0)**(n_nonzero % 2)) * (dg / dg.abs).prod
1090
+
1091
+ logdet = Numo::NMath.log(dg.abs).sum(axis: -1)
1092
+
1093
+ [sign, logdet]
1094
+ end
1095
+
1096
+ # Computes the rank of a matrix using SVD.
1097
+ #
1098
+ # @param a [Numo::NArray] The input matrix.
1099
+ # @param tol [Float] The threshold value for small singular values of `a`.
1100
+ # @param driver [String] The LAPACK driver to be used ('svd' or 'sdd').
1101
+ # @return [Integer] The rank of the matrix.
1102
+ def matrix_rank(a, tol: nil, driver: 'svd')
1103
+ return a.ne(0).any? ? 1 : 0 if a.ndim < 2
1104
+
1105
+ s = svdvals(a, driver: driver)
1106
+ tol ||= s.max(axis: -1, keepdims: true) * (a.shape[-2..].max * s.class::EPSILON)
1107
+ s.gt(tol).count(axis: -1)
1108
+ end
1109
+
1110
+ # Computes the least-squares solution to a linear matrix equation.
1111
+ #
1112
+ # @param a [Numo::NArray] The m-by-n input matrix.
1113
+ # @param b [Numo::NArray] The m-dimensional right-hand side vector or the m-by-nrhs right-hand side matrix.
1114
+ # @param driver [String] The LAPACK driver to be used (This argument is ignored, 'lsd' is always used).
1115
+ # @param rcond [Float] The threshold value for small singular values of `a`.
1116
+ # @return [Array<Numo::NArray, Float/Complex, Integer, Numo::NArray>] The least-squares solution matrix / vector `x`,
1117
+ # the sum of squared residuals, the effective rank of `a`, and the singular values of `a`.
1118
+ def lstsq(a, b, driver: 'lsd', rcond: nil) # rubocop:disable Lint/UnusedMethodArgument, Metrics/AbcSize
1119
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
1120
+ raise Numo::NArray::ShapeError, "incompatible dimensions: a.shape[0] = #{a.shape[0]} != b.shape[0] = #{b.shape[0]}" if a.shape[0] != b.shape[0]
1121
+
1122
+ bchr = blas_char(a)
1123
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
1124
+
1125
+ m, n = a.shape
1126
+ if m < n
1127
+ if b.ndim == 1
1128
+ x = Numo::DFloat.zeros(n)
1129
+ x[0...b.size] = b
1130
+ else
1131
+ x = Numo::DFloat.zeros(n, b.shape[1])
1132
+ x[0...b.shape[0], 0...b.shape[1]] = b
1133
+ end
1134
+ else
1135
+ x = b.dup
1136
+ end
1137
+
1138
+ fnc = :"#{bchr}gelsd"
1139
+ s, rank, info = Numo::Linalg::Lapack.send(fnc, a.dup, x, rcond: rcond)
1140
+
1141
+ raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
1142
+ raise 'the algorithm for computing the SVD failed to converge' if info.positive?
1143
+
1144
+ resids = x.class[]
1145
+ if m > n
1146
+ if rank == n
1147
+ resids = if b.ndim == 1
1148
+ (x[n..].abs**2).sum(axis: 0)
1149
+ else
1150
+ (x[n..-1, true].abs**2).sum(axis: 0)
1151
+ end
1152
+ end
1153
+ x = if b.ndim == 1
1154
+ x[false, 0...n]
1155
+ else
1156
+ x[false, 0...n, true]
1157
+ end
1158
+ end
1159
+
1160
+ [x, resids, rank, s]
1161
+ end
1162
+
1163
+ # Computes the matrix exponential using a scaling and squaring algorithm with a Pade approximation.
1164
+ #
1165
+ # @param a [Numo::NArray] The n-by-n square matrix.
1166
+ # @param ord [Integer] The order of the Padé approximation.
1167
+ # @return [Numo::NArray] The matrix exponential of `a`.
1168
+ #
1169
+ # Reference:
1170
+ # - C. Moler and C. Van Loan, "Nineteen Dubious Ways to Compute the Exponential of a Matrix, Twenty-Five Years Later," SIAM Review, vol. 45, no. 1, pp. 3-49, 2003.
1171
+ def expm(a, ord = 8) # rubocop:disable Metrics/AbcSize
1172
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
1173
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
1174
+
1175
+ norm = a.abs.max
1176
+ n_sqr = norm.positive? ? [0, Math.log2(norm).to_i + 1].max : 0
1177
+ a /= 2**n_sqr
1178
+
1179
+ x = a.dup
1180
+ c = 0.5
1181
+ sgn = 1
1182
+ nume = a.class.eye(a.shape[0]) + (c * a)
1183
+ deno = a.class.eye(a.shape[0]) - (c * a)
1184
+ (2..ord).each do |k|
1185
+ c *= (ord - k + 1).fdiv(k * ((2 * ord) - k + 1))
1186
+ x = a.dot(x)
1187
+ c_x = c * x
1188
+ nume += c_x
1189
+ deno += sgn * c_x
1190
+ sgn = -sgn
1191
+ end
1192
+
1193
+ a_expm = Numo::Linalg.solve(deno, nume)
1194
+ n_sqr.times { a_expm = a_expm.dot(a_expm) }
1195
+ a_expm
1196
+ end
1197
+
1198
+ # Computes the inverse of a matrix using its LU decomposition.
1199
+ #
1200
+ # @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
1201
+ # @param ipiv [Numo::Int32] The pivot indices from `lu_fact`.
1202
+ # @return [Numo::NArray] The inverse of the matrix `A`.
1203
+ def lu_inv(lu, ipiv)
1204
+ bchr = blas_char(lu)
1205
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
1206
+
1207
+ fnc = :"#{bchr}getri"
1208
+ inv, info = Numo::Linalg::Lapack.send(fnc, lu.dup, ipiv)
1209
+
1210
+ raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
1211
+ raise 'the matrix is singular and its inverse could not be computed' if info.positive?
1212
+
1213
+ inv
1214
+ end
1215
+
1216
+ # Computes the inverse of a symmetric / Hermitian positive-definite matrix using its Cholesky decomposition.
1217
+ #
1218
+ # @example
1219
+ # require 'numo/linalg'
1220
+ #
1221
+ # a = Numo::DFloat.new(3, 5).rand - 0.5
1222
+ # a = a.dot(a.transpose)
1223
+ # c = Numo::Linalg.cho_fact(a)
1224
+ # tri_inv_a = Numo::Linalg.cho_inv(c)
1225
+ # tri_inv_a = tri_inv_a.triu
1226
+ # inv_a = tri_inv_a + tri_inv_a.transpose - tri_inv_a.diagonal.diag
1227
+ # error = (inv_a.dot(a) - Numo::DFloat.eye(3)).abs.max
1228
+ # pp error
1229
+ # # => 1.923726113137665e-15
1230
+ #
1231
+ # @param a [Numo::NArray] The Cholesky decomposition of the n-by-n symmetric / Hermitian positive-definite matrix.
1232
+ # @param uplo [String] The part of the matrix to be used ('U' or 'L').
1233
+ # @return [Numo::NArray] The upper- / lower-triangular matrix of the inverse of `a`.
1234
+ def cho_inv(a, uplo: 'U')
1235
+ bchr = blas_char(a)
1236
+ raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
1237
+
1238
+ fnc = :"#{bchr}potri"
1239
+ inv, info = Numo::Linalg::Lapack.send(fnc, a.dup, uplo: uplo)
1240
+
1241
+ raise "the #{info.abs}-th argument of #{fnc} had illegal value" if info.negative?
1242
+ raise "the (#{info}, #info)-th element of the factor U or L is zero, and the inverse could not be computed." if info.positive?
1243
+
1244
+ inv
1245
+ end
1246
+
1247
+ # @!visibility private
1248
+ def _lud_permutation(lud, ipiv, uplo: 'U', hermitian: true) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
1249
+ n = lud.shape[0]
1250
+ d = lud.class.zeros(n, n)
1251
+ perm = Numo::Int32.new(n).seq
1252
+ if uplo == 'U'
1253
+ u = lud.triu.tap { |m| m[m.diag_indices] = 1 }
1254
+ # If IPIV(k) > 0, then rows and columns k and IPIV(k) were interchanged
1255
+ # and D(k,k) is a 1-by-1 diagonal block.
1256
+ # IF UPLO = 'U' and If IPIV(k) = IPIV(k-1) < 0, then
1257
+ # rows and columns k-1 and -IPIV(k) were interchanged
1258
+ # and D(k-1:k,k-1:k) is a 2-by-2 diagonal block.
1259
+ changed_2x2 = false
1260
+ n.times do |k|
1261
+ d[k, k] = lud[k, k]
1262
+ if ipiv[k].positive?
1263
+ i = ipiv[k] - 1
1264
+ u[[i, k], 0..k] = u[[k, i], 0..k].dup
1265
+ perm[[i, k]] = perm[[k, i]].dup
1266
+ elsif k.positive? && ipiv[k].negative? && ipiv[k] == ipiv[k - 1] && !changed_2x2
1267
+ i = -ipiv[k] - 1
1268
+ d[k - 1, k] = lud[k - 1, k]
1269
+ d[k, k - 1] = hermitian ? d[k - 1, k].conj : d[k - 1, k]
1270
+ u[k - 1, k] = 0
1271
+ u[[i, k - 1], 0..k] = u[[k - 1, i], 0..k].dup
1272
+ perm[[i, k - 1]] = perm[[k - 1, i]].dup
1273
+ changed_2x2 = true
1274
+ next
1275
+ end
1276
+ changed_2x2 = false if changed_2x2
1277
+ end
1278
+ [u, d, perm.sort_index]
1279
+ else
1280
+ l = lud.tril.tap { |m| m[m.diag_indices] = 1 }
1281
+ # If UPLO = 'L' and IPIV(k) = IPIV(k+1) < 0, then
1282
+ # rows and columns k+1 and -IPIV(k) were interchanged
1283
+ # and D(k:k+1,k:k+1) is a 2-by-2 diagonal block.
1284
+ changed_2x2 = false
1285
+ (n - 1).downto(0) do |k|
1286
+ d[k, k] = lud[k, k]
1287
+ if ipiv[k].positive?
1288
+ i = ipiv[k] - 1
1289
+ l[[i, k], k...n] = l[[k, i], k...n].dup
1290
+ perm[[i, k]] = perm[[k, i]].dup
1291
+ elsif k < n - 1 && ipiv[k].negative? && ipiv[k] == ipiv[k + 1] && !changed_2x2
1292
+ i = -ipiv[k] - 1
1293
+ d[k + 1, k] = lud[k + 1, k]
1294
+ d[k, k + 1] = hermitian ? d[k + 1, k].conj : d[k + 1, k]
1295
+ l[k + 1, k] = 0
1296
+ l[[i, k + 1], k...n] = l[[k + 1, i], k...n].dup
1297
+ perm[[i, k + 1]] = perm[[k + 1, i]].dup
1298
+ changed_2x2 = true
1299
+ next
1300
+ end
1301
+ changed_2x2 = false if changed_2x2
1302
+ end
1303
+ [l, d, perm.sort_index]
1304
+ end
1305
+ end
1306
+
1307
+ private_class_method :_lud_permutation
1308
+ end
1309
+ end