numo-linalg-alt 0.3.0 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/numo/linalg/blas/dot.c +61 -61
  4. data/ext/numo/linalg/blas/dot_sub.c +60 -60
  5. data/ext/numo/linalg/blas/gemm.c +161 -152
  6. data/ext/numo/linalg/blas/gemv.c +135 -131
  7. data/ext/numo/linalg/blas/nrm2.c +54 -54
  8. data/ext/numo/linalg/lapack/gebal.c +87 -0
  9. data/ext/numo/linalg/lapack/gebal.h +15 -0
  10. data/ext/numo/linalg/lapack/gees.c +243 -224
  11. data/ext/numo/linalg/lapack/geev.c +131 -114
  12. data/ext/numo/linalg/lapack/gelsd.c +85 -74
  13. data/ext/numo/linalg/lapack/geqrf.c +56 -55
  14. data/ext/numo/linalg/lapack/gerqf.c +70 -0
  15. data/ext/numo/linalg/lapack/gerqf.h +15 -0
  16. data/ext/numo/linalg/lapack/gesdd.c +100 -90
  17. data/ext/numo/linalg/lapack/gesv.c +84 -82
  18. data/ext/numo/linalg/lapack/gesvd.c +144 -133
  19. data/ext/numo/linalg/lapack/getrf.c +55 -54
  20. data/ext/numo/linalg/lapack/getri.c +68 -67
  21. data/ext/numo/linalg/lapack/getrs.c +96 -92
  22. data/ext/numo/linalg/lapack/gges.c +214 -0
  23. data/ext/numo/linalg/lapack/gges.h +15 -0
  24. data/ext/numo/linalg/lapack/heev.c +56 -54
  25. data/ext/numo/linalg/lapack/heevd.c +56 -54
  26. data/ext/numo/linalg/lapack/heevr.c +111 -100
  27. data/ext/numo/linalg/lapack/hegv.c +79 -76
  28. data/ext/numo/linalg/lapack/hegvd.c +79 -76
  29. data/ext/numo/linalg/lapack/hegvx.c +134 -122
  30. data/ext/numo/linalg/lapack/hetrf.c +56 -52
  31. data/ext/numo/linalg/lapack/lange.c +49 -48
  32. data/ext/numo/linalg/lapack/orgqr.c +65 -64
  33. data/ext/numo/linalg/lapack/orgrq.c +78 -0
  34. data/ext/numo/linalg/lapack/orgrq.h +15 -0
  35. data/ext/numo/linalg/lapack/potrf.c +53 -52
  36. data/ext/numo/linalg/lapack/potri.c +53 -52
  37. data/ext/numo/linalg/lapack/potrs.c +78 -76
  38. data/ext/numo/linalg/lapack/syev.c +56 -54
  39. data/ext/numo/linalg/lapack/syevd.c +56 -54
  40. data/ext/numo/linalg/lapack/syevr.c +109 -100
  41. data/ext/numo/linalg/lapack/sygv.c +79 -75
  42. data/ext/numo/linalg/lapack/sygvd.c +79 -75
  43. data/ext/numo/linalg/lapack/sygvx.c +134 -122
  44. data/ext/numo/linalg/lapack/sytrf.c +58 -54
  45. data/ext/numo/linalg/lapack/trtrs.c +83 -79
  46. data/ext/numo/linalg/lapack/ungqr.c +65 -64
  47. data/ext/numo/linalg/lapack/ungrq.c +78 -0
  48. data/ext/numo/linalg/lapack/ungrq.h +15 -0
  49. data/ext/numo/linalg/linalg.c +24 -13
  50. data/ext/numo/linalg/linalg.h +5 -0
  51. data/ext/numo/linalg/util.c +8 -0
  52. data/ext/numo/linalg/util.h +1 -0
  53. data/lib/numo/linalg/version.rb +1 -1
  54. data/lib/numo/linalg.rb +235 -3
  55. metadata +12 -2
@@ -42,9 +42,11 @@ char blas_char(VALUE nary_arr) {
42
42
  if (RB_TYPE_P(arg, T_ARRAY)) {
43
43
  arg = rb_funcall(numo_cNArray, rb_intern("asarray"), 1, arg);
44
44
  }
45
- if (CLASS_OF(arg) == numo_cBit || CLASS_OF(arg) == numo_cInt64 || CLASS_OF(arg) == numo_cInt32 ||
46
- CLASS_OF(arg) == numo_cInt16 || CLASS_OF(arg) == numo_cInt8 || CLASS_OF(arg) == numo_cUInt64 ||
47
- CLASS_OF(arg) == numo_cUInt32 || CLASS_OF(arg) == numo_cUInt16 || CLASS_OF(arg) == numo_cUInt8) {
45
+ if (CLASS_OF(arg) == numo_cBit || CLASS_OF(arg) == numo_cInt64 ||
46
+ CLASS_OF(arg) == numo_cInt32 || CLASS_OF(arg) == numo_cInt16 ||
47
+ CLASS_OF(arg) == numo_cInt8 || CLASS_OF(arg) == numo_cUInt64 ||
48
+ CLASS_OF(arg) == numo_cUInt32 || CLASS_OF(arg) == numo_cUInt16 ||
49
+ CLASS_OF(arg) == numo_cUInt8) {
48
50
  if (type == 'n') {
49
51
  type = 'd';
50
52
  }
@@ -97,8 +99,7 @@ static VALUE linalg_blas_call(int argc, VALUE* argv, VALUE self) {
97
99
  }
98
100
 
99
101
  char fn_str[256];
100
- snprintf(fn_str, sizeof(fn_str), "%c%s",
101
- type, rb_id2name(rb_to_id(rb_to_symbol(fn_name))));
102
+ snprintf(fn_str, sizeof(fn_str), "%c%s", type, rb_id2name(rb_to_id(rb_to_symbol(fn_name))));
102
103
  ID fn_id = rb_intern(fn_str);
103
104
  size_t n = RARRAY_LEN(nary_arr);
104
105
  VALUE ret = Qnil;
@@ -146,7 +147,8 @@ static VALUE linalg_dot(VALUE self, VALUE a_, VALUE b_) {
146
147
  ret = rb_funcall(rb_mLinalgBlas, rb_intern("call"), 3, ID2SYM(fn_id), a, b);
147
148
  } else {
148
149
  VALUE kw_args = rb_hash_new();
149
- if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
150
+ if (!RTEST(nary_check_contiguous(b)) &&
151
+ RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
150
152
  b = rb_funcall(b, rb_intern("transpose"), 0);
151
153
  rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
152
154
  } else {
@@ -160,7 +162,8 @@ static VALUE linalg_dot(VALUE self, VALUE a_, VALUE b_) {
160
162
  } else {
161
163
  if (b_ndim == 1) {
162
164
  VALUE kw_args = rb_hash_new();
163
- if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
165
+ if (!RTEST(nary_check_contiguous(a)) &&
166
+ RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
164
167
  a = rb_funcall(a, rb_intern("transpose"), 0);
165
168
  rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
166
169
  } else {
@@ -172,13 +175,15 @@ static VALUE linalg_dot(VALUE self, VALUE a_, VALUE b_) {
172
175
  ret = rb_funcallv_kw(rb_mLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
173
176
  } else {
174
177
  VALUE kw_args = rb_hash_new();
175
- if (!RTEST(nary_check_contiguous(a)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
178
+ if (!RTEST(nary_check_contiguous(a)) &&
179
+ RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
176
180
  a = rb_funcall(a, rb_intern("transpose"), 0);
177
181
  rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("T"));
178
182
  } else {
179
183
  rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("N"));
180
184
  }
181
- if (!RTEST(nary_check_contiguous(b)) && RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
185
+ if (!RTEST(nary_check_contiguous(b)) &&
186
+ RTEST(rb_funcall(b, rb_intern("fortran_contiguous?"), 0))) {
182
187
  b = rb_funcall(b, rb_intern("transpose"), 0);
183
188
  rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("T"));
184
189
  } else {
@@ -202,7 +207,8 @@ void Init_linalg(void) {
202
207
 
203
208
  /**
204
209
  * Document-module: Numo::Linalg
205
- * Numo::Linalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
210
+ * Numo::Linalg is a subset library from Numo::Linalg consisting only of methods used
211
+ * in Machine Learning algorithms.
206
212
  */
207
213
  rb_mLinalg = rb_define_module_under(rb_mNumo, "Linalg");
208
214
  /**
@@ -228,7 +234,7 @@ void Init_linalg(void) {
228
234
  * @param [Numo::NArray] a
229
235
  * @return [String]
230
236
  */
231
- rb_define_module_function(rb_mLinalg, "blas_char", RUBY_METHOD_FUNC(linalg_blas_char), -1);
237
+ rb_define_module_function(rb_mLinalg, "blas_char", linalg_blas_char, -1);
232
238
  /**
233
239
  * Calculates dot product of two vectors / matrices.
234
240
  *
@@ -237,7 +243,7 @@ void Init_linalg(void) {
237
243
  * @param [Numo::NArray] b
238
244
  * @return [Float|Complex|Numo::NArray]
239
245
  */
240
- rb_define_module_function(rb_mLinalg, "dot", RUBY_METHOD_FUNC(linalg_dot), 2);
246
+ rb_define_module_function(rb_mLinalg, "dot", linalg_dot, 2);
241
247
  /**
242
248
  * Calls BLAS function prefixed with BLAS char.
243
249
  *
@@ -247,7 +253,7 @@ void Init_linalg(void) {
247
253
  * @example
248
254
  * Numo::Linalg::Blas.call(:gemv, a, b)
249
255
  */
250
- rb_define_singleton_method(rb_mLinalgBlas, "call", RUBY_METHOD_FUNC(linalg_blas_call), -1);
256
+ rb_define_singleton_method(rb_mLinalgBlas, "call", linalg_blas_call, -1);
251
257
 
252
258
  define_linalg_blas_dot(rb_mLinalgBlas);
253
259
  define_linalg_blas_dot_sub(rb_mLinalgBlas);
@@ -255,9 +261,13 @@ void Init_linalg(void) {
255
261
  define_linalg_blas_gemv(rb_mLinalgBlas);
256
262
  define_linalg_blas_nrm2(rb_mLinalgBlas);
257
263
  define_linalg_lapack_geqrf(rb_mLinalgLapack);
264
+ define_linalg_lapack_gerqf(rb_mLinalgLapack);
258
265
  define_linalg_lapack_orgqr(rb_mLinalgLapack);
266
+ define_linalg_lapack_orgrq(rb_mLinalgLapack);
259
267
  define_linalg_lapack_ungqr(rb_mLinalgLapack);
268
+ define_linalg_lapack_ungrq(rb_mLinalgLapack);
260
269
  define_linalg_lapack_gees(rb_mLinalgLapack);
270
+ define_linalg_lapack_gges(rb_mLinalgLapack);
261
271
  define_linalg_lapack_geev(rb_mLinalgLapack);
262
272
  define_linalg_lapack_gesv(rb_mLinalgLapack);
263
273
  define_linalg_lapack_gesvd(rb_mLinalgLapack);
@@ -285,6 +295,7 @@ void Init_linalg(void) {
285
295
  define_linalg_lapack_gelsd(rb_mLinalgLapack);
286
296
  define_linalg_lapack_sytrf(rb_mLinalgLapack);
287
297
  define_linalg_lapack_hetrf(rb_mLinalgLapack);
298
+ define_linalg_lapack_gebal(rb_mLinalgLapack);
288
299
 
289
300
  rb_define_alias(rb_singleton_class(rb_mLinalgBlas), "znrm2", "dznrm2");
290
301
  rb_define_alias(rb_singleton_class(rb_mLinalgBlas), "cnrm2", "scnrm2");
@@ -51,16 +51,19 @@
51
51
  #include "blas/gemv.h"
52
52
  #include "blas/nrm2.h"
53
53
 
54
+ #include "lapack/gebal.h"
54
55
  #include "lapack/gees.h"
55
56
  #include "lapack/geev.h"
56
57
  #include "lapack/gelsd.h"
57
58
  #include "lapack/geqrf.h"
59
+ #include "lapack/gerqf.h"
58
60
  #include "lapack/gesdd.h"
59
61
  #include "lapack/gesv.h"
60
62
  #include "lapack/gesvd.h"
61
63
  #include "lapack/getrf.h"
62
64
  #include "lapack/getri.h"
63
65
  #include "lapack/getrs.h"
66
+ #include "lapack/gges.h"
64
67
  #include "lapack/heev.h"
65
68
  #include "lapack/heevd.h"
66
69
  #include "lapack/heevr.h"
@@ -70,6 +73,7 @@
70
73
  #include "lapack/hetrf.h"
71
74
  #include "lapack/lange.h"
72
75
  #include "lapack/orgqr.h"
76
+ #include "lapack/orgrq.h"
73
77
  #include "lapack/potrf.h"
74
78
  #include "lapack/potri.h"
75
79
  #include "lapack/potrs.h"
@@ -82,5 +86,6 @@
82
86
  #include "lapack/sytrf.h"
83
87
  #include "lapack/trtrs.h"
84
88
  #include "lapack/ungqr.h"
89
+ #include "lapack/ungrq.h"
85
90
 
86
91
  #endif /* NUMO_LINALG_ALT_LINALG_H */
@@ -20,6 +20,14 @@ char get_jobz(VALUE val) {
20
20
  return jobz;
21
21
  }
22
22
 
23
+ char get_jobvs(VALUE val) {
24
+ const char jobvs = NUM2CHR(val);
25
+ if (jobvs != 'N' && jobvs != 'V') {
26
+ rb_raise(rb_eArgError, "jobvs must be 'N' or 'V'");
27
+ }
28
+ return jobvs;
29
+ }
30
+
23
31
  char get_range(VALUE val) {
24
32
  const char range = NUM2CHR(val);
25
33
 
@@ -8,6 +8,7 @@
8
8
 
9
9
  lapack_int get_itype(VALUE val);
10
10
  char get_jobz(VALUE val);
11
+ char get_jobvs(VALUE val);
11
12
  char get_range(VALUE val);
12
13
  char get_uplo(VALUE val);
13
14
  int get_matrix_layout(VALUE val);
@@ -5,6 +5,6 @@ module Numo
5
5
  # Numo::Linalg Alternative (numo-linalg-alt) is an alternative to Numo::Linalg.
6
6
  module Linalg
7
7
  # The version of numo-linalg-alt you install.
8
- VERSION = '0.3.0'
8
+ VERSION = '0.4.1'
9
9
  end
10
10
  end
data/lib/numo/linalg.rb CHANGED
@@ -560,6 +560,127 @@ module Numo
560
560
  [q, r]
561
561
  end
562
562
 
563
# Computes the RQ decomposition `A = R * Q` of a matrix.
#
# The factorization is performed by LAPACK xGERQF. When `Q` is requested, it is
# formed explicitly with xORGRQ (real types) or xUNGRQ (complex types).
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(2, 3).rand
#   r, q = Numo::Linalg.rq(a)
#   pp r
#   # =>
#   # Numo::DFloat#shape=[2,3]
#   # [[0, -0.381748, -0.79309],
#   #  [0, 0, -0.41502]]
#   pp q
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0.227957, 0.874475, -0.428169],
#   #  [0.844617, -0.396377, -0.359872],
#   #  [-0.484416, -0.279603, -0.828953]]
#   puts (a - r.dot(q)).abs.max
#   # => 5.551115123125783e-17
#
#   r, q = Numo::Linalg.rq(a, mode: 'economic')
#   pp r
#   # =>
#   # Numo::DFloat#shape=[2,2]
#   # [[-0.381748, -0.79309],
#   #  [0, -0.41502]]
#   pp q
#   # =>
#   # Numo::DFloat#shape=[2,3]
#   # [[0.844617, -0.396377, -0.359872],
#   #  [-0.484416, -0.279603, -0.828953]]
#   puts (a - r.dot(q)).abs.max
#   # => 5.551115123125783e-17
#
# @param a [Numo::NArray] The m-by-n matrix to be decomposed.
# @param mode [String] The mode of decomposition.
#   - "full" -- returns both R [m, n] and Q [n, n],
#   - "r" -- returns only R,
#   - "economic" -- returns both R [m, k] and Q [k, n], where k = min(m, n).
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
# @raise [ArgumentError] if `mode` is unknown or the array type is unsupported.
# @raise [RuntimeError] if a LAPACK routine reports an illegal argument.
# @return [Array<Numo::NArray>/Numo::NArray]
#   if mode='full' or 'economic', returns [R, Q].
#   if mode='r', returns R.
def rq(a, mode: 'full') # rubocop:disable Metrics/AbcSize
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise ArgumentError, "invalid mode: #{mode}" unless %w[full r economic].include?(mode)

  type_chr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if type_chr == 'n'

  # xGERQF stores R in the upper trapezoid and the Householder reflectors of Q below it.
  factor_fn = :"#{type_chr}gerqf"
  packed, tau, info = Numo::Linalg::Lapack.send(factor_fn, a.dup)
  raise "the #{-info}-th argument of #{factor_fn} had illegal value" if info.negative?

  n_rows, n_cols = packed.shape
  offset = n_cols - n_rows
  r = packed.triu(offset).dup
  # For a wide input in economic mode, only the trailing square part of R is kept.
  r = r[true, offset...n_cols].dup if mode == 'economic' && offset.positive?
  return r if mode == 'r'

  gen_fn = %w[d s].include?(type_chr) ? :"#{type_chr}orgrq" : :"#{type_chr}ungrq"
  reflectors =
    if offset.negative?
      # Tall input: the reflectors live in the last n rows.
      packed[(n_rows - n_cols)...n_rows, 0...n_cols].dup
    elsif mode == 'economic'
      packed.dup
    else
      # Full Q is n-by-n; place the reflectors in its bottom rows.
      full = packed.class.zeros(n_cols, n_cols)
      full[offset...n_cols, true] = packed
      full
    end

  q, info = Numo::Linalg::Lapack.send(gen_fn, reflectors, tau)
  raise "the #{-info}-th argument of #{gen_fn} had illegal value" if info.negative?

  [r, q]
end
637
+
638
# Computes the QZ decomposition (generalized Schur decomposition) of a pair of square matrices.
#
# The decomposition satisfies `A = Q * AA * Z^H` and `B = Q * BB * Z^H`.
# `Q` and `Z` are unitary matrices. `AA` and `BB` are upper triangular matrices
# (or quasi-upper triangular matrices in the real case).
# The work is delegated to LAPACK xGGES.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat.new(5, 5).rand
#   b = Numo::DFloat.new(5, 5).rand
#
#   aa, bb, q, z = Numo::Linalg.qz(a, b)
#
#   pp (a - q.dot(aa).dot(z.transpose)).abs.max
#   # => 1.7763568394002505e-15
#   pp (b - q.dot(bb).dot(z.transpose)).abs.max
#   # => 1.1102230246251565e-15
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param b [Numo::NArray] The n-by-n square matrix.
# @raise [Numo::NArray::ShapeError] if `a` or `b` is not square and 2-dimensional,
#   or if their shapes differ.
# @raise [ArgumentError] if the array types are unsupported.
# @raise [RuntimeError] if xGGES reports an illegal argument or a failure.
# @return [Array<Numo::NArray, Numo::NArray, Numo::NArray, Numo::NArray>]
#   The matrices `AA`, `BB`, `Q`, and `Z` such that `A = Q * AA * Z^H` and `B = Q * BB * Z^H`.
def qz(a, b) # rubocop:disable Metrics/AbcSize
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2
  raise Numo::NArray::ShapeError, 'input array b must be 2-dimensional' unless b.ndim == 2
  raise Numo::NArray::ShapeError, 'input array a must be square' unless a.shape[0] == a.shape[1]
  raise Numo::NArray::ShapeError, 'input array b must be square' unless b.shape[0] == b.shape[1]
  unless a.shape == b.shape
    raise Numo::NArray::ShapeError,
          "incompatible dimensions: a.shape = #{a.shape} != b.shape = #{b.shape}"
  end

  type_chr = blas_char(a, b)
  raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if type_chr == 'n'

  gges_fn = :"#{type_chr}gges"
  # Real drivers return (aa, bb, alphar, alphai, beta, q, z, sdim, info);
  # complex drivers return (aa, bb, alpha, beta, q, z, sdim, info).
  # In both layouts, q, z, sdim and info are the last four entries.
  aa, bb, *rest = Numo::Linalg::Lapack.send(gges_fn, a.dup, b.dup)
  q, z, _sdim, info = rest.last(4)

  raise "the #{-info}-th argument of #{gges_fn} had illegal value" if info.negative?
  raise 'the QZ algorithm failed.' if info.positive?

  [aa, bb, q, z]
end
683
+
563
684
  # Computes the Schur decomposition of a square matrix.
564
685
  # The Schur decomposition is given by `A = Z * T * Z^H`,
565
686
  # where `A` is the input matrix, `Z` is a unitary matrix,
@@ -604,11 +725,10 @@ module Numo
604
725
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'
605
726
 
606
727
  fnc = :"#{bchr}gees"
607
- b = a.dup
608
728
  if %w[d s].include?(bchr)
609
- _wr, _wi, v, sdim, info = Numo::Linalg::Lapack.send(fnc, b, jobvs: 'V', sort: sort)
729
+ b, _wr, _wi, v, sdim, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvs: 'V', sort: sort)
610
730
  else
611
- _w, v, sdim, info = Numo::Linalg::Lapack.send(fnc, b, jobvs: 'V', sort: sort)
731
+ b, _w, v, sdim, info = Numo::Linalg::Lapack.send(fnc, a.dup, jobvs: 'V', sort: sort)
612
732
  end
613
733
 
614
734
  n = a.shape[0]
@@ -1069,6 +1189,102 @@ module Numo
1069
1189
  [r, scale]
1070
1190
  end
1071
1191
 
1192
# Computes a diagonal similarity transformation that balances a square matrix.
#
# Balancing is performed by LAPACK xGEBAL. Its JOB argument is chosen from
# `permute` and `scale`: 'B' for both, 'P' for permute only, 'S' for scale only,
# and 'N' for neither.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[1, 0, 0], [1, 2, 0], [1, 2, 3]]
#   b, h = Numo::Linalg.matrix_balance(a)
#   pp b
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[3, 2, 1],
#   #  [0, 2, 1],
#   #  [0, 0, 1]]
#   pp h
#   # =>
#   # Numo::DFloat#shape=[3,3]
#   # [[0, 0, 1],
#   #  [0, 1, 0],
#   #  [1, 0, 0]]
#   pp (Numo::Linalg.inv(h).dot(a).dot(h) - b).abs.max
#   # => 0.0
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @param permute [Boolean] The flag indicating whether to permute the matrix.
# @param scale [Boolean] The flag indicating whether to scale the matrix.
# @param separate [Boolean] The flag indicating whether to return the scaling factors
#   and permutation indices separately.
# @raise [Numo::NArray::ShapeError] if `a` is not 2-dimensional.
# @raise [ArgumentError] if `a` is not square or its array type is unsupported.
# @raise [RuntimeError] if xGEBAL reports an illegal argument.
# @return [Array<Numo::NArray, Numo::NArray>] if `separate` is `false`, returns the balanced
#   matrix and the similarity transformation matrix `H` ([b, h]). If `separate` is `true`,
#   returns the balanced matrix, the scaling factors, and the permutation indices
#   (a Numo::Int32 array) ([b, scaler, perm]).
def matrix_balance(a, permute: true, scale: true, separate: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' unless a.ndim == 2

  n = a.shape[0]
  raise ArgumentError, 'input array a must be square' unless a.shape[1] == n

  type_chr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if type_chr == 'n'

  job = if permute
          scale ? 'B' : 'P'
        else
          scale ? 'S' : 'N'
        end
  gebal_fn = :"#{type_chr}gebal"
  b, lo, hi, prm_scl, info = Numo::Linalg::Lapack.send(gebal_fn, a.dup, job: job)

  raise "the #{info.abs}-th argument of #{gebal_fn} had illegal value" if info.negative?

  # xGEBAL reports 1-based (Fortran) indices; shift them to 0-based.
  lo -= 1
  hi -= 1
  zero_based = Numo::Int32.cast(prm_scl) - 1

  # Entries lo..hi of prm_scl are scaling factors; every other entry is left at 1.
  scaler = prm_scl.class.ones(n)
  scaler[lo...(hi + 1)] = prm_scl[lo...(hi + 1)]

  # Entries outside lo..hi encode row/column interchanges. Replay them on an identity
  # permutation: the trailing ones from the last index down, then the leading ones upward.
  perm = Numo::Int32.new(n).seq
  swap = lambda do |src, dst|
    perm[src], perm[dst] = perm[dst], perm[src] unless src == dst
  end
  (n - 1).downto(hi + 1) { |j| swap.call(zero_based[j], j) }
  0.upto(lo - 1) { |j| swap.call(zero_based[j], j) }

  return [b, scaler, perm] if separate

  # Build H = D[inv_perm, :], where inv_perm is the inverse of perm.
  inv_perm = Numo::Int32.zeros(n)
  inv_perm[perm] = Numo::Int32.new(n).seq
  h = scaler.diag[inv_perm, true].dup

  [b, h]
end
1287
+
1072
1288
  # Computes the eigenvalues and right and/or left eigenvectors of a general square matrix.
1073
1289
  #
1074
1290
  # @example
@@ -1381,6 +1597,22 @@ module Numo
1381
1597
  end
1382
1598
  end
1383
1599
 
1600
# Computes the matrix tangent.
#
# The tangent is evaluated as `tan(A) = cos(A)^-1 * sin(A)`. Since `sin(A)` and `cos(A)`
# are both functions of the same matrix `A`, they commute, so this equals
# `sin(A) * cos(A)^-1`. The product is obtained by solving `cos(A) * X = sin(A)`.
# This avoids forming the inverse of `cos(A)` explicitly, which would cost an extra
# triangular inversion and lose accuracy when `cos(A)` is ill-conditioned.
#
# @example
#   require 'numo/linalg'
#
#   a = Numo::DFloat[[0.5, 0.0], [0.0, 0.25]]
#   t = Numo::Linalg.tanm(a)
#   # For a diagonal matrix the result is approximately
#   # [[tan(0.5), 0], [0, tan(0.25)]] = [[0.546302, 0], [0, 0.255342]]
#
# @param a [Numo::NArray] The n-by-n square matrix.
# @raise [Numo::NArray::ShapeError] if `a` is not a square 2-dimensional matrix.
# @raise [ArgumentError] if the array type of `a` is not supported.
# @return [Numo::NArray] The matrix tangent of `a`.
def tanm(a)
  raise Numo::NArray::ShapeError, 'input array a must be 2-dimensional' if a.ndim != 2
  raise Numo::NArray::ShapeError, 'input array a must be square' if a.shape[0] != a.shape[1]

  bchr = blas_char(a)
  raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

  a_sin = sinm(a)
  a_cos = cosm(a)
  # cos(A)^-1 * sin(A) == sin(A) * cos(A)^-1 because both are functions of A.
  Numo::Linalg.solve(a_cos, a_sin)
end
1615
+
1384
1616
  # Computes the inverse of a matrix using its LU decomposition.
1385
1617
  #
1386
1618
  # @param lu [Numo::NArray] The LU decomposition of the n-by-n matrix `A`.
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: numo-linalg-alt
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
@@ -50,6 +50,8 @@ files:
50
50
  - ext/numo/linalg/converter.c
51
51
  - ext/numo/linalg/converter.h
52
52
  - ext/numo/linalg/extconf.rb
53
+ - ext/numo/linalg/lapack/gebal.c
54
+ - ext/numo/linalg/lapack/gebal.h
53
55
  - ext/numo/linalg/lapack/gees.c
54
56
  - ext/numo/linalg/lapack/gees.h
55
57
  - ext/numo/linalg/lapack/geev.c
@@ -58,6 +60,8 @@ files:
58
60
  - ext/numo/linalg/lapack/gelsd.h
59
61
  - ext/numo/linalg/lapack/geqrf.c
60
62
  - ext/numo/linalg/lapack/geqrf.h
63
+ - ext/numo/linalg/lapack/gerqf.c
64
+ - ext/numo/linalg/lapack/gerqf.h
61
65
  - ext/numo/linalg/lapack/gesdd.c
62
66
  - ext/numo/linalg/lapack/gesdd.h
63
67
  - ext/numo/linalg/lapack/gesv.c
@@ -70,6 +74,8 @@ files:
70
74
  - ext/numo/linalg/lapack/getri.h
71
75
  - ext/numo/linalg/lapack/getrs.c
72
76
  - ext/numo/linalg/lapack/getrs.h
77
+ - ext/numo/linalg/lapack/gges.c
78
+ - ext/numo/linalg/lapack/gges.h
73
79
  - ext/numo/linalg/lapack/heev.c
74
80
  - ext/numo/linalg/lapack/heev.h
75
81
  - ext/numo/linalg/lapack/heevd.c
@@ -88,6 +94,8 @@ files:
88
94
  - ext/numo/linalg/lapack/lange.h
89
95
  - ext/numo/linalg/lapack/orgqr.c
90
96
  - ext/numo/linalg/lapack/orgqr.h
97
+ - ext/numo/linalg/lapack/orgrq.c
98
+ - ext/numo/linalg/lapack/orgrq.h
91
99
  - ext/numo/linalg/lapack/potrf.c
92
100
  - ext/numo/linalg/lapack/potrf.h
93
101
  - ext/numo/linalg/lapack/potri.c
@@ -112,6 +120,8 @@ files:
112
120
  - ext/numo/linalg/lapack/trtrs.h
113
121
  - ext/numo/linalg/lapack/ungqr.c
114
122
  - ext/numo/linalg/lapack/ungqr.h
123
+ - ext/numo/linalg/lapack/ungrq.c
124
+ - ext/numo/linalg/lapack/ungrq.h
115
125
  - ext/numo/linalg/linalg.c
116
126
  - ext/numo/linalg/linalg.h
117
127
  - ext/numo/linalg/util.c
@@ -126,7 +136,7 @@ metadata:
126
136
  homepage_uri: https://github.com/yoshoku/numo-linalg-alt
127
137
  source_code_uri: https://github.com/yoshoku/numo-linalg-alt
128
138
  changelog_uri: https://github.com/yoshoku/numo-linalg-alt/blob/main/CHANGELOG.md
129
- documentation_uri: https://gemdocs.org/gems/numo-linalg-alt/0.3.0/
139
+ documentation_uri: https://gemdocs.org/gems/numo-linalg-alt/0.4.1/
130
140
  rubygems_mfa_required: 'true'
131
141
  rdoc_options: []
132
142
  require_paths: