numo-linalg 0.1.1 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +13 -9
- data/Rakefile +17 -0
- data/ext/numo/linalg/blas/blas.c +20 -6
- data/ext/numo/linalg/blas/depend.erb +2 -2
- data/ext/numo/linalg/blas/extconf.rb +15 -42
- data/ext/numo/linalg/blas/numo_blas.h +6 -0
- data/ext/numo/linalg/blas/tmpl/mv.c +3 -2
- data/ext/numo/linalg/lapack/depend.erb +2 -2
- data/ext/numo/linalg/lapack/extconf.rb +14 -19
- data/ext/numo/linalg/lapack/gen/spec.rb +5 -0
- data/ext/numo/linalg/lapack/lapack.c +49 -6
- data/ext/numo/linalg/lapack/numo_lapack.h +3 -0
- data/ext/numo/linalg/lapack/tmpl/gqr.c +1 -1
- data/ext/numo/linalg/lapack/tmpl/sygvx.c +130 -0
- data/ext/numo/linalg/mkmf_linalg.rb +84 -0
- data/lib/numo/linalg/autoloader.rb +18 -3
- data/lib/numo/linalg/function.rb +317 -92
- data/lib/numo/linalg/loader.rb +70 -77
- data/lib/numo/linalg/version.rb +1 -1
- data/numo-linalg.gemspec +5 -3
- data/spec/linalg/autoloader_spec.rb +10 -0
- data/spec/linalg/function/cho_fact_spec.rb +31 -0
- data/spec/linalg/function/cho_inv_spec.rb +39 -0
- data/spec/linalg/function/cho_solve_spec.rb +66 -0
- data/spec/linalg/function/cholesky_spec.rb +43 -0
- data/spec/linalg/function/cond_spec.rb +57 -0
- data/spec/linalg/function/det_spec.rb +21 -0
- data/spec/linalg/function/dot_spec.rb +84 -0
- data/spec/linalg/function/eig_spec.rb +53 -0
- data/spec/linalg/function/eigh_spec.rb +81 -0
- data/spec/linalg/function/eigvals_spec.rb +27 -0
- data/spec/linalg/function/eigvalsh_spec.rb +60 -0
- data/spec/linalg/function/expm_spec.rb +36 -0
- data/spec/linalg/function/inv_spec.rb +57 -0
- data/spec/linalg/function/ldl_spec.rb +51 -0
- data/spec/linalg/function/lstsq_spec.rb +80 -0
- data/spec/linalg/function/lu_fact_spec.rb +34 -0
- data/spec/linalg/function/lu_inv_spec.rb +21 -0
- data/spec/linalg/function/lu_solve_spec.rb +40 -0
- data/spec/linalg/function/lu_spec.rb +46 -0
- data/spec/linalg/function/matmul_spec.rb +41 -0
- data/spec/linalg/function/matrix_power_spec.rb +31 -0
- data/spec/linalg/function/matrix_rank_spec.rb +33 -0
- data/spec/linalg/function/norm_spec.rb +81 -0
- data/spec/linalg/function/null_space_spec.rb +41 -0
- data/spec/linalg/function/orth_spec.rb +43 -0
- data/spec/linalg/function/pinv_spec.rb +48 -0
- data/spec/linalg/function/qr_spec.rb +82 -0
- data/spec/linalg/function/slogdet_spec.rb +21 -0
- data/spec/linalg/function/solve_spec.rb +98 -0
- data/spec/linalg/function/svd_spec.rb +88 -0
- data/spec/linalg/function/svdvals_spec.rb +40 -0
- data/spec/spec_helper.rb +55 -0
- metadata +107 -14
- data/spec/lapack_spec.rb +0 -13
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
/* Per-template aliases: args_t and func_p expand to names unique to this
   generated function (<%=func_name%>); they are #undef'd at the end. */
#define args_t <%=func_name%>_args_t
#define func_p <%=func_name%>_p

/* Options parsed from the Ruby keyword arguments and handed to the
   iterator through na_ndloop3's option pointer (lp->opt_ptr). */
typedef struct {
    int order;   /* matrix layout; passed as the first argument of func_p */
    int itype;   /* problem type 1, 2 or 3 (see the @param itype doc) */
    char jobz;   /* 'V': eigenvalues and eigenvectors, 'N': eigenvalues only */
    char uplo;   /* triangle of a and b to access ('U' or 'L') */
    char range;  /* 'A': all eigenvalues, 'I': indices il..iu */
    int il;      /* 1-based index of the smallest eigenvalue wanted (range 'I') */
    int iu;      /* 1-based index of the largest eigenvalue wanted (range 'I') */
} args_t;

/* LAPACK(E) entry point; checked via CHECK_FUNC before use
   (presumably loaded from the backend library on first call — confirm). */
static <%=func_name%>_t func_p = 0;
|
|
15
|
+
|
|
16
|
+
/*
 * Iterator run by na_ndloop3 for the Ruby-facing function of this template
 * (registered with NO_LOOP | NDF_EXTRACT). Operand order:
 *   0: a (overwritten), 1: b (overwritten), 2: w, 3: z, 4: ifail, 5: info.
 * Calls the LAPACK(E) routine through func_p with the options in args_t.
 */
static void
<%=c_iter%>(na_loop_t * const lp)
{
    dtype *a, *b, *z;
    rtype *w;
    int *ifail;
    int *info;
    int m, n, lda, ldb, ldz;
    /* Value bounds are only meaningful for a value-range selection, which
       this wrapper does not expose (range is 'A' or 'I'). */
    rtype vl = 0, vu = 0;
    /* 0 presumably lets LAPACK choose its default tolerance (per the
       ?sygvx ABSTOL docs — verify). */
    rtype abstol = 0;

    args_t *g;

    a = (dtype*)NDL_PTR(lp, 0);
    b = (dtype*)NDL_PTR(lp, 1);
    w = (rtype*)NDL_PTR(lp, 2);
    z = (dtype*)NDL_PTR(lp, 3);
    ifail = (int*)NDL_PTR(lp, 4);
    info = (int*)NDL_PTR(lp, 5);
    g = (args_t*)(lp->opt_ptr);

    n = NDL_SHAPE(lp, 0)[1];
    /* Leading dimensions of a and b in elements, derived from their step;
       assumes row-major storage with contiguous rows — TODO confirm. */
    lda = NDL_STEP(lp, 0) / sizeof(dtype);
    ldb = NDL_STEP(lp, 1) / sizeof(dtype);
    /* z is allocated n-by-m by the caller, so its row length is its second dimension. */
    ldz = NDL_SHAPE(lp, 3)[1];

    /* LAPACK writes the number of eigenvalues found into m; it is not used further here. */
    *info = (*func_p)( g->order, g->itype, g->jobz, g->range, g->uplo, n, a, lda, b, ldb,
                       vl, vu, g->il, g->iu, abstol, &m, w, z, ldz, ifail );

    CHECK_ERROR(*info);
}
|
|
47
|
+
|
|
48
|
+
/*
|
|
49
|
+
<%
|
|
50
|
+
params = [
|
|
51
|
+
mat("a",:inplace),
|
|
52
|
+
mat("b",:inplace),
|
|
53
|
+
"@param [Integer] itype Specifies the problem type to be solved. If 1: A*x = (lambda)*B*x, If 2: A*B*x = (lambda)*x, If 3: B*A*x = (lambda)*x.",
|
|
54
|
+
jobe("jobz"),
|
|
55
|
+
opt("uplo"),
|
|
56
|
+
opt("order"),
|
|
57
|
+
"@param [String or Symbol] range If 'A': Compute all eigenvalues, if 'I': Compute eigenvalues with indices il to iu (default='A')",
|
|
58
|
+
"@param [Integer] il Specifies the index of the smallest eigenvalue in ascending order to be returned. If range = 'A', il is not referenced.",
|
|
59
|
+
"@param [Integer] iu Specifies the index of the largest eigenvalue in ascending order to be returned. Constraint: 1<=il<=iu<=N. If range = 'A', iu is not referenced.",
|
|
60
|
+
].select{|x| x}.join("\n ")
|
|
61
|
+
return_name="a, b, w, z, ifail, info"
|
|
62
|
+
%>
|
|
63
|
+
@overload <%=name%>(a, b, [itype:1, jobz:'V', uplo:'U', order:'R', range:'A', il: 1, iu: 1])
|
|
64
|
+
<%=params%>
|
|
65
|
+
@return [[<%=return_name%>]]
|
|
66
|
+
Array<<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,Integer>
|
|
67
|
+
<%=outparam(return_name)%>
|
|
68
|
+
|
|
69
|
+
<%=description%>
|
|
70
|
+
*/
|
|
71
|
+
|
|
72
|
+
static VALUE
|
|
73
|
+
<%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
|
|
74
|
+
{
|
|
75
|
+
VALUE a, b, ans;
|
|
76
|
+
int n, nb, m;
|
|
77
|
+
narray_t *na1, *na2;
|
|
78
|
+
size_t w_shape[1];
|
|
79
|
+
size_t z_shape[2];
|
|
80
|
+
size_t ifail_shape[1];
|
|
81
|
+
|
|
82
|
+
ndfunc_arg_in_t ain[2] = {{OVERWRITE, 2}, {OVERWRITE, 2}};
|
|
83
|
+
ndfunc_arg_out_t aout[4] = {{cRT, 1, w_shape}, {cT, 2, z_shape}, {cI, 1, ifail_shape}, {cInt, 0}};
|
|
84
|
+
ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP | NDF_EXTRACT, 2, 4, ain, aout};
|
|
85
|
+
|
|
86
|
+
args_t g;
|
|
87
|
+
VALUE opts[7] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
|
88
|
+
VALUE kw_hash = Qnil;
|
|
89
|
+
ID kw_table[7] = {id_order, id_jobz, id_uplo, id_itype, id_range, id_il, id_iu};
|
|
90
|
+
|
|
91
|
+
CHECK_FUNC(func_p,"<%=func_name%>");
|
|
92
|
+
|
|
93
|
+
rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
|
|
94
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 7, opts);
|
|
95
|
+
g.order = option_order(opts[0]);
|
|
96
|
+
g.jobz = option_job(opts[1], 'V', 'N');
|
|
97
|
+
g.uplo = option_uplo(opts[2]);
|
|
98
|
+
g.itype = NUM2INT(option_value(opts[3], INT2FIX(1)));
|
|
99
|
+
g.range = option_range(opts[4], 'A', 'I');
|
|
100
|
+
g.il = NUM2INT(option_value(opts[5], INT2FIX(1)));
|
|
101
|
+
g.iu = NUM2INT(option_value(opts[6], INT2FIX(1)));
|
|
102
|
+
|
|
103
|
+
COPY_OR_CAST_TO(a, cT);
|
|
104
|
+
GetNArray(a, na1);
|
|
105
|
+
CHECK_DIM_GE(na1, 2);
|
|
106
|
+
|
|
107
|
+
COPY_OR_CAST_TO(b, cT);
|
|
108
|
+
GetNArray(b, na2);
|
|
109
|
+
CHECK_DIM_GE(na2, 2);
|
|
110
|
+
CHECK_SQUARE("matrix a", na1);
|
|
111
|
+
n = COL_SIZE(na1);
|
|
112
|
+
CHECK_SQUARE("matrix b", na2);
|
|
113
|
+
nb = COL_SIZE(na2);
|
|
114
|
+
if (n != nb) {
|
|
115
|
+
rb_raise(nary_eShapeError, "matrix a and b must have same size");
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
m = g.range == 'I' ? g.iu - g.il + 1 : n;
|
|
119
|
+
w_shape[0] = m;
|
|
120
|
+
z_shape[0] = n;
|
|
121
|
+
z_shape[1] = m;
|
|
122
|
+
ifail_shape[0] = m;
|
|
123
|
+
|
|
124
|
+
ans = na_ndloop3(&ndf, &g, 2, a, b);
|
|
125
|
+
|
|
126
|
+
return rb_ary_new3(6, a, b, RARRAY_AREF(ans, 0), RARRAY_AREF(ans, 1), RARRAY_AREF(ans, 2), RARRAY_AREF(ans, 3));
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/* Drop the per-template aliases so the next generated function can reuse the names. */
#undef args_t
#undef func_p
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
require 'mkmf'
|
|
2
|
+
|
|
3
|
+
# Writes lib/site_conf.rb, recording the build-time configuration as
# constants under Numo::Linalg: the backend chosen with with_config("backend")
# and the library directories given via dir_config for mkl/openblas/atlas/
# blas/lapack. It also records the shared-library extension under
# Numo::Linalg::Loader.
def create_site_conf
  # dir_config(...)[1] is the library directory (mkmf's dir_config returns
  # [include_dir, lib_dir]); nil when not configured.
  ldirs = [
    dir_config("mkl")[1],
    dir_config("openblas")[1],
    dir_config("atlas")[1],
    dir_config("blas")[1],
    dir_config("lapack")[1],
  ]
  bked = with_config("backend")

  # NOTE(review): fiddle is not used below; presumably required so the build
  # fails early if Fiddle (used by the runtime loader) is unavailable — confirm.
  require 'fiddle'

  message "creating lib/site_conf.rb\n"

  FileUtils.mkdir_p "lib"

  ext = detect_library_extension

  # The generated file is Ruby source; values are embedded via #inspect.
  open("lib/site_conf.rb","w"){|f| f.write "
module Numo
  module Linalg

    BACKEND = #{bked.inspect}
    MKL_LIBPATH = #{ldirs[0].inspect}
    OPENBLAS_LIBPATH = #{ldirs[1].inspect}
    ATLAS_LIBPATH = #{ldirs[2].inspect}
    BLAS_LIBPATH = #{ldirs[3].inspect}
    LAPACK_LIBPATH = #{ldirs[4].inspect}

    module Loader
      EXT = '#{ext}'
    end

  end
end"
  }
end
|
|
40
|
+
|
|
41
|
+
# Returns the shared-library file extension for the host OS:
# 'dll' on Windows-like hosts, 'dylib' on macOS, 'so' elsewhere.
def detect_library_extension
  host = RbConfig::CONFIG['host_os']
  return 'dll' if host =~ /mswin|msys|mingw|cygwin/
  return 'dylib' if host =~ /darwin|mac os/
  'so'
end
|
|
51
|
+
|
|
52
|
+
require 'numo/narray'
|
|
53
|
+
|
|
54
|
+
# Prepends "-I<dir>/numo" to $INCFLAGS for the first $LOAD_PATH entry
# that contains numo/numo/narray.h; leaves $INCFLAGS alone otherwise.
def find_narray_h
  $LOAD_PATH.each do |dir|
    next unless File.exist?(File.join(dir, 'numo/numo/narray.h'))
    $INCFLAGS = "-I#{dir}/numo " + $INCFLAGS
    break
  end
end
|
|
62
|
+
|
|
63
|
+
# Prepends "-L<dir>/numo" to $LDFLAGS for the first $LOAD_PATH entry
# that contains numo/libnarray.a; leaves $LDFLAGS alone otherwise.
def find_libnarray_a
  $LOAD_PATH.each do |dir|
    next unless File.exist?(File.join(dir, 'numo/libnarray.a'))
    $LDFLAGS = "-L#{dir}/numo " + $LDFLAGS
    break
  end
end
|
|
71
|
+
|
|
72
|
+
# Renders <base_dir>/depend.erb with ERB and writes the result to
# <base_dir>/depend (the output file is opened before the template is read).
def create_depend(base_dir)
  require 'erb'
  message "creating depend\n"
  File.open("#{base_dir}/depend", "w") do |out|
    template_path = "#{base_dir}/depend.erb"
    File.open(template_path, "r") do |src|
      template = ERB.new(src.read)
      template.filename = template_path
      out.print(template.result)
    end
  end
end
|
|
@@ -19,15 +19,23 @@ module Numo
|
|
|
19
19
|
# @return [String] name of loaded backend library (mkl/openblas/lapack)
|
|
20
20
|
def load_library
|
|
21
21
|
mkl_dirs = ['/opt/intel/lib', '/opt/intel/lib64', '/opt/intel/mkl/lib', '/opt/intel/mkl/lib64']
|
|
22
|
-
openblas_dirs = ['/opt/
|
|
22
|
+
openblas_dirs = ['/opt/OpenBLAS/lib', '/opt/OpenBLAS/lib64', '/opt/openblas/lib', '/opt/openblas/lib64',
|
|
23
|
+
'/usr/local/opt/openblas/lib']
|
|
23
24
|
atlas_dirs = ['/opt/atlas/lib', '/opt/atlas/lib64',
|
|
24
25
|
'/usr/lib/atlas', '/usr/lib64/atlas', '/usr/local/opt/atlas/lib']
|
|
25
26
|
lapacke_dirs = ['/opt/lapack/lib', '/opt/lapack/lib64', '/opt/local/lib/lapack',
|
|
26
27
|
'/usr/local/opt/lapack/lib']
|
|
27
28
|
opt_dirs = ['/opt/local/lib', '/opt/local/lib64', '/opt/lib', '/opt/lib64']
|
|
28
29
|
base_dirs = ['/usr/local/lib', '/usr/local/lib64', '/usr/lib', '/usr/lib64']
|
|
30
|
+
base_dirs.concat(Dir["/usr/lib/#{RbConfig::CONFIG['host_cpu']}-*"])
|
|
29
31
|
base_dirs.unshift(*ENV['LD_LIBRARY_PATH'].split(':')) unless ENV['LD_LIBRARY_PATH'].nil?
|
|
30
32
|
|
|
33
|
+
select_dirs(base_dirs)
|
|
34
|
+
select_dirs(opt_dirs)
|
|
35
|
+
select_dirs(lapacke_dirs)
|
|
36
|
+
select_dirs(atlas_dirs)
|
|
37
|
+
select_dirs(mkl_dirs)
|
|
38
|
+
|
|
31
39
|
mkl_libs = find_mkl_libs([*base_dirs, *opt_dirs, *mkl_dirs])
|
|
32
40
|
openblas_libs = find_openblas_libs([*base_dirs, *opt_dirs, *openblas_dirs])
|
|
33
41
|
atlas_libs = find_atlas_libs([*base_dirs, *opt_dirs, *atlas_dirs, *lapacke_dirs])
|
|
@@ -66,11 +74,18 @@ module Numo
|
|
|
66
74
|
end
|
|
67
75
|
end
|
|
68
76
|
|
|
77
|
+
# Removes, in place, every entry of dirs that is not an existing directory.
def select_dirs(dirs)
  dirs.select! { |dir| File.directory?(dir) }
end
|
|
80
|
+
|
|
69
81
|
# For each library name, searches lib_dirs in order for lib<name>[64].<ext>
# (optionally versioned, e.g. .so.3) and maps the name (as a Symbol) to the
# last glob match in the first directory that has one, or nil if none.
def find_libs(lib_names, lib_dirs)
  lib_ext = detect_library_extension
  lib_names.each_with_object({}) do |name, found|
    path = nil
    lib_dirs.each do |dir|
      path = Dir.glob("#{dir}/lib#{name}{,64}.#{lib_ext}{,.*}").last
      break unless path.nil?
    end
    found[name.to_sym] = path
  end
end
|
data/lib/numo/linalg/function.rb
CHANGED
|
@@ -12,12 +12,19 @@ module Numo; module Linalg
|
|
|
12
12
|
# defined from data-types of arguments.
|
|
13
13
|
# @param [Symbol] func function name without BLAS char.
|
|
14
14
|
# @param args arguments passed to Blas function.
|
|
15
|
+
# @param kwargs keyword arguments passed to Blas function.
|
|
15
16
|
# @example
|
|
16
17
|
# c = Numo::Linalg::Blas.call(:gemm, a, b)
|
|
17
|
-
def self.call(func
|
|
18
|
+
def self.call(func, *args, **kwargs)
  # Prefix the BLAS type character chosen from the argument types.
  name = (Linalg.blas_char(*args) + func.to_s).to_sym
  target = FIXNAME[name] || name
  # Ruby 2.6 and earlier raise ArgumentError when an empty **kwargs is
  # splatted into a method without keywords, so only forward non-empty ones.
  return send(target, *args) if kwargs.empty?
  send(target, *args, **kwargs)
end
|
|
22
29
|
|
|
23
30
|
end
|
|
@@ -34,12 +41,19 @@ module Numo; module Linalg
|
|
|
34
41
|
# defined from data-types of arguments.
|
|
35
42
|
# @param [Symbol,String] func function name without BLAS char.
|
|
36
43
|
# @param args arguments passed to Lapack function.
|
|
44
|
+
# @param kwargs keyword arguments passed to Lapack function.
|
|
37
45
|
# @example
|
|
38
46
|
# s = Numo::Linalg::Lapack.call(:gesv, a)
|
|
39
|
-
def self.call(func
|
|
47
|
+
def self.call(func, *args, **kwargs)
  # Prefix the BLAS type character chosen from the argument types.
  name = (Linalg.blas_char(*args) + func.to_s).to_sym
  target = FIXNAME[name] || name
  # Ruby 2.6 and earlier raise ArgumentError when an empty **kwargs is
  # splatted into a method without keywords, so only forward non-empty ones.
  return send(target, *args) if kwargs.empty?
  send(target, *args, **kwargs)
end
|
|
44
58
|
|
|
45
59
|
end
|
|
@@ -65,7 +79,7 @@ module Numo; module Linalg
|
|
|
65
79
|
NArray.array_type(a)
|
|
66
80
|
end
|
|
67
81
|
if k && k < NArray
|
|
68
|
-
t = k::UPCAST[t]
|
|
82
|
+
t = k::UPCAST[t] || t::UPCAST[k]
|
|
69
83
|
end
|
|
70
84
|
end
|
|
71
85
|
BLAS_CHAR[t] || raise(TypeError,"invalid data type for BLAS/LAPACK")
|
|
@@ -86,16 +100,61 @@ module Numo; module Linalg
|
|
|
86
100
|
when 1
|
|
87
101
|
case b.ndim
|
|
88
102
|
when 1
|
|
89
|
-
|
|
103
|
+
func = blas_char(a, b) =~ /c|z/ ? :dotu : :dot
|
|
104
|
+
Blas.call(func, a, b)
|
|
90
105
|
else
|
|
91
|
-
|
|
106
|
+
if b.contiguous?
|
|
107
|
+
trans = 't'
|
|
108
|
+
else
|
|
109
|
+
if b.fortran_contiguous?
|
|
110
|
+
trans = 'n'
|
|
111
|
+
b = b.transpose
|
|
112
|
+
else
|
|
113
|
+
trans = 't'
|
|
114
|
+
b = b.dup
|
|
115
|
+
end
|
|
116
|
+
end
|
|
117
|
+
Blas.call(:gemv, b, a, trans:trans)
|
|
92
118
|
end
|
|
93
119
|
else
|
|
94
120
|
case b.ndim
|
|
95
121
|
when 1
|
|
96
|
-
|
|
122
|
+
if a.contiguous?
|
|
123
|
+
trans = 'n'
|
|
124
|
+
else
|
|
125
|
+
if a.fortran_contiguous?
|
|
126
|
+
trans = 't'
|
|
127
|
+
a = a.transpose
|
|
128
|
+
else
|
|
129
|
+
trans = 'n'
|
|
130
|
+
a = a.dup
|
|
131
|
+
end
|
|
132
|
+
end
|
|
133
|
+
Blas.call(:gemv, a, b, trans:trans)
|
|
97
134
|
else
|
|
98
|
-
|
|
135
|
+
if a.contiguous?
|
|
136
|
+
transa = 'n'
|
|
137
|
+
else
|
|
138
|
+
if a.fortran_contiguous?
|
|
139
|
+
transa = 't'
|
|
140
|
+
a = a.transpose
|
|
141
|
+
else
|
|
142
|
+
transa = 'n'
|
|
143
|
+
a = a.dup
|
|
144
|
+
end
|
|
145
|
+
end
|
|
146
|
+
if b.contiguous?
|
|
147
|
+
transb = 'n'
|
|
148
|
+
else
|
|
149
|
+
if b.fortran_contiguous?
|
|
150
|
+
transb='t'
|
|
151
|
+
b = b.transpose
|
|
152
|
+
else
|
|
153
|
+
transb='n'
|
|
154
|
+
b = b.dup
|
|
155
|
+
end
|
|
156
|
+
end
|
|
157
|
+
Blas.call(:gemm, a, b, transa:transa, transb:transb)
|
|
99
158
|
end
|
|
100
159
|
end
|
|
101
160
|
end
|
|
@@ -194,10 +253,10 @@ module Numo; module Linalg
|
|
|
194
253
|
#
|
|
195
254
|
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
|
196
255
|
# @param mode [String]
|
|
197
|
-
# - "reduce"
|
|
198
|
-
# - "r"
|
|
199
|
-
# - "
|
|
200
|
-
# - "raw"
|
|
256
|
+
# - "reduce" -- returns both Q and R,
|
|
257
|
+
# - "r" -- returns only R,
|
|
258
|
+
# - "economic" -- returns both Q and R but computed in economy-size,
|
|
259
|
+
# - "raw" -- returns QR and TAU used in LAPACK.
|
|
201
260
|
# @return [r] if mode:"r"
|
|
202
261
|
# @return [[q,r]] if mode:"reduce" or "economic"
|
|
203
262
|
# @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
|
|
@@ -295,6 +354,74 @@ module Numo; module Linalg
|
|
|
295
354
|
end
|
|
296
355
|
end
|
|
297
356
|
|
|
357
|
+
# Computes an orthonormal basis for the range of matrix A.
|
|
358
|
+
#
|
|
359
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray).
|
|
360
|
+
# @param rcond [Float] (optional)
|
|
361
|
+
# rcond is used to determine the effective rank of A.
|
|
362
|
+
# Singular values `s[i] <= rcond * s.max` are treated as zero.
|
|
363
|
+
# If rcond < 0 (or nil), machine epsilon times the largest dimension of A is used instead.
|
|
364
|
+
# @return [Numo::NArray] The orthonormal basis for the range of matrix A.
|
|
365
|
+
|
|
366
|
+
def orth(a, rcond: -1)
  raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
  s, u, = svd(a)
  # Take epsilon from the singular values' class, which is always a float
  # type. a.class may be an integer NArray (svd upcasts it), and integer
  # classes have no EPSILON constant.
  tol = s.max * (rcond.nil? || rcond < 0 ? s.class::EPSILON * a.shape.max : rcond)
  # Effective rank: the number of singular values above the tolerance.
  k = (s > tol).count
  u[true, 0...k]
end
|
|
373
|
+
|
|
374
|
+
# Computes an orthonormal basis for the null space of matrix A.
|
|
375
|
+
#
|
|
376
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray).
|
|
377
|
+
# @param rcond [Float] (optional)
|
|
378
|
+
# rcond is used to determine the effective rank of A.
|
|
379
|
+
# Singular values `s[i] <= rcond * s.max` are treated as zero.
|
|
380
|
+
# If rcond < 0 (or nil), machine epsilon times the largest dimension of A is used instead.
|
|
381
|
+
# @return [Numo::NArray] The orthonormal basis for the null space of matrix A.
|
|
382
|
+
|
|
383
|
+
def null_space(a, rcond: -1)
  raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
  s, _u, vh = svd(a)
  # Take epsilon from the singular values' class, which is always a float
  # type. a.class may be an integer NArray (svd upcasts it), and integer
  # classes have no EPSILON constant.
  tol = s.max * (rcond.nil? || rcond < 0 ? s.class::EPSILON * a.shape.max : rcond)
  # Effective rank: the number of singular values above the tolerance.
  k = (s > tol).count
  # Full rank: the null space is trivial.
  return a.class.new if k == vh.shape[0]
  # The remaining rows of vh span the null space; return them as columns.
  r = vh[k..-1, true].transpose.dup
  blas_char(vh) =~ /c|z/ ? r.conj : r
end
|
|
392
|
+
|
|
393
|
+
# Computes an LU factorization of a M-by-N matrix A
|
|
394
|
+
# using partial pivoting with row interchanges.
|
|
395
|
+
#
|
|
396
|
+
# The factorization has the form
|
|
397
|
+
#
|
|
398
|
+
# A = P * L * U
|
|
399
|
+
#
|
|
400
|
+
# where P is a permutation matrix, L is lower triangular with unit
|
|
401
|
+
# diagonal elements (lower trapezoidal if m > n), and U is upper
|
|
402
|
+
# triangular (upper trapezoidal if m < n).
|
|
403
|
+
#
|
|
404
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
|
405
|
+
# @param permute_l [Bool] (optional) If true, perform the matrix product of P and L.
|
|
406
|
+
# @return [[p,l,u]] if permute_l == false
|
|
407
|
+
# @return [[pl,u]] if permute_l == true
|
|
408
|
+
#
|
|
409
|
+
# - **p** [Numo::NArray] -- The permutation matrix P.
|
|
410
|
+
# - **l** [Numo::NArray] -- The factor L.
|
|
411
|
+
# - **u** [Numo::NArray] -- The factor U.
|
|
412
|
+
|
|
413
|
+
def lu(a, permute_l: false)
  raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
  m, n = a.shape
  k = [m, n].min
  factors, ip = lu_fact(a)

  # L: strictly-lower part of the packed factors plus a unit diagonal,
  # restricted to the first k columns.
  lower = factors.tril
  lower[lower.diag_indices(0)] = 1.0
  lower = lower[true, 0...k]

  # U: upper part of the packed factors, restricted to the first k rows.
  upper = factors.triu[0...k, 0...n]

  # Build P by replaying the 1-based pivot swaps on the identity's columns.
  perm = Numo::DFloat.eye(m)
  ip.to_a.each_with_index do |piv, col|
    perm[true, [piv - 1, col]] = perm[true, [col, piv - 1]].dup
  end

  return [perm.dot(lower), upper] if permute_l
  [perm, lower, upper]
end
|
|
298
425
|
|
|
299
426
|
# Computes an LU factorization of a M-by-N matrix A
|
|
300
427
|
# using partial pivoting with row interchanges.
|
|
@@ -308,26 +435,14 @@ module Numo; module Linalg
|
|
|
308
435
|
# triangular (upper trapezoidal if m < n).
|
|
309
436
|
#
|
|
310
437
|
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
|
311
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
|
312
|
-
# 'gen','sym','her'. (optional, default='gen')
|
|
313
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
|
314
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
|
315
438
|
# @return [[lu, ipiv]]
|
|
316
439
|
# - **lu** [Numo::NArray] -- The factors L and U from the factorization
|
|
317
440
|
# `A = P*L*U`; the unit diagonal elements of L are not stored.
|
|
318
441
|
# - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
|
|
319
442
|
# row i of the matrix was interchanged with row IPIV(i).
|
|
320
443
|
|
|
321
|
-
def lu_fact(a
|
|
322
|
-
|
|
323
|
-
when /^gen?(trf)?$/i
|
|
324
|
-
Lapack.call(:getrf, a)[0..1]
|
|
325
|
-
when /^(sym?|her?)(trf)?$/i
|
|
326
|
-
func = driver[0..2].downcase+"trf"
|
|
327
|
-
Lapack.call(func, a, uplo:uplo)[0..1]
|
|
328
|
-
else
|
|
329
|
-
raise ArgumentError, "invalid driver: #{driver}"
|
|
330
|
-
end
|
|
444
|
+
def lu_fact(a)
  # getrf returns the packed LU factors, the pivots and info; keep the first two.
  Lapack.call(:getrf, a).first(2)
end
|
|
332
447
|
|
|
333
448
|
# Computes the inverse of a matrix using the LU factorization
|
|
@@ -345,22 +460,10 @@ module Numo; module Linalg
|
|
|
345
460
|
# @param ipiv [Numo::NArray] The pivot indices from
|
|
346
461
|
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
|
347
462
|
# interchanged with row IPIV(i).
|
|
348
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
|
349
|
-
# 'gen','sym','her'. (optional, default='gen')
|
|
350
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
|
351
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
|
352
463
|
# @return [Numo::NArray] the inverse of the original matrix A.
|
|
353
464
|
|
|
354
|
-
def lu_inv(lu, ipiv
|
|
355
|
-
|
|
356
|
-
when /^gen?(tri)?$/i
|
|
357
|
-
Lapack.call(:getri, lu, ipiv)[0]
|
|
358
|
-
when /^(sym?|her?)(tri)?$/i
|
|
359
|
-
func = driver[0..2].downcase+"tri"
|
|
360
|
-
Lapack.call(func, lu, ipiv, uplo:uplo)[0]
|
|
361
|
-
else
|
|
362
|
-
raise ArgumentError, "invalid driver: #{driver}"
|
|
363
|
-
end
|
|
465
|
+
def lu_inv(lu, ipiv)
  # getri's first result is the inverse computed from the LU factors.
  Lapack.call(:getri, lu, ipiv).first
end
|
|
365
468
|
|
|
366
469
|
# Solves a system of linear equations
|
|
@@ -377,31 +480,100 @@ module Numo; module Linalg
|
|
|
377
480
|
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
|
378
481
|
# interchanged with row IPIV(i).
|
|
379
482
|
# @param b [Numo::NArray] the right hand side matrix B.
|
|
380
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
|
381
|
-
# 'gen','sym','her'. (optional, default='gen')
|
|
382
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
|
383
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
|
384
483
|
# @param trans [String or Symbol]
|
|
385
|
-
# Specifies the form of the system of equations
|
|
386
|
-
# (omitted if not driver:"gen"):
|
|
387
|
-
#
|
|
484
|
+
# Specifies the form of the system of equations:
|
|
388
485
|
# - If 'N': `A * X = B` (No transpose).
|
|
389
486
|
# - If 'T': `A*\*T* X = B` (Transpose).
|
|
390
487
|
# - If 'C': `A*\*T* X = B` (Conjugate transpose = Transpose).
|
|
391
488
|
# @return [Numo::NArray] the solution matrix X.
|
|
392
489
|
|
|
393
|
-
def lu_solve(lu, ipiv, b,
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
490
|
+
def lu_solve(lu, ipiv, b, trans:"N")
  # getrs's first result is the solution matrix X.
  Lapack.call(:getrs, lu, ipiv, b, trans: trans).first
end
|
|
493
|
+
|
|
494
|
+
# Computes the LDLt or Bunch-Kaufman factorization of a symmetric/Hermitian matrix A.
|
|
495
|
+
# The factorization has the form
|
|
496
|
+
#
|
|
497
|
+
# A = U*D*U**T or A = L*D*L**T
|
|
498
|
+
#
|
|
499
|
+
# where U (or L) is a product of permutation and unit upper (lower) triangular matrices
|
|
500
|
+
# and D is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
|
|
501
|
+
#
|
|
502
|
+
# @param a [Numo::NArray] m-by-m matrix A (>= 2-dimensional NArray)
|
|
503
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper ('U') or lower ('L') triangle.
|
|
504
|
+
# @param hermitian [Bool] optional, default=true. If true, hermitian matrix is assumed.
|
|
505
|
+
# (ignored when a real-valued matrix is given)
|
|
506
|
+
#
|
|
507
|
+
# @return [[lu,d,perm]]
|
|
508
|
+
#
|
|
509
|
+
# - **lu** [Numo::NArray] -- The permuted upper (lower) triangular matrix U (L).
|
|
510
|
+
# - **d** [Numo::NArray] -- The block diagonal matrix D.
|
|
511
|
+
# - **perm** [Numo::NArray] -- The row-permutation index for changing lu into triangular form.
|
|
512
|
+
|
|
513
|
+
def ldl(a, uplo: 'U', hermitian: true)
  raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
  raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]

  # uplo is documented as String or Symbol. A plain `uplo == 'U'` is false
  # for :U or 'u', while LAPACK still factors the upper triangle, so the
  # post-processing below would read the wrong one. Compare the first
  # character case-insensitively instead.
  upper = uplo.to_s.upcase.start_with?('U')

  is_complex = blas_char(a) =~ /c|z/
  # Hermitian factorization only for complex input with hermitian: true.
  func = is_complex && hermitian ? 'hetrf' : 'sytrf'
  lud, ipiv, = Lapack.call(func.to_sym, a, uplo: uplo)

  # Unit upper (lower) triangular factor; the diagonal of lud gives D's diagonal.
  lu = (upper ? lud.triu : lud.tril).tap { |mat| mat[mat.diag_indices(0)] = 1.0 }
  d = lud[lud.diag_indices(0)].diag

  # Walk ipiv (per the LAPACK ?sytrf IPIV convention). A positive entry is a
  # 1x1 pivot with a row swap. Two equal negative neighbours mark a 2x2 block,
  # whose off-diagonal element is moved from lu into d.
  m = a.shape[0]
  n = m - 1
  changed_2x2 = false
  perm = Numo::Int32.new(m).seq
  m.times do |t|
    i = upper ? t : n - t
    j = upper ? i - 1 : i + 1
    r = upper ? 0..i : i..n
    if ipiv[i] > 0
      k = ipiv[i] - 1
      lu[[k, i], r] = lu[[i, k], r].dup
      perm[[k, i]] = perm[[i, k]].dup
    elsif j.between?(0, n) && ipiv[i] == ipiv[j] && !changed_2x2
      k = ipiv[i].abs - 1
      d[j, i] = lud[j, i]
      d[i, j] = is_complex && hermitian ? lud[j, i].conj : lud[j, i]
      lu[j, i] = 0.0
      lu[[k, j], r] = lu[[j, k], r].dup
      perm[[k, j]] = perm[[j, k]].dup
      changed_2x2 = true
      next
    end
    changed_2x2 = false
  end

  [lu, d, perm.sort_index]
end
|
|
404
551
|
|
|
552
|
+
# Computes the Cholesky factorization of a symmetric/Hermitian
|
|
553
|
+
# positive definite matrix A. The factorization has the form
|
|
554
|
+
#
|
|
555
|
+
# A = U**H * U, if UPLO = 'U', or
|
|
556
|
+
# A = L * L**H, if UPLO = 'L',
|
|
557
|
+
#
|
|
558
|
+
# where U is an upper triangular matrix and L is a lower triangular matrix.
|
|
559
|
+
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
|
|
560
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
|
561
|
+
# ('U') or lower ('L') triangle.
|
|
562
|
+
# @return [Numo::NArray] The factor U or L.
|
|
563
|
+
|
|
564
|
+
def cholesky(a, uplo: 'U')
  raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
  raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
  factor = Lapack.call(:potrf, a, uplo: uplo)[0]
  # uplo is documented as String or Symbol. Compare its first character
  # case-insensitively so :U / 'u' select the same triangle LAPACK factored;
  # `uplo == 'U'` would wrongly take the lower branch for them.
  if uplo.to_s.upcase.start_with?('U')
    factor.triu
  else
    # TODO: Use the tril method if the version of Numo::NArray
    # in the runtime dependency list becomes 0.9.1.3 or higher.
    m, = a.shape
    # Mask with a lower-triangular matrix of ones to keep only L.
    factor * Numo::DFloat.ones(m, m).triu.transpose
  end
end
|
|
405
577
|
|
|
406
578
|
# Computes the Cholesky factorization of a symmetric/Hermitian
|
|
407
579
|
# positive definite matrix A. The factorization has the form
|
|
@@ -409,16 +581,16 @@ module Numo; module Linalg
|
|
|
409
581
|
# A = U**H * U, if UPLO = 'U', or
|
|
410
582
|
# A = L * L**H, if UPLO = 'L',
|
|
411
583
|
#
|
|
412
|
-
# where U is an upper triangular matrix and L is lower triangular
|
|
584
|
+
# where U is an upper triangular matrix and L is a lower triangular matrix.
|
|
413
585
|
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
|
|
414
586
|
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
|
415
587
|
# ('U') or lower ('L') triangle.
|
|
416
|
-
# @return [Numo::NArray] the factor
|
|
588
|
+
# @return [Numo::NArray] The matrix which has the Cholesky factor in upper or lower triangular part.
|
|
589
|
+
# The other triangular part is not referenced by LAPACK and should be ignored.
|
|
417
590
|
|
|
418
591
|
def cho_fact(a, uplo:'U')
  # potrf's first result holds the Cholesky factor in the selected triangle.
  Lapack.call(:potrf, a, uplo: uplo).first
end
|
|
421
|
-
#alias cholesky cho_fact
|
|
422
594
|
|
|
423
595
|
# Computes the inverse of a symmetric/Hermitian
|
|
424
596
|
# positive definite matrix A using the Cholesky factorization
|
|
@@ -479,27 +651,41 @@ module Numo; module Linalg
|
|
|
479
651
|
[w,vl,vr] #.compact
|
|
480
652
|
end
|
|
481
653
|
|
|
482
|
-
#
|
|
483
|
-
#
|
|
654
|
+
# Obtains the eigenvalues and, optionally, the eigenvectors
|
|
655
|
+
# by solving an ordinary or generalized eigenvalue problem
|
|
656
|
+
# for a square symmetric / Hermitian matrix.
|
|
484
657
|
#
|
|
485
|
-
# @param a [Numo::NArray] square
|
|
486
|
-
# @param
|
|
658
|
+
# @param a [Numo::NArray] square symmetric matrix (>= 2-dimensional NArray)
|
|
659
|
+
# @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
|
|
660
|
+
# If nil, identity matrix is assumed.
|
|
661
|
+
# @param vals_only [Bool] (optional) If false, eigenvectors are computed.
|
|
662
|
+
# @param vals_range [Range] (optional)
|
|
663
|
+
# The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
|
|
664
|
+
# If nil or 0...N (N is the size of the matrix a), all eigenvalues and eigenvectors are returned.
|
|
487
665
|
# @param uplo [String or Symbol] (optional, default='U')
|
|
488
666
|
# Access upper ('U') or lower ('L') triangle.
|
|
667
|
+
# @param turbo [Bool] (optional) If true, the divide-and-conquer algorithm is used (ignored when vals_range is given).
|
|
489
668
|
# @return [[w,v]]
|
|
490
669
|
# - **w** [Numo::NArray] -- The eigenvalues.
|
|
491
670
|
# - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
|
|
492
671
|
|
|
493
|
-
def eigh(a, vals_only:false, uplo:
|
|
672
|
+
def eigh(a, b=nil, vals_only:false, vals_range:nil, uplo:'U', turbo:false)
  jobz = vals_only ? 'N' : 'V' # jobz: Compute eigenvalues and eigenvectors.
  # An ordinary eigenproblem is the generalized one with B = I.
  b = a.class.eye(a.shape[0]) if b.nil?
  # Hermitian drivers for complex input, symmetric drivers otherwise.
  base = blas_char(a, b) =~ /c|z/ ? 'hegv' : 'sygv'
  # Driver names are built without mutating string literals: `'hegv' << 'd'`
  # warns under Ruby 3.4 chilled strings and raises with frozen_string_literal.
  if vals_range.nil?
    # The *gvd variant uses divide and conquer.
    func = turbo ? "#{base}d" : base
    # *gv/*gvd return the eigenvectors in place of a.
    v, _, w, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz)
  else
    il = vals_range.first(1)[0]
    iu = vals_range.last(1)[0]
    raise ArgumentError, 'vals_range must not be empty' if il.nil? || iu.nil?
    # *gvx selects eigenpairs by index; LAPACK indices are 1-based.
    _, _, w, v, = Lapack.call("#{base}x".to_sym, a, b, uplo: uplo, jobz: jobz,
                              range: 'I', il: il + 1, iu: iu + 1)
  end
  v = nil if vals_only
  [w, v]
end
|
|
504
690
|
|
|
505
691
|
# Computes the eigenvalues only for a square nonsymmetric matrix A.
|
|
@@ -519,23 +705,22 @@ module Numo; module Linalg
|
|
|
519
705
|
w
|
|
520
706
|
end
|
|
521
707
|
|
|
522
|
-
#
# Obtains the eigenvalues by solving an ordinary or generalized eigenvalue problem
# for a square symmetric / Hermitian matrix.
#
# @param a [Numo::NArray] square symmetric/hermitian matrix
#   (>= 2-dimensional NArray)
# @param b [Numo::NArray] (optional) square symmetric/Hermitian positive definite
#   matrix (>= 2-dimensional NArray). If nil, identity matrix is assumed.
# @param vals_range [Range] (optional)
#   The range of indices of the eigenvalues (in ascending order) to be returned.
#   If nil or 0...N (N is the size of the matrix a), all eigenvalues are returned.
# @param uplo [String or Symbol] (optional, default='U')
#   Access upper ('U') or lower ('L') triangle.
# @param turbo [Bool] (optional) If true, divide and conquer algorithm is used.
#   It has no effect when vals_range is given.
# @return [Numo::NArray] eigenvalues

def eigvalsh(a, b=nil, vals_range:nil, uplo:'U', turbo:false)
  # Delegate to eigh without eigenvectors and keep only the eigenvalues.
  eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo).first
end
|
|
540
725
|
|
|
541
726
|
|
|
@@ -801,7 +986,7 @@ module Numo; module Linalg
|
|
|
801
986
|
# returns lu, x, ipiv, info
|
|
802
987
|
Lapack.call(:gesv, a, b)[1]
|
|
803
988
|
when /^(sym?|her?|pos?)(sv)?$/i
|
|
804
|
-
func = driver[0..
|
|
989
|
+
func = driver[0..1].downcase+"sv"
|
|
805
990
|
Lapack.call(func, a, b, uplo:uplo)[1]
|
|
806
991
|
else
|
|
807
992
|
raise ArgumentError, "invalid driver: #{driver}"
|
|
@@ -834,8 +1019,8 @@ module Numo; module Linalg
|
|
|
834
1019
|
solve(a, b, driver:d, uplo:uplo)
|
|
835
1020
|
when /(ge|sy|he)tr[fi]$/
|
|
836
1021
|
d = $1
|
|
837
|
-
lu, piv = lu_fact(a
|
|
838
|
-
lu_inv(lu, piv
|
|
1022
|
+
lu, piv = lu_fact(a)
|
|
1023
|
+
lu_inv(lu, piv)
|
|
839
1024
|
when /potr[fi]$/
|
|
840
1025
|
lu = cho_fact(a, uplo:uplo)
|
|
841
1026
|
cho_inv(lu, uplo:uplo)
|
|
@@ -913,14 +1098,11 @@ module Numo; module Linalg
|
|
|
913
1098
|
when n
|
|
914
1099
|
resids = (x[n..-1,true].abs**2).sum(axis:0)
|
|
915
1100
|
when NArray
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
# NArray does not suppurt this yet.
|
|
922
|
-
resids[mask,true] = (x[mask,n..-1,true].abs**2).sum(axis:-2)
|
|
923
|
-
end
|
|
1101
|
+
resids = (x[false,n..-1,true].abs**2).sum(axis:-2)
|
|
1102
|
+
## NArray does not suppurt this yet.
|
|
1103
|
+
# resids = x[false,0,true].new_zeros
|
|
1104
|
+
# mask = rank.eq(n)
|
|
1105
|
+
# resids[mask,true] = (x[mask,n..-1,true].abs**2).sum(axis:-2)
|
|
924
1106
|
end
|
|
925
1107
|
end
|
|
926
1108
|
x = x[false,0...n,true]
|
|
@@ -992,6 +1174,49 @@ module Numo; module Linalg
|
|
|
992
1174
|
end
|
|
993
1175
|
end
|
|
994
1176
|
|
|
1177
|
+
# Compute the matrix exponential using Pade approximation method.
#
# Scaling and squaring: the matrix is divided by 2**s so that its infinity
# norm is at most 1, the diagonal Pade approximant N/D of order ord is
# evaluated, and the result is squared s times.
#
# @param a [Numo::NArray] square matrix (2-dimensional NArray)
# @param ord [Integer] order of the diagonal Pade approximation (default: 8)
# @return [Numo::NArray]
# @raise [NArray::ShapeError] if a is not a 2-dimensional square matrix
# @example
#   a = Numo::Linalg.expm(Numo::DFloat.zeros([2,2]))
#   # => Numo::DFloat#shape=[2,2]
#   # [[1, 0],
#   #  [0, 1]]
#   b = Numo::Linalg.expm(Numo::DFloat[[1, 2], [-1, 3]] * Complex::I)
#   # => Numo::DComplex#shape=[2,2]
#   # [[0.426459+1.89218i, -2.13721-0.978113i],
#   #  [1.06861+0.489056i, -1.71076+0.914063i]]

def expm(a, ord=8)
  if a.ndim != 2 || a.shape[0] != a.shape[1]
    raise NArray::ShapeError, 'matrix a is not square matrix'
  end

  inf_norm = norm(a, 'inf')
  n_squarings = inf_norm.zero? ? 0 : [0, Math.log2(inf_norm).ceil.to_i].max
  # Divide by a Float: integer NArrays are promoted to floating point instead
  # of being truncated by integer division (float/complex dtypes are kept).
  a = a / (2.0**n_squarings)

  sz_mat = a.shape[0]
  c = 1   # Pade coefficient c_k, updated by recurrence
  s = -1  # sign (-1)**k applied to the denominator terms
  x = Numo::DFloat.eye(sz_mat) # running power a**k
  n = Numo::DFloat.eye(sz_mat) # numerator   N = sum c_k * a**k
  d = Numo::DFloat.eye(sz_mat) # denominator D = sum (-1)**k * c_k * a**k

  (1..ord).each do |k|
    c *= (ord - k + 1).fdiv((2 * ord - k + 1) * k)
    x = a.dot(x)
    cx = c * x
    n += cx
    d += s * cx
    s *= -1
  end

  # exp(a / 2**s) ~ D^-1 N; undo the scaling by repeated squaring.
  res = solve(d, n)
  n_squarings.times { res = res.dot(res) }
  res
end
|
|
1219
|
+
|
|
995
1220
|
# @!visibility private
|
|
996
1221
|
def _make_complex_eigvecs(w, vin) # :nodoc:
|
|
997
1222
|
v = w.class.cast(vin)
|