numo-linalg-alt 0.4.0 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +11 -0
  3. data/README.md +36 -0
  4. data/ext/numo/linalg/blas/dot.c +2 -2
  5. data/ext/numo/linalg/blas/dot_sub.c +2 -2
  6. data/ext/numo/linalg/blas/gemm.c +4 -4
  7. data/ext/numo/linalg/blas/gemv.c +4 -4
  8. data/ext/numo/linalg/blas/nrm2.c +4 -4
  9. data/ext/numo/linalg/extconf.rb +1 -0
  10. data/ext/numo/linalg/lapack/gebal.c +87 -0
  11. data/ext/numo/linalg/lapack/gebal.h +15 -0
  12. data/ext/numo/linalg/lapack/gees.c +4 -4
  13. data/ext/numo/linalg/lapack/geev.c +4 -4
  14. data/ext/numo/linalg/lapack/gehrd.c +77 -0
  15. data/ext/numo/linalg/lapack/gehrd.h +15 -0
  16. data/ext/numo/linalg/lapack/gelsd.c +4 -4
  17. data/ext/numo/linalg/lapack/geqrf.c +4 -4
  18. data/ext/numo/linalg/lapack/gerqf.c +4 -4
  19. data/ext/numo/linalg/lapack/gesdd.c +4 -4
  20. data/ext/numo/linalg/lapack/gesv.c +4 -4
  21. data/ext/numo/linalg/lapack/gesvd.c +4 -4
  22. data/ext/numo/linalg/lapack/getrf.c +4 -4
  23. data/ext/numo/linalg/lapack/getri.c +4 -4
  24. data/ext/numo/linalg/lapack/getrs.c +4 -4
  25. data/ext/numo/linalg/lapack/gges.c +4 -4
  26. data/ext/numo/linalg/lapack/heev.c +2 -2
  27. data/ext/numo/linalg/lapack/heevd.c +2 -2
  28. data/ext/numo/linalg/lapack/heevr.c +2 -2
  29. data/ext/numo/linalg/lapack/hegv.c +2 -2
  30. data/ext/numo/linalg/lapack/hegvd.c +2 -2
  31. data/ext/numo/linalg/lapack/hegvx.c +2 -2
  32. data/ext/numo/linalg/lapack/hetrf.c +2 -2
  33. data/ext/numo/linalg/lapack/lange.c +11 -12
  34. data/ext/numo/linalg/lapack/orghr.c +82 -0
  35. data/ext/numo/linalg/lapack/orghr.h +15 -0
  36. data/ext/numo/linalg/lapack/orgqr.c +2 -2
  37. data/ext/numo/linalg/lapack/orgrq.c +2 -2
  38. data/ext/numo/linalg/lapack/potrf.c +4 -4
  39. data/ext/numo/linalg/lapack/potri.c +4 -4
  40. data/ext/numo/linalg/lapack/potrs.c +4 -4
  41. data/ext/numo/linalg/lapack/syev.c +2 -2
  42. data/ext/numo/linalg/lapack/syevd.c +2 -2
  43. data/ext/numo/linalg/lapack/syevr.c +2 -2
  44. data/ext/numo/linalg/lapack/sygv.c +2 -2
  45. data/ext/numo/linalg/lapack/sygvd.c +2 -2
  46. data/ext/numo/linalg/lapack/sygvx.c +2 -2
  47. data/ext/numo/linalg/lapack/sytrf.c +4 -4
  48. data/ext/numo/linalg/lapack/trtrs.c +4 -4
  49. data/ext/numo/linalg/lapack/unghr.c +82 -0
  50. data/ext/numo/linalg/lapack/unghr.h +15 -0
  51. data/ext/numo/linalg/lapack/ungqr.c +2 -2
  52. data/ext/numo/linalg/lapack/ungrq.c +2 -2
  53. data/ext/numo/linalg/linalg.c +7 -3
  54. data/ext/numo/linalg/linalg.h +4 -0
  55. data/lib/numo/linalg/version.rb +1 -1
  56. data/lib/numo/linalg.rb +188 -0
  57. metadata +10 -2
