numo-linalg 0.1.1 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +13 -9
  3. data/Rakefile +17 -0
  4. data/ext/numo/linalg/blas/blas.c +20 -6
  5. data/ext/numo/linalg/blas/depend.erb +2 -2
  6. data/ext/numo/linalg/blas/extconf.rb +15 -42
  7. data/ext/numo/linalg/blas/numo_blas.h +6 -0
  8. data/ext/numo/linalg/blas/tmpl/mv.c +3 -2
  9. data/ext/numo/linalg/lapack/depend.erb +2 -2
  10. data/ext/numo/linalg/lapack/extconf.rb +14 -19
  11. data/ext/numo/linalg/lapack/gen/spec.rb +5 -0
  12. data/ext/numo/linalg/lapack/lapack.c +49 -6
  13. data/ext/numo/linalg/lapack/numo_lapack.h +3 -0
  14. data/ext/numo/linalg/lapack/tmpl/gqr.c +1 -1
  15. data/ext/numo/linalg/lapack/tmpl/sygvx.c +130 -0
  16. data/ext/numo/linalg/mkmf_linalg.rb +84 -0
  17. data/lib/numo/linalg/autoloader.rb +18 -3
  18. data/lib/numo/linalg/function.rb +317 -92
  19. data/lib/numo/linalg/loader.rb +70 -77
  20. data/lib/numo/linalg/version.rb +1 -1
  21. data/numo-linalg.gemspec +5 -3
  22. data/spec/linalg/autoloader_spec.rb +10 -0
  23. data/spec/linalg/function/cho_fact_spec.rb +31 -0
  24. data/spec/linalg/function/cho_inv_spec.rb +39 -0
  25. data/spec/linalg/function/cho_solve_spec.rb +66 -0
  26. data/spec/linalg/function/cholesky_spec.rb +43 -0
  27. data/spec/linalg/function/cond_spec.rb +57 -0
  28. data/spec/linalg/function/det_spec.rb +21 -0
  29. data/spec/linalg/function/dot_spec.rb +84 -0
  30. data/spec/linalg/function/eig_spec.rb +53 -0
  31. data/spec/linalg/function/eigh_spec.rb +81 -0
  32. data/spec/linalg/function/eigvals_spec.rb +27 -0
  33. data/spec/linalg/function/eigvalsh_spec.rb +60 -0
  34. data/spec/linalg/function/expm_spec.rb +36 -0
  35. data/spec/linalg/function/inv_spec.rb +57 -0
  36. data/spec/linalg/function/ldl_spec.rb +51 -0
  37. data/spec/linalg/function/lstsq_spec.rb +80 -0
  38. data/spec/linalg/function/lu_fact_spec.rb +34 -0
  39. data/spec/linalg/function/lu_inv_spec.rb +21 -0
  40. data/spec/linalg/function/lu_solve_spec.rb +40 -0
  41. data/spec/linalg/function/lu_spec.rb +46 -0
  42. data/spec/linalg/function/matmul_spec.rb +41 -0
  43. data/spec/linalg/function/matrix_power_spec.rb +31 -0
  44. data/spec/linalg/function/matrix_rank_spec.rb +33 -0
  45. data/spec/linalg/function/norm_spec.rb +81 -0
  46. data/spec/linalg/function/null_space_spec.rb +41 -0
  47. data/spec/linalg/function/orth_spec.rb +43 -0
  48. data/spec/linalg/function/pinv_spec.rb +48 -0
  49. data/spec/linalg/function/qr_spec.rb +82 -0
  50. data/spec/linalg/function/slogdet_spec.rb +21 -0
  51. data/spec/linalg/function/solve_spec.rb +98 -0
  52. data/spec/linalg/function/svd_spec.rb +88 -0
  53. data/spec/linalg/function/svdvals_spec.rb +40 -0
  54. data/spec/spec_helper.rb +55 -0
  55. metadata +107 -14
  56. data/spec/lapack_spec.rb +0 -13
@@ -0,0 +1,130 @@
1
+ #define args_t <%=func_name%>_args_t
2
+ #define func_p <%=func_name%>_p
3
+
4
+ typedef struct {
5
+ int order;
6
+ int itype;
7
+ char jobz;
8
+ char uplo;
9
+ char range;
10
+ int il;
11
+ int iu;
12
+ } args_t;
13
+
14
+ static <%=func_name%>_t func_p = 0;
15
+
16
+ static void
17
+ <%=c_iter%>(na_loop_t * const lp)
18
+ {
19
+ dtype *a, *b, *z;
20
+ rtype *w;
21
+ int *ifail;
22
+ int *info;
23
+ int m, n, lda, ldb, ldz;
24
+ rtype vl = 0, vu = 0;
25
+ rtype abstol = 0;
26
+
27
+ args_t *g;
28
+
29
+ a = (dtype*)NDL_PTR(lp, 0);
30
+ b = (dtype*)NDL_PTR(lp, 1);
31
+ w = (rtype*)NDL_PTR(lp, 2);
32
+ z = (dtype*)NDL_PTR(lp, 3);
33
+ ifail = (int*)NDL_PTR(lp, 4);
34
+ info = (int*)NDL_PTR(lp, 5);
35
+ g = (args_t*)(lp->opt_ptr);
36
+
37
+ n = NDL_SHAPE(lp, 0)[1];
38
+ lda = NDL_STEP(lp, 0) / sizeof(dtype);
39
+ ldb = NDL_STEP(lp, 1) / sizeof(dtype);
40
+ ldz = NDL_SHAPE(lp, 3)[1];
41
+
42
+ *info = (*func_p)( g->order, g->itype, g->jobz, g->range, g->uplo, n, a, lda, b, ldb,
43
+ vl, vu, g->il, g->iu, abstol, &m, w, z, ldz, ifail );
44
+
45
+ CHECK_ERROR(*info);
46
+ }
47
+
48
+ /*
49
+ <%
50
+ params = [
51
+ mat("a",:inplace),
52
+ mat("b",:inplace),
53
+ "@param [Integer] itype Specifies the problem type to be solved. If 1: A*x = (lambda)*B*x, If 2: A*B*x = (lambda)*x, If 3: B*A*x = (lambda)*x.",
54
+ jobe("jobz"),
55
+ opt("uplo"),
56
+ opt("order"),
57
+ "@param [String or Symbol] range If 'A': Compute all eigenvalues, if 'I': Compute eigenvalues with indices il to iu (default='A')",
58
+ "@param [Integer] il Specifies the index of the smallest eigenvalue in ascending order to be returned. If range = 'A', il is not referenced.",
59
+ "@param [Integer] iu Specifies the index of the largest eigenvalue in ascending order to be returned. Constraint: 1<=il<=iu<=N. If range = 'A', iu is not referenced.",
60
+ ].select{|x| x}.join("\n ")
61
+ return_name="a, b, w, z, ifail, info"
62
+ %>
63
+ @overload <%=name%>(a, b, [itype:1, jobz:'V', uplo:'U', order:'R', range:'I', il: 1, iu: 2])
64
+ <%=params%>
65
+ @return [[<%=return_name%>]]
66
+ Array<<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,Integer>
67
+ <%=outparam(return_name)%>
68
+
69
+ <%=description%>
70
+ */
71
+
72
+ static VALUE
73
+ <%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
74
+ {
75
+ VALUE a, b, ans;
76
+ int n, nb, m;
77
+ narray_t *na1, *na2;
78
+ size_t w_shape[1];
79
+ size_t z_shape[2];
80
+ size_t ifail_shape[1];
81
+
82
+ ndfunc_arg_in_t ain[2] = {{OVERWRITE, 2}, {OVERWRITE, 2}};
83
+ ndfunc_arg_out_t aout[4] = {{cRT, 1, w_shape}, {cT, 2, z_shape}, {cI, 1, ifail_shape}, {cInt, 0}};
84
+ ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP | NDF_EXTRACT, 2, 4, ain, aout};
85
+
86
+ args_t g;
87
+ VALUE opts[7] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
88
+ VALUE kw_hash = Qnil;
89
+ ID kw_table[7] = {id_order, id_jobz, id_uplo, id_itype, id_range, id_il, id_iu};
90
+
91
+ CHECK_FUNC(func_p,"<%=func_name%>");
92
+
93
+ rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
94
+ rb_get_kwargs(kw_hash, kw_table, 0, 7, opts);
95
+ g.order = option_order(opts[0]);
96
+ g.jobz = option_job(opts[1], 'V', 'N');
97
+ g.uplo = option_uplo(opts[2]);
98
+ g.itype = NUM2INT(option_value(opts[3], INT2FIX(1)));
99
+ g.range = option_range(opts[4], 'A', 'I');
100
+ g.il = NUM2INT(option_value(opts[5], INT2FIX(1)));
101
+ g.iu = NUM2INT(option_value(opts[6], INT2FIX(1)));
102
+
103
+ COPY_OR_CAST_TO(a, cT);
104
+ GetNArray(a, na1);
105
+ CHECK_DIM_GE(na1, 2);
106
+
107
+ COPY_OR_CAST_TO(b, cT);
108
+ GetNArray(b, na2);
109
+ CHECK_DIM_GE(na2, 2);
110
+ CHECK_SQUARE("matrix a", na1);
111
+ n = COL_SIZE(na1);
112
+ CHECK_SQUARE("matrix b", na2);
113
+ nb = COL_SIZE(na2);
114
+ if (n != nb) {
115
+ rb_raise(nary_eShapeError, "matrix a and b must have same size");
116
+ }
117
+
118
+ m = g.range == 'I' ? g.iu - g.il + 1 : n;
119
+ w_shape[0] = m;
120
+ z_shape[0] = n;
121
+ z_shape[1] = m;
122
+ ifail_shape[0] = m;
123
+
124
+ ans = na_ndloop3(&ndf, &g, 2, a, b);
125
+
126
+ return rb_ary_new3(6, a, b, RARRAY_AREF(ans, 0), RARRAY_AREF(ans, 1), RARRAY_AREF(ans, 2), RARRAY_AREF(ans, 3));
127
+ }
128
+
129
+ #undef args_t
130
+ #undef func_p
@@ -0,0 +1,84 @@
1
+ require 'mkmf'
2
+
3
+ def create_site_conf
4
+ ldirs = [
5
+ dir_config("mkl")[1],
6
+ dir_config("openblas")[1],
7
+ dir_config("atlas")[1],
8
+ dir_config("blas")[1],
9
+ dir_config("lapack")[1],
10
+ ]
11
+ bked = with_config("backend")
12
+
13
+ require 'fiddle'
14
+
15
+ message "creating lib/site_conf.rb\n"
16
+
17
+ FileUtils.mkdir_p "lib"
18
+
19
+ ext = detect_library_extension
20
+
21
+ open("lib/site_conf.rb","w"){|f| f.write "
22
+ module Numo
23
+ module Linalg
24
+
25
+ BACKEND = #{bked.inspect}
26
+ MKL_LIBPATH = #{ldirs[0].inspect}
27
+ OPENBLAS_LIBPATH = #{ldirs[1].inspect}
28
+ ATLAS_LIBPATH = #{ldirs[2].inspect}
29
+ BLAS_LIBPATH = #{ldirs[3].inspect}
30
+ LAPACK_LIBPATH = #{ldirs[4].inspect}
31
+
32
+ module Loader
33
+ EXT = '#{ext}'
34
+ end
35
+
36
+ end
37
+ end"
38
+ }
39
+ end
40
+
41
+ def detect_library_extension
42
+ case RbConfig::CONFIG['host_os']
43
+ when /mswin|msys|mingw|cygwin/
44
+ 'dll'
45
+ when /darwin|mac os/
46
+ 'dylib'
47
+ else
48
+ 'so'
49
+ end
50
+ end
51
+
52
+ require 'numo/narray'
53
+
54
+ def find_narray_h
55
+ $LOAD_PATH.each do |x|
56
+ if File.exist? File.join(x,'numo/numo/narray.h')
57
+ $INCFLAGS = "-I#{x}/numo " + $INCFLAGS
58
+ break
59
+ end
60
+ end
61
+ end
62
+
63
+ def find_libnarray_a
64
+ $LOAD_PATH.each do |x|
65
+ if File.exist? File.join(x,'numo/libnarray.a')
66
+ $LDFLAGS = "-L#{x}/numo " + $LDFLAGS
67
+ break
68
+ end
69
+ end
70
+ end
71
+
72
+ def create_depend(base_dir)
73
+ require 'erb'
74
+ message "creating depend\n"
75
+ dep_path = "#{base_dir}/depend"
76
+ File.open(dep_path, "w") do |dep|
77
+ dep_erb_path = "#{base_dir}/depend.erb"
78
+ File.open(dep_erb_path, "r") do |dep_erb|
79
+ erb = ERB.new(dep_erb.read)
80
+ erb.filename = dep_erb_path
81
+ dep.print(erb.result)
82
+ end
83
+ end
84
+ end
@@ -19,15 +19,23 @@ module Numo
19
19
  # @return [String] name of loaded backend library (mkl/openblas/lapack)
20
20
  def load_library
21
21
  mkl_dirs = ['/opt/intel/lib', '/opt/intel/lib64', '/opt/intel/mkl/lib', '/opt/intel/mkl/lib64']
22
- openblas_dirs = ['/opt/openblas/lib', '/opt/openblas/lib64', '/usr/local/opt/openblas/lib']
22
+ openblas_dirs = ['/opt/OpenBLAS/lib', '/opt/OpenBLAS/lib64', '/opt/openblas/lib', '/opt/openblas/lib64',
23
+ '/usr/local/opt/openblas/lib']
23
24
  atlas_dirs = ['/opt/atlas/lib', '/opt/atlas/lib64',
24
25
  '/usr/lib/atlas', '/usr/lib64/atlas', '/usr/local/opt/atlas/lib']
25
26
  lapacke_dirs = ['/opt/lapack/lib', '/opt/lapack/lib64', '/opt/local/lib/lapack',
26
27
  '/usr/local/opt/lapack/lib']
27
28
  opt_dirs = ['/opt/local/lib', '/opt/local/lib64', '/opt/lib', '/opt/lib64']
28
29
  base_dirs = ['/usr/local/lib', '/usr/local/lib64', '/usr/lib', '/usr/lib64']
30
+ base_dirs.concat(Dir["/usr/lib/#{RbConfig::CONFIG['host_cpu']}-*"])
29
31
  base_dirs.unshift(*ENV['LD_LIBRARY_PATH'].split(':')) unless ENV['LD_LIBRARY_PATH'].nil?
30
32
 
33
+ select_dirs(base_dirs)
34
+ select_dirs(opt_dirs)
35
+ select_dirs(lapacke_dirs)
36
+ select_dirs(atlas_dirs)
37
+ select_dirs(mkl_dirs)
38
+
31
39
  mkl_libs = find_mkl_libs([*base_dirs, *opt_dirs, *mkl_dirs])
32
40
  openblas_libs = find_openblas_libs([*base_dirs, *opt_dirs, *openblas_dirs])
33
41
  atlas_libs = find_atlas_libs([*base_dirs, *opt_dirs, *atlas_dirs, *lapacke_dirs])
@@ -66,11 +74,18 @@ module Numo
66
74
  end
67
75
  end
68
76
 
77
+ def select_dirs(dirs)
78
+ dirs.select!{|d| Dir.exist?(d)}
79
+ end
80
+
69
81
  def find_libs(lib_names, lib_dirs)
70
82
  lib_ext = detect_library_extension
71
83
  lib_arr = lib_names.map do |l|
72
- [l.to_sym, lib_dirs.map { |d| "#{d}/lib#{l}.#{lib_ext}" }
73
- .keep_if { |f| File.exist?(f) }.first]
84
+ x = nil
85
+ lib_dirs.each do |d|
86
+ break if x = Dir.glob("#{d}/lib#{l}{,64}.#{lib_ext}{,.*}").last
87
+ end
88
+ [l.to_sym, x]
74
89
  end
