numo-linalg 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +7 -3
- data/ext/numo/linalg/blas/extconf.rb +1 -2
- data/ext/numo/linalg/blas/numo_blas.h +6 -0
- data/ext/numo/linalg/blas/tmpl/mv.c +3 -2
- data/ext/numo/linalg/lapack/gen/spec.rb +5 -0
- data/ext/numo/linalg/lapack/lapack.c +29 -0
- data/ext/numo/linalg/lapack/numo_lapack.h +3 -0
- data/ext/numo/linalg/lapack/tmpl/gqr.c +1 -1
- data/ext/numo/linalg/lapack/tmpl/sygvx.c +130 -0
- data/ext/numo/linalg/mkmf_linalg.rb +2 -19
- data/lib/numo/linalg/function.rb +168 -77
- data/lib/numo/linalg/loader.rb +6 -14
- data/lib/numo/linalg/version.rb +1 -1
- data/numo-linalg.gemspec +2 -1
- data/spec/linalg/autoloader_spec.rb +27 -0
- data/spec/linalg/function/cho_fact_spec.rb +31 -0
- data/spec/linalg/function/cho_inv_spec.rb +39 -0
- data/spec/linalg/function/cho_solve_spec.rb +66 -0
- data/spec/linalg/function/cholesky_spec.rb +43 -0
- data/spec/linalg/function/cond_spec.rb +57 -0
- data/spec/linalg/function/det_spec.rb +21 -0
- data/spec/linalg/function/dot_spec.rb +84 -0
- data/spec/linalg/function/eig_spec.rb +53 -0
- data/spec/linalg/function/eigh_spec.rb +81 -0
- data/spec/linalg/function/eigvals_spec.rb +27 -0
- data/spec/linalg/function/eigvalsh_spec.rb +60 -0
- data/spec/linalg/function/inv_spec.rb +57 -0
- data/spec/linalg/function/ldl_spec.rb +51 -0
- data/spec/linalg/function/lstsq_spec.rb +80 -0
- data/spec/linalg/function/lu_fact_spec.rb +34 -0
- data/spec/linalg/function/lu_inv_spec.rb +21 -0
- data/spec/linalg/function/lu_solve_spec.rb +40 -0
- data/spec/linalg/function/lu_spec.rb +46 -0
- data/spec/linalg/function/matmul_spec.rb +41 -0
- data/spec/linalg/function/matrix_power_spec.rb +31 -0
- data/spec/linalg/function/matrix_rank_spec.rb +33 -0
- data/spec/linalg/function/norm_spec.rb +81 -0
- data/spec/linalg/function/pinv_spec.rb +48 -0
- data/spec/linalg/function/qr_spec.rb +82 -0
- data/spec/linalg/function/slogdet_spec.rb +21 -0
- data/spec/linalg/function/solve_spec.rb +98 -0
- data/spec/linalg/function/svd_spec.rb +88 -0
- data/spec/linalg/function/svdvals_spec.rb +40 -0
- data/spec/spec_helper.rb +55 -0
- metadata +79 -6
- data/spec/lapack_spec.rb +0 -13
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ef1ecaaf8a71faee9a7cef26bc35fae50230224856ab5dd6186cee93e5cca69d
|
4
|
+
data.tar.gz: f2e5e5944ad4832ef4ad4db05f6e94142fefb780e19282d94ece2a055cf447a0
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3fcf1506aa63ccbba57208e358102ae0dc537cd80b530e4b359b07d01a7cb1da11261dfd6d7028224175ae0f38014ec0373b99fc339caa2b952ad48c0fb26fd7
|
7
|
+
data.tar.gz: c7a9eb8b65719571120a6a122721e1eaadc240164e182a5095ff45018453a90f4cc48f99ef219305827beace8839bdebb0d27d35dfb5f64ec64f54d5406a8910
|
data/README.md
CHANGED
@@ -18,7 +18,7 @@ This design allows you to change backend libraries without re-compiling.
|
|
18
18
|
* Matrix and vector products
|
19
19
|
* dot, matmul
|
20
20
|
* Decomposition
|
21
|
-
* lu\_fact, lu\_inv, lu\_solve, cho\_fact, cho\_inv, cho\_solve
|
21
|
+
* lu, lu\_fact, lu\_inv, lu\_solve, ldl, cholesky, cho\_fact, cho\_inv, cho\_solve,
|
22
22
|
qr, svd, svdvals
|
23
23
|
* Matrix eigenvalues
|
24
24
|
* eig, eigh, eigvals, eigvalsh
|
@@ -78,8 +78,12 @@ require "numo/linalg"
|
|
78
78
|
|
79
79
|
## Authors
|
80
80
|
|
81
|
-
* Masahiro
|
82
|
-
* Makoto
|
81
|
+
* Masahiro Tanaka
|
82
|
+
* Makoto Kishimoto
|
83
|
+
* Atsushi Tatsuma
|
84
|
+
|
85
|
+
## Acknowledgement
|
86
|
+
|
83
87
|
* This work is partly supported by 2016 Ruby Association Grant.
|
84
88
|
|
85
89
|
## ToDo
|
@@ -1,5 +1,4 @@
|
|
1
1
|
require 'mkmf'
|
2
|
-
require 'numo/narray'
|
3
2
|
require_relative '../mkmf_linalg'
|
4
3
|
|
5
4
|
srcs = %w(
|
@@ -21,7 +20,7 @@ if !have_header("numo/narray.h")
|
|
21
20
|
exit(1)
|
22
21
|
end
|
23
22
|
|
24
|
-
if RUBY_PLATFORM =~ /cygwin|mingw/
|
23
|
+
if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
|
25
24
|
find_libnarray_a
|
26
25
|
unless have_library("narray","nary_new")
|
27
26
|
puts "libnarray.a not found"
|
@@ -31,6 +31,12 @@ extern void numo_cblas_check_func(void **func, const char *name);
|
|
31
31
|
#define SWAP_IFROW(order,a,b,tmp) \
|
32
32
|
{ if ((order)==CblasRowMajor) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
33
33
|
|
34
|
+
#define SWAP_IFNOTRANS(trans,a,b,tmp) \
|
35
|
+
{ if ((trans)==CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
36
|
+
|
37
|
+
#define SWAP_IFTRANS(trans,a,b,tmp) \
|
38
|
+
{ if ((trans)!=CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
39
|
+
|
34
40
|
#define SWAP_IFCOLTR(order,trans,a,b,tmp) \
|
35
41
|
{ if (((order)==CblasRowMajor && (trans)!=CblasNoTrans) || \
|
36
42
|
((order)!=CblasRowMajor && (trans)==CblasNoTrans)) \
|
@@ -79,7 +79,7 @@ static void
|
|
79
79
|
opt("beta"),
|
80
80
|
!is_ge && opt("side"),
|
81
81
|
!is_ge && opt("uplo"),
|
82
|
-
is_ge || is_tr && opt("trans"),
|
82
|
+
(is_ge || is_tr) && opt("trans"),
|
83
83
|
opt("order")
|
84
84
|
].select{|x| x}.join("\n ")
|
85
85
|
%>
|
@@ -143,9 +143,10 @@ static VALUE
|
|
143
143
|
CHECK_DIM_GE(na2,1);
|
144
144
|
nx = COL_SIZE(na2);
|
145
145
|
#if GE
|
146
|
-
|
146
|
+
SWAP_IFCOL(g.order, ma, na, tmp);
|
147
147
|
g.m = ma;
|
148
148
|
g.n = na;
|
149
|
+
SWAP_IFTRANS(g.trans, ma, na, tmp);
|
149
150
|
#else
|
150
151
|
CHECK_SQUARE("a",na1);
|
151
152
|
#endif
|
@@ -6,6 +6,9 @@ def_id "jobz"
|
|
6
6
|
def_id "jobvl"
|
7
7
|
def_id "jobvr"
|
8
8
|
def_id "trans"
|
9
|
+
def_id "range"
|
10
|
+
def_id "il"
|
11
|
+
def_id "iu"
|
9
12
|
def_id "rcond"
|
10
13
|
def_id "itype"
|
11
14
|
def_id "norm"
|
@@ -58,11 +61,13 @@ when /c|z/
|
|
58
61
|
decl "?heevd", "syev"
|
59
62
|
decl "?hegv", "sygv"
|
60
63
|
decl "?hegvd", "sygv"
|
64
|
+
decl "?hegvx", "sygvx"
|
61
65
|
else
|
62
66
|
decl "?syev"
|
63
67
|
decl "?syevd", "syev"
|
64
68
|
decl "?sygv"
|
65
69
|
decl "?sygvd", "sygv"
|
70
|
+
decl "?sygvx"
|
66
71
|
end
|
67
72
|
|
68
73
|
# factorize
|
@@ -124,6 +124,35 @@ numo_lapacke_option_job(VALUE job, char true_char, char false_char)
|
|
124
124
|
return 0;
|
125
125
|
}
|
126
126
|
|
127
|
+
char
|
128
|
+
numo_lapacke_option_range(VALUE job, char true_char, char false_char)
|
129
|
+
{
|
130
|
+
char *ptr, c;
|
131
|
+
|
132
|
+
switch(TYPE(job)) {
|
133
|
+
case T_NIL:
|
134
|
+
case T_UNDEF:
|
135
|
+
case T_TRUE:
|
136
|
+
return true_char;
|
137
|
+
case T_FALSE:
|
138
|
+
return false_char;
|
139
|
+
case T_SYMBOL:
|
140
|
+
job = rb_sym2str(job);
|
141
|
+
case T_STRING:
|
142
|
+
ptr = RSTRING_PTR(job);
|
143
|
+
if (RSTRING_LEN(job) > 0) {
|
144
|
+
c = ptr[0];
|
145
|
+
if (c >= 'a' && c <= 'z') {
|
146
|
+
c -= 'a'-'A';
|
147
|
+
}
|
148
|
+
return c;
|
149
|
+
}
|
150
|
+
break;
|
151
|
+
}
|
152
|
+
rb_raise(rb_eArgError,"invalid value for JOB option");
|
153
|
+
return 0;
|
154
|
+
}
|
155
|
+
|
127
156
|
char
|
128
157
|
numo_lapacke_option_trans(VALUE trans)
|
129
158
|
{
|
@@ -15,6 +15,9 @@ extern int numo_lapacke_option_order(VALUE order);
|
|
15
15
|
#define option_job numo_lapacke_option_job
|
16
16
|
extern char numo_lapacke_option_job(VALUE job, char true_char, char false_char);
|
17
17
|
|
18
|
+
#define option_range numo_lapacke_option_range
|
19
|
+
extern char numo_lapacke_option_range(VALUE range, char true_char, char false_char);
|
20
|
+
|
18
21
|
#define option_trans numo_lapacke_option_trans
|
19
22
|
extern char numo_lapacke_option_trans(VALUE trans);
|
20
23
|
|
@@ -30,7 +30,7 @@ static void
|
|
30
30
|
SWAP_IFCOL(g->order,m,n);
|
31
31
|
lda = NDL_STEP(lp,0) / sizeof(dtype);
|
32
32
|
|
33
|
-
printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
|
33
|
+
//printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
|
34
34
|
|
35
35
|
*info = (*func_p)(g->order, m, n, k, a, lda, tau);
|
36
36
|
CHECK_ERROR(*info);
|
@@ -0,0 +1,130 @@
|
|
1
|
+
#define args_t <%=func_name%>_args_t
|
2
|
+
#define func_p <%=func_name%>_p
|
3
|
+
|
4
|
+
typedef struct {
|
5
|
+
int order;
|
6
|
+
int itype;
|
7
|
+
char jobz;
|
8
|
+
char uplo;
|
9
|
+
char range;
|
10
|
+
int il;
|
11
|
+
int iu;
|
12
|
+
} args_t;
|
13
|
+
|
14
|
+
static <%=func_name%>_t func_p = 0;
|
15
|
+
|
16
|
+
static void
|
17
|
+
<%=c_iter%>(na_loop_t * const lp)
|
18
|
+
{
|
19
|
+
dtype *a, *b, *z;
|
20
|
+
rtype *w;
|
21
|
+
int *ifail;
|
22
|
+
int *info;
|
23
|
+
int m, n, lda, ldb, ldz;
|
24
|
+
rtype vl = 0, vu = 0;
|
25
|
+
rtype abstol = 0;
|
26
|
+
|
27
|
+
args_t *g;
|
28
|
+
|
29
|
+
a = (dtype*)NDL_PTR(lp, 0);
|
30
|
+
b = (dtype*)NDL_PTR(lp, 1);
|
31
|
+
w = (rtype*)NDL_PTR(lp, 2);
|
32
|
+
z = (dtype*)NDL_PTR(lp, 3);
|
33
|
+
ifail = (int*)NDL_PTR(lp, 4);
|
34
|
+
info = (int*)NDL_PTR(lp, 5);
|
35
|
+
g = (args_t*)(lp->opt_ptr);
|
36
|
+
|
37
|
+
n = NDL_SHAPE(lp, 0)[1];
|
38
|
+
lda = NDL_STEP(lp, 0) / sizeof(dtype);
|
39
|
+
ldb = NDL_STEP(lp, 1) / sizeof(dtype);
|
40
|
+
ldz = NDL_SHAPE(lp, 3)[1];
|
41
|
+
|
42
|
+
*info = (*func_p)( g->order, g->itype, g->jobz, g->range, g->uplo, n, a, lda, b, ldb,
|
43
|
+
vl, vu, g->il, g->iu, abstol, &m, w, z, ldz, ifail );
|
44
|
+
|
45
|
+
CHECK_ERROR(*info);
|
46
|
+
}
|
47
|
+
|
48
|
+
/*
|
49
|
+
<%
|
50
|
+
params = [
|
51
|
+
mat("a",:inplace),
|
52
|
+
mat("b",:inplace),
|
53
|
+
"@param [Integer] itype Specifies the problem type to be solved. If 1: A*x = (lambda)*B*x, If 2: A*B*x = (lambda)*x, If 3: B*A*x = (lambda)*x.",
|
54
|
+
jobe("jobz"),
|
55
|
+
opt("uplo"),
|
56
|
+
opt("order"),
|
57
|
+
"@param [String or Symbol] range If 'A': Compute all eigenvalues, if 'I': Compute eigenvalues with indices il to iu (default='A')",
|
58
|
+
"@param [Integer] il Specifies the index of the smallest eigenvalue in ascending order to be returned. If range = 'A', il is not referenced.",
|
59
|
+
"@param [Integer] iu Specifies the index of the largest eigenvalue in ascending order to be returned. Constraint: 1<=il<=iu<=N. If range = 'A', iu is not referenced.",
|
60
|
+
].select{|x| x}.join("\n ")
|
61
|
+
return_name="a, b, w, z, ifail, info"
|
62
|
+
%>
|
63
|
+
@overload <%=name%>(a, b, [itype:1, jobz:'V', uplo:'U', order:'R', range:'I', il: 1, iu: 2])
|
64
|
+
<%=params%>
|
65
|
+
@return [[<%=return_name%>]]
|
66
|
+
Array<<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,Integer>
|
67
|
+
<%=outparam(return_name)%>
|
68
|
+
|
69
|
+
<%=description%>
|
70
|
+
*/
|
71
|
+
|
72
|
+
static VALUE
|
73
|
+
<%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
|
74
|
+
{
|
75
|
+
VALUE a, b, ans;
|
76
|
+
int n, nb, m;
|
77
|
+
narray_t *na1, *na2;
|
78
|
+
size_t w_shape[1];
|
79
|
+
size_t z_shape[2];
|
80
|
+
size_t ifail_shape[1];
|
81
|
+
|
82
|
+
ndfunc_arg_in_t ain[2] = {{OVERWRITE, 2}, {OVERWRITE, 2}};
|
83
|
+
ndfunc_arg_out_t aout[4] = {{cRT, 1, w_shape}, {cT, 2, z_shape}, {cI, 1, ifail_shape}, {cInt, 0}};
|
84
|
+
ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP | NDF_EXTRACT, 2, 4, ain, aout};
|
85
|
+
|
86
|
+
args_t g;
|
87
|
+
VALUE opts[7] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
88
|
+
VALUE kw_hash = Qnil;
|
89
|
+
ID kw_table[7] = {id_order, id_jobz, id_uplo, id_itype, id_range, id_il, id_iu};
|
90
|
+
|
91
|
+
CHECK_FUNC(func_p,"<%=func_name%>");
|
92
|
+
|
93
|
+
rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
|
94
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 7, opts);
|
95
|
+
g.order = option_order(opts[0]);
|
96
|
+
g.jobz = option_job(opts[1], 'V', 'N');
|
97
|
+
g.uplo = option_uplo(opts[2]);
|
98
|
+
g.itype = NUM2INT(option_value(opts[3], INT2FIX(1)));
|
99
|
+
g.range = option_range(opts[4], 'A', 'I');
|
100
|
+
g.il = NUM2INT(option_value(opts[5], INT2FIX(1)));
|
101
|
+
g.iu = NUM2INT(option_value(opts[6], INT2FIX(1)));
|
102
|
+
|
103
|
+
COPY_OR_CAST_TO(a, cT);
|
104
|
+
GetNArray(a, na1);
|
105
|
+
CHECK_DIM_GE(na1, 2);
|
106
|
+
|
107
|
+
COPY_OR_CAST_TO(b, cT);
|
108
|
+
GetNArray(b, na2);
|
109
|
+
CHECK_DIM_GE(na2, 2);
|
110
|
+
CHECK_SQUARE("matrix a", na1);
|
111
|
+
n = COL_SIZE(na1);
|
112
|
+
CHECK_SQUARE("matrix b", na2);
|
113
|
+
nb = COL_SIZE(na2);
|
114
|
+
if (n != nb) {
|
115
|
+
rb_raise(nary_eShapeError, "matrix a and b must have same size");
|
116
|
+
}
|
117
|
+
|
118
|
+
m = g.range == 'I' ? g.iu - g.il + 1 : n;
|
119
|
+
w_shape[0] = m;
|
120
|
+
z_shape[0] = n;
|
121
|
+
z_shape[1] = m;
|
122
|
+
ifail_shape[0] = m;
|
123
|
+
|
124
|
+
ans = na_ndloop3(&ndf, &g, 2, a, b);
|
125
|
+
|
126
|
+
return rb_ary_new3(6, a, b, RARRAY_AREF(ans, 0), RARRAY_AREF(ans, 1), RARRAY_AREF(ans, 2), RARRAY_AREF(ans, 3));
|
127
|
+
}
|
128
|
+
|
129
|
+
#undef args_t
|
130
|
+
#undef func_p
|
@@ -17,24 +17,6 @@ def create_site_conf
|
|
17
17
|
FileUtils.mkdir_p "lib"
|
18
18
|
|
19
19
|
ext = detect_library_extension
|
20
|
-
need_version = false
|
21
|
-
if ext == 'so'
|
22
|
-
begin
|
23
|
-
Fiddle.dlopen "libm.so"
|
24
|
-
rescue
|
25
|
-
(5..7).each do |i|
|
26
|
-
begin
|
27
|
-
Fiddle.dlopen "libm.so.#{i}"
|
28
|
-
need_version = true
|
29
|
-
break
|
30
|
-
rescue
|
31
|
-
end
|
32
|
-
end
|
33
|
-
if !need_version
|
34
|
-
raise "failed to check whether dynamically linked shared object needs version suffix"
|
35
|
-
end
|
36
|
-
end
|
37
|
-
end
|
38
20
|
|
39
21
|
open("lib/site_conf.rb","w"){|f| f.write "
|
40
22
|
module Numo
|
@@ -49,7 +31,6 @@ module Numo
|
|
49
31
|
|
50
32
|
module Loader
|
51
33
|
EXT = '#{ext}'
|
52
|
-
NEED_VERSION_SUFFIX = #{need_version}
|
53
34
|
end
|
54
35
|
|
55
36
|
end
|
@@ -68,6 +49,8 @@ def detect_library_extension
|
|
68
49
|
end
|
69
50
|
end
|
70
51
|
|
52
|
+
require 'numo/narray'
|
53
|
+
|
71
54
|
def find_narray_h
|
72
55
|
$LOAD_PATH.each do |x|
|
73
56
|
if File.exist? File.join(x,'numo/numo/narray.h')
|
data/lib/numo/linalg/function.rb
CHANGED
@@ -65,7 +65,7 @@ module Numo; module Linalg
|
|
65
65
|
NArray.array_type(a)
|
66
66
|
end
|
67
67
|
if k && k < NArray
|
68
|
-
t = k::UPCAST[t]
|
68
|
+
t = k::UPCAST[t] || t::UPCAST[k]
|
69
69
|
end
|
70
70
|
end
|
71
71
|
BLAS_CHAR[t] || raise(TypeError,"invalid data type for BLAS/LAPACK")
|
@@ -86,7 +86,8 @@ module Numo; module Linalg
|
|
86
86
|
when 1
|
87
87
|
case b.ndim
|
88
88
|
when 1
|
89
|
-
|
89
|
+
func = blas_char(a, b) =~ /c|z/ ? :dotu : :dot
|
90
|
+
Blas.call(func, a, b)
|
90
91
|
else
|
91
92
|
Blas.call(:gemv, b, a, trans:'t')
|
92
93
|
end
|
@@ -194,10 +195,10 @@ module Numo; module Linalg
|
|
194
195
|
#
|
195
196
|
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
196
197
|
# @param mode [String]
|
197
|
-
# - "reduce"
|
198
|
-
# - "r"
|
199
|
-
# - "
|
200
|
-
# - "raw"
|
198
|
+
# - "reduce" -- returns both Q and R,
|
199
|
+
# - "r" -- returns only R,
|
200
|
+
# - "economic" -- returns both Q and R but computed in economy-size,
|
201
|
+
# - "raw" -- returns QR and TAU used in LAPACK.
|
201
202
|
# @return [r] if mode:"r"
|
202
203
|
# @return [[q,r]] if mode:"reduce" or "economic"
|
203
204
|
# @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
|
@@ -295,6 +296,38 @@ module Numo; module Linalg
|
|
295
296
|
end
|
296
297
|
end
|
297
298
|
|
299
|
+
# Computes an LU factorization of a M-by-N matrix A
|
300
|
+
# using partial pivoting with row interchanges.
|
301
|
+
#
|
302
|
+
# The factorization has the form
|
303
|
+
#
|
304
|
+
# A = P * L * U
|
305
|
+
#
|
306
|
+
# where P is a permutation matrix, L is lower triangular with unit
|
307
|
+
# diagonal elements (lower trapezoidal if m > n), and U is upper
|
308
|
+
# triangular (upper trapezoidal if m < n).
|
309
|
+
#
|
310
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
311
|
+
# @param permute_l [Bool] (optional) If true, perform the matrix product of P and L.
|
312
|
+
# @return [[p,l,u]] if permute_l == false
|
313
|
+
# @return [[pl,u]] if permute_l == true
|
314
|
+
#
|
315
|
+
# - **p** [Numo::NArray] -- The permutation matrix P.
|
316
|
+
# - **l** [Numo::NArray] -- The factor L.
|
317
|
+
# - **u** [Numo::NArray] -- The factor U.
|
318
|
+
|
319
|
+
def lu(a, permute_l: false)
|
320
|
+
raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
|
321
|
+
m, n = a.shape
|
322
|
+
k = [m, n].min
|
323
|
+
lu, ip = lu_fact(a)
|
324
|
+
l = lu.tril.tap { |mat| mat[mat.diag_indices(0)] = 1.0 }[true, 0...k]
|
325
|
+
u = lu.triu[0...k, 0...n]
|
326
|
+
p = Numo::DFloat.eye(m).tap do |mat|
|
327
|
+
ip.to_a.each_with_index { |i, j| mat[true, [i - 1, j]] = mat[true, [j, i - 1]].dup }
|
328
|
+
end
|
329
|
+
permute_l ? [p.dot(l), u] : [p, l, u]
|
330
|
+
end
|
298
331
|
|
299
332
|
# Computes an LU factorization of a M-by-N matrix A
|
300
333
|
# using partial pivoting with row interchanges.
|
@@ -308,26 +341,14 @@ module Numo; module Linalg
|
|
308
341
|
# triangular (upper trapezoidal if m < n).
|
309
342
|
#
|
310
343
|
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
311
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
312
|
-
# 'gen','sym','her'. (optional, default='gen')
|
313
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
314
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
315
344
|
# @return [[lu, ipiv]]
|
316
345
|
# - **lu** [Numo::NArray] -- The factors L and U from the factorization
|
317
346
|
# `A = P*L*U`; the unit diagonal elements of L are not stored.
|
318
347
|
# - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
|
319
348
|
# row i of the matrix was interchanged with row IPIV(i).
|
320
349
|
|
321
|
-
def lu_fact(a
|
322
|
-
|
323
|
-
when /^gen?(trf)?$/i
|
324
|
-
Lapack.call(:getrf, a)[0..1]
|
325
|
-
when /^(sym?|her?)(trf)?$/i
|
326
|
-
func = driver[0..2].downcase+"trf"
|
327
|
-
Lapack.call(func, a, uplo:uplo)[0..1]
|
328
|
-
else
|
329
|
-
raise ArgumentError, "invalid driver: #{driver}"
|
330
|
-
end
|
350
|
+
def lu_fact(a)
|
351
|
+
Lapack.call(:getrf, a)[0..1]
|
331
352
|
end
|
332
353
|
|
333
354
|
# Computes the inverse of a matrix using the LU factorization
|
@@ -345,22 +366,10 @@ module Numo; module Linalg
|
|
345
366
|
# @param ipiv [Numo::NArray] The pivot indices from
|
346
367
|
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
347
368
|
# interchanged with row IPIV(i).
|
348
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
349
|
-
# 'gen','sym','her'. (optional, default='gen')
|
350
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
351
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
352
369
|
# @return [Numo::NArray] the inverse of the original matrix A.
|
353
370
|
|
354
|
-
def lu_inv(lu, ipiv
|
355
|
-
|
356
|
-
when /^gen?(tri)?$/i
|
357
|
-
Lapack.call(:getri, lu, ipiv)[0]
|
358
|
-
when /^(sym?|her?)(tri)?$/i
|
359
|
-
func = driver[0..2].downcase+"tri"
|
360
|
-
Lapack.call(func, lu, ipiv, uplo:uplo)[0]
|
361
|
-
else
|
362
|
-
raise ArgumentError, "invalid driver: #{driver}"
|
363
|
-
end
|
371
|
+
def lu_inv(lu, ipiv)
|
372
|
+
Lapack.call(:getri, lu, ipiv)[0]
|
364
373
|
end
|
365
374
|
|
366
375
|
# Solves a system of linear equations
|
@@ -377,31 +386,100 @@ module Numo; module Linalg
|
|
377
386
|
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
378
387
|
# interchanged with row IPIV(i).
|
379
388
|
# @param b [Numo::NArray] the right hand side matrix B.
|
380
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
381
|
-
# 'gen','sym','her'. (optional, default='gen')
|
382
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
383
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
384
389
|
# @param trans [String or Symbol]
|
385
|
-
# Specifies the form of the system of equations
|
386
|
-
# (omitted if not driver:"gen"):
|
387
|
-
#
|
390
|
+
# Specifies the form of the system of equations:
|
388
391
|
# - If 'N': `A * X = B` (No transpose).
|
389
392
|
# - If 'T': `A*\*T* X = B` (Transpose).
|
390
393
|
# - If 'C': `A*\*T* X = B` (Conjugate transpose = Transpose).
|
391
394
|
# @return [Numo::NArray] the solution matrix X.
|
392
395
|
|
393
|
-
def lu_solve(lu, ipiv, b,
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
396
|
+
def lu_solve(lu, ipiv, b, trans:"N")
|
397
|
+
Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
|
398
|
+
end
|
399
|
+
|
400
|
+
# Computes the LDLt or Bunch-Kaufman factorization of a symmetric/Hermitian matrix A.
|
401
|
+
# The factorization has the form
|
402
|
+
#
|
403
|
+
# A = U*D*U**T or A = L*D*L**T
|
404
|
+
#
|
405
|
+
# where U (or L) is a product of permutation and unit upper (lower) triangular matrices
|
406
|
+
# and D is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
|
407
|
+
#
|
408
|
+
# @param a [Numo::NArray] m-by-m matrix A (>= 2-dimensional NArray)
|
409
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper ('U') or lower ('L') triangle.
|
410
|
+
# @param hermitian [Bool] optional, default=true. If true, hermitian matrix is assumed.
|
411
|
+
# (omitted when real-value matrix is given)
|
412
|
+
#
|
413
|
+
# @return [[lu,d,perm]]
|
414
|
+
#
|
415
|
+
# - **lu** [Numo::NArray] -- The permuted upper (lower) triangular matrix U (L).
|
416
|
+
# - **d** [Numo::NArray] -- The block diagonal matrix D.
|
417
|
+
# - **perm** [Numo::NArray] -- The row-permutation index for changing lu into triangular form.
|
418
|
+
|
419
|
+
def ldl(a, uplo: 'U', hermitian: true)
|
420
|
+
raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
|
421
|
+
raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
|
422
|
+
|
423
|
+
is_complex = blas_char(a) =~ /c|z/
|
424
|
+
func = is_complex && hermitian ? 'hetrf' : 'sytrf'
|
425
|
+
lud, ipiv, = Lapack.call(func.to_sym, a, uplo: uplo)
|
426
|
+
|
427
|
+
lu = (uplo == 'U' ? lud.triu : lud.tril).tap { |mat| mat[mat.diag_indices(0)] = 1.0 }
|
428
|
+
d = lud[lud.diag_indices(0)].diag
|
429
|
+
|
430
|
+
m = a.shape[0]
|
431
|
+
n = m - 1
|
432
|
+
changed_2x2 = false
|
433
|
+
perm = Numo::Int32.new(m).seq
|
434
|
+
m.times do |t|
|
435
|
+
i = uplo == 'U' ? t : n - t
|
436
|
+
j = uplo == 'U' ? i - 1 : i + 1;
|
437
|
+
r = uplo == 'U' ? 0..i : i..n;
|
438
|
+
if ipiv[i] > 0
|
439
|
+
k = ipiv[i] - 1
|
440
|
+
lu[[k, i], r] = lu[[i, k], r].dup
|
441
|
+
perm[[k, i]] = perm[[i, k]].dup
|
442
|
+
elsif j.between?(0, n) && ipiv[i] == ipiv[j] && !changed_2x2
|
443
|
+
k = ipiv[i].abs - 1
|
444
|
+
d[j, i] = lud[j, i]
|
445
|
+
d[i, j] = is_complex && hermitian ? lud[j, i].conj : lud[j, i]
|
446
|
+
lu[j, i] = 0.0
|
447
|
+
lu[[k, j], r] = lu[[j, k], r].dup
|
448
|
+
perm[[k, j]] = perm[[j, k]].dup
|
449
|
+
changed_2x2 = true
|
450
|
+
next
|
451
|
+
end
|
452
|
+
changed_2x2 = false
|
402
453
|
end
|
454
|
+
|
455
|
+
[lu, d, perm.sort_index]
|
403
456
|
end
|
404
457
|
|
458
|
+
# Computes the Cholesky factorization of a symmetric/Hermitian
|
459
|
+
# positive definite matrix A. The factorization has the form
|
460
|
+
#
|
461
|
+
# A = U**H * U, if UPLO = 'U', or
|
462
|
+
# A = L * L**H, if UPLO = 'L',
|
463
|
+
#
|
464
|
+
# where U is an upper triangular matrix and L is a lower triangular matrix.
|
465
|
+
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
|
466
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
467
|
+
# ('U') or lower ('L') triangle.
|
468
|
+
# @return [Numo::NArray] The factor U or L.
|
469
|
+
|
470
|
+
def cholesky(a, uplo: 'U')
|
471
|
+
raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
|
472
|
+
raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
|
473
|
+
factor = Lapack.call(:potrf, a, uplo: uplo)[0]
|
474
|
+
if uplo == 'U'
|
475
|
+
factor.triu
|
476
|
+
else
|
477
|
+
# TODO: Use the tril method if the version of Numo::NArray
|
478
|
+
# in the runtime dependency list becomes 0.9.1.3 or higher.
|
479
|
+
m, = a.shape
|
480
|
+
factor * Numo::DFloat.ones(m, m).triu.transpose
|
481
|
+
end
|
482
|
+
end
|
405
483
|
|
406
484
|
# Computes the Cholesky factorization of a symmetric/Hermitian
|
407
485
|
# positive definite matrix A. The factorization has the form
|
@@ -409,16 +487,16 @@ module Numo; module Linalg
|
|
409
487
|
# A = U**H * U, if UPLO = 'U', or
|
410
488
|
# A = L * L**H, if UPLO = 'L',
|
411
489
|
#
|
412
|
-
# where U is an upper triangular matrix and L is lower triangular
|
490
|
+
# where U is an upper triangular matrix and L is a lower triangular matrix.
|
413
491
|
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
|
414
492
|
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
415
493
|
# ('U') or lower ('L') triangle.
|
416
|
-
# @return [Numo::NArray] the factor
|
494
|
+
# @return [Numo::NArray] The matrix which has the Cholesky factor in upper or lower triangular part.
|
495
|
+
# The remaining part contains unspecified values.
|
417
496
|
|
418
497
|
def cho_fact(a, uplo:'U')
|
419
498
|
Lapack.call(:potrf, a, uplo:uplo)[0]
|
420
499
|
end
|
421
|
-
#alias cholesky cho_fact
|
422
500
|
|
423
501
|
# Computes the inverse of a symmetric/Hermitian
|
424
502
|
# positive definite matrix A using the Cholesky factorization
|
@@ -479,27 +557,41 @@ module Numo; module Linalg
|
|
479
557
|
[w,vl,vr] #.compact
|
480
558
|
end
|
481
559
|
|
482
|
-
#
|
483
|
-
#
|
560
|
+
# Obtains the eigenvalues and, optionally, the eigenvectors
|
561
|
+
# by solving an ordinary or generalized eigenvalue problem
|
562
|
+
# for a square symmetric / Hermitian matrix.
|
484
563
|
#
|
485
|
-
# @param a [Numo::NArray] square
|
486
|
-
# @param
|
564
|
+
# @param a [Numo::NArray] square symmetric matrix (>= 2-dimensional NArray)
|
565
|
+
# @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
|
566
|
+
# If nil, identity matrix is assumed.
|
567
|
+
# @param vals_only [Bool] (optional) If false, eigenvectors are computed.
|
568
|
+
# @param vals_range [Range] (optional)
|
569
|
+
# The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
|
570
|
+
# If nil or 0...N (N is the size of the matrix a), all eigenvalues and eigenvectors are returned.
|
487
571
|
# @param uplo [String or Symbol] (optional, default='U')
|
488
572
|
# Access upper ('U') or lower ('L') triangle.
|
573
|
+
# @param turbo [Bool] (optional) If true, divide and conquer algorithm is used.
|
489
574
|
# @return [[w,v]]
|
490
575
|
# - **w** [Numo::NArray] -- The eigenvalues.
|
491
576
|
# - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
|
492
577
|
|
493
|
-
def eigh(a, vals_only:false, uplo:
|
578
|
+
def eigh(a, b=nil, vals_only:false, vals_range:nil, uplo:'U', turbo:false)
|
494
579
|
jobz = vals_only ? 'N' : 'V' # jobz: Compute eigenvalues and eigenvectors.
|
495
|
-
|
496
|
-
|
497
|
-
|
580
|
+
b = a.class.eye(a.shape[0]) if b.nil?
|
581
|
+
func = blas_char(a, b) =~ /c|z/ ? 'hegv' : 'sygv'
|
582
|
+
if vals_range.nil?
|
583
|
+
func << 'd' if turbo
|
584
|
+
v, u_, w, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz)
|
585
|
+
v = nil if vals_only
|
586
|
+
[w, v]
|
498
587
|
else
|
499
|
-
func
|
588
|
+
func << 'x'
|
589
|
+
il = vals_range.first(1)[0]
|
590
|
+
iu = vals_range.last(1)[0]
|
591
|
+
a_, b_, w, v, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz, range: 'I', il: il + 1, iu: iu + 1)
|
592
|
+
v = nil if vals_only
|
593
|
+
[w, v]
|
500
594
|
end
|
501
|
-
w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
|
502
|
-
[w,v] #.compact
|
503
595
|
end
|
504
596
|
|
505
597
|
# Computes the eigenvalues only for a square nonsymmetric matrix A.
|
@@ -519,23 +611,22 @@ module Numo; module Linalg
|
|
519
611
|
w
|
520
612
|
end
|
521
613
|
|
522
|
-
#
|
614
|
+
# Obtains the eigenvalues by solving an ordinary or generalized eigenvalue problem
|
615
|
+
# for a square symmetric / Hermitian matrix.
|
523
616
|
#
|
524
617
|
# @param a [Numo::NArray] square symmetric/hermitian matrix
|
525
618
|
# (>= 2-dimensional NArray)
|
619
|
+
# @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensional NArray)
|
620
|
+
# If nil, identity matrix is assumed.
|
621
|
+
# @param vals_range [Range] (optional)
|
622
|
+
# The range of indices of the eigenvalues (in ascending order) to be returned.
|
623
|
+
# If nil or 0...N (N is the size of the matrix a), all eigenvalues are returned.
|
526
624
|
# @param uplo [String or Symbol] (optional, default='U')
|
527
625
|
# Access upper ('U') or lower ('L') triangle.
|
528
626
|
# @return [Numo::NArray] eigenvalues
|
529
627
|
|
530
|
-
def eigvalsh(a, uplo:
|
531
|
-
|
532
|
-
case blas_char(a)
|
533
|
-
when /c|z/
|
534
|
-
func = turbo ? :hegv : :heev
|
535
|
-
else
|
536
|
-
func = turbo ? :sygv : :syev
|
537
|
-
end
|
538
|
-
Lapack.call(func, a, uplo:uplo, jobz:jobz)[0]
|
628
|
+
def eigvalsh(a, b=nil, vals_range:nil, uplo:'U', turbo:false)
|
629
|
+
eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo).first
|
539
630
|
end
|
540
631
|
|
541
632
|
|
@@ -801,7 +892,7 @@ module Numo; module Linalg
|
|
801
892
|
# returns lu, x, ipiv, info
|
802
893
|
Lapack.call(:gesv, a, b)[1]
|
803
894
|
when /^(sym?|her?|pos?)(sv)?$/i
|
804
|
-
func = driver[0..
|
895
|
+
func = driver[0..1].downcase+"sv"
|
805
896
|
Lapack.call(func, a, b, uplo:uplo)[1]
|
806
897
|
else
|
807
898
|
raise ArgumentError, "invalid driver: #{driver}"
|
@@ -834,8 +925,8 @@ module Numo; module Linalg
|
|
834
925
|
solve(a, b, driver:d, uplo:uplo)
|
835
926
|
when /(ge|sy|he)tr[fi]$/
|
836
927
|
d = $1
|
837
|
-
lu, piv = lu_fact(a
|
838
|
-
lu_inv(lu, piv
|
928
|
+
lu, piv = lu_fact(a)
|
929
|
+
lu_inv(lu, piv)
|
839
930
|
when /potr[fi]$/
|
840
931
|
lu = cho_fact(a, uplo:uplo)
|
841
932
|
cho_inv(lu, uplo:uplo)
|