data/lib/numo/linalg.rb CHANGED
@@ -740,6 +740,63 @@ module Numo
740
740
  [b, v, sdim]
741
741
  end
742
742
 
743
# Computes the Hessenberg decomposition of a square matrix.
# The Hessenberg decomposition is given by `A = Q * H * Q^H`,
# where `A` is the input matrix, `Q` is a unitary matrix,
# and `H` is an upper Hessenberg matrix.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 2, 3], [4, 5, 6], [7, 8, 9]] * 0.5
#   h, q = Numo::Linalg.hessenberg(a, calc_q: true)
#
#   pp h
#   # => Numo::DFloat#shape=[3,3]
#   # [[0.5, -1.7985, -0.124035],
#   #  [-4.03113, 7.02308, 1.41538],
#   #  [0, 0.415385, -0.0230769]]
#   pp q
#   # => Numo::DFloat#shape=[3,3]
#   # [[1, 0, 0],
#   #  [0, -0.496139, -0.868243],
#   #  [0, -0.868243, 0.496139]]
#   pp (a - q.dot(h).dot(q.transpose)).abs.max
#   # => 1.7763568394002505e-15
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param calc_q [Boolean] The flag indicating whether to calculate the unitary matrix `Q`.
# @return [Numo::NArray] if calc_q=false, the Hessenberg form `H`.
# @return [Array<Numo::NArray, Numo::NArray>] if calc_q=true,
#   the Hessenberg form `H` and the unitary matrix `Q`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
# @raise [ArgumentError] if the element type of `a` is not supported.
# @raise [RuntimeError] if a LAPACK routine reports an illegal argument.
def hessenberg(a, calc_q: false)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  # gebal is used only to obtain ILO/IHI for gehrd/orghr. JOB='N' is passed explicitly so
  # that no permutation or scaling is applied: otherwise gehrd/orghr would factor the
  # *balanced* matrix and the returned Q would not satisfy A = Q * H * Q^H for the input A.
  func = :"#{bchr}gebal"
  b, ilo, ihi, _, info = Numo::Linalg::Lapack.send(func, a.dup, job: 'N')

  raise "the #{-info}-th argument of #{func} had illegal value" if info.negative?

  # gehrd stores H in the upper triangle plus first subdiagonal and the Householder
  # reflectors below it; tau holds their scalar factors.
  func = :"#{bchr}gehrd"
  hq, tau, info = Numo::Linalg::Lapack.send(func, b, ilo: ilo, ihi: ihi)

  raise "the #{-info}-th argument of #{func} had illegal value" if info.negative?

  # Keep the first subdiagonal and everything above it: the upper Hessenberg form.
  h = hq.triu(-1)
  return h unless calc_q

  # Real types use the orthogonal generator (orghr); complex types use the unitary one (unghr).
  func = %w[d s].include?(bchr) ? :"#{bchr}orghr" : :"#{bchr}unghr"
  q, info = Numo::Linalg::Lapack.send(func, hq, tau, ilo: ilo, ihi: ihi)

  raise "the #{-info}-th argument of #{func} had illegal value" if info.negative?

  [h, q]
end
799
+
743
800
  # Solves linear equation `A * x = b` or `A * X = B` for `x` from square matrix `A`.
744
801
  #
745
802
  # @example
@@ -1189,6 +1246,102 @@ module Numo
1189
1246
  [r, scale]
1190
1247
  end
1191
1248
 
