numo-linalg 0.1.2 → 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +7 -3
  3. data/ext/numo/linalg/blas/extconf.rb +1 -2
  4. data/ext/numo/linalg/blas/numo_blas.h +6 -0
  5. data/ext/numo/linalg/blas/tmpl/mv.c +3 -2
  6. data/ext/numo/linalg/lapack/gen/spec.rb +5 -0
  7. data/ext/numo/linalg/lapack/lapack.c +29 -0
  8. data/ext/numo/linalg/lapack/numo_lapack.h +3 -0
  9. data/ext/numo/linalg/lapack/tmpl/gqr.c +1 -1
  10. data/ext/numo/linalg/lapack/tmpl/sygvx.c +130 -0
  11. data/ext/numo/linalg/mkmf_linalg.rb +2 -19
  12. data/lib/numo/linalg/function.rb +168 -77
  13. data/lib/numo/linalg/loader.rb +6 -14
  14. data/lib/numo/linalg/version.rb +1 -1
  15. data/numo-linalg.gemspec +2 -1
  16. data/spec/linalg/autoloader_spec.rb +27 -0
  17. data/spec/linalg/function/cho_fact_spec.rb +31 -0
  18. data/spec/linalg/function/cho_inv_spec.rb +39 -0
  19. data/spec/linalg/function/cho_solve_spec.rb +66 -0
  20. data/spec/linalg/function/cholesky_spec.rb +43 -0
  21. data/spec/linalg/function/cond_spec.rb +57 -0
  22. data/spec/linalg/function/det_spec.rb +21 -0
  23. data/spec/linalg/function/dot_spec.rb +84 -0
  24. data/spec/linalg/function/eig_spec.rb +53 -0
  25. data/spec/linalg/function/eigh_spec.rb +81 -0
  26. data/spec/linalg/function/eigvals_spec.rb +27 -0
  27. data/spec/linalg/function/eigvalsh_spec.rb +60 -0
  28. data/spec/linalg/function/inv_spec.rb +57 -0
  29. data/spec/linalg/function/ldl_spec.rb +51 -0
  30. data/spec/linalg/function/lstsq_spec.rb +80 -0
  31. data/spec/linalg/function/lu_fact_spec.rb +34 -0
  32. data/spec/linalg/function/lu_inv_spec.rb +21 -0
  33. data/spec/linalg/function/lu_solve_spec.rb +40 -0
  34. data/spec/linalg/function/lu_spec.rb +46 -0
  35. data/spec/linalg/function/matmul_spec.rb +41 -0
  36. data/spec/linalg/function/matrix_power_spec.rb +31 -0
  37. data/spec/linalg/function/matrix_rank_spec.rb +33 -0
  38. data/spec/linalg/function/norm_spec.rb +81 -0
  39. data/spec/linalg/function/pinv_spec.rb +48 -0
  40. data/spec/linalg/function/qr_spec.rb +82 -0
  41. data/spec/linalg/function/slogdet_spec.rb +21 -0
  42. data/spec/linalg/function/solve_spec.rb +98 -0
  43. data/spec/linalg/function/svd_spec.rb +88 -0
  44. data/spec/linalg/function/svdvals_spec.rb +40 -0
  45. data/spec/spec_helper.rb +55 -0
  46. metadata +79 -6
  47. data/spec/lapack_spec.rb +0 -13
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 611dd9ee204102a1688b5d199b95b5376d32239b04dfcf69a605da54599a9310
4
- data.tar.gz: d999af4188a239e9a78b7f071d850f2b11bf822c1c0237c592c6ebe4979b278d
3
+ metadata.gz: ef1ecaaf8a71faee9a7cef26bc35fae50230224856ab5dd6186cee93e5cca69d
4
+ data.tar.gz: f2e5e5944ad4832ef4ad4db05f6e94142fefb780e19282d94ece2a055cf447a0
5
5
  SHA512:
6
- metadata.gz: 7da7684a72653a7718de5ff51a99ed14ca0dc3770cae735139487cf56ef2610f292cf3e970abd681320d8a1847664dcb7d6e9ac2c936867abb2cae1154bdb053
7
- data.tar.gz: 70bb3e22d41fee16e953b53467b709cdd96bd80adceb526123df941335782054823bb84a9d703ca45e73e88c0ad34b7cbf90699a53cc7e7dcb1277251e7fa0e6
6
+ metadata.gz: 3fcf1506aa63ccbba57208e358102ae0dc537cd80b530e4b359b07d01a7cb1da11261dfd6d7028224175ae0f38014ec0373b99fc339caa2b952ad48c0fb26fd7
7
+ data.tar.gz: c7a9eb8b65719571120a6a122721e1eaadc240164e182a5095ff45018453a90f4cc48f99ef219305827beace8839bdebb0d27d35dfb5f64ec64f54d5406a8910
data/README.md CHANGED
@@ -18,7 +18,7 @@ This design allows you to change backend libraries without re-compiling.
18
18
  * Matrix and vector products
19
19
  * dot, matmul
20
20
  * Decomposition
21
- * lu\_fact, lu\_inv, lu\_solve, cho\_fact, cho\_inv, cho\_solve
21
+ * lu, lu\_fact, lu\_inv, lu\_solve, ldl, cholesky, cho\_fact, cho\_inv, cho\_solve,
22
22
  qr, svd, svdvals
23
23
  * Matrix eigenvalues
24
24
  * eig, eigh, eigvals, eigvalsh
@@ -78,8 +78,12 @@ require "numo/linalg"
78
78
 
79
79
  ## Authors
80
80
 
81
- * Masahiro TANAKA
82
- * Makoto KISHIMOTO
81
+ * Masahiro Tanaka
82
+ * Makoto Kishimoto
83
+ * Atsushi Tatsuma
84
+
85
+ ## Acknowledgement
86
+
83
87
  * This work is partly supported by 2016 Ruby Association Grant.
84
88
 
85
89
  ## ToDo
@@ -1,5 +1,4 @@
1
1
  require 'mkmf'
2
- require 'numo/narray'
3
2
  require_relative '../mkmf_linalg'
4
3
 
5
4
  srcs = %w(
@@ -21,7 +20,7 @@ if !have_header("numo/narray.h")
21
20
  exit(1)
22
21
  end
23
22
 
24
- if RUBY_PLATFORM =~ /cygwin|mingw/
23
+ if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
25
24
  find_libnarray_a
26
25
  unless have_library("narray","nary_new")
27
26
  puts "libnarray.a not found"
@@ -31,6 +31,12 @@ extern void numo_cblas_check_func(void **func, const char *name);
31
31
  #define SWAP_IFROW(order,a,b,tmp) \
32
32
  { if ((order)==CblasRowMajor) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
33
33
 
34
+ #define SWAP_IFNOTRANS(trans,a,b,tmp) \
35
+ { if ((trans)==CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
36
+
37
+ #define SWAP_IFTRANS(trans,a,b,tmp) \
38
+ { if ((trans)!=CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
39
+
34
40
  #define SWAP_IFCOLTR(order,trans,a,b,tmp) \
35
41
  { if (((order)==CblasRowMajor && (trans)!=CblasNoTrans) || \
36
42
  ((order)!=CblasRowMajor && (trans)==CblasNoTrans)) \
@@ -79,7 +79,7 @@ static void
79
79
  opt("beta"),
80
80
  !is_ge && opt("side"),
81
81
  !is_ge && opt("uplo"),
82
- is_ge || is_tr && opt("trans"),
82
+ (is_ge || is_tr) && opt("trans"),
83
83
  opt("order")
84
84
  ].select{|x| x}.join("\n ")
85
85
  %>
@@ -143,9 +143,10 @@ static VALUE
143
143
  CHECK_DIM_GE(na2,1);
144
144
  nx = COL_SIZE(na2);
145
145
  #if GE
146
- SWAP_IFCOLTR(g.order,g.trans, ma,na, tmp);
146
+ SWAP_IFCOL(g.order, ma, na, tmp);
147
147
  g.m = ma;
148
148
  g.n = na;
149
+ SWAP_IFTRANS(g.trans, ma, na, tmp);
149
150
  #else
150
151
  CHECK_SQUARE("a",na1);
151
152
  #endif
@@ -6,6 +6,9 @@ def_id "jobz"
6
6
  def_id "jobvl"
7
7
  def_id "jobvr"
8
8
  def_id "trans"
9
+ def_id "range"
10
+ def_id "il"
11
+ def_id "iu"
9
12
  def_id "rcond"
10
13
  def_id "itype"
11
14
  def_id "norm"
@@ -58,11 +61,13 @@ when /c|z/
58
61
  decl "?heevd", "syev"
59
62
  decl "?hegv", "sygv"
60
63
  decl "?hegvd", "sygv"
64
+ decl "?hegvx", "sygvx"
61
65
  else
62
66
  decl "?syev"
63
67
  decl "?syevd", "syev"
64
68
  decl "?sygv"
65
69
  decl "?sygvd", "sygv"
70
+ decl "?sygvx"
66
71
  end
67
72
 
68
73
  # factorize
@@ -124,6 +124,35 @@ numo_lapacke_option_job(VALUE job, char true_char, char false_char)
124
124
  return 0;
125
125
  }
126
126
 
127
+ char
128
+ numo_lapacke_option_range(VALUE job, char true_char, char false_char)
129
+ {
130
+ char *ptr, c;
131
+
132
+ switch(TYPE(job)) {
133
+ case T_NIL:
134
+ case T_UNDEF:
135
+ case T_TRUE:
136
+ return true_char;
137
+ case T_FALSE:
138
+ return false_char;
139
+ case T_SYMBOL:
140
+ job = rb_sym2str(job);
141
+ case T_STRING:
142
+ ptr = RSTRING_PTR(job);
143
+ if (RSTRING_LEN(job) > 0) {
144
+ c = ptr[0];
145
+ if (c >= 'a' && c <= 'z') {
146
+ c -= 'a'-'A';
147
+ }
148
+ return c;
149
+ }
150
+ break;
151
+ }
152
+ rb_raise(rb_eArgError,"invalid value for JOB option");
153
+ return 0;
154
+ }
155
+
127
156
  char
128
157
  numo_lapacke_option_trans(VALUE trans)
129
158
  {
@@ -15,6 +15,9 @@ extern int numo_lapacke_option_order(VALUE order);
15
15
  #define option_job numo_lapacke_option_job
16
16
  extern char numo_lapacke_option_job(VALUE job, char true_char, char false_char);
17
17
 
18
+ #define option_range numo_lapacke_option_range
19
+ extern char numo_lapacke_option_range(VALUE range, char true_char, char false_char);
20
+
18
21
  #define option_trans numo_lapacke_option_trans
19
22
  extern char numo_lapacke_option_trans(VALUE trans);
20
23
 
@@ -30,7 +30,7 @@ static void
30
30
  SWAP_IFCOL(g->order,m,n);
31
31
  lda = NDL_STEP(lp,0) / sizeof(dtype);
32
32
 
33
- printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
33
+ //printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
34
34
 
35
35
  *info = (*func_p)(g->order, m, n, k, a, lda, tau);
36
36
  CHECK_ERROR(*info);
@@ -0,0 +1,130 @@
1
+ #define args_t <%=func_name%>_args_t
2
+ #define func_p <%=func_name%>_p
3
+
4
+ typedef struct {
5
+ int order;
6
+ int itype;
7
+ char jobz;
8
+ char uplo;
9
+ char range;
10
+ int il;
11
+ int iu;
12
+ } args_t;
13
+
14
+ static <%=func_name%>_t func_p = 0;
15
+
16
+ static void
17
+ <%=c_iter%>(na_loop_t * const lp)
18
+ {
19
+ dtype *a, *b, *z;
20
+ rtype *w;
21
+ int *ifail;
22
+ int *info;
23
+ int m, n, lda, ldb, ldz;
24
+ rtype vl = 0, vu = 0;
25
+ rtype abstol = 0;
26
+
27
+ args_t *g;
28
+
29
+ a = (dtype*)NDL_PTR(lp, 0);
30
+ b = (dtype*)NDL_PTR(lp, 1);
31
+ w = (rtype*)NDL_PTR(lp, 2);
32
+ z = (dtype*)NDL_PTR(lp, 3);
33
+ ifail = (int*)NDL_PTR(lp, 4);
34
+ info = (int*)NDL_PTR(lp, 5);
35
+ g = (args_t*)(lp->opt_ptr);
36
+
37
+ n = NDL_SHAPE(lp, 0)[1];
38
+ lda = NDL_STEP(lp, 0) / sizeof(dtype);
39
+ ldb = NDL_STEP(lp, 1) / sizeof(dtype);
40
+ ldz = NDL_SHAPE(lp, 3)[1];
41
+
42
+ *info = (*func_p)( g->order, g->itype, g->jobz, g->range, g->uplo, n, a, lda, b, ldb,
43
+ vl, vu, g->il, g->iu, abstol, &m, w, z, ldz, ifail );
44
+
45
+ CHECK_ERROR(*info);
46
+ }
47
+
48
+ /*
49
+ <%
50
+ params = [
51
+ mat("a",:inplace),
52
+ mat("b",:inplace),
53
+ "@param [Integer] itype Specifies the problem type to be solved. If 1: A*x = (lambda)*B*x, If 2: A*B*x = (lambda)*x, If 3: B*A*x = (lambda)*x.",
54
+ jobe("jobz"),
55
+ opt("uplo"),
56
+ opt("order"),
57
+ "@param [String or Symbol] range If 'A': Compute all eigenvalues, if 'I': Compute eigenvalues with indices il to iu (default='A')",
58
+ "@param [Integer] il Specifies the index of the smallest eigenvalue in ascending order to be returned. If range = 'A', il is not referenced.",
59
+ "@param [Integer] iu Specifies the index of the largest eigenvalue in ascending order to be returned. Constraint: 1<=il<=iu<=N. If range = 'A', iu is not referenced.",
60
+ ].select{|x| x}.join("\n ")
61
+ return_name="a, b, w, z, ifail, info"
62
+ %>
63
+ @overload <%=name%>(a, b, [itype:1, jobz:'V', uplo:'U', order:'R', range:'I', il: 1, iu: 2])
64
+ <%=params%>
65
+ @return [[<%=return_name%>]]
66
+ Array<<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,Integer>
67
+ <%=outparam(return_name)%>
68
+
69
+ <%=description%>
70
+ */
71
+
72
+ static VALUE
73
+ <%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
74
+ {
75
+ VALUE a, b, ans;
76
+ int n, nb, m;
77
+ narray_t *na1, *na2;
78
+ size_t w_shape[1];
79
+ size_t z_shape[2];
80
+ size_t ifail_shape[1];
81
+
82
+ ndfunc_arg_in_t ain[2] = {{OVERWRITE, 2}, {OVERWRITE, 2}};
83
+ ndfunc_arg_out_t aout[4] = {{cRT, 1, w_shape}, {cT, 2, z_shape}, {cI, 1, ifail_shape}, {cInt, 0}};
84
+ ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP | NDF_EXTRACT, 2, 4, ain, aout};
85
+
86
+ args_t g;
87
+ VALUE opts[7] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
88
+ VALUE kw_hash = Qnil;
89
+ ID kw_table[7] = {id_order, id_jobz, id_uplo, id_itype, id_range, id_il, id_iu};
90
+
91
+ CHECK_FUNC(func_p,"<%=func_name%>");
92
+
93
+ rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
94
+ rb_get_kwargs(kw_hash, kw_table, 0, 7, opts);
95
+ g.order = option_order(opts[0]);
96
+ g.jobz = option_job(opts[1], 'V', 'N');
97
+ g.uplo = option_uplo(opts[2]);
98
+ g.itype = NUM2INT(option_value(opts[3], INT2FIX(1)));
99
+ g.range = option_range(opts[4], 'A', 'I');
100
+ g.il = NUM2INT(option_value(opts[5], INT2FIX(1)));
101
+ g.iu = NUM2INT(option_value(opts[6], INT2FIX(1)));
102
+
103
+ COPY_OR_CAST_TO(a, cT);
104
+ GetNArray(a, na1);
105
+ CHECK_DIM_GE(na1, 2);
106
+
107
+ COPY_OR_CAST_TO(b, cT);
108
+ GetNArray(b, na2);
109
+ CHECK_DIM_GE(na2, 2);
110
+ CHECK_SQUARE("matrix a", na1);
111
+ n = COL_SIZE(na1);
112
+ CHECK_SQUARE("matrix b", na2);
113
+ nb = COL_SIZE(na2);
114
+ if (n != nb) {
115
+ rb_raise(nary_eShapeError, "matrix a and b must have same size");
116
+ }
117
+
118
+ m = g.range == 'I' ? g.iu - g.il + 1 : n;
119
+ w_shape[0] = m;
120
+ z_shape[0] = n;
121
+ z_shape[1] = m;
122
+ ifail_shape[0] = m;
123
+
124
+ ans = na_ndloop3(&ndf, &g, 2, a, b);
125
+
126
+ return rb_ary_new3(6, a, b, RARRAY_AREF(ans, 0), RARRAY_AREF(ans, 1), RARRAY_AREF(ans, 2), RARRAY_AREF(ans, 3));
127
+ }
128
+
129
+ #undef args_t
130
+ #undef func_p
@@ -17,24 +17,6 @@ def create_site_conf
17
17
  FileUtils.mkdir_p "lib"
18
18
 
19
19
  ext = detect_library_extension
20
- need_version = false
21
- if ext == 'so'
22
- begin
23
- Fiddle.dlopen "libm.so"
24
- rescue
25
- (5..7).each do |i|
26
- begin
27
- Fiddle.dlopen "libm.so.#{i}"
28
- need_version = true
29
- break
30
- rescue
31
- end
32
- end
33
- if !need_version
34
- raise "failed to check whether dynamically linked shared object needs version suffix"
35
- end
36
- end
37
- end
38
20
 
39
21
  open("lib/site_conf.rb","w"){|f| f.write "
40
22
  module Numo
@@ -49,7 +31,6 @@ module Numo
49
31
 
50
32
  module Loader
51
33
  EXT = '#{ext}'
52
- NEED_VERSION_SUFFIX = #{need_version}
53
34
  end
54
35
 
55
36
  end
@@ -68,6 +49,8 @@ def detect_library_extension
68
49
  end
69
50
  end
70
51
 
52
+ require 'numo/narray'
53
+
71
54
  def find_narray_h
72
55
  $LOAD_PATH.each do |x|
73
56
  if File.exist? File.join(x,'numo/numo/narray.h')
@@ -65,7 +65,7 @@ module Numo; module Linalg
65
65
  NArray.array_type(a)
66
66
  end
67
67
  if k && k < NArray
68
- t = k::UPCAST[t]
68
+ t = k::UPCAST[t] || t::UPCAST[k]
69
69
  end
70
70
  end
71
71
  BLAS_CHAR[t] || raise(TypeError,"invalid data type for BLAS/LAPACK")
@@ -86,7 +86,8 @@ module Numo; module Linalg
86
86
  when 1
87
87
  case b.ndim
88
88
  when 1
89
- Blas.call(:dot, a, b)
89
+ func = blas_char(a, b) =~ /c|z/ ? :dotu : :dot
90
+ Blas.call(func, a, b)
90
91
  else
91
92
  Blas.call(:gemv, b, a, trans:'t')
92
93
  end
@@ -194,10 +195,10 @@ module Numo; module Linalg
194
195
  #
195
196
  # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
196
197
  # @param mode [String]
197
- # - "reduce" -- returns both Q and R,
198
- # - "r" -- returns only R,
199
- # - "economy" -- returns both Q and R but computed in economy-size,
200
- # - "raw" -- returns QR and TAU used in LAPACK.
198
+ # - "reduce" -- returns both Q and R,
199
+ # - "r" -- returns only R,
200
+ # - "economic" -- returns both Q and R but computed in economy-size,
201
+ # - "raw" -- returns QR and TAU used in LAPACK.
201
202
  # @return [r] if mode:"r"
202
203
  # @return [[q,r]] if mode:"reduce" or "economic"
203
204
  # @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
@@ -295,6 +296,38 @@ module Numo; module Linalg
295
296
  end
296
297
  end
297
298
 
299
+ # Computes an LU factorization of a M-by-N matrix A
300
+ # using partial pivoting with row interchanges.
301
+ #
302
+ # The factorization has the form
303
+ #
304
+ # A = P * L * U
305
+ #
306
+ # where P is a permutation matrix, L is lower triangular with unit
307
+ # diagonal elements (lower trapezoidal if m > n), and U is upper
308
+ # triangular (upper trapezoidal if m < n).
309
+ #
310
+ # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
311
+ # @param permute_l [Bool] (optional) If true, perform the matrix product of P and L.
312
+ # @return [[p,l,u]] if permute_l == false
313
+ # @return [[pl,u]] if permute_l == true
314
+ #
315
+ # - **p** [Numo::NArray] -- The permutation matrix P.
316
+ # - **l** [Numo::NArray] -- The factor L.
317
+ # - **u** [Numo::NArray] -- The factor U.
318
+
319
+ def lu(a, permute_l: false)
320
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
321
+ m, n = a.shape
322
+ k = [m, n].min
323
+ lu, ip = lu_fact(a)
324
+ l = lu.tril.tap { |mat| mat[mat.diag_indices(0)] = 1.0 }[true, 0...k]
325
+ u = lu.triu[0...k, 0...n]
326
+ p = Numo::DFloat.eye(m).tap do |mat|
327
+ ip.to_a.each_with_index { |i, j| mat[true, [i - 1, j]] = mat[true, [j, i - 1]].dup }
328
+ end
329
+ permute_l ? [p.dot(l), u] : [p, l, u]
330
+ end
298
331
 
299
332
  # Computes an LU factorization of a M-by-N matrix A
300
333
  # using partial pivoting with row interchanges.
@@ -308,26 +341,14 @@ module Numo; module Linalg
308
341
  # triangular (upper trapezoidal if m < n).
309
342
  #
310
343
  # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
311
- # @param driver [String or Symbol] choose LAPACK diriver from
312
- # 'gen','sym','her'. (optional, default='gen')
313
- # @param uplo [String or Symbol] optional, default='U'. Access upper
314
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
315
344
  # @return [[lu, ipiv]]
316
345
  # - **lu** [Numo::NArray] -- The factors L and U from the factorization
317
346
  # `A = P*L*U`; the unit diagonal elements of L are not stored.
318
347
  # - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
319
348
  # row i of the matrix was interchanged with row IPIV(i).
320
349
 
321
- def lu_fact(a, driver:"gen", uplo:"U")
322
- case driver.to_s
323
- when /^gen?(trf)?$/i
324
- Lapack.call(:getrf, a)[0..1]
325
- when /^(sym?|her?)(trf)?$/i
326
- func = driver[0..2].downcase+"trf"
327
- Lapack.call(func, a, uplo:uplo)[0..1]
328
- else
329
- raise ArgumentError, "invalid driver: #{driver}"
330
- end
350
+ def lu_fact(a)
351
+ Lapack.call(:getrf, a)[0..1]
331
352
  end
332
353
 
333
354
  # Computes the inverse of a matrix using the LU factorization
@@ -345,22 +366,10 @@ module Numo; module Linalg
345
366
  # @param ipiv [Numo::NArray] The pivot indices from
346
367
  # Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
347
368
  # interchanged with row IPIV(i).
348
- # @param driver [String or Symbol] choose LAPACK diriver from
349
- # 'gen','sym','her'. (optional, default='gen')
350
- # @param uplo [String or Symbol] optional, default='U'. Access upper
351
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
352
369
  # @return [Numo::NArray] the inverse of the original matrix A.
353
370
 
354
- def lu_inv(lu, ipiv, driver:"gen", uplo:"U")
355
- case driver.to_s
356
- when /^gen?(tri)?$/i
357
- Lapack.call(:getri, lu, ipiv)[0]
358
- when /^(sym?|her?)(tri)?$/i
359
- func = driver[0..2].downcase+"tri"
360
- Lapack.call(func, lu, ipiv, uplo:uplo)[0]
361
- else
362
- raise ArgumentError, "invalid driver: #{driver}"
363
- end
371
+ def lu_inv(lu, ipiv)
372
+ Lapack.call(:getri, lu, ipiv)[0]
364
373
  end
365
374
 
366
375
  # Solves a system of linear equations
@@ -377,31 +386,100 @@ module Numo; module Linalg
377
386
  # Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
378
387
  # interchanged with row IPIV(i).
379
388
  # @param b [Numo::NArray] the right hand side matrix B.
380
- # @param driver [String or Symbol] choose LAPACK diriver from
381
- # 'gen','sym','her'. (optional, default='gen')
382
- # @param uplo [String or Symbol] optional, default='U'. Access upper
383
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
384
389
  # @param trans [String or Symbol]
385
- # Specifies the form of the system of equations
386
- # (omitted if not driver:"gen"):
387
- #
390
+ # Specifies the form of the system of equations:
388
391
  # - If 'N': `A * X = B` (No transpose).
389
392
  # - If 'T': `A*\*T* X = B` (Transpose).
390
393
  # - If 'C': `A*\*T* X = B` (Conjugate transpose = Transpose).
391
394
  # @return [Numo::NArray] the solution matrix X.
392
395
 
393
- def lu_solve(lu, ipiv, b, driver:"gen", uplo:"U", trans:"N")
394
- case driver.to_s
395
- when /^gen?(trs)?$/i
396
- Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
397
- when /^(sym?|her?)(trs)?$/i
398
- func = driver[0..2].downcase+"trs"
399
- Lapack.call(func, lu, ipiv, b, uplo:uplo)[0]
400
- else
401
- raise ArgumentError, "invalid driver: #{driver}"
396
+ def lu_solve(lu, ipiv, b, trans:"N")
397
+ Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
398
+ end
399
+
400
+ # Computes the LDLt or Bunch-Kaufman factorization of a symmetric/Hermitian matrix A.
401
+ # The factorization has the form
402
+ #
403
+ # A = U*D*U**T or A = L*D*L**T
404
+ #
405
+ # where U (or L) is a product of permutation and unit upper (lower) triangular matrices
406
+ # and D is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
407
+ #
408
+ # @param a [Numo::NArray] m-by-m matrix A (>= 2-dimensional NArray)
409
+ # @param uplo [String or Symbol] optional, default='U'. Access upper ('U') or lower ('L') triangle.
410
+ # @param hermitian [Bool] optional, default=true. If true, hermitian matrix is assumed.
411
+ # (omitted when real-value matrix is given)
412
+ #
413
+ # @return [[lu,d,perm]]
414
+ #
415
+ # - **lu** [Numo::NArray] -- The permuted upper (lower) triangular matrix U (L).
416
+ # - **d** [Numo::NArray] -- The block diagonal matrix D.
417
+ # - **perm** [Numo::NArray] -- The row-permutation index for changing lu into triangular form.
418
+
419
+ def ldl(a, uplo: 'U', hermitian: true)
420
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
421
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
422
+
423
+ is_complex = blas_char(a) =~ /c|z/
424
+ func = is_complex && hermitian ? 'hetrf' : 'sytrf'
425
+ lud, ipiv, = Lapack.call(func.to_sym, a, uplo: uplo)
426
+
427
+ lu = (uplo == 'U' ? lud.triu : lud.tril).tap { |mat| mat[mat.diag_indices(0)] = 1.0 }
428
+ d = lud[lud.diag_indices(0)].diag
429
+
430
+ m = a.shape[0]
431
+ n = m - 1
432
+ changed_2x2 = false
433
+ perm = Numo::Int32.new(m).seq
434
+ m.times do |t|
435
+ i = uplo == 'U' ? t : n - t
436
+ j = uplo == 'U' ? i - 1 : i + 1;
437
+ r = uplo == 'U' ? 0..i : i..n;
438
+ if ipiv[i] > 0
439
+ k = ipiv[i] - 1
440
+ lu[[k, i], r] = lu[[i, k], r].dup
441
+ perm[[k, i]] = perm[[i, k]].dup
442
+ elsif j.between?(0, n) && ipiv[i] == ipiv[j] && !changed_2x2
443
+ k = ipiv[i].abs - 1
444
+ d[j, i] = lud[j, i]
445
+ d[i, j] = is_complex && hermitian ? lud[j, i].conj : lud[j, i]
446
+ lu[j, i] = 0.0
447
+ lu[[k, j], r] = lu[[j, k], r].dup
448
+ perm[[k, j]] = perm[[j, k]].dup
449
+ changed_2x2 = true
450
+ next
451
+ end
452
+ changed_2x2 = false
402
453
  end
454
+
455
+ [lu, d, perm.sort_index]
403
456
  end
404
457
 
458
+ # Computes the Cholesky factorization of a symmetric/Hermitian
459
+ # positive definite matrix A. The factorization has the form
460
+ #
461
+ # A = U**H * U, if UPLO = 'U', or
462
+ # A = L * L**H, if UPLO = 'L',
463
+ #
464
+ # where U is an upper triangular matrix and L is a lower triangular matrix.
465
+ # @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
466
+ # @param uplo [String or Symbol] optional, default='U'. Access upper
467
+ # ('U') or lower ('L') triangle.
468
+ # @return [Numo::NArray] The factor U or L.
469
+
470
+ def cholesky(a, uplo: 'U')
471
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
472
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
473
+ factor = Lapack.call(:potrf, a, uplo: uplo)[0]
474
+ if uplo == 'U'
475
+ factor.triu
476
+ else
477
+ # TODO: Use the tril method if the version of Numo::NArray
478
+ # in the runtime dependency list becomes 0.9.1.3 or higher.
479
+ m, = a.shape
480
+ factor * Numo::DFloat.ones(m, m).triu.transpose
481
+ end
482
+ end
405
483
 
406
484
  # Computes the Cholesky factorization of a symmetric/Hermitian
407
485
  # positive definite matrix A. The factorization has the form
@@ -409,16 +487,16 @@ module Numo; module Linalg
409
487
  # A = U**H * U, if UPLO = 'U', or
410
488
  # A = L * L**H, if UPLO = 'L',
411
489
  #
412
- # where U is an upper triangular matrix and L is lower triangular
490
+ # where U is an upper triangular matrix and L is a lower triangular matrix.
413
491
  # @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
414
492
  # @param uplo [String or Symbol] optional, default='U'. Access upper
415
493
  # or ('U') lower ('L') triangle.
416
- # @return [Numo::NArray] the factor U or L.
494
+ # @return [Numo::NArray] The matrix which has the Cholesky factor in upper or lower triangular part.
495
+ # The remaining part contains unspecified values.
417
496
 
418
497
  def cho_fact(a, uplo:'U')
419
498
  Lapack.call(:potrf, a, uplo:uplo)[0]
420
499
  end
421
- #alias cholesky cho_fact
422
500
 
423
501
  # Computes the inverse of a symmetric/Hermitian
424
502
  # positive definite matrix A using the Cholesky factorization
@@ -479,27 +557,41 @@ module Numo; module Linalg
479
557
  [w,vl,vr] #.compact
480
558
  end
481
559
 
482
- # Computes the eigenvalues and, optionally, the left and/or right
483
- # eigenvectors for a square symmetric/hermitian matrix A.
560
+ # Obtains the eigenvalues and, optionally, the eigenvectors
561
+ # by solving an ordinary or generalized eigenvalue problem
562
+ # for a square symmetric / Hermitian matrix.
484
563
  #
485
- # @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensinal NArray)
486
- # @param values_only [Bool] (optional) If false, eigenvectors are computed.
564
+ # @param a [Numo::NArray] square symmetric matrix (>= 2-dimensional NArray)
565
+ # @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
566
+ # If nil, identity matrix is assumed.
567
+ # @param vals_only [Bool] (optional) If false, eigenvectors are computed.
568
+ # @param vals_range [Range] (optional)
569
+ # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
570
+ # If nil or 0...N (N is the size of the matrix a), all eigenvalues and eigenvectors are returned.
487
571
  # @param uplo [String or Symbol] (optional, default='U')
488
572
  # Access upper ('U') or lower ('L') triangle.
573
+ # @param turbo [Bool] (optional) If true, divide and conquer algorithm is used.
489
574
  # @return [[w,v]]
490
575
  # - **w** [Numo::NArray] -- The eigenvalues.
491
576
  # - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
492
577
 
493
- def eigh(a, vals_only:false, uplo:false, turbo:false)
578
+ def eigh(a, b=nil, vals_only:false, vals_range:nil, uplo:'U', turbo:false)
494
579
  jobz = vals_only ? 'N' : 'V' # jobz: Compute eigenvalues and eigenvectors.
495
- case blas_char(a)
496
- when /c|z/
497
- func = turbo ? :hegv : :heev
580
+ b = a.class.eye(a.shape[0]) if b.nil?
581
+ func = blas_char(a, b) =~ /c|z/ ? 'hegv' : 'sygv'
582
+ if vals_range.nil?
583
+ func << 'd' if turbo
584
+ v, u_, w, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz)
585
+ v = nil if vals_only
586
+ [w, v]
498
587
  else
499
- func = turbo ? :sygv : :syev
588
+ func << 'x'
589
+ il = vals_range.first(1)[0]
590
+ iu = vals_range.last(1)[0]
591
+ a_, b_, w, v, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz, range: 'I', il: il + 1, iu: iu + 1)
592
+ v = nil if vals_only
593
+ [w, v]
500
594
  end
501
- w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
502
- [w,v] #.compact
503
595
  end
504
596
 
505
597
  # Computes the eigenvalues only for a square nonsymmetric matrix A.
@@ -519,23 +611,22 @@ module Numo; module Linalg
519
611
  w
520
612
  end
521
613
 
522
- # Computes the eigenvalues for a square symmetric/hermitian matrix A.
614
+ # Obtains the eigenvalues by solving an ordinary or generalized eigenvalue problem
615
+ # for a square symmetric / Hermitian matrix.
523
616
  #
524
617
  # @param a [Numo::NArray] square symmetric/hermitian matrix
525
618
  # (>= 2-dimensional NArray)
619
+ # @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
620
+ # If nil, identity matrix is assumed.
621
+ # @param vals_range [Range] (optional)
622
+ # The range of indices of the eigenvalues (in ascending order) to be returned.
623
+ # If nil or 0...N (N is the size of the matrix a), all eigenvalues are returned.
526
624
  # @param uplo [String or Symbol] (optional, default='U')
527
625
  # Access upper ('U') or lower ('L') triangle.
528
626
  # @return [Numo::NArray] eigenvalues
529
627
 
530
- def eigvalsh(a, uplo:false, turbo:false)
531
- jobz = 'N' # jobz: Compute eigenvalues and eigenvectors.
532
- case blas_char(a)
533
- when /c|z/
534
- func = turbo ? :hegv : :heev
535
- else
536
- func = turbo ? :sygv : :syev
537
- end
538
- Lapack.call(func, a, uplo:uplo, jobz:jobz)[0]
628
+ def eigvalsh(a, b=nil, vals_range:nil, uplo:'U', turbo:false)
629
+ eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo).first
539
630
  end
540
631
 
541
632
 
@@ -801,7 +892,7 @@ module Numo; module Linalg
801
892
  # returns lu, x, ipiv, info
802
893
  Lapack.call(:gesv, a, b)[1]
803
894
  when /^(sym?|her?|pos?)(sv)?$/i
804
- func = driver[0..2].downcase+"sv"
895
+ func = driver[0..1].downcase+"sv"
805
896
  Lapack.call(func, a, b, uplo:uplo)[1]
806
897
  else
807
898
  raise ArgumentError, "invalid driver: #{driver}"
@@ -834,8 +925,8 @@ module Numo; module Linalg
834
925
  solve(a, b, driver:d, uplo:uplo)
835
926
  when /(ge|sy|he)tr[fi]$/
836
927
  d = $1
837
- lu, piv = lu_fact(a, driver:d, uplo:uplo)
838
- lu_inv(lu, piv, driver:d, uplo:uplo)
928
+ lu, piv = lu_fact(a)
929
+ lu_inv(lu, piv)
839
930
  when /potr[fi]$/
840
931
  lu = cho_fact(a, uplo:uplo)
841
932
  cho_inv(lu, uplo:uplo)