75
90
  Hash[*lib_arr.flatten]
76
91
  end
@@ -12,12 +12,19 @@ module Numo; module Linalg
12
12
  # defined from data-types of arguments.
13
13
  # @param [Symbol] func function name without BLAS char.
14
14
  # @param args arguments passed to Blas function.
15
+ # @param kwargs keyword arguments passed to Blas function.
15
16
  # @example
16
17
  # c = Numo::Linalg::Blas.call(:gemm, a, b)
17
- def self.call(func,*args)
18
+ def self.call(func, *args, **kwargs)
18
19
  fn = (Linalg.blas_char(*args) + func.to_s).to_sym
19
20
  fn = FIXNAME[fn] || fn
20
- send(fn,*args)
21
+ if kwargs.empty?
22
+ # This conditional branch is necessary to prevent ArgumentError
23
+ # that occurs in Ruby 2.6 or earlier.
24
+ send(fn, *args)
25
+ else
26
+ send(fn, *args, **kwargs)
27
+ end
21
28
  end
22
29
 
23
30
  end
@@ -34,12 +41,19 @@ module Numo; module Linalg
34
41
  # defined from data-types of arguments.
35
42
  # @param [Symbol,String] func function name without BLAS char.
36
43
  # @param args arguments passed to Lapack function.
44
+ # @param kwargs keyword arguments passed to Lapack function.
37
45
  # @example
38
46
  # s = Numo::Linalg::Lapack.call(:gesv, a)
39
- def self.call(func,*args)
47
+ def self.call(func, *args, **kwargs)
40
48
  fn = (Linalg.blas_char(*args) + func.to_s).to_sym
41
49
  fn = FIXNAME[fn] || fn
42
- send(fn,*args)
50
+ if kwargs.empty?
51
+ # This conditional branch is necessary to prevent ArgumentError
52
+ # that occurs in Ruby 2.6 or earlier.
53
+ send(fn, *args)
54
+ else
55
+ send(fn, *args, **kwargs)
56
+ end
43
57
  end
44
58
 
45
59
  end
@@ -65,7 +79,7 @@ module Numo; module Linalg
65
79
  NArray.array_type(a)
66
80
  end
67
81
  if k && k < NArray
68
- t = k::UPCAST[t]
82
+ t = k::UPCAST[t] || t::UPCAST[k]
69
83
  end
70
84
  end
71
85
  BLAS_CHAR[t] || raise(TypeError,"invalid data type for BLAS/LAPACK")
@@ -86,16 +100,61 @@ module Numo; module Linalg
86
100
  when 1
87
101
  case b.ndim
88
102
  when 1
89
- Blas.call(:dot, a, b)
103
+ func = blas_char(a, b) =~ /c|z/ ? :dotu : :dot
104
+ Blas.call(func, a, b)
90
105
  else
91
- Blas.call(:gemv, b, a, trans:'t')
106
+ if b.contiguous?
107
+ trans = 't'
108
+ else
109
+ if b.fortran_contiguous?
110
+ trans = 'n'
111
+ b = b.transpose
112
+ else
113
+ trans = 't'
114
+ b = b.dup
115
+ end
116
+ end
117
+ Blas.call(:gemv, b, a, trans:trans)
92
118
  end
93
119
  else
94
120
  case b.ndim
95
121
  when 1
96
- Blas.call(:gemv, a, b)
122
+ if a.contiguous?
123
+ trans = 'n'
124
+ else
125
+ if a.fortran_contiguous?
126
+ trans = 't'
127
+ a = a.transpose
128
+ else
129
+ trans = 'n'
130
+ a = a.dup
131
+ end
132
+ end
133
+ Blas.call(:gemv, a, b, trans:trans)
97
134
  else
98
- Blas.call(:gemm, a, b)
135
+ if a.contiguous?
136
+ transa = 'n'
137
+ else
138
+ if a.fortran_contiguous?
139
+ transa = 't'
140
+ a = a.transpose
141
+ else
142
+ transa = 'n'
143
+ a = a.dup
144
+ end
145
+ end
146
+ if b.contiguous?
147
+ transb = 'n'
148
+ else
149
+ if b.fortran_contiguous?
150
+ transb='t'
151
+ b = b.transpose
152
+ else
153
+ transb='n'
154
+ b = b.dup
155
+ end
156
+ end
157
+ Blas.call(:gemm, a, b, transa:transa, transb:transb)
99
158
  end
100
159
  end
101
160
  end
@@ -194,10 +253,10 @@ module Numo; module Linalg
194
253
  #
195
254
  # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
196
255
  # @param mode [String]
197
- # - "reduce" -- returns both Q and R,
198
- # - "r" -- returns only R,
199
- # - "economy" -- returns both Q and R but computed in economy-size,
200
- # - "raw" -- returns QR and TAU used in LAPACK.
256
+ # - "reduce" -- returns both Q and R,
257
+ # - "r" -- returns only R,
258
+ # - "economic" -- returns both Q and R but computed in economy-size,
259
+ # - "raw" -- returns QR and TAU used in LAPACK.
201
260
  # @return [r] if mode:"r"
202
261
  # @return [[q,r]] if mode:"reduce" or "economic"
203
262
  # @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
@@ -295,6 +354,74 @@ module Numo; module Linalg
295
354
  end
296
355
  end
297
356
 