1249
# Computes a diagonal similarity transformation that balances a square matrix.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 0, 0], [1, 2, 0], [1, 2, 3]]
#   b, h = Numo::Linalg.matrix_balance(a)
#   pp b
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[3, 2, 1],
#   #  [0, 2, 1],
#   #  [0, 0, 1]]
#   pp h
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0, 0, 1],
#   #  [0, 1, 0],
#   #  [1, 0, 0]]
#   pp (Numo::Linalg.inv(h).dot(a).dot(h) - b).abs.max
#   # => 0.0
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param permute [Boolean] The flag indicating whether to permute the matrix.
# @param scale [Boolean] The flag indicating whether to scale the matrix.
# @param separate [Boolean] The flag indicating whether to return scaling factors and permutation indices
#   separately.
# @return [Array<Numo::NArray, Numo::NArray>] if `separate` is `false`, the balanced matrix and the
#   similarity transformation matrix `H` ([b, h]). if `separate` is `true`, the balanced matrix, the
#   scaling factors, and the permutation indices ([b, scaler, perm]).
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
# @raise [ArgumentError] if `a` is not square or its element type is not supported.
# @raise [RuntimeError] if gebal reports an illegal argument.
def matrix_balance(a, permute: true, scale: true, separate: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2

  n = a.shape[0]
  raise ArgumentError, 'input array a must be square' if a.shape[1] != n

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  # LAPACK JOB code: 'B' = permute and scale, 'P' = permute only, 'S' = scale only, 'N' = neither.
  job = if permute
          scale ? 'B' : 'P'
        else
          scale ? 'S' : 'N'
        end

  routine = :"#{bchr}gebal"
  balanced, ilo, ihi, pivscale, info = Numo::Linalg::Lapack.send(routine, a.dup, job: job)

  raise "the #{info.abs}-th argument of #{routine} had illegal value" if info.negative?

  # gebal reports 1-based bounds and indices; shift everything to 0-based Ruby indices.
  lo = ilo - 1
  hi = ihi - 1
  pivots = Numo::Int32.cast(pivscale) - 1

  # Inside lo..hi the gebal output holds scaling factors; outside it the factor is 1.
  scaler = pivscale.class.ones(n)
  scaler[lo...(hi + 1)] = pivscale[lo...(hi + 1)]

  # Replay the row/column interchanges recorded by gebal, in the order LAPACK applied them:
  # first from the bottom (n-1 down to hi+1), then from the top (0 up to lo-1).
  perm = Numo::Int32.new(n).seq
  swap = lambda do |p, q|
    perm[p], perm[q] = perm[q], perm[p]
  end

  if hi < n - 1
    pivots[(hi + 1)...n].to_a.reverse_each.with_index(1) do |target, offset|
      pos = n - offset
      swap.call(target, pos) unless target == pos
    end
  end

  if lo.positive?
    pivots[0...lo].to_a.each_with_index do |target, pos|
      swap.call(target, pos) unless target == pos
    end
  end

  return [balanced, scaler, perm] if separate

  # H is the diagonal scaling matrix with its rows reordered by the inverse permutation.
  inv_perm = Numo::Int32.zeros(n)
  inv_perm[perm] = Numo::Int32.new(n).seq
  [balanced, scaler.diag[inv_perm, true].dup]
end
1344
+
1192
1345
  # Computes the eigenvalues and right and/or left eigenvectors of a general square matrix.
1193
1346
  #
1194
1347
  # @example
@@ -1517,6 +1670,41 @@ module Numo
1517
1670
  a_sin.dot(Numo::Linalg.inv(a_cos))
1518
1671
  end
1519
1672
 
1673
# Computes the matrix hyperbolic sine, `sinh(A) = (exp(A) - exp(-A)) / 2`,
# from two matrix exponentials.
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @return [Numo::NArray] The matrix hyperbolic sine of `a`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
def sinhm(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  exp_pos = expm(a)
  exp_neg = expm(-a)
  0.5 * (exp_pos - exp_neg)
end
1683
+
1684
# Computes the matrix hyperbolic cosine, `cosh(A) = (exp(A) + exp(-A)) / 2`,
# from two matrix exponentials.
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @return [Numo::NArray] The matrix hyperbolic cosine of `a`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
def coshm(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]

  exp_pos = expm(a)
  exp_neg = expm(-a)
  0.5 * (exp_pos + exp_neg)
end
1694
+
1695
# Computes the matrix hyperbolic tangent, `tanh(A) = sinh(A) * cosh(A)^-1`.
#
# `exp(A)` and `exp(-A)` are each computed once and shared between `sinh(A)` and
# `cosh(A)`. This halves the number of matrix exponentials compared with evaluating
# the hyperbolic sine and cosine independently.
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @return [Numo::NArray] The matrix hyperbolic tangent of `a`.
# @raise [Numo::NArray::ShapeError] if `a` is not a 2-dimensional square matrix.
def tanhm(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  exp_pos = expm(a)
  exp_neg = expm(-a)
  a_sinh = 0.5 * (exp_pos - exp_neg)
  a_cosh = 0.5 * (exp_pos + exp_neg)
  a_sinh.dot(Numo::Linalg.inv(a_cosh))
end
1707
+
1520
1708
  # Computes the inverse of a matrix using its LU decomposition.
1521
1709
  #
1522
1710
  # @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-linalg-alt
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
@@ -50,10 +50,14 @@ files:
50
50
  - ext/numo/linalg/converter.c
51
51
  - ext/numo/linalg/converter.h
52
52
  - ext/numo/linalg/extconf.rb
53
+ - ext/numo/linalg/lapack/gebal.c
54
+ - ext/numo/linalg/lapack/gebal.h
53
55
  - ext/numo/linalg/lapack/gees.c
54
56
  - ext/numo/linalg/lapack/gees.h
55
57
  - ext/numo/linalg/lapack/geev.c
56
58
  - ext/numo/linalg/lapack/geev.h
59
+ - ext/numo/linalg/lapack/gehrd.c
60
+ - ext/numo/linalg/lapack/gehrd.h
57
61
  - ext/numo/linalg/lapack/gelsd.c
58
62
  - ext/numo/linalg/lapack/gelsd.h
59
63
  - ext/numo/linalg/lapack/geqrf.c
@@ -90,6 +94,8 @@ files:
90
94
  - ext/numo/linalg/lapack/hetrf.h
91
95
  - ext/numo/linalg/lapack/lange.c
92
96
  - ext/numo/linalg/lapack/lange.h
97
+ - ext/numo/linalg/lapack/orghr.c
98
+ - ext/numo/linalg/lapack/orghr.h
93
99
  - ext/numo/linalg/lapack/orgqr.c
94
100
  - ext/numo/linalg/lapack/orgqr.h
95
101
  - ext/numo/linalg/lapack/orgrq.c
@@ -116,6 +122,8 @@ files:
116
122
  - ext/numo/linalg/lapack/sytrf.h
117
123
  - ext/numo/linalg/lapack/trtrs.c
118
124
  - ext/numo/linalg/lapack/trtrs.h
125
+ - ext/numo/linalg/lapack/unghr.c
126
+ - ext/numo/linalg/lapack/unghr.h
119
127
  - ext/numo/linalg/lapack/ungqr.c
120
128
  - ext/numo/linalg/lapack/ungqr.h
121
129
  - ext/numo/linalg/lapack/ungrq.c
@@ -134,7 +142,7 @@ metadata:
134
142
  homepage_uri: https://github.com/yoshoku/numo-linalg-alt
135
143
  source_code_uri: https://github.com/yoshoku/numo-linalg-alt
136
144
  changelog_uri: https://github.com/yoshoku/numo-linalg-alt/blob/main/CHANGELOG.md
137
- documentation_uri: https://gemdocs.org/gems/numo-linalg-alt/0.4.0/
145
+ documentation_uri: https://gemdocs.org/gems/numo-linalg-alt/0.5.0/
138
146
  rubygems_mfa_required: 'true'
139
147
  rdoc_options: []
140
148
  require_paths: