numo-linalg 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/README.md +7 -3
  3. data/ext/numo/linalg/blas/extconf.rb +1 -2
  4. data/ext/numo/linalg/blas/numo_blas.h +6 -0
  5. data/ext/numo/linalg/blas/tmpl/mv.c +3 -2
  6. data/ext/numo/linalg/lapack/gen/spec.rb +5 -0
  7. data/ext/numo/linalg/lapack/lapack.c +29 -0
  8. data/ext/numo/linalg/lapack/numo_lapack.h +3 -0
  9. data/ext/numo/linalg/lapack/tmpl/gqr.c +1 -1
  10. data/ext/numo/linalg/lapack/tmpl/sygvx.c +130 -0
  11. data/ext/numo/linalg/mkmf_linalg.rb +2 -19
  12. data/lib/numo/linalg/function.rb +168 -77
  13. data/lib/numo/linalg/loader.rb +6 -14
  14. data/lib/numo/linalg/version.rb +1 -1
  15. data/numo-linalg.gemspec +2 -1
  16. data/spec/linalg/autoloader_spec.rb +27 -0
  17. data/spec/linalg/function/cho_fact_spec.rb +31 -0
  18. data/spec/linalg/function/cho_inv_spec.rb +39 -0
  19. data/spec/linalg/function/cho_solve_spec.rb +66 -0
  20. data/spec/linalg/function/cholesky_spec.rb +43 -0
  21. data/spec/linalg/function/cond_spec.rb +57 -0
  22. data/spec/linalg/function/det_spec.rb +21 -0
  23. data/spec/linalg/function/dot_spec.rb +84 -0
  24. data/spec/linalg/function/eig_spec.rb +53 -0
  25. data/spec/linalg/function/eigh_spec.rb +81 -0
  26. data/spec/linalg/function/eigvals_spec.rb +27 -0
  27. data/spec/linalg/function/eigvalsh_spec.rb +60 -0
  28. data/spec/linalg/function/inv_spec.rb +57 -0
  29. data/spec/linalg/function/ldl_spec.rb +51 -0
  30. data/spec/linalg/function/lstsq_spec.rb +80 -0
  31. data/spec/linalg/function/lu_fact_spec.rb +34 -0
  32. data/spec/linalg/function/lu_inv_spec.rb +21 -0
  33. data/spec/linalg/function/lu_solve_spec.rb +40 -0
  34. data/spec/linalg/function/lu_spec.rb +46 -0
  35. data/spec/linalg/function/matmul_spec.rb +41 -0
  36. data/spec/linalg/function/matrix_power_spec.rb +31 -0
  37. data/spec/linalg/function/matrix_rank_spec.rb +33 -0
  38. data/spec/linalg/function/norm_spec.rb +81 -0
  39. data/spec/linalg/function/pinv_spec.rb +48 -0
  40. data/spec/linalg/function/qr_spec.rb +82 -0
  41. data/spec/linalg/function/slogdet_spec.rb +21 -0
  42. data/spec/linalg/function/solve_spec.rb +98 -0
  43. data/spec/linalg/function/svd_spec.rb +88 -0
  44. data/spec/linalg/function/svdvals_spec.rb +40 -0
  45. data/spec/spec_helper.rb +55 -0
  46. metadata +79 -6
  47. data/spec/lapack_spec.rb +0 -13
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 611dd9ee204102a1688b5d199b95b5376d32239b04dfcf69a605da54599a9310
4
- data.tar.gz: d999af4188a239e9a78b7f071d850f2b11bf822c1c0237c592c6ebe4979b278d
3
+ metadata.gz: ef1ecaaf8a71faee9a7cef26bc35fae50230224856ab5dd6186cee93e5cca69d
4
+ data.tar.gz: f2e5e5944ad4832ef4ad4db05f6e94142fefb780e19282d94ece2a055cf447a0
5
5
  SHA512:
6
- metadata.gz: 7da7684a72653a7718de5ff51a99ed14ca0dc3770cae735139487cf56ef2610f292cf3e970abd681320d8a1847664dcb7d6e9ac2c936867abb2cae1154bdb053
7
- data.tar.gz: 70bb3e22d41fee16e953b53467b709cdd96bd80adceb526123df941335782054823bb84a9d703ca45e73e88c0ad34b7cbf90699a53cc7e7dcb1277251e7fa0e6
6
+ metadata.gz: 3fcf1506aa63ccbba57208e358102ae0dc537cd80b530e4b359b07d01a7cb1da11261dfd6d7028224175ae0f38014ec0373b99fc339caa2b952ad48c0fb26fd7
7
+ data.tar.gz: c7a9eb8b65719571120a6a122721e1eaadc240164e182a5095ff45018453a90f4cc48f99ef219305827beace8839bdebb0d27d35dfb5f64ec64f54d5406a8910
data/README.md CHANGED
@@ -18,7 +18,7 @@ This desgin allows you to change backend libraries without re-compiling.
18
18
  * Matrix and vector products
19
19
  * dot, matmul
20
20
  * Decomposition
21
- * lu\_fact, lu\_inv, lu\_solve, cho\_fact, cho\_inv, cho\_solve
21
+ * lu, lu\_fact, lu\_inv, lu\_solve, ldl, cholesky, cho\_fact, cho\_inv, cho\_solve,
22
22
  qr, svd, svdvals
23
23
  * Matrix eigenvalues
24
24
  * eig, eigh, eigvals, eigvalsh
@@ -78,8 +78,12 @@ require "numo/linalg"
78
78
 
79
79
  ## Authors
80
80
 
81
- * Masahiro TANAKA
82
- * Makoto KISHIMOTO
81
+ * Masahiro Tanaka
82
+ * Makoto Kishimoto
83
+ * Atsushi Tatsuma
84
+
85
+ ## Acknowledgement
86
+
83
87
  * This work is partly supported by 2016 Ruby Association Grant.
84
88
 
85
89
  ## ToDo
data/ext/numo/linalg/blas/extconf.rb CHANGED
@@ -1,5 +1,4 @@
1
1
  require 'mkmf'
2
- require 'numo/narray'
3
2
  require_relative '../mkmf_linalg'
4
3
 
5
4
  srcs = %w(
@@ -21,7 +20,7 @@ if !have_header("numo/narray.h")
21
20
  exit(1)
22
21
  end
23
22
 
24
- if RUBY_PLATFORM =~ /cygwin|mingw/
23
+ if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
25
24
  find_libnarray_a
26
25
  unless have_library("narray","nary_new")
27
26
  puts "libnarray.a not found"
data/ext/numo/linalg/blas/numo_blas.h CHANGED
@@ -31,6 +31,12 @@ extern void numo_cblas_check_func(void **func, const char *name);
31
31
  #define SWAP_IFROW(order,a,b,tmp) \
32
32
  { if ((order)==CblasRowMajor) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
33
33
 
34
+ #define SWAP_IFNOTRANS(trans,a,b,tmp) \
35
+ { if ((trans)==CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
36
+
37
+ #define SWAP_IFTRANS(trans,a,b,tmp) \
38
+ { if ((trans)!=CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
39
+
34
40
  #define SWAP_IFCOLTR(order,trans,a,b,tmp) \
35
41
  { if (((order)==CblasRowMajor && (trans)!=CblasNoTrans) || \
36
42
  ((order)!=CblasRowMajor && (trans)==CblasNoTrans)) \
data/ext/numo/linalg/blas/tmpl/mv.c CHANGED
@@ -79,7 +79,7 @@ static void
79
79
  opt("beta"),
80
80
  !is_ge && opt("side"),
81
81
  !is_ge && opt("uplo"),
82
- is_ge || is_tr && opt("trans"),
82
+ (is_ge || is_tr) && opt("trans"),
83
83
  opt("order")
84
84
  ].select{|x| x}.join("\n ")
85
85
  %>
@@ -143,9 +143,10 @@ static VALUE
143
143
  CHECK_DIM_GE(na2,1);
144
144
  nx = COL_SIZE(na2);
145
145
  #if GE
146
- SWAP_IFCOLTR(g.order,g.trans, ma,na, tmp);
146
+ SWAP_IFCOL(g.order, ma, na, tmp);
147
147
  g.m = ma;
148
148
  g.n = na;
149
+ SWAP_IFTRANS(g.trans, ma, na, tmp);
149
150
  #else
150
151
  CHECK_SQUARE("a",na1);
151
152
  #endif
data/ext/numo/linalg/lapack/gen/spec.rb CHANGED
@@ -6,6 +6,9 @@ def_id "jobz"
6
6
  def_id "jobvl"
7
7
  def_id "jobvr"
8
8
  def_id "trans"
9
+ def_id "range"
10
+ def_id "il"
11
+ def_id "iu"
9
12
  def_id "rcond"
10
13
  def_id "itype"
11
14
  def_id "norm"
@@ -58,11 +61,13 @@ when /c|z/
58
61
  decl "?heevd", "syev"
59
62
  decl "?hegv", "sygv"
60
63
  decl "?hegvd", "sygv"
64
+ decl "?hegvx", "sygvx"
61
65
  else
62
66
  decl "?syev"
63
67
  decl "?syevd", "syev"
64
68
  decl "?sygv"
65
69
  decl "?sygvd", "sygv"
70
+ decl "?sygvx"
66
71
  end
67
72
 
68
73
  # factorize
data/ext/numo/linalg/lapack/lapack.c CHANGED
@@ -124,6 +124,35 @@ numo_lapacke_option_job(VALUE job, char true_char, char false_char)
124
124
  return 0;
125
125
  }
126
126
 
127
+ char
128
+ numo_lapacke_option_range(VALUE job, char true_char, char false_char)
129
+ {
130
+ char *ptr, c;
131
+
132
+ switch(TYPE(job)) {
133
+ case T_NIL:
134
+ case T_UNDEF:
135
+ case T_TRUE:
136
+ return true_char;
137
+ case T_FALSE:
138
+ return false_char;
139
+ case T_SYMBOL:
140
+ job = rb_sym2str(job);
141
+ case T_STRING:
142
+ ptr = RSTRING_PTR(job);
143
+ if (RSTRING_LEN(job) > 0) {
144
+ c = ptr[0];
145
+ if (c >= 'a' && c <= 'z') {
146
+ c -= 'a'-'A';
147
+ }
148
+ return c;
149
+ }
150
+ break;
151
+ }
152
+ rb_raise(rb_eArgError,"invalid value for JOB option");
153
+ return 0;
154
+ }
155
+
127
156
  char
128
157
  numo_lapacke_option_trans(VALUE trans)
129
158
  {
data/ext/numo/linalg/lapack/numo_lapack.h CHANGED
@@ -15,6 +15,9 @@ extern int numo_lapacke_option_order(VALUE order);
15
15
  #define option_job numo_lapacke_option_job
16
16
  extern char numo_lapacke_option_job(VALUE job, char true_char, char false_char);
17
17
 
18
+ #define option_range numo_lapacke_option_range
19
+ extern char numo_lapacke_option_range(VALUE range, char true_char, char false_char);
20
+
18
21
  #define option_trans numo_lapacke_option_trans
19
22
  extern char numo_lapacke_option_trans(VALUE trans);
20
23
 
data/ext/numo/linalg/lapack/tmpl/gqr.c CHANGED
@@ -30,7 +30,7 @@ static void
30
30
  SWAP_IFCOL(g->order,m,n);
31
31
  lda = NDL_STEP(lp,0) / sizeof(dtype);
32
32
 
33
- printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
33
+ //printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
34
34
 
35
35
  *info = (*func_p)(g->order, m, n, k, a, lda, tau);
36
36
  CHECK_ERROR(*info);
data/ext/numo/linalg/lapack/tmpl/sygvx.c ADDED
@@ -0,0 +1,130 @@
1
+ #define args_t <%=func_name%>_args_t
2
+ #define func_p <%=func_name%>_p
3
+
4
+ typedef struct {
5
+ int order;
6
+ int itype;
7
+ char jobz;
8
+ char uplo;
9
+ char range;
10
+ int il;
11
+ int iu;
12
+ } args_t;
13
+
14
+ static <%=func_name%>_t func_p = 0;
15
+
16
+ static void
17
+ <%=c_iter%>(na_loop_t * const lp)
18
+ {
19
+ dtype *a, *b, *z;
20
+ rtype *w;
21
+ int *ifail;
22
+ int *info;
23
+ int m, n, lda, ldb, ldz;
24
+ rtype vl = 0, vu = 0;
25
+ rtype abstol = 0;
26
+
27
+ args_t *g;
28
+
29
+ a = (dtype*)NDL_PTR(lp, 0);
30
+ b = (dtype*)NDL_PTR(lp, 1);
31
+ w = (rtype*)NDL_PTR(lp, 2);
32
+ z = (dtype*)NDL_PTR(lp, 3);
33
+ ifail = (int*)NDL_PTR(lp, 4);
34
+ info = (int*)NDL_PTR(lp, 5);
35
+ g = (args_t*)(lp->opt_ptr);
36
+
37
+ n = NDL_SHAPE(lp, 0)[1];
38
+ lda = NDL_STEP(lp, 0) / sizeof(dtype);
39
+ ldb = NDL_STEP(lp, 1) / sizeof(dtype);
40
+ ldz = NDL_SHAPE(lp, 3)[1];
41
+
42
+ *info = (*func_p)( g->order, g->itype, g->jobz, g->range, g->uplo, n, a, lda, b, ldb,
43
+ vl, vu, g->il, g->iu, abstol, &m, w, z, ldz, ifail );
44
+
45
+ CHECK_ERROR(*info);
46
+ }
47
+
48
+ /*
49
+ <%
50
+ params = [
51
+ mat("a",:inplace),
52
+ mat("b",:inplace),
53
+ "@param [Integer] itype Specifies the problem type to be solved. If 1: A*x = (lambda)*B*x, If 2: A*B*x = (lambda)*x, If 3: B*A*x = (lambda)*x.",
54
+ jobe("jobz"),
55
+ opt("uplo"),
56
+ opt("order"),
57
+ "@param [String or Symbol] range If 'A': Compute all eigenvalues, if 'I': Compute eigenvalues with indices il to iu (default='A')",
58
+ "@param [Integer] il Specifies the index of the smallest eigenvalue in ascending order to be returned. If range = 'A', il is not referenced.",
59
+ "@param [Integer] iu Specifies the index of the largest eigenvalue in ascending order to be returned. Constraint: 1<=il<=iu<=N. If range = 'A', iu is not referenced.",
60
+ ].select{|x| x}.join("\n ")
61
+ return_name="a, b, w, z, ifail, info"
62
+ %>
63
+ @overload <%=name%>(a, b, [itype:1, jobz:'V', uplo:'U', order:'R', range:'I', il: 1, il: 2])
64
+ <%=params%>
65
+ @return [[<%=return_name%>]]
66
+ Array<<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,Integer>
67
+ <%=outparam(return_name)%>
68
+
69
+ <%=description%>
70
+ */
71
+
72
+ static VALUE
73
+ <%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
74
+ {
75
+ VALUE a, b, ans;
76
+ int n, nb, m;
77
+ narray_t *na1, *na2;
78
+ size_t w_shape[1];
79
+ size_t z_shape[2];
80
+ size_t ifail_shape[1];
81
+
82
+ ndfunc_arg_in_t ain[2] = {{OVERWRITE, 2}, {OVERWRITE, 2}};
83
+ ndfunc_arg_out_t aout[4] = {{cRT, 1, w_shape}, {cT, 2, z_shape}, {cI, 1, ifail_shape}, {cInt, 0}};
84
+ ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP | NDF_EXTRACT, 2, 4, ain, aout};
85
+
86
+ args_t g;
87
+ VALUE opts[7] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
88
+ VALUE kw_hash = Qnil;
89
+ ID kw_table[7] = {id_order, id_jobz, id_uplo, id_itype, id_range, id_il, id_iu};
90
+
91
+ CHECK_FUNC(func_p,"<%=func_name%>");
92
+
93
+ rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
94
+ rb_get_kwargs(kw_hash, kw_table, 0, 7, opts);
95
+ g.order = option_order(opts[0]);
96
+ g.jobz = option_job(opts[1], 'V', 'N');
97
+ g.uplo = option_uplo(opts[2]);
98
+ g.itype = NUM2INT(option_value(opts[3], INT2FIX(1)));
99
+ g.range = option_range(opts[4], 'A', 'I');
100
+ g.il = NUM2INT(option_value(opts[5], INT2FIX(1)));
101
+ g.iu = NUM2INT(option_value(opts[6], INT2FIX(1)));
102
+
103
+ COPY_OR_CAST_TO(a, cT);
104
+ GetNArray(a, na1);
105
+ CHECK_DIM_GE(na1, 2);
106
+
107
+ COPY_OR_CAST_TO(b, cT);
108
+ GetNArray(b, na2);
109
+ CHECK_DIM_GE(na2, 2);
110
+ CHECK_SQUARE("matrix a", na1);
111
+ n = COL_SIZE(na1);
112
+ CHECK_SQUARE("matrix b", na2);
113
+ nb = COL_SIZE(na2);
114
+ if (n != nb) {
115
+ rb_raise(nary_eShapeError, "matrix a and b must have same size");
116
+ }
117
+
118
+ m = g.range == 'I' ? g.iu - g.il + 1 : n;
119
+ w_shape[0] = m;
120
+ z_shape[0] = n;
121
+ z_shape[1] = m;
122
+ ifail_shape[0] = m;
123
+
124
+ ans = na_ndloop3(&ndf, &g, 2, a, b);
125
+
126
+ return rb_ary_new3(6, a, b, RARRAY_AREF(ans, 0), RARRAY_AREF(ans, 1), RARRAY_AREF(ans, 2), RARRAY_AREF(ans, 3));
127
+ }
128
+
129
+ #undef args_t
130
+ #undef func_p
data/ext/numo/linalg/mkmf_linalg.rb CHANGED
@@ -17,24 +17,6 @@ def create_site_conf
17
17
  FileUtils.mkdir_p "lib"
18
18
 
19
19
  ext = detect_library_extension
20
- need_version = false
21
- if ext == 'so'
22
- begin
23
- Fiddle.dlopen "libm.so"
24
- rescue
25
- (5..7).each do |i|
26
- begin
27
- Fiddle.dlopen "libm.so.#{i}"
28
- need_version = true
29
- break
30
- rescue
31
- end
32
- end
33
- if !need_version
34
- raise "failed to check whether dynamically linked shared object needs version suffix"
35
- end
36
- end
37
- end
38
20
 
39
21
  open("lib/site_conf.rb","w"){|f| f.write "
40
22
  module Numo
@@ -49,7 +31,6 @@ module Numo
49
31
 
50
32
  module Loader
51
33
  EXT = '#{ext}'
52
- NEED_VERSION_SUFFIX = #{need_version}
53
34
  end
54
35
 
55
36
  end
@@ -68,6 +49,8 @@ def detect_library_extension
68
49
  end
69
50
  end
70
51
 
52
+ require 'numo/narray'
53
+
71
54
  def find_narray_h
72
55
  $LOAD_PATH.each do |x|
73
56
  if File.exist? File.join(x,'numo/numo/narray.h')
data/lib/numo/linalg/function.rb CHANGED
@@ -65,7 +65,7 @@ module Numo; module Linalg
65
65
  NArray.array_type(a)
66
66
  end
67
67
  if k && k < NArray
68
- t = k::UPCAST[t]
68
+ t = k::UPCAST[t] || t::UPCAST[k]
69
69
  end
70
70
  end
71
71
  BLAS_CHAR[t] || raise(TypeError,"invalid data type for BLAS/LAPACK")
@@ -86,7 +86,8 @@ module Numo; module Linalg
86
86
  when 1
87
87
  case b.ndim
88
88
  when 1
89
- Blas.call(:dot, a, b)
89
+ func = blas_char(a, b) =~ /c|z/ ? :dotu : :dot
90
+ Blas.call(func, a, b)
90
91
  else
91
92
  Blas.call(:gemv, b, a, trans:'t')
92
93
  end
@@ -194,10 +195,10 @@ module Numo; module Linalg
194
195
  #
195
196
  # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensinal NArray)
196
197
  # @param mode [String]
197
- # - "reduce" -- returns both Q and R,
198
- # - "r" -- returns only R,
199
- # - "economy" -- returns both Q and R but computed in economy-size,
200
- # - "raw" -- returns QR and TAU used in LAPACK.
198
+ # - "reduce" -- returns both Q and R,
199
+ # - "r" -- returns only R,
200
+ # - "economic" -- returns both Q and R but computed in economy-size,
201
+ # - "raw" -- returns QR and TAU used in LAPACK.
201
202
  # @return [r] if mode:"r"
202
203
  # @return [[q,r]] if mode:"reduce" or "economic"
203
204
  # @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
@@ -295,6 +296,38 @@ module Numo; module Linalg
295
296
  end
296
297
  end
297
298
 
299
+ # Computes an LU factorization of a M-by-N matrix A
300
+ # using partial pivoting with row interchanges.
301
+ #
302
+ # The factorization has the form
303
+ #
304
+ # A = P * L * U
305
+ #
306
+ # where P is a permutation matrix, L is lower triangular with unit
307
+ # diagonal elements (lower trapezoidal if m > n), and U is upper
308
+ # triangular (upper trapezoidal if m < n).
309
+ #
310
+ # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensinal NArray)
311
+ # @param permute_l [Bool] (optional) If true, perform the matrix product of P and L.
312
+ # @return [[p,l,u]] if permute_l == false
313
+ # @return [[pl,u]] if permute_l == true
314
+ #
315
+ # - **p** [Numo::NArray] -- The permutation matrix P.
316
+ # - **l** [Numo::NArray] -- The factor L.
317
+ # - **u** [Numo::NArray] -- The factor U.
318
+
319
+ def lu(a, permute_l: false)
320
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
321
+ m, n = a.shape
322
+ k = [m, n].min
323
+ lu, ip = lu_fact(a)
324
+ l = lu.tril.tap { |mat| mat[mat.diag_indices(0)] = 1.0 }[true, 0...k]
325
+ u = lu.triu[0...k, 0...n]
326
+ p = Numo::DFloat.eye(m).tap do |mat|
327
+ ip.to_a.each_with_index { |i, j| mat[true, [i - 1, j]] = mat[true, [j, i - 1]].dup }
328
+ end
329
+ permute_l ? [p.dot(l), u] : [p, l, u]
330
+ end
298
331
 
299
332
  # Computes an LU factorization of a M-by-N matrix A
300
333
  # using partial pivoting with row interchanges.
@@ -308,26 +341,14 @@ module Numo; module Linalg
308
341
  # triangular (upper trapezoidal if m < n).
309
342
  #
310
343
  # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensinal NArray)
311
- # @param driver [String or Symbol] choose LAPACK diriver from
312
- # 'gen','sym','her'. (optional, default='gen')
313
- # @param uplo [String or Symbol] optional, default='U'. Access upper
314
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
315
344
  # @return [[lu, ipiv]]
316
345
  # - **lu** [Numo::NArray] -- The factors L and U from the factorization
317
346
  # `A = P*L*U`; the unit diagonal elements of L are not stored.
318
347
  # - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
319
348
  # row i of the matrix was interchanged with row IPIV(i).
320
349
 
321
- def lu_fact(a, driver:"gen", uplo:"U")
322
- case driver.to_s
323
- when /^gen?(trf)?$/i
324
- Lapack.call(:getrf, a)[0..1]
325
- when /^(sym?|her?)(trf)?$/i
326
- func = driver[0..2].downcase+"trf"
327
- Lapack.call(func, a, uplo:uplo)[0..1]
328
- else
329
- raise ArgumentError, "invalid driver: #{driver}"
330
- end
350
+ def lu_fact(a)
351
+ Lapack.call(:getrf, a)[0..1]
331
352
  end
332
353
 
333
354
  # Computes the inverse of a matrix using the LU factorization
@@ -345,22 +366,10 @@ module Numo; module Linalg
345
366
  # @param ipiv [Numo::NArray] The pivot indices from
346
367
  # Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
347
368
  # interchanged with row IPIV(i).
348
- # @param driver [String or Symbol] choose LAPACK diriver from
349
- # 'gen','sym','her'. (optional, default='gen')
350
- # @param uplo [String or Symbol] optional, default='U'. Access upper
351
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
352
369
  # @return [Numo::NArray] the inverse of the original matrix A.
353
370
 
354
- def lu_inv(lu, ipiv, driver:"gen", uplo:"U")
355
- case driver.to_s
356
- when /^gen?(tri)?$/i
357
- Lapack.call(:getri, lu, ipiv)[0]
358
- when /^(sym?|her?)(tri)?$/i
359
- func = driver[0..2].downcase+"tri"
360
- Lapack.call(func, lu, ipiv, uplo:uplo)[0]
361
- else
362
- raise ArgumentError, "invalid driver: #{driver}"
363
- end
371
+ def lu_inv(lu, ipiv)
372
+ Lapack.call(:getri, lu, ipiv)[0]
364
373
  end
365
374
 
366
375
  # Solves a system of linear equations
@@ -377,31 +386,100 @@ module Numo; module Linalg
377
386
  # Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
378
387
  # interchanged with row IPIV(i).
379
388
  # @param b [Numo::NArray] the right hand side matrix B.
380
- # @param driver [String or Symbol] choose LAPACK diriver from
381
- # 'gen','sym','her'. (optional, default='gen')
382
- # @param uplo [String or Symbol] optional, default='U'. Access upper
383
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
384
389
  # @param trans [String or Symbol]
385
- # Specifies the form of the system of equations
386
- # (omitted if not driver:"gen"):
387
- #
390
+ # Specifies the form of the system of equations:
388
391
  # - If 'N': `A * X = B` (No transpose).
389
392
  # - If 'T': `A*\*T* X = B` (Transpose).
390
393
  # - If 'C': `A*\*T* X = B` (Conjugate transpose = Transpose).
391
394
  # @return [Numo::NArray] the solution matrix X.
392
395
 
393
- def lu_solve(lu, ipiv, b, driver:"gen", uplo:"U", trans:"N")
394
- case driver.to_s
395
- when /^gen?(trs)?$/i
396
- Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
397
- when /^(sym?|her?)(trs)?$/i
398
- func = driver[0..2].downcase+"trs"
399
- Lapack.call(func, lu, ipiv, b, uplo:uplo)[0]
400
- else
401
- raise ArgumentError, "invalid driver: #{driver}"
396
+ def lu_solve(lu, ipiv, b, trans:"N")
397
+ Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
398
+ end
399
+
400
+ # Computes the LDLt or Bunch-Kaufman factorization of a symmetric/Hermitian matrix A.
401
+ # The factorization has the form
402
+ #
403
+ # A = U*D*U**T or A = L*D*L**T
404
+ #
405
+ # where U (or L) is a product of permutation and unit upper (lower) triangular matrices
406
+ # and D is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
407
+ #
408
+ # @param a [Numo::NArray] m-by-m matrix A (>= 2-dimensinal NArray)
409
+ # @param uplo [String or Symbol] optional, default='U'. Access upper or ('U') lower ('L') triangle.
410
+ # @param hermitian [Bool] optional, default=true. If true, hermitian matrix is assumed.
411
+ # (omitted when real-value matrix is given)
412
+ #
413
+ # @return [[lu,d,perm]]
414
+ #
415
+ # - **lu** [Numo::NArray] -- The permutated upper (lower) triangular matrix U (L).
416
+ # - **d** [Numo::NArray] -- The block diagonal matrix D.
417
+ # - **perm** [Numo::NArray] -- The row-permutation index for changing lu into triangular form.
418
+
419
+ def ldl(a, uplo: 'U', hermitian: true)
420
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
421
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
422
+
423
+ is_complex = blas_char(a) =~ /c|z/
424
+ func = is_complex && hermitian ? 'hetrf' : 'sytrf'
425
+ lud, ipiv, = Lapack.call(func.to_sym, a, uplo: uplo)
426
+
427
+ lu = (uplo == 'U' ? lud.triu : lud.tril).tap { |mat| mat[mat.diag_indices(0)] = 1.0 }
428
+ d = lud[lud.diag_indices(0)].diag
429
+
430
+ m = a.shape[0]
431
+ n = m - 1
432
+ changed_2x2 = false
433
+ perm = Numo::Int32.new(m).seq
434
+ m.times do |t|
435
+ i = uplo == 'U' ? t : n - t
436
+ j = uplo == 'U' ? i - 1 : i + 1;
437
+ r = uplo == 'U' ? 0..i : i..n;
438
+ if ipiv[i] > 0
439
+ k = ipiv[i] - 1
440
+ lu[[k, i], r] = lu[[i, k], r].dup
441
+ perm[[k, i]] = perm[[i, k]].dup
442
+ elsif j.between?(0, n) && ipiv[i] == ipiv[j] && !changed_2x2
443
+ k = ipiv[i].abs - 1
444
+ d[j, i] = lud[j, i]
445
+ d[i, j] = is_complex && hermitian ? lud[j, i].conj : lud[j, i]
446
+ lu[j, i] = 0.0
447
+ lu[[k, j], r] = lu[[j, k], r].dup
448
+ perm[[k, j]] = perm[[j, k]].dup
449
+ changed_2x2 = true
450
+ next
451
+ end
452
+ changed_2x2 = false
402
453
  end
454
+
455
+ [lu, d, perm.sort_index]
403
456
  end
404
457
 
458
+ # Computes the Cholesky factorization of a symmetric/Hermitian
459
+ # positive definite matrix A. The factorization has the form
460
+ #
461
+ # A = U**H * U, if UPLO = 'U', or
462
+ # A = L * L**H, if UPLO = 'L',
463
+ #
464
+ # where U is an upper triangular matrix and L is a lower triangular matrix.
465
+ # @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensinal NArray)
466
+ # @param uplo [String or Symbol] optional, default='U'. Access upper
467
+ # or ('U') lower ('L') triangle.
468
+ # @return [Numo::NArray] The factor U or L.
469
+
470
+ def cholesky(a, uplo: 'U')
471
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
472
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
473
+ factor = Lapack.call(:potrf, a, uplo: uplo)[0]
474
+ if uplo == 'U'
475
+ factor.triu
476
+ else
477
+ # TODO: Use the tril method if the verision of Numo::NArray
478
+ # in the runtime dependency list becomes 0.9.1.3 or higher.
479
+ m, = a.shape
480
+ factor * Numo::DFloat.ones(m, m).triu.transpose
481
+ end
482
+ end
405
483
 
406
484
  # Computes the Cholesky factorization of a symmetric/Hermitian
407
485
  # positive definite matrix A. The factorization has the form
@@ -409,16 +487,16 @@ module Numo; module Linalg
409
487
  # A = U**H * U, if UPLO = 'U', or
410
488
  # A = L * L**H, if UPLO = 'L',
411
489
  #
412
- # where U is an upper triangular matrix and L is lower triangular
490
+ # where U is an upper triangular matrix and L is a lower triangular matrix.
413
491
  # @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensinal NArray)
414
492
  # @param uplo [String or Symbol] optional, default='U'. Access upper
415
493
  # or ('U') lower ('L') triangle.
416
- # @return [Numo::NArray] the factor U or L.
494
+ # @return [Numo::NArray] The matrix which has the Cholesky factor in upper or lower triangular part.
495
+ # Remain part consists of random values.
417
496
 
418
497
  def cho_fact(a, uplo:'U')
419
498
  Lapack.call(:potrf, a, uplo:uplo)[0]
420
499
  end
421
- #alias cholesky cho_fact
422
500
 
423
501
  # Computes the inverse of a symmetric/Hermitian
424
502
  # positive definite matrix A using the Cholesky factorization
@@ -479,27 +557,41 @@ module Numo; module Linalg
479
557
  [w,vl,vr] #.compact
480
558
  end
481
559
 
482
- # Computes the eigenvalues and, optionally, the left and/or right
483
- # eigenvectors for a square symmetric/hermitian matrix A.
560
+ # Obtains the eigenvalues and, optionally, the eigenvectors
561
+ # by solving an ordinary or generalized eigenvalue problem
562
+ # for a square symmetric / Hermitian matrix.
484
563
  #
485
- # @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensinal NArray)
486
- # @param values_only [Bool] (optional) If false, eigenvectors are computed.
564
+ # @param a [Numo::NArray] square symmetric matrix (>= 2-dimensinal NArray)
565
+ # @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensinal NArray)
566
+ # If nil, identity matrix is assumed.
567
+ # @param vals_only [Bool] (optional) If false, eigenvectors are computed.
568
+ # @param vals_range [Range] (optional)
569
+ # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
570
+ # If nil or 0...N (N is the size of the matrix a), all eigenvalues and eigenvectors are returned.
487
571
  # @param uplo [String or Symbol] (optional, default='U')
488
572
  # Access upper ('U') or lower ('L') triangle.
573
+ # @param turbo [Bool] (optional) If true, divide and conquer algorithm is used.
489
574
  # @return [[w,v]]
490
575
  # - **w** [Numo::NArray] -- The eigenvalues.
491
576
  # - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
492
577
 
493
- def eigh(a, vals_only:false, uplo:false, turbo:false)
578
+ def eigh(a, b=nil, vals_only:false, vals_range:nil, uplo:'U', turbo:false)
494
579
  jobz = vals_only ? 'N' : 'V' # jobz: Compute eigenvalues and eigenvectors.
495
- case blas_char(a)
496
- when /c|z/
497
- func = turbo ? :hegv : :heev
580
+ b = a.class.eye(a.shape[0]) if b.nil?
581
+ func = blas_char(a, b) =~ /c|z/ ? 'hegv' : 'sygv'
582
+ if vals_range.nil?
583
+ func << 'd' if turbo
584
+ v, u_, w, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz)
585
+ v = nil if vals_only
586
+ [w, v]
498
587
  else
499
- func = turbo ? :sygv : :syev
588
+ func << 'x'
589
+ il = vals_range.first(1)[0]
590
+ iu = vals_range.last(1)[0]
591
+ a_, b_, w, v, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz, range: 'I', il: il + 1, iu: iu + 1)
592
+ v = nil if vals_only
593
+ [w, v]
500
594
  end
501
- w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
502
- [w,v] #.compact
503
595
  end
504
596
 
505
597
  # Computes the eigenvalues only for a square nonsymmetric matrix A.
@@ -519,23 +611,22 @@ module Numo; module Linalg
519
611
  w
520
612
  end
521
613
 
522
- # Computes the eigenvalues for a square symmetric/hermitian matrix A.
614
+ # Obtains the eigenvalues by solving an ordinary or generalized eigenvalue problem
615
+ # for a square symmetric / Hermitian matrix.
523
616
  #
524
617
  # @param a [Numo::NArray] square symmetric/hermitian matrix
525
618
  # (>= 2-dimensinal NArray)
619
+ # @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensinal NArray)
620
+ # If nil, identity matrix is assumed.
621
+ # @param vals_range [Range] (optional)
622
+ # The range of indices of the eigenvalues (in ascending order) to be returned.
623
+ # If nil or 0...N (N is the size of the matrix a), all eigenvalues are returned.
526
624
  # @param uplo [String or Symbol] (optional, default='U')
527
625
  # Access upper ('U') or lower ('L') triangle.
528
626
  # @return [Numo::NArray] eigenvalues
529
627
 
530
- def eigvalsh(a, uplo:false, turbo:false)
531
- jobz = 'N' # jobz: Compute eigenvalues and eigenvectors.
532
- case blas_char(a)
533
- when /c|z/
534
- func = turbo ? :hegv : :heev
535
- else
536
- func = turbo ? :sygv : :syev
537
- end
538
- Lapack.call(func, a, uplo:uplo, jobz:jobz)[0]
628
+ def eigvalsh(a, b=nil, vals_range:nil, uplo:'U', turbo:false)
629
+ eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo).first
539
630
  end
540
631
 
541
632
 
@@ -801,7 +892,7 @@ module Numo; module Linalg
801
892
  # returns lu, x, ipiv, info
802
893
  Lapack.call(:gesv, a, b)[1]
803
894
  when /^(sym?|her?|pos?)(sv)?$/i
804
- func = driver[0..2].downcase+"sv"
895
+ func = driver[0..1].downcase+"sv"
805
896
  Lapack.call(func, a, b, uplo:uplo)[1]
806
897
  else
807
898
  raise ArgumentError, "invalid driver: #{driver}"
@@ -834,8 +925,8 @@ module Numo; module Linalg
834
925
  solve(a, b, driver:d, uplo:uplo)
835
926
  when /(ge|sy|he)tr[fi]$/
836
927
  d = $1
837
- lu, piv = lu_fact(a, driver:d, uplo:uplo)
838
- lu_inv(lu, piv, driver:d, uplo:uplo)
928
+ lu, piv = lu_fact(a)
929
+ lu_inv(lu, piv)
839
930
  when /potr[fi]$/
840
931
  lu = cho_fact(a, uplo:uplo)
841
932
  cho_inv(lu, uplo:uplo)