numo-linalg-alt 0.5.0 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +7 -0
  3. data/README.md +47 -5
  4. data/ext/numo/linalg/blas/blas_common.h +30 -0
  5. data/ext/numo/linalg/blas/blas_util.c +39 -0
  6. data/ext/numo/linalg/blas/blas_util.h +11 -0
  7. data/ext/numo/linalg/{converter.h → blas/converter.h} +0 -2
  8. data/ext/numo/linalg/blas/dot.c +1 -1
  9. data/ext/numo/linalg/blas/dot.h +1 -6
  10. data/ext/numo/linalg/blas/dot_sub.c +1 -1
  11. data/ext/numo/linalg/blas/dot_sub.h +1 -6
  12. data/ext/numo/linalg/blas/gemm.c +21 -21
  13. data/ext/numo/linalg/blas/gemm.h +3 -9
  14. data/ext/numo/linalg/blas/gemv.c +10 -10
  15. data/ext/numo/linalg/blas/gemv.h +3 -9
  16. data/ext/numo/linalg/blas/nrm2.c +1 -1
  17. data/ext/numo/linalg/blas/nrm2.h +1 -6
  18. data/ext/numo/linalg/extconf.rb +33 -6
  19. data/ext/numo/linalg/lapack/gebal.h +1 -1
  20. data/ext/numo/linalg/lapack/gees.c +4 -4
  21. data/ext/numo/linalg/lapack/gees.h +1 -1
  22. data/ext/numo/linalg/lapack/geev.c +8 -24
  23. data/ext/numo/linalg/lapack/geev.h +1 -1
  24. data/ext/numo/linalg/lapack/gehrd.h +1 -1
  25. data/ext/numo/linalg/lapack/gelsd.h +1 -1
  26. data/ext/numo/linalg/lapack/geqrf.h +1 -1
  27. data/ext/numo/linalg/lapack/gerqf.h +1 -1
  28. data/ext/numo/linalg/lapack/gesdd.h +1 -1
  29. data/ext/numo/linalg/lapack/gesv.h +1 -1
  30. data/ext/numo/linalg/lapack/gesvd.h +1 -1
  31. data/ext/numo/linalg/lapack/getrf.h +1 -1
  32. data/ext/numo/linalg/lapack/getri.h +1 -1
  33. data/ext/numo/linalg/lapack/getrs.h +1 -1
  34. data/ext/numo/linalg/lapack/gges.c +4 -4
  35. data/ext/numo/linalg/lapack/gges.h +1 -1
  36. data/ext/numo/linalg/lapack/heev.c +1 -1
  37. data/ext/numo/linalg/lapack/heev.h +1 -1
  38. data/ext/numo/linalg/lapack/heevd.c +1 -1
  39. data/ext/numo/linalg/lapack/heevd.h +1 -1
  40. data/ext/numo/linalg/lapack/heevr.c +1 -1
  41. data/ext/numo/linalg/lapack/heevr.h +1 -1
  42. data/ext/numo/linalg/lapack/hegv.c +1 -1
  43. data/ext/numo/linalg/lapack/hegv.h +1 -1
  44. data/ext/numo/linalg/lapack/hegvd.c +1 -1
  45. data/ext/numo/linalg/lapack/hegvd.h +1 -1
  46. data/ext/numo/linalg/lapack/hegvx.c +1 -1
  47. data/ext/numo/linalg/lapack/hegvx.h +1 -1
  48. data/ext/numo/linalg/lapack/hetrf.h +1 -1
  49. data/ext/numo/linalg/lapack/lange.h +1 -1
  50. data/ext/numo/linalg/lapack/lapack_util.c +57 -0
  51. data/ext/numo/linalg/lapack/lapack_util.h +27 -0
  52. data/ext/numo/linalg/lapack/orghr.h +1 -1
  53. data/ext/numo/linalg/lapack/orgqr.h +1 -1
  54. data/ext/numo/linalg/lapack/orgrq.h +1 -1
  55. data/ext/numo/linalg/lapack/potrf.h +1 -1
  56. data/ext/numo/linalg/lapack/potri.h +1 -1
  57. data/ext/numo/linalg/lapack/potrs.h +1 -1
  58. data/ext/numo/linalg/lapack/syev.c +1 -1
  59. data/ext/numo/linalg/lapack/syev.h +1 -1
  60. data/ext/numo/linalg/lapack/syevd.c +1 -1
  61. data/ext/numo/linalg/lapack/syevd.h +1 -1
  62. data/ext/numo/linalg/lapack/syevr.c +1 -1
  63. data/ext/numo/linalg/lapack/syevr.h +1 -1
  64. data/ext/numo/linalg/lapack/sygv.c +1 -1
  65. data/ext/numo/linalg/lapack/sygv.h +1 -1
  66. data/ext/numo/linalg/lapack/sygvd.c +1 -1
  67. data/ext/numo/linalg/lapack/sygvd.h +1 -1
  68. data/ext/numo/linalg/lapack/sygvx.c +1 -1
  69. data/ext/numo/linalg/lapack/sygvx.h +1 -1
  70. data/ext/numo/linalg/lapack/sytrf.h +1 -1
  71. data/ext/numo/linalg/lapack/trtrs.h +1 -1
  72. data/ext/numo/linalg/lapack/unghr.h +1 -1
  73. data/ext/numo/linalg/lapack/ungqr.h +1 -1
  74. data/ext/numo/linalg/lapack/ungrq.h +1 -1
  75. data/ext/numo/linalg/linalg.c +2 -0
  76. data/ext/numo/linalg/linalg.h +14 -6
  77. data/lib/numo/linalg/version.rb +1 -1
  78. data/lib/numo/linalg.rb +68 -20
  79. metadata +9 -6
  80. data/ext/numo/linalg/util.c +0 -103
  81. data/ext/numo/linalg/util.h +0 -18
  82. /data/ext/numo/linalg/{converter.c → blas/converter.c} +0 -0
@@ -35,15 +35,23 @@
35
35
 
36
36
  #include <ruby.h>
37
37
 
38
- #include <cblas.h>
39
- #include <lapacke.h>
40
- #include <openblas_config.h>
41
-
42
38
  #include <numo/narray.h>
43
39
  #include <numo/template.h>
44
40
 
45
- #include "converter.h"
46
- #include "util.h"
41
+ #ifndef _DEFINED_SCOMPLEX
42
+ #define _DEFINED_SCOMPLEX 1
43
+ #endif
44
+ #ifndef _DEFINED_DCOMPLEX
45
+ #define _DEFINED_DCOMPLEX 1
46
+ #endif
47
+
48
+ #include <cblas.h>
49
+ #include <lapacke.h>
50
+
51
+ #include "extconf.h"
52
+ #ifdef HAVE_OPENBLAS_CONFIG_H
53
+ #include <openblas_config.h>
54
+ #endif
47
55
 
48
56
  #include "blas/dot.h"
49
57
  #include "blas/dot_sub.h"
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::Linalg Alternative (numo-linalg-alt) is an alternative to Numo::Linalg.
6
6
  module Linalg
7
7
  # The version of numo-linalg-alt you install.
8
- VERSION = '0.5.0'
8
+ VERSION = '0.6.0'
9
9
  end
10
10
  end
data/lib/numo/linalg.rb CHANGED
@@ -340,7 +340,9 @@ module Numo
340
340
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
341
341
 
342
342
  fnc = :"#{bchr}potrs"
343
- x, _info = Numo::Linalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
343
+ x, info = Numo::Linalg::Lapack.send(fnc, a, b.dup, uplo: uplo)
344
+ raise "the #{-info}-th argument of potrs had illegal value" if info.negative?
345
+
344
346
  x
345
347
  end
346
348
 
@@ -364,17 +366,13 @@ module Numo
364
366
 
365
367
  getrf = :"#{bchr}getrf"
366
368
  lu, piv, info = Numo::Linalg::Lapack.send(getrf, a.dup)
369
+ raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
370
+ raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?
367
371
 
368
- if info.zero?
369
- det_l = 1
370
- det_u = lu.diagonal.prod
371
- det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
372
- det_l * det_u * det_p
373
- elsif info.positive?
374
- raise 'the factor U is singular, and the inverse matrix could not be computed.'
375
- else
376
- raise "the #{-info}-th argument of getrf had illegal value"
377
- end
372
+ det_l = 1
373
+ det_u = lu.diagonal.prod
374
+ det_p = piv.map_with_index { |v, i| v == i + 1 ? 1 : -1 }.prod
375
+ det_l * det_u * det_p
378
376
  end
379
377
 
380
378
  # Computes the inverse matrix of a square matrix.
@@ -407,13 +405,14 @@ module Numo
407
405
  getri = :"#{bchr}getri"
408
406
 
409
407
  lu, piv, info = Numo::Linalg::Lapack.send(getrf, a.dup)
410
- if info.zero?
411
- Numo::Linalg::Lapack.send(getri, lu, piv)[0]
412
- elsif info.positive?
413
- raise 'the factor U is singular, and the inverse matrix could not be computed.'
414
- else
415
- raise "the #{-info}-th argument of getrf had illegal value"
416
- end
408
+ raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
409
+ raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?
410
+
411
+ a_inv, info = Numo::Linalg::Lapack.send(getri, lu, piv)
412
+ raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
413
+ raise 'the factor U is singular, and the inverse matrix could not be computed.' if info.positive?
414
+
415
+ a_inv
417
416
  end
418
417
 
419
418
  # Computes the (Moore-Penrose) pseudo-inverse of a matrix using singular value decomposition.
@@ -441,7 +440,7 @@ module Numo
441
440
  rank = s.gt(rcond * s[0]).count
442
441
 
443
442
  u = u[true, 0...rank] / s[0...rank]
444
- u.dot(vh[0...rank, true]).conj.transpose
443
+ u.dot(vh[0...rank, true]).conj.transpose.dup
445
444
  end
446
445
 
447
446
  # Computes the polar decomposition of a matrix.
@@ -830,7 +829,11 @@ module Numo
830
829
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'
831
830
 
832
831
  gesv = :"#{bchr}gesv"
833
- Numo::Linalg::Lapack.send(gesv, a.dup, b.dup)[1]
832
+ _lu, x, _ipiv, info = Numo::Linalg::Lapack.send(gesv, a.dup, b.dup)
833
+ raise "the #{-info}-th argument of getrf had illegal value" if info.negative?
834
+ raise 'the factor U is singular, and the solution could not be computed.' if info.positive?
835
+
836
+ x
834
837
  end
835
838
 
836
839
  # Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix.
@@ -1616,6 +1619,22 @@ module Numo
1616
1619
  a_expm
1617
1620
  end
1618
1621
 
1622
+ # Computes the matrix logarithm using its eigenvalue decomposition.
1623
+ #
1624
+ # @param a [Numo::NArray] The n-by-n square matrix.
1625
+ # @return [Numo::NArray] The matrix logarithm of `a`.
1626
+ def logm(a)
1627
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
1628
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
1629
+
1630
+ ev, vl, = eig(a, left: true, right: false)
1631
+ v = vl.transpose.conj
1632
+ inv_v = Numo::Linalg.inv(v)
1633
+ log_ev = Numo::NMath.log(ev)
1634
+
1635
+ inv_v.dot(log_ev.diag).dot(v)
1636
+ end
1637
+
1619
1638
  # Computes the matrix sine using the matrix exponential.
1620
1639
  #
1621
1640
  # @param a [Numo::NArray] The n-by-n square matrix.
@@ -1705,6 +1724,35 @@ module Numo
1705
1724
  a_sinh.dot(Numo::Linalg.inv(a_cosh))
1706
1725
  end
1707
1726
 
1727
+ # Computes the square root of a matrix using its eigenvalue decomposition.
1728
+ #
1729
+ # @param a [Numo::NArray] The n-by-n square matrix.
1730
+ # @return [Numo::NArray] The matrix square root of `a`.
1731
+ def sqrtm(a)
1732
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
1733
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
1734
+
1735
+ ev, vl, = eig(a, left: true, right: false)
1736
+ v = vl.transpose.conj
1737
+ inv_v = Numo::Linalg.inv(v)
1738
+ sqrt_ev = Numo::NMath.sqrt(ev)
1739
+
1740
+ inv_v.dot(sqrt_ev.diag).dot(v)
1741
+ end
1742
+
1743
+ # Computes the matrix sign function using its inverse and square root matrices.
1744
+ #
1745
+ # @param a [Numo::NArray] The n-by-n square matrix.
1746
+ # @return [Numo::NArray] The matrix sign function of `a`.
1747
+ def signm(a)
1748
+ raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
1749
+ raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]
1750
+
1751
+ a_sqrt = sqrtm(a.dot(a))
1752
+ a_inv = Numo::Linalg.inv(a)
1753
+ a_inv.dot(a_sqrt)
1754
+ end
1755
+
1708
1756
  # Computes the inverse of a matrix using its LU decomposition.
1709
1757
  #
1710
1758
  # @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-linalg-alt
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.0
4
+ version: 0.6.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
@@ -37,6 +37,11 @@ files:
37
37
  - CODE_OF_CONDUCT.md
38
38
  - LICENSE.txt
39
39
  - README.md
40
+ - ext/numo/linalg/blas/blas_common.h
41
+ - ext/numo/linalg/blas/blas_util.c
42
+ - ext/numo/linalg/blas/blas_util.h
43
+ - ext/numo/linalg/blas/converter.c
44
+ - ext/numo/linalg/blas/converter.h
40
45
  - ext/numo/linalg/blas/dot.c
41
46
  - ext/numo/linalg/blas/dot.h
42
47
  - ext/numo/linalg/blas/dot_sub.c
@@ -47,8 +52,6 @@ files:
47
52
  - ext/numo/linalg/blas/gemv.h
48
53
  - ext/numo/linalg/blas/nrm2.c
49
54
  - ext/numo/linalg/blas/nrm2.h
50
- - ext/numo/linalg/converter.c
51
- - ext/numo/linalg/converter.h
52
55
  - ext/numo/linalg/extconf.rb
53
56
  - ext/numo/linalg/lapack/gebal.c
54
57
  - ext/numo/linalg/lapack/gebal.h
@@ -94,6 +97,8 @@ files:
94
97
  - ext/numo/linalg/lapack/hetrf.h
95
98
  - ext/numo/linalg/lapack/lange.c
96
99
  - ext/numo/linalg/lapack/lange.h
100
+ - ext/numo/linalg/lapack/lapack_util.c
101
+ - ext/numo/linalg/lapack/lapack_util.h
97
102
  - ext/numo/linalg/lapack/orghr.c
98
103
  - ext/numo/linalg/lapack/orghr.h
99
104
  - ext/numo/linalg/lapack/orgqr.c
@@ -130,8 +135,6 @@ files:
130
135
  - ext/numo/linalg/lapack/ungrq.h
131
136
  - ext/numo/linalg/linalg.c
132
137
  - ext/numo/linalg/linalg.h
133
- - ext/numo/linalg/util.c
134
- - ext/numo/linalg/util.h
135
138
  - lib/numo/linalg.rb
136
139
  - lib/numo/linalg/version.rb
137
140
  - vendor/tmp/.gitkeep
@@ -142,7 +145,7 @@ metadata:
142
145
  homepage_uri: https://github.com/yoshoku/numo-linalg-alt
143
146
  source_code_uri: https://github.com/yoshoku/numo-linalg-alt
144
147
  changelog_uri: https://github.com/yoshoku/numo-linalg-alt/blob/main/CHANGELOG.md
145
- documentation_uri: https://gemdocs.org/gems/numo-linalg-alt/0.5.0/
148
+ documentation_uri: https://gemdocs.org/gems/numo-linalg-alt/0.6.0/
146
149
  rubygems_mfa_required: 'true'
147
150
  rdoc_options: []
148
151
  require_paths:
@@ -1,103 +0,0 @@
1
- #include "util.h"
2
-
3
- lapack_int get_itype(VALUE val) {
4
- const lapack_int itype = NUM2INT(val);
5
-
6
- if (itype != 1 && itype != 2 && itype != 3) {
7
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
8
- }
9
-
10
- return itype;
11
- }
12
-
13
- char get_jobz(VALUE val) {
14
- const char jobz = NUM2CHR(val);
15
-
16
- if (jobz != 'N' && jobz != 'V') {
17
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
18
- }
19
-
20
- return jobz;
21
- }
22
-
23
- char get_jobvs(VALUE val) {
24
- const char jobvs = NUM2CHR(val);
25
- if (jobvs != 'N' && jobvs != 'V') {
26
- rb_raise(rb_eArgError, "jobvs must be 'N' or 'V'");
27
- }
28
- return jobvs;
29
- }
30
-
31
- char get_range(VALUE val) {
32
- const char range = NUM2CHR(val);
33
-
34
- if (range != 'A' && range != 'V' && range != 'I') {
35
- rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
36
- }
37
-
38
- return range;
39
- }
40
-
41
- char get_uplo(VALUE val) {
42
- const char uplo = NUM2CHR(val);
43
-
44
- if (uplo != 'U' && uplo != 'L') {
45
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
46
- }
47
-
48
- return uplo;
49
- }
50
-
51
- int get_matrix_layout(VALUE val) {
52
- const char option = NUM2CHR(val);
53
-
54
- switch (option) {
55
- case 'r':
56
- case 'R':
57
- break;
58
- case 'c':
59
- case 'C':
60
- rb_warn("Numo::Linalg does not support column major.");
61
- break;
62
- }
63
-
64
- return LAPACK_ROW_MAJOR;
65
- }
66
-
67
- enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
68
- const char option = NUM2CHR(val);
69
- enum CBLAS_TRANSPOSE res = CblasNoTrans;
70
-
71
- switch (option) {
72
- case 'n':
73
- case 'N':
74
- res = CblasNoTrans;
75
- break;
76
- case 't':
77
- case 'T':
78
- res = CblasTrans;
79
- break;
80
- case 'c':
81
- case 'C':
82
- res = CblasConjTrans;
83
- break;
84
- }
85
-
86
- return res;
87
- }
88
-
89
- enum CBLAS_ORDER get_cblas_order(VALUE val) {
90
- const char option = NUM2CHR(val);
91
-
92
- switch (option) {
93
- case 'r':
94
- case 'R':
95
- break;
96
- case 'c':
97
- case 'C':
98
- rb_warn("Numo::Linalg does not support column major.");
99
- break;
100
- }
101
-
102
- return CblasRowMajor;
103
- }
@@ -1,18 +0,0 @@
1
- #ifndef NUMO_LINALG_ALT_UTIL_H
2
- #define NUMO_LINALG_ALT_UTIL_H 1
3
-
4
- #include <ruby.h>
5
-
6
- #include <cblas.h>
7
- #include <lapacke.h>
8
-
9
- lapack_int get_itype(VALUE val);
10
- char get_jobz(VALUE val);
11
- char get_jobvs(VALUE val);
12
- char get_range(VALUE val);
13
- char get_uplo(VALUE val);
14
- int get_matrix_layout(VALUE val);
15
- enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val);
16
- enum CBLAS_ORDER get_cblas_order(VALUE val);
17
-
18
- #endif // NUMO_LINALG_ALT_UTIL_H
File without changes