357
+ # Computes an orthonormal basis for the range of matrix A.
358
+ #
359
+ # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray).
360
+ # @param rcond [Float] (optional)
361
+ # rcond is used to determine the effective rank of A.
362
+ # Singular values `s[i] <= rcond * s.max` are treated as zero.
363
+ # If rcond < 0, machine precision is used instead.
364
+ # @return [Numo::NArray] The orthonormal basis for the range of matrix A.
365
+
366
+ def orth(a, rcond: -1)
367
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
368
+ s, u, = svd(a)
369
+ tol = s.max * (rcond.nil? || rcond < 0 ? a.class::EPSILON * a.shape.max : rcond)
370
+ k = (s > tol).count
371
+ u[true, 0...k]
372
+ end
373
+
374
+ # Computes an orthonormal basis for the null space of matrix A.
375
+ #
376
+ # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray).
377
+ # @param rcond [Float] (optional)
378
+ # rcond is used to determine the effective rank of A.
379
+ # Singular values `s[i] <= rcond * s.max` are treated as zero.
380
+ # If rcond < 0, machine precision is used instead.
381
+ # @return [Numo::NArray] The orthonormal basis for the null space of matrix A.
382
+
383
+ def null_space(a, rcond: -1)
384
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
385
+ s, _u, vh = svd(a)
386
+ tol = s.max * (rcond.nil? || rcond < 0 ? a.class::EPSILON * a.shape.max : rcond)
387
+ k = (s > tol).count
388
+ return a.class.new if k == vh.shape[0]
389
+ r = vh[k..-1, true].transpose.dup
390
+ blas_char(vh) =~ /c|z/ ? r.conj : r
391
+ end
392
+
393
+ # Computes an LU factorization of a M-by-N matrix A
394
+ # using partial pivoting with row interchanges.
395
+ #
396
+ # The factorization has the form
397
+ #
398
+ # A = P * L * U
399
+ #
400
+ # where P is a permutation matrix, L is lower triangular with unit
401
+ # diagonal elements (lower trapezoidal if m > n), and U is upper
402
+ # triangular (upper trapezoidal if m < n).
403
+ #
404
+ # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
405
+ # @param permute_l [Bool] (optional) If true, perform the matrix product of P and L.
406
+ # @return [[p,l,u]] if permute_l == false
407
+ # @return [[pl,u]] if permute_l == true
408
+ #
409
+ # - **p** [Numo::NArray] -- The permutation matrix P.
410
+ # - **l** [Numo::NArray] -- The factor L.
411
+ # - **u** [Numo::NArray] -- The factor U.
412
+
413
+ def lu(a, permute_l: false)
414
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
415
+ m, n = a.shape
416
+ k = [m, n].min
417
+ lu, ip = lu_fact(a)
418
+ l = lu.tril.tap { |mat| mat[mat.diag_indices(0)] = 1.0 }[true, 0...k]
419
+ u = lu.triu[0...k, 0...n]
420
+ p = Numo::DFloat.eye(m).tap do |mat|
421
+ ip.to_a.each_with_index { |i, j| mat[true, [i - 1, j]] = mat[true, [j, i - 1]].dup }
422
+ end
423
+ permute_l ? [p.dot(l), u] : [p, l, u]
424
+ end
298
425
 
299
426
  # Computes an LU factorization of a M-by-N matrix A
300
427
  # using partial pivoting with row interchanges.
@@ -308,26 +435,14 @@ module Numo; module Linalg
308
435
  # triangular (upper trapezoidal if m < n).
309
436
  #
310
437
  # @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
311
- # @param driver [String or Symbol] choose LAPACK diriver from
312
- # 'gen','sym','her'. (optional, default='gen')
313
- # @param uplo [String or Symbol] optional, default='U'. Access upper
314
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
315
438
  # @return [[lu, ipiv]]
316
439
  # - **lu** [Numo::NArray] -- The factors L and U from the factorization
317
440
  # `A = P*L*U`; the unit diagonal elements of L are not stored.
318
441
  # - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
319
442
  # row i of the matrix was interchanged with row IPIV(i).
320
443
 
321
- def lu_fact(a, driver:"gen", uplo:"U")
322
- case driver.to_s
323
- when /^gen?(trf)?$/i
324
- Lapack.call(:getrf, a)[0..1]
325
- when /^(sym?|her?)(trf)?$/i
326
- func = driver[0..2].downcase+"trf"
327
- Lapack.call(func, a, uplo:uplo)[0..1]
328
- else
329
- raise ArgumentError, "invalid driver: #{driver}"
330
- end
444
+ def lu_fact(a)
445
+ Lapack.call(:getrf, a)[0..1]
331
446
  end
332
447
 
333
448
  # Computes the inverse of a matrix using the LU factorization
@@ -345,22 +460,10 @@ module Numo; module Linalg
345
460
  # @param ipiv [Numo::NArray] The pivot indices from
346
461
  # Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
347
462
  # interchanged with row IPIV(i).
348
- # @param driver [String or Symbol] choose LAPACK diriver from
349
- # 'gen','sym','her'. (optional, default='gen')
350
- # @param uplo [String or Symbol] optional, default='U'. Access upper
351
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
352
463
  # @return [Numo::NArray] the inverse of the original matrix A.
353
464
 
354
- def lu_inv(lu, ipiv, driver:"gen", uplo:"U")
355
- case driver.to_s
356
- when /^gen?(tri)?$/i
357
- Lapack.call(:getri, lu, ipiv)[0]
358
- when /^(sym?|her?)(tri)?$/i
359
- func = driver[0..2].downcase+"tri"
360
- Lapack.call(func, lu, ipiv, uplo:uplo)[0]
361
- else
362
- raise ArgumentError, "invalid driver: #{driver}"
363
- end
465
+ def lu_inv(lu, ipiv)
466
+ Lapack.call(:getri, lu, ipiv)[0]
364
467
  end
365
468
 
366
469
  # Solves a system of linear equations
@@ -377,31 +480,100 @@ module Numo; module Linalg
377
480
  # Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
378
481
  # interchanged with row IPIV(i).
379
482
  # @param b [Numo::NArray] the right hand side matrix B.
380
- # @param driver [String or Symbol] choose LAPACK diriver from
381
- # 'gen','sym','her'. (optional, default='gen')
382
- # @param uplo [String or Symbol] optional, default='U'. Access upper
383
- # or ('U') lower ('L') triangle. (omitted when driver:"gen")
384
483
  # @param trans [String or Symbol]
385
- # Specifies the form of the system of equations
386
- # (omitted if not driver:"gen"):
387
- #
484
+ # Specifies the form of the system of equations:
388
485
  # - If 'N': `A * X = B` (No transpose).
389
486
  # - If 'T': `A*\*T* X = B` (Transpose).
390
487
  # - If 'C': `A*\*T* X = B` (Conjugate transpose = Transpose).
391
488
  # @return [Numo::NArray] the solution matrix X.
392
489
 
393
- def lu_solve(lu, ipiv, b, driver:"gen", uplo:"U", trans:"N")
394
- case driver.to_s
395
- when /^gen?(trs)?$/i
396
- Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
397
- when /^(sym?|her?)(trs)?$/i
398
- func = driver[0..2].downcase+"trs"
399
- Lapack.call(func, lu, ipiv, b, uplo:uplo)[0]
400
- else
401
- raise ArgumentError, "invalid driver: #{driver}"
490
+ def lu_solve(lu, ipiv, b, trans:"N")
491
+ Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
492
+ end
493
+
494
+ # Computes the LDLt or Bunch-Kaufman factorization of a symmetric/Hermitian matrix A.
495
+ # The factorization has the form
496
+ #
497
+ # A = U*D*U**T or A = L*D*L**T
498
+ #
499
+ # where U (or L) is a product of permutation and unit upper (lower) triangular matrices
500
+ # and D is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
501
+ #
502
+ # @param a [Numo::NArray] m-by-m matrix A (>= 2-dimensional NArray)
503
+ # @param uplo [String or Symbol] optional, default='U'. Access upper ('U') or lower ('L') triangle.
504
+ # @param hermitian [Bool] optional, default=true. If true, hermitian matrix is assumed.
505
+ # (omitted when real-value matrix is given)
506
+ #
507
+ # @return [[lu,d,perm]]
508
+ #
509
+ # - **lu** [Numo::NArray] -- The permuted upper (lower) triangular matrix U (L).
510
+ # - **d** [Numo::NArray] -- The block diagonal matrix D.
511
+ # - **perm** [Numo::NArray] -- The row-permutation index for changing lu into triangular form.
512
+
513
+ def ldl(a, uplo: 'U', hermitian: true)
514
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
515
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
516
+
517
+ is_complex = blas_char(a) =~ /c|z/
518
+ func = is_complex && hermitian ? 'hetrf' : 'sytrf'
519
+ lud, ipiv, = Lapack.call(func.to_sym, a, uplo: uplo)
520
+
521
+ lu = (uplo == 'U' ? lud.triu : lud.tril).tap { |mat| mat[mat.diag_indices(0)] = 1.0 }
522
+ d = lud[lud.diag_indices(0)].diag
523
+
524
+ m = a.shape[0]
525
+ n = m - 1
526
+ changed_2x2 = false
527
+ perm = Numo::Int32.new(m).seq
528
+ m.times do |t|
529
+ i = uplo == 'U' ? t : n - t
530
+ j = uplo == 'U' ? i - 1 : i + 1;
531
+ r = uplo == 'U' ? 0..i : i..n;
532
+ if ipiv[i] > 0
533
+ k = ipiv[i] - 1
534
+ lu[[k, i], r] = lu[[i, k], r].dup
535
+ perm[[k, i]] = perm[[i, k]].dup
536
+ elsif j.between?(0, n) && ipiv[i] == ipiv[j] && !changed_2x2
537
+ k = ipiv[i].abs - 1
538
+ d[j, i] = lud[j, i]
539
+ d[i, j] = is_complex && hermitian ? lud[j, i].conj : lud[j, i]
540
+ lu[j, i] = 0.0
541
+ lu[[k, j], r] = lu[[j, k], r].dup
542
+ perm[[k, j]] = perm[[j, k]].dup
543
+ changed_2x2 = true
544
+ next
545
+ end
546
+ changed_2x2 = false
402
547
  end
548
+
549
+ [lu, d, perm.sort_index]
403
550
  end
404
551
 
552
+ # Computes the Cholesky factorization of a symmetric/Hermitian
553
+ # positive definite matrix A. The factorization has the form
554
+ #
555
+ # A = U**H * U, if UPLO = 'U', or
556
+ # A = L * L**H, if UPLO = 'L',
557
+ #
558
+ # where U is an upper triangular matrix and L is a lower triangular matrix.
559
+ # @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
560
+ # @param uplo [String or Symbol] optional, default='U'. Access upper ('U')
561
+ # or lower ('L') triangle.
562
+ # @return [Numo::NArray] The factor U or L.
563
+
564
+ def cholesky(a, uplo: 'U')
565
+ raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
566
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
567
+ factor = Lapack.call(:potrf, a, uplo: uplo)[0]
568
+ if uplo == 'U'
569
+ factor.triu
570
+ else
571
+ # TODO: Use the tril method if the version of Numo::NArray
572
+ # in the runtime dependency list becomes 0.9.1.3 or higher.
573
+ m, = a.shape
574
+ factor * Numo::DFloat.ones(m, m).triu.transpose
575
+ end
576
+ end
405
577
 
406
578
  # Computes the Cholesky factorization of a symmetric/Hermitian
407
579
  # positive definite matrix A. The factorization has the form
@@ -409,16 +581,16 @@ module Numo; module Linalg
409
581
  # A = U**H * U, if UPLO = 'U', or
410
582
  # A = L * L**H, if UPLO = 'L',
411
583
  #
412
- # where U is an upper triangular matrix and L is lower triangular
584
+ # where U is an upper triangular matrix and L is a lower triangular matrix.
413
585
  # @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
414
586
  # @param uplo [String or Symbol] optional, default='U'. Access upper ('U')
415
587
  # or lower ('L') triangle.
416
- # @return [Numo::NArray] the factor U or L.
588
+ # @return [Numo::NArray] The matrix which has the Cholesky factor in upper or lower triangular part.
589
+ # The remaining part is not meaningful and may contain arbitrary values.
417
590
 
418
591
  def cho_fact(a, uplo:'U')
419
592
  Lapack.call(:potrf, a, uplo:uplo)[0]
420
593
  end
421
- #alias cholesky cho_fact
422
594
 
423
595
  # Computes the inverse of a symmetric/Hermitian
424
596
  # positive definite matrix A using the Cholesky factorization
@@ -479,27 +651,41 @@ module Numo; module Linalg
479
651
  [w,vl,vr] #.compact
480
652
  end
481
653
 
482
- # Computes the eigenvalues and, optionally, the left and/or right
483
- # eigenvectors for a square symmetric/hermitian matrix A.
654
+ # Obtains the eigenvalues and, optionally, the eigenvectors
655
+ # by solving an ordinary or generalized eigenvalue problem
656
+ # for a square symmetric / Hermitian matrix.
484
657
  #
485
- # @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensinal NArray)
486
- # @param values_only [Bool] (optional) If false, eigenvectors are computed.
658
+ # @param a [Numo::NArray] square symmetric matrix (>= 2-dimensional NArray)
659
+ # @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
660
+ # If nil, identity matrix is assumed.
661
+ # @param vals_only [Bool] (optional) If false, eigenvectors are computed.
662
+ # @param vals_range [Range] (optional)
663
+ # The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
664
+ # If nil or 0...N (N is the size of the matrix a), all eigenvalues and eigenvectors are returned.
487
665
  # @param uplo [String or Symbol] (optional, default='U')
488
666
  # Access upper ('U') or lower ('L') triangle.
667
+ # @param turbo [Bool] (optional) If true, divide and conquer algorithm is used.
489
668
  # @return [[w,v]]
490
669
  # - **w** [Numo::NArray] -- The eigenvalues.
491
670
  # - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
492
671
 
493
- def eigh(a, vals_only:false, uplo:false, turbo:false)
672
+ def eigh(a, b=nil, vals_only:false, vals_range:nil, uplo:'U', turbo:false)
494
673
  jobz = vals_only ? 'N' : 'V' # jobz: Compute eigenvalues and eigenvectors.
495
- case blas_char(a)
496
- when /c|z/
497
- func = turbo ? :hegv : :heev
674
+ b = a.class.eye(a.shape[0]) if b.nil?
675
+ func = blas_char(a, b) =~ /c|z/ ? 'hegv' : 'sygv'
676
+ if vals_range.nil?
677
+ func << 'd' if turbo
678
+ v, u_, w, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz)
679
+ v = nil if vals_only
680
+ [w, v]
498
681
  else
499
- func = turbo ? :sygv : :syev
682
+ func << 'x'
683
+ il = vals_range.first(1)[0]
684
+ iu = vals_range.last(1)[0]
685
+ a_, b_, w, v, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz, range: 'I', il: il + 1, iu: iu + 1)
686
+ v = nil if vals_only
687
+ [w, v]
500
688
  end
501
- w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
502
- [w,v] #.compact
503
689
  end
504
690
 
505
691
  # Computes the eigenvalues only for a square nonsymmetric matrix A.
@@ -519,23 +705,22 @@ module Numo; module Linalg
519
705
  w
520
706
  end
521
707
 
522
- # Computes the eigenvalues for a square symmetric/hermitian matrix A.
708
+ # Obtains the eigenvalues by solving an ordinary or generalized eigenvalue problem
709
+ # for a square symmetric / Hermitian matrix.
523
710
  #
524
711
  # @param a [Numo::NArray] square symmetric/hermitian matrix
525
712
  # (>= 2-dimensional NArray)
713
+ # @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
714
+ # If nil, identity matrix is assumed.
715
+ # @param vals_range [Range] (optional)
716
+ # The range of indices of the eigenvalues (in ascending order) to be returned.
717
+ # If nil or 0...N (N is the size of the matrix a), all eigenvalues are returned.
526
718
  # @param uplo [String or Symbol] (optional, default='U')
527
719
  # Access upper ('U') or lower ('L') triangle.
528
720
  # @return [Numo::NArray] eigenvalues
529
721
 
530
- def eigvalsh(a, uplo:false, turbo:false)
531
- jobz = 'N' # jobz: Compute eigenvalues and eigenvectors.
532
- case blas_char(a)
533
- when /c|z/
534
- func = turbo ? :hegv : :heev
535
- else
536
- func = turbo ? :sygv : :syev
537
- end
538
- Lapack.call(func, a, uplo:uplo, jobz:jobz)[0]
722
+ def eigvalsh(a, b=nil, vals_range:nil, uplo:'U', turbo:false)
723
+ eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo).first
539
724
  end
540
725
 
541
726
 
@@ -801,7 +986,7 @@ module Numo; module Linalg
801
986
  # returns lu, x, ipiv, info
802
987
  Lapack.call(:gesv, a, b)[1]
803
988
  when /^(sym?|her?|pos?)(sv)?$/i
804
- func = driver[0..2].downcase+"sv"
989
+ func = driver[0..1].downcase+"sv"
805
990
  Lapack.call(func, a, b, uplo:uplo)[1]
806
991
  else
807
992
  raise ArgumentError, "invalid driver: #{driver}"
@@ -834,8 +1019,8 @@ module Numo; module Linalg
834
1019
  solve(a, b, driver:d, uplo:uplo)
835
1020
  when /(ge|sy|he)tr[fi]$/
836
1021
  d = $1
837
- lu, piv = lu_fact(a, driver:d, uplo:uplo)
838
- lu_inv(lu, piv, driver:d, uplo:uplo)
1022
+ lu, piv = lu_fact(a)
1023
+ lu_inv(lu, piv)
839
1024
  when /potr[fi]$/
840
1025
  lu = cho_fact(a, uplo:uplo)
841
1026
  cho_inv(lu, uplo:uplo)
@@ -913,14 +1098,11 @@ module Numo; module Linalg
913
1098
  when n
914
1099
  resids = (x[n..-1,true].abs**2).sum(axis:0)
915
1100
  when NArray
916
- if true
917
- resids = (x[false,n..-1,true].abs**2).sum(axis:-2)
918
- else
919
- resids = x[false,0,true].new_zeros
920
- mask = rank.eq(n)
921
- # NArray does not suppurt this yet.
922
- resids[mask,true] = (x[mask,n..-1,true].abs**2).sum(axis:-2)
923
- end
1101
+ resids = (x[false,n..-1,true].abs**2).sum(axis:-2)
1102
+ ## NArray does not support this yet.
1103
+ # resids = x[false,0,true].new_zeros
1104
+ # mask = rank.eq(n)
1105
+ # resids[mask,true] = (x[mask,n..-1,true].abs**2).sum(axis:-2)
924
1106
  end
925
1107
  end
926
1108
  x = x[false,0...n,true]
@@ -992,6 +1174,49 @@ module Numo; module Linalg
992
1174
  end
993
1175
  end
994
1176
 
1177
+ # Compute the matrix exponential using Pade approximation method.
1178
+ #
1179
+ # @param a [Numo::NArray] square matrix (>= 2-dimensional NArray)
1180
+ # @param ord [Integer] order of approximation
1181
+ # @return [Numo::NArray]
1182
+ # @example
1183
+ # a = Numo::Linalg.expm(Numo::DFloat.zeros([2,2]))
1184
+ # => Numo::DFloat#shape=[2,2]
1185
+ # [[1, 0],
1186
+ # [0, 1]]
1187
+ # b = Numo::Linalg.expm(Numo::DFloat[[1, 2], [-1, 3]] * Complex::I)
1188
+ # => Numo::DComplex#shape=[2,2]
1189
+ # [[0.426459+1.89218i, -2.13721-0.978113i],
1190
+ # [1.06861+0.489056i, -1.71076+0.914063i]]
1191
+
1192
+ def expm(a, ord=8)
1193
+ raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
1194
+
1195
+ inf_norm = norm(a, 'inf')
1196
+ n_squarings = inf_norm.zero? ? 0 : [0, Math.log2(inf_norm).ceil.to_i].max
1197
+ a = a / (2**n_squarings)
1198
+
1199
+ sz_mat = a.shape[0]
1200
+ c = 1
1201
+ s = -1
1202
+ x = Numo::DFloat.eye(sz_mat)
1203
+ n = Numo::DFloat.eye(sz_mat)
1204
+ d = Numo::DFloat.eye(sz_mat)
1205
+
1206
+ (1..ord).each do |k|
1207
+ c *= (ord - k + 1).fdiv((2 * ord - k + 1) * k)
1208
+ x = a.dot(x)
1209
+ cx = c * x
1210
+ n += cx
1211
+ d += s * cx
1212
+ s *= -1
1213
+ end
1214
+
1215
+ res = solve(d, n)
1216
+ n_squarings.times { res = res.dot(res) }
1217
+ res
1218
+ end
1219
+
995
1220
  # @!visibility private
996
1221
  def _make_complex_eigvecs(w, vin) # :nodoc:
997
1222
  v = w.class.cast(vin)