numo-linalg 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +7 -3
- data/ext/numo/linalg/blas/extconf.rb +1 -2
- data/ext/numo/linalg/blas/numo_blas.h +6 -0
- data/ext/numo/linalg/blas/tmpl/mv.c +3 -2
- data/ext/numo/linalg/lapack/gen/spec.rb +5 -0
- data/ext/numo/linalg/lapack/lapack.c +29 -0
- data/ext/numo/linalg/lapack/numo_lapack.h +3 -0
- data/ext/numo/linalg/lapack/tmpl/gqr.c +1 -1
- data/ext/numo/linalg/lapack/tmpl/sygvx.c +130 -0
- data/ext/numo/linalg/mkmf_linalg.rb +2 -19
- data/lib/numo/linalg/function.rb +168 -77
- data/lib/numo/linalg/loader.rb +6 -14
- data/lib/numo/linalg/version.rb +1 -1
- data/numo-linalg.gemspec +2 -1
- data/spec/linalg/autoloader_spec.rb +27 -0
- data/spec/linalg/function/cho_fact_spec.rb +31 -0
- data/spec/linalg/function/cho_inv_spec.rb +39 -0
- data/spec/linalg/function/cho_solve_spec.rb +66 -0
- data/spec/linalg/function/cholesky_spec.rb +43 -0
- data/spec/linalg/function/cond_spec.rb +57 -0
- data/spec/linalg/function/det_spec.rb +21 -0
- data/spec/linalg/function/dot_spec.rb +84 -0
- data/spec/linalg/function/eig_spec.rb +53 -0
- data/spec/linalg/function/eigh_spec.rb +81 -0
- data/spec/linalg/function/eigvals_spec.rb +27 -0
- data/spec/linalg/function/eigvalsh_spec.rb +60 -0
- data/spec/linalg/function/inv_spec.rb +57 -0
- data/spec/linalg/function/ldl_spec.rb +51 -0
- data/spec/linalg/function/lstsq_spec.rb +80 -0
- data/spec/linalg/function/lu_fact_spec.rb +34 -0
- data/spec/linalg/function/lu_inv_spec.rb +21 -0
- data/spec/linalg/function/lu_solve_spec.rb +40 -0
- data/spec/linalg/function/lu_spec.rb +46 -0
- data/spec/linalg/function/matmul_spec.rb +41 -0
- data/spec/linalg/function/matrix_power_spec.rb +31 -0
- data/spec/linalg/function/matrix_rank_spec.rb +33 -0
- data/spec/linalg/function/norm_spec.rb +81 -0
- data/spec/linalg/function/pinv_spec.rb +48 -0
- data/spec/linalg/function/qr_spec.rb +82 -0
- data/spec/linalg/function/slogdet_spec.rb +21 -0
- data/spec/linalg/function/solve_spec.rb +98 -0
- data/spec/linalg/function/svd_spec.rb +88 -0
- data/spec/linalg/function/svdvals_spec.rb +40 -0
- data/spec/spec_helper.rb +55 -0
- metadata +79 -6
- data/spec/lapack_spec.rb +0 -13
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ef1ecaaf8a71faee9a7cef26bc35fae50230224856ab5dd6186cee93e5cca69d
|
4
|
+
data.tar.gz: f2e5e5944ad4832ef4ad4db05f6e94142fefb780e19282d94ece2a055cf447a0
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3fcf1506aa63ccbba57208e358102ae0dc537cd80b530e4b359b07d01a7cb1da11261dfd6d7028224175ae0f38014ec0373b99fc339caa2b952ad48c0fb26fd7
|
7
|
+
data.tar.gz: c7a9eb8b65719571120a6a122721e1eaadc240164e182a5095ff45018453a90f4cc48f99ef219305827beace8839bdebb0d27d35dfb5f64ec64f54d5406a8910
|
data/README.md
CHANGED
@@ -18,7 +18,7 @@ This desgin allows you to change backend libraries without re-compiling.
|
|
18
18
|
* Matrix and vector products
|
19
19
|
* dot, matmul
|
20
20
|
* Decomposition
|
21
|
-
* lu\_fact, lu\_inv, lu\_solve, cho\_fact, cho\_inv, cho\_solve
|
21
|
+
* lu, lu\_fact, lu\_inv, lu\_solve, ldl, cholesky, cho\_fact, cho\_inv, cho\_solve,
|
22
22
|
qr, svd, svdvals
|
23
23
|
* Matrix eigenvalues
|
24
24
|
* eig, eigh, eigvals, eigvalsh
|
@@ -78,8 +78,12 @@ require "numo/linalg"
|
|
78
78
|
|
79
79
|
## Authors
|
80
80
|
|
81
|
-
* Masahiro
|
82
|
-
* Makoto
|
81
|
+
* Masahiro Tanaka
|
82
|
+
* Makoto Kishimoto
|
83
|
+
* Atsushi Tatsuma
|
84
|
+
|
85
|
+
## Acknowledgement
|
86
|
+
|
83
87
|
* This work is partly supported by 2016 Ruby Association Grant.
|
84
88
|
|
85
89
|
## ToDo
|
@@ -1,5 +1,4 @@
|
|
1
1
|
require 'mkmf'
|
2
|
-
require 'numo/narray'
|
3
2
|
require_relative '../mkmf_linalg'
|
4
3
|
|
5
4
|
srcs = %w(
|
@@ -21,7 +20,7 @@ if !have_header("numo/narray.h")
|
|
21
20
|
exit(1)
|
22
21
|
end
|
23
22
|
|
24
|
-
if RUBY_PLATFORM =~ /cygwin|mingw/
|
23
|
+
if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
|
25
24
|
find_libnarray_a
|
26
25
|
unless have_library("narray","nary_new")
|
27
26
|
puts "libnarray.a not found"
|
@@ -31,6 +31,12 @@ extern void numo_cblas_check_func(void **func, const char *name);
|
|
31
31
|
#define SWAP_IFROW(order,a,b,tmp) \
|
32
32
|
{ if ((order)==CblasRowMajor) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
33
33
|
|
34
|
+
#define SWAP_IFNOTRANS(trans,a,b,tmp) \
|
35
|
+
{ if ((trans)==CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
36
|
+
|
37
|
+
#define SWAP_IFTRANS(trans,a,b,tmp) \
|
38
|
+
{ if ((trans)!=CblasNoTrans) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
39
|
+
|
34
40
|
#define SWAP_IFCOLTR(order,trans,a,b,tmp) \
|
35
41
|
{ if (((order)==CblasRowMajor && (trans)!=CblasNoTrans) || \
|
36
42
|
((order)!=CblasRowMajor && (trans)==CblasNoTrans)) \
|
@@ -79,7 +79,7 @@ static void
|
|
79
79
|
opt("beta"),
|
80
80
|
!is_ge && opt("side"),
|
81
81
|
!is_ge && opt("uplo"),
|
82
|
-
is_ge || is_tr && opt("trans"),
|
82
|
+
(is_ge || is_tr) && opt("trans"),
|
83
83
|
opt("order")
|
84
84
|
].select{|x| x}.join("\n ")
|
85
85
|
%>
|
@@ -143,9 +143,10 @@ static VALUE
|
|
143
143
|
CHECK_DIM_GE(na2,1);
|
144
144
|
nx = COL_SIZE(na2);
|
145
145
|
#if GE
|
146
|
-
|
146
|
+
SWAP_IFCOL(g.order, ma, na, tmp);
|
147
147
|
g.m = ma;
|
148
148
|
g.n = na;
|
149
|
+
SWAP_IFTRANS(g.trans, ma, na, tmp);
|
149
150
|
#else
|
150
151
|
CHECK_SQUARE("a",na1);
|
151
152
|
#endif
|
@@ -6,6 +6,9 @@ def_id "jobz"
|
|
6
6
|
def_id "jobvl"
|
7
7
|
def_id "jobvr"
|
8
8
|
def_id "trans"
|
9
|
+
def_id "range"
|
10
|
+
def_id "il"
|
11
|
+
def_id "iu"
|
9
12
|
def_id "rcond"
|
10
13
|
def_id "itype"
|
11
14
|
def_id "norm"
|
@@ -58,11 +61,13 @@ when /c|z/
|
|
58
61
|
decl "?heevd", "syev"
|
59
62
|
decl "?hegv", "sygv"
|
60
63
|
decl "?hegvd", "sygv"
|
64
|
+
decl "?hegvx", "sygvx"
|
61
65
|
else
|
62
66
|
decl "?syev"
|
63
67
|
decl "?syevd", "syev"
|
64
68
|
decl "?sygv"
|
65
69
|
decl "?sygvd", "sygv"
|
70
|
+
decl "?sygvx"
|
66
71
|
end
|
67
72
|
|
68
73
|
# factorize
|
@@ -124,6 +124,35 @@ numo_lapacke_option_job(VALUE job, char true_char, char false_char)
|
|
124
124
|
return 0;
|
125
125
|
}
|
126
126
|
|
127
|
+
char
|
128
|
+
numo_lapacke_option_range(VALUE job, char true_char, char false_char)
|
129
|
+
{
|
130
|
+
char *ptr, c;
|
131
|
+
|
132
|
+
switch(TYPE(job)) {
|
133
|
+
case T_NIL:
|
134
|
+
case T_UNDEF:
|
135
|
+
case T_TRUE:
|
136
|
+
return true_char;
|
137
|
+
case T_FALSE:
|
138
|
+
return false_char;
|
139
|
+
case T_SYMBOL:
|
140
|
+
job = rb_sym2str(job);
|
141
|
+
case T_STRING:
|
142
|
+
ptr = RSTRING_PTR(job);
|
143
|
+
if (RSTRING_LEN(job) > 0) {
|
144
|
+
c = ptr[0];
|
145
|
+
if (c >= 'a' && c <= 'z') {
|
146
|
+
c -= 'a'-'A';
|
147
|
+
}
|
148
|
+
return c;
|
149
|
+
}
|
150
|
+
break;
|
151
|
+
}
|
152
|
+
rb_raise(rb_eArgError,"invalid value for JOB option");
|
153
|
+
return 0;
|
154
|
+
}
|
155
|
+
|
127
156
|
char
|
128
157
|
numo_lapacke_option_trans(VALUE trans)
|
129
158
|
{
|
@@ -15,6 +15,9 @@ extern int numo_lapacke_option_order(VALUE order);
|
|
15
15
|
#define option_job numo_lapacke_option_job
|
16
16
|
extern char numo_lapacke_option_job(VALUE job, char true_char, char false_char);
|
17
17
|
|
18
|
+
#define option_range numo_lapacke_option_range
|
19
|
+
extern char numo_lapacke_option_range(VALUE range, char true_char, char false_char);
|
20
|
+
|
18
21
|
#define option_trans numo_lapacke_option_trans
|
19
22
|
extern char numo_lapacke_option_trans(VALUE trans);
|
20
23
|
|
@@ -30,7 +30,7 @@ static void
|
|
30
30
|
SWAP_IFCOL(g->order,m,n);
|
31
31
|
lda = NDL_STEP(lp,0) / sizeof(dtype);
|
32
32
|
|
33
|
-
printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
|
33
|
+
//printf("order=%d m=%d n=%d k=%d lda=%d \n",g->order,m,n,k,lda);
|
34
34
|
|
35
35
|
*info = (*func_p)(g->order, m, n, k, a, lda, tau);
|
36
36
|
CHECK_ERROR(*info);
|
@@ -0,0 +1,130 @@
|
|
1
|
+
#define args_t <%=func_name%>_args_t
|
2
|
+
#define func_p <%=func_name%>_p
|
3
|
+
|
4
|
+
typedef struct {
|
5
|
+
int order;
|
6
|
+
int itype;
|
7
|
+
char jobz;
|
8
|
+
char uplo;
|
9
|
+
char range;
|
10
|
+
int il;
|
11
|
+
int iu;
|
12
|
+
} args_t;
|
13
|
+
|
14
|
+
static <%=func_name%>_t func_p = 0;
|
15
|
+
|
16
|
+
static void
|
17
|
+
<%=c_iter%>(na_loop_t * const lp)
|
18
|
+
{
|
19
|
+
dtype *a, *b, *z;
|
20
|
+
rtype *w;
|
21
|
+
int *ifail;
|
22
|
+
int *info;
|
23
|
+
int m, n, lda, ldb, ldz;
|
24
|
+
rtype vl = 0, vu = 0;
|
25
|
+
rtype abstol = 0;
|
26
|
+
|
27
|
+
args_t *g;
|
28
|
+
|
29
|
+
a = (dtype*)NDL_PTR(lp, 0);
|
30
|
+
b = (dtype*)NDL_PTR(lp, 1);
|
31
|
+
w = (rtype*)NDL_PTR(lp, 2);
|
32
|
+
z = (dtype*)NDL_PTR(lp, 3);
|
33
|
+
ifail = (int*)NDL_PTR(lp, 4);
|
34
|
+
info = (int*)NDL_PTR(lp, 5);
|
35
|
+
g = (args_t*)(lp->opt_ptr);
|
36
|
+
|
37
|
+
n = NDL_SHAPE(lp, 0)[1];
|
38
|
+
lda = NDL_STEP(lp, 0) / sizeof(dtype);
|
39
|
+
ldb = NDL_STEP(lp, 1) / sizeof(dtype);
|
40
|
+
ldz = NDL_SHAPE(lp, 3)[1];
|
41
|
+
|
42
|
+
*info = (*func_p)( g->order, g->itype, g->jobz, g->range, g->uplo, n, a, lda, b, ldb,
|
43
|
+
vl, vu, g->il, g->iu, abstol, &m, w, z, ldz, ifail );
|
44
|
+
|
45
|
+
CHECK_ERROR(*info);
|
46
|
+
}
|
47
|
+
|
48
|
+
/*
|
49
|
+
<%
|
50
|
+
params = [
|
51
|
+
mat("a",:inplace),
|
52
|
+
mat("b",:inplace),
|
53
|
+
"@param [Integer] itype Specifies the problem type to be solved. If 1: A*x = (lambda)*B*x, If 2: A*B*x = (lambda)*x, If 3: B*A*x = (lambda)*x.",
|
54
|
+
jobe("jobz"),
|
55
|
+
opt("uplo"),
|
56
|
+
opt("order"),
|
57
|
+
"@param [String or Symbol] range If 'A': Compute all eigenvalues, if 'I': Compute eigenvalues with indices il to iu (default='A')",
|
58
|
+
"@param [Integer] il Specifies the index of the smallest eigenvalue in ascending order to be returned. If range = 'A', il is not referenced.",
|
59
|
+
"@param [Integer] iu Specifies the index of the largest eigenvalue in ascending order to be returned. Constraint: 1<=il<=iu<=N. If range = 'A', iu is not referenced.",
|
60
|
+
].select{|x| x}.join("\n ")
|
61
|
+
return_name="a, b, w, z, ifail, info"
|
62
|
+
%>
|
63
|
+
@overload <%=name%>(a, b, [itype:1, jobz:'V', uplo:'U', order:'R', range:'I', il: 1, il: 2])
|
64
|
+
<%=params%>
|
65
|
+
@return [[<%=return_name%>]]
|
66
|
+
Array<<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,<%=real_class_name%>,Integer>
|
67
|
+
<%=outparam(return_name)%>
|
68
|
+
|
69
|
+
<%=description%>
|
70
|
+
*/
|
71
|
+
|
72
|
+
static VALUE
|
73
|
+
<%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
|
74
|
+
{
|
75
|
+
VALUE a, b, ans;
|
76
|
+
int n, nb, m;
|
77
|
+
narray_t *na1, *na2;
|
78
|
+
size_t w_shape[1];
|
79
|
+
size_t z_shape[2];
|
80
|
+
size_t ifail_shape[1];
|
81
|
+
|
82
|
+
ndfunc_arg_in_t ain[2] = {{OVERWRITE, 2}, {OVERWRITE, 2}};
|
83
|
+
ndfunc_arg_out_t aout[4] = {{cRT, 1, w_shape}, {cT, 2, z_shape}, {cI, 1, ifail_shape}, {cInt, 0}};
|
84
|
+
ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP | NDF_EXTRACT, 2, 4, ain, aout};
|
85
|
+
|
86
|
+
args_t g;
|
87
|
+
VALUE opts[7] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
88
|
+
VALUE kw_hash = Qnil;
|
89
|
+
ID kw_table[7] = {id_order, id_jobz, id_uplo, id_itype, id_range, id_il, id_iu};
|
90
|
+
|
91
|
+
CHECK_FUNC(func_p,"<%=func_name%>");
|
92
|
+
|
93
|
+
rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
|
94
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 7, opts);
|
95
|
+
g.order = option_order(opts[0]);
|
96
|
+
g.jobz = option_job(opts[1], 'V', 'N');
|
97
|
+
g.uplo = option_uplo(opts[2]);
|
98
|
+
g.itype = NUM2INT(option_value(opts[3], INT2FIX(1)));
|
99
|
+
g.range = option_range(opts[4], 'A', 'I');
|
100
|
+
g.il = NUM2INT(option_value(opts[5], INT2FIX(1)));
|
101
|
+
g.iu = NUM2INT(option_value(opts[6], INT2FIX(1)));
|
102
|
+
|
103
|
+
COPY_OR_CAST_TO(a, cT);
|
104
|
+
GetNArray(a, na1);
|
105
|
+
CHECK_DIM_GE(na1, 2);
|
106
|
+
|
107
|
+
COPY_OR_CAST_TO(b, cT);
|
108
|
+
GetNArray(b, na2);
|
109
|
+
CHECK_DIM_GE(na2, 2);
|
110
|
+
CHECK_SQUARE("matrix a", na1);
|
111
|
+
n = COL_SIZE(na1);
|
112
|
+
CHECK_SQUARE("matrix b", na2);
|
113
|
+
nb = COL_SIZE(na2);
|
114
|
+
if (n != nb) {
|
115
|
+
rb_raise(nary_eShapeError, "matrix a and b must have same size");
|
116
|
+
}
|
117
|
+
|
118
|
+
m = g.range == 'I' ? g.iu - g.il + 1 : n;
|
119
|
+
w_shape[0] = m;
|
120
|
+
z_shape[0] = n;
|
121
|
+
z_shape[1] = m;
|
122
|
+
ifail_shape[0] = m;
|
123
|
+
|
124
|
+
ans = na_ndloop3(&ndf, &g, 2, a, b);
|
125
|
+
|
126
|
+
return rb_ary_new3(6, a, b, RARRAY_AREF(ans, 0), RARRAY_AREF(ans, 1), RARRAY_AREF(ans, 2), RARRAY_AREF(ans, 3));
|
127
|
+
}
|
128
|
+
|
129
|
+
#undef args_t
|
130
|
+
#undef func_p
|
@@ -17,24 +17,6 @@ def create_site_conf
|
|
17
17
|
FileUtils.mkdir_p "lib"
|
18
18
|
|
19
19
|
ext = detect_library_extension
|
20
|
-
need_version = false
|
21
|
-
if ext == 'so'
|
22
|
-
begin
|
23
|
-
Fiddle.dlopen "libm.so"
|
24
|
-
rescue
|
25
|
-
(5..7).each do |i|
|
26
|
-
begin
|
27
|
-
Fiddle.dlopen "libm.so.#{i}"
|
28
|
-
need_version = true
|
29
|
-
break
|
30
|
-
rescue
|
31
|
-
end
|
32
|
-
end
|
33
|
-
if !need_version
|
34
|
-
raise "failed to check whether dynamically linked shared object needs version suffix"
|
35
|
-
end
|
36
|
-
end
|
37
|
-
end
|
38
20
|
|
39
21
|
open("lib/site_conf.rb","w"){|f| f.write "
|
40
22
|
module Numo
|
@@ -49,7 +31,6 @@ module Numo
|
|
49
31
|
|
50
32
|
module Loader
|
51
33
|
EXT = '#{ext}'
|
52
|
-
NEED_VERSION_SUFFIX = #{need_version}
|
53
34
|
end
|
54
35
|
|
55
36
|
end
|
@@ -68,6 +49,8 @@ def detect_library_extension
|
|
68
49
|
end
|
69
50
|
end
|
70
51
|
|
52
|
+
require 'numo/narray'
|
53
|
+
|
71
54
|
def find_narray_h
|
72
55
|
$LOAD_PATH.each do |x|
|
73
56
|
if File.exist? File.join(x,'numo/numo/narray.h')
|
data/lib/numo/linalg/function.rb
CHANGED
@@ -65,7 +65,7 @@ module Numo; module Linalg
|
|
65
65
|
NArray.array_type(a)
|
66
66
|
end
|
67
67
|
if k && k < NArray
|
68
|
-
t = k::UPCAST[t]
|
68
|
+
t = k::UPCAST[t] || t::UPCAST[k]
|
69
69
|
end
|
70
70
|
end
|
71
71
|
BLAS_CHAR[t] || raise(TypeError,"invalid data type for BLAS/LAPACK")
|
@@ -86,7 +86,8 @@ module Numo; module Linalg
|
|
86
86
|
when 1
|
87
87
|
case b.ndim
|
88
88
|
when 1
|
89
|
-
|
89
|
+
func = blas_char(a, b) =~ /c|z/ ? :dotu : :dot
|
90
|
+
Blas.call(func, a, b)
|
90
91
|
else
|
91
92
|
Blas.call(:gemv, b, a, trans:'t')
|
92
93
|
end
|
@@ -194,10 +195,10 @@ module Numo; module Linalg
|
|
194
195
|
#
|
195
196
|
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensinal NArray)
|
196
197
|
# @param mode [String]
|
197
|
-
# - "reduce"
|
198
|
-
# - "r"
|
199
|
-
# - "
|
200
|
-
# - "raw"
|
198
|
+
# - "reduce" -- returns both Q and R,
|
199
|
+
# - "r" -- returns only R,
|
200
|
+
# - "economic" -- returns both Q and R but computed in economy-size,
|
201
|
+
# - "raw" -- returns QR and TAU used in LAPACK.
|
201
202
|
# @return [r] if mode:"r"
|
202
203
|
# @return [[q,r]] if mode:"reduce" or "economic"
|
203
204
|
# @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
|
@@ -295,6 +296,38 @@ module Numo; module Linalg
|
|
295
296
|
end
|
296
297
|
end
|
297
298
|
|
299
|
+
# Computes an LU factorization of a M-by-N matrix A
|
300
|
+
# using partial pivoting with row interchanges.
|
301
|
+
#
|
302
|
+
# The factorization has the form
|
303
|
+
#
|
304
|
+
# A = P * L * U
|
305
|
+
#
|
306
|
+
# where P is a permutation matrix, L is lower triangular with unit
|
307
|
+
# diagonal elements (lower trapezoidal if m > n), and U is upper
|
308
|
+
# triangular (upper trapezoidal if m < n).
|
309
|
+
#
|
310
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensinal NArray)
|
311
|
+
# @param permute_l [Bool] (optional) If true, perform the matrix product of P and L.
|
312
|
+
# @return [[p,l,u]] if permute_l == false
|
313
|
+
# @return [[pl,u]] if permute_l == true
|
314
|
+
#
|
315
|
+
# - **p** [Numo::NArray] -- The permutation matrix P.
|
316
|
+
# - **l** [Numo::NArray] -- The factor L.
|
317
|
+
# - **u** [Numo::NArray] -- The factor U.
|
318
|
+
|
319
|
+
def lu(a, permute_l: false)
|
320
|
+
raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
|
321
|
+
m, n = a.shape
|
322
|
+
k = [m, n].min
|
323
|
+
lu, ip = lu_fact(a)
|
324
|
+
l = lu.tril.tap { |mat| mat[mat.diag_indices(0)] = 1.0 }[true, 0...k]
|
325
|
+
u = lu.triu[0...k, 0...n]
|
326
|
+
p = Numo::DFloat.eye(m).tap do |mat|
|
327
|
+
ip.to_a.each_with_index { |i, j| mat[true, [i - 1, j]] = mat[true, [j, i - 1]].dup }
|
328
|
+
end
|
329
|
+
permute_l ? [p.dot(l), u] : [p, l, u]
|
330
|
+
end
|
298
331
|
|
299
332
|
# Computes an LU factorization of a M-by-N matrix A
|
300
333
|
# using partial pivoting with row interchanges.
|
@@ -308,26 +341,14 @@ module Numo; module Linalg
|
|
308
341
|
# triangular (upper trapezoidal if m < n).
|
309
342
|
#
|
310
343
|
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensinal NArray)
|
311
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
312
|
-
# 'gen','sym','her'. (optional, default='gen')
|
313
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
314
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
315
344
|
# @return [[lu, ipiv]]
|
316
345
|
# - **lu** [Numo::NArray] -- The factors L and U from the factorization
|
317
346
|
# `A = P*L*U`; the unit diagonal elements of L are not stored.
|
318
347
|
# - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
|
319
348
|
# row i of the matrix was interchanged with row IPIV(i).
|
320
349
|
|
321
|
-
def lu_fact(a
|
322
|
-
|
323
|
-
when /^gen?(trf)?$/i
|
324
|
-
Lapack.call(:getrf, a)[0..1]
|
325
|
-
when /^(sym?|her?)(trf)?$/i
|
326
|
-
func = driver[0..2].downcase+"trf"
|
327
|
-
Lapack.call(func, a, uplo:uplo)[0..1]
|
328
|
-
else
|
329
|
-
raise ArgumentError, "invalid driver: #{driver}"
|
330
|
-
end
|
350
|
+
def lu_fact(a)
|
351
|
+
Lapack.call(:getrf, a)[0..1]
|
331
352
|
end
|
332
353
|
|
333
354
|
# Computes the inverse of a matrix using the LU factorization
|
@@ -345,22 +366,10 @@ module Numo; module Linalg
|
|
345
366
|
# @param ipiv [Numo::NArray] The pivot indices from
|
346
367
|
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
347
368
|
# interchanged with row IPIV(i).
|
348
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
349
|
-
# 'gen','sym','her'. (optional, default='gen')
|
350
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
351
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
352
369
|
# @return [Numo::NArray] the inverse of the original matrix A.
|
353
370
|
|
354
|
-
def lu_inv(lu, ipiv
|
355
|
-
|
356
|
-
when /^gen?(tri)?$/i
|
357
|
-
Lapack.call(:getri, lu, ipiv)[0]
|
358
|
-
when /^(sym?|her?)(tri)?$/i
|
359
|
-
func = driver[0..2].downcase+"tri"
|
360
|
-
Lapack.call(func, lu, ipiv, uplo:uplo)[0]
|
361
|
-
else
|
362
|
-
raise ArgumentError, "invalid driver: #{driver}"
|
363
|
-
end
|
371
|
+
def lu_inv(lu, ipiv)
|
372
|
+
Lapack.call(:getri, lu, ipiv)[0]
|
364
373
|
end
|
365
374
|
|
366
375
|
# Solves a system of linear equations
|
@@ -377,31 +386,100 @@ module Numo; module Linalg
|
|
377
386
|
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
378
387
|
# interchanged with row IPIV(i).
|
379
388
|
# @param b [Numo::NArray] the right hand side matrix B.
|
380
|
-
# @param driver [String or Symbol] choose LAPACK diriver from
|
381
|
-
# 'gen','sym','her'. (optional, default='gen')
|
382
|
-
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
383
|
-
# or ('U') lower ('L') triangle. (omitted when driver:"gen")
|
384
389
|
# @param trans [String or Symbol]
|
385
|
-
# Specifies the form of the system of equations
|
386
|
-
# (omitted if not driver:"gen"):
|
387
|
-
#
|
390
|
+
# Specifies the form of the system of equations:
|
388
391
|
# - If 'N': `A * X = B` (No transpose).
|
389
392
|
# - If 'T': `A*\*T* X = B` (Transpose).
|
390
393
|
# - If 'C': `A*\*T* X = B` (Conjugate transpose = Transpose).
|
391
394
|
# @return [Numo::NArray] the solution matrix X.
|
392
395
|
|
393
|
-
def lu_solve(lu, ipiv, b,
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
396
|
+
def lu_solve(lu, ipiv, b, trans:"N")
|
397
|
+
Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
|
398
|
+
end
|
399
|
+
|
400
|
+
# Computes the LDLt or Bunch-Kaufman factorization of a symmetric/Hermitian matrix A.
|
401
|
+
# The factorization has the form
|
402
|
+
#
|
403
|
+
# A = U*D*U**T or A = L*D*L**T
|
404
|
+
#
|
405
|
+
# where U (or L) is a product of permutation and unit upper (lower) triangular matrices
|
406
|
+
# and D is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
|
407
|
+
#
|
408
|
+
# @param a [Numo::NArray] m-by-m matrix A (>= 2-dimensinal NArray)
|
409
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper or ('U') lower ('L') triangle.
|
410
|
+
# @param hermitian [Bool] optional, default=true. If true, hermitian matrix is assumed.
|
411
|
+
# (omitted when real-value matrix is given)
|
412
|
+
#
|
413
|
+
# @return [[lu,d,perm]]
|
414
|
+
#
|
415
|
+
# - **lu** [Numo::NArray] -- The permutated upper (lower) triangular matrix U (L).
|
416
|
+
# - **d** [Numo::NArray] -- The block diagonal matrix D.
|
417
|
+
# - **perm** [Numo::NArray] -- The row-permutation index for changing lu into triangular form.
|
418
|
+
|
419
|
+
def ldl(a, uplo: 'U', hermitian: true)
|
420
|
+
raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
|
421
|
+
raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
|
422
|
+
|
423
|
+
is_complex = blas_char(a) =~ /c|z/
|
424
|
+
func = is_complex && hermitian ? 'hetrf' : 'sytrf'
|
425
|
+
lud, ipiv, = Lapack.call(func.to_sym, a, uplo: uplo)
|
426
|
+
|
427
|
+
lu = (uplo == 'U' ? lud.triu : lud.tril).tap { |mat| mat[mat.diag_indices(0)] = 1.0 }
|
428
|
+
d = lud[lud.diag_indices(0)].diag
|
429
|
+
|
430
|
+
m = a.shape[0]
|
431
|
+
n = m - 1
|
432
|
+
changed_2x2 = false
|
433
|
+
perm = Numo::Int32.new(m).seq
|
434
|
+
m.times do |t|
|
435
|
+
i = uplo == 'U' ? t : n - t
|
436
|
+
j = uplo == 'U' ? i - 1 : i + 1;
|
437
|
+
r = uplo == 'U' ? 0..i : i..n;
|
438
|
+
if ipiv[i] > 0
|
439
|
+
k = ipiv[i] - 1
|
440
|
+
lu[[k, i], r] = lu[[i, k], r].dup
|
441
|
+
perm[[k, i]] = perm[[i, k]].dup
|
442
|
+
elsif j.between?(0, n) && ipiv[i] == ipiv[j] && !changed_2x2
|
443
|
+
k = ipiv[i].abs - 1
|
444
|
+
d[j, i] = lud[j, i]
|
445
|
+
d[i, j] = is_complex && hermitian ? lud[j, i].conj : lud[j, i]
|
446
|
+
lu[j, i] = 0.0
|
447
|
+
lu[[k, j], r] = lu[[j, k], r].dup
|
448
|
+
perm[[k, j]] = perm[[j, k]].dup
|
449
|
+
changed_2x2 = true
|
450
|
+
next
|
451
|
+
end
|
452
|
+
changed_2x2 = false
|
402
453
|
end
|
454
|
+
|
455
|
+
[lu, d, perm.sort_index]
|
403
456
|
end
|
404
457
|
|
458
|
+
# Computes the Cholesky factorization of a symmetric/Hermitian
|
459
|
+
# positive definite matrix A. The factorization has the form
|
460
|
+
#
|
461
|
+
# A = U**H * U, if UPLO = 'U', or
|
462
|
+
# A = L * L**H, if UPLO = 'L',
|
463
|
+
#
|
464
|
+
# where U is an upper triangular matrix and L is a lower triangular matrix.
|
465
|
+
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensinal NArray)
|
466
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
467
|
+
# or ('U') lower ('L') triangle.
|
468
|
+
# @return [Numo::NArray] The factor U or L.
|
469
|
+
|
470
|
+
def cholesky(a, uplo: 'U')
|
471
|
+
raise NArray::ShapeError, '2-d array is required' if a.ndim < 2
|
472
|
+
raise NArray::ShapeError, 'matrix a is not square matrix' if a.shape[0] != a.shape[1]
|
473
|
+
factor = Lapack.call(:potrf, a, uplo: uplo)[0]
|
474
|
+
if uplo == 'U'
|
475
|
+
factor.triu
|
476
|
+
else
|
477
|
+
# TODO: Use the tril method if the verision of Numo::NArray
|
478
|
+
# in the runtime dependency list becomes 0.9.1.3 or higher.
|
479
|
+
m, = a.shape
|
480
|
+
factor * Numo::DFloat.ones(m, m).triu.transpose
|
481
|
+
end
|
482
|
+
end
|
405
483
|
|
406
484
|
# Computes the Cholesky factorization of a symmetric/Hermitian
|
407
485
|
# positive definite matrix A. The factorization has the form
|
@@ -409,16 +487,16 @@ module Numo; module Linalg
|
|
409
487
|
# A = U**H * U, if UPLO = 'U', or
|
410
488
|
# A = L * L**H, if UPLO = 'L',
|
411
489
|
#
|
412
|
-
# where U is an upper triangular matrix and L is lower triangular
|
490
|
+
# where U is an upper triangular matrix and L is a lower triangular matrix.
|
413
491
|
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensinal NArray)
|
414
492
|
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
415
493
|
# or ('U') lower ('L') triangle.
|
416
|
-
# @return [Numo::NArray] the factor
|
494
|
+
# @return [Numo::NArray] The matrix which has the Cholesky factor in upper or lower triangular part.
|
495
|
+
# Remain part consists of random values.
|
417
496
|
|
418
497
|
def cho_fact(a, uplo:'U')
|
419
498
|
Lapack.call(:potrf, a, uplo:uplo)[0]
|
420
499
|
end
|
421
|
-
#alias cholesky cho_fact
|
422
500
|
|
423
501
|
# Computes the inverse of a symmetric/Hermitian
|
424
502
|
# positive definite matrix A using the Cholesky factorization
|
@@ -479,27 +557,41 @@ module Numo; module Linalg
|
|
479
557
|
[w,vl,vr] #.compact
|
480
558
|
end
|
481
559
|
|
482
|
-
#
|
483
|
-
#
|
560
|
+
# Obtains the eigenvalues and, optionally, the eigenvectors
|
561
|
+
# by solving an ordinary or generalized eigenvalue problem
|
562
|
+
# for a square symmetric / Hermitian matrix.
|
484
563
|
#
|
485
|
-
# @param a [Numo::NArray] square
|
486
|
-
# @param
|
564
|
+
# @param a [Numo::NArray] square symmetric matrix (>= 2-dimensinal NArray)
|
565
|
+
# @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensinal NArray)
|
566
|
+
# If nil, identity matrix is assumed.
|
567
|
+
# @param vals_only [Bool] (optional) If false, eigenvectors are computed.
|
568
|
+
# @param vals_range [Range] (optional)
|
569
|
+
# The range of indices of the eigenvalues (in ascending order) and corresponding eigenvectors to be returned.
|
570
|
+
# If nil or 0...N (N is the size of the matrix a), all eigenvalues and eigenvectors are returned.
|
487
571
|
# @param uplo [String or Symbol] (optional, default='U')
|
488
572
|
# Access upper ('U') or lower ('L') triangle.
|
573
|
+
# @param turbo [Bool] (optional) If true, divide and conquer algorithm is used.
|
489
574
|
# @return [[w,v]]
|
490
575
|
# - **w** [Numo::NArray] -- The eigenvalues.
|
491
576
|
# - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
|
492
577
|
|
493
|
-
def eigh(a, vals_only:false, uplo:
|
578
|
+
def eigh(a, b=nil, vals_only:false, vals_range:nil, uplo:'U', turbo:false)
|
494
579
|
jobz = vals_only ? 'N' : 'V' # jobz: Compute eigenvalues and eigenvectors.
|
495
|
-
|
496
|
-
|
497
|
-
|
580
|
+
b = a.class.eye(a.shape[0]) if b.nil?
|
581
|
+
func = blas_char(a, b) =~ /c|z/ ? 'hegv' : 'sygv'
|
582
|
+
if vals_range.nil?
|
583
|
+
func << 'd' if turbo
|
584
|
+
v, u_, w, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz)
|
585
|
+
v = nil if vals_only
|
586
|
+
[w, v]
|
498
587
|
else
|
499
|
-
func
|
588
|
+
func << 'x'
|
589
|
+
il = vals_range.first(1)[0]
|
590
|
+
iu = vals_range.last(1)[0]
|
591
|
+
a_, b_, w, v, = Lapack.call(func.to_sym, a, b, uplo: uplo, jobz: jobz, range: 'I', il: il + 1, iu: iu + 1)
|
592
|
+
v = nil if vals_only
|
593
|
+
[w, v]
|
500
594
|
end
|
501
|
-
w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
|
502
|
-
[w,v] #.compact
|
503
595
|
end
|
504
596
|
|
505
597
|
# Computes the eigenvalues only for a square nonsymmetric matrix A.
|
@@ -519,23 +611,22 @@ module Numo; module Linalg
|
|
519
611
|
w
|
520
612
|
end
|
521
613
|
|
522
|
-
#
|
614
|
+
# Obtains the eigenvalues by solving an ordinary or generalized eigenvalue problem
|
615
|
+
# for a square symmetric / Hermitian matrix.
|
523
616
|
#
|
524
617
|
# @param a [Numo::NArray] square symmetric/hermitian matrix
|
525
618
|
# (>= 2-dimensinal NArray)
|
619
|
+
# @param b [Numo::NArray] (optional) square symmetric matrix (>= 2-dimensinal NArray)
|
620
|
+
# If nil, identity matrix is assumed.
|
621
|
+
# @param vals_range [Range] (optional)
|
622
|
+
# The range of indices of the eigenvalues (in ascending order) to be returned.
|
623
|
+
# If nil or 0...N (N is the size of the matrix a), all eigenvalues are returned.
|
526
624
|
# @param uplo [String or Symbol] (optional, default='U')
|
527
625
|
# Access upper ('U') or lower ('L') triangle.
|
528
626
|
# @return [Numo::NArray] eigenvalues
|
529
627
|
|
530
|
-
def eigvalsh(a, uplo:
|
531
|
-
|
532
|
-
case blas_char(a)
|
533
|
-
when /c|z/
|
534
|
-
func = turbo ? :hegv : :heev
|
535
|
-
else
|
536
|
-
func = turbo ? :sygv : :syev
|
537
|
-
end
|
538
|
-
Lapack.call(func, a, uplo:uplo, jobz:jobz)[0]
|
628
|
+
def eigvalsh(a, b=nil, vals_range:nil, uplo:'U', turbo:false)
|
629
|
+
eigh(a, b, vals_only: true, vals_range: vals_range, uplo: uplo, turbo: turbo).first
|
539
630
|
end
|
540
631
|
|
541
632
|
|
@@ -801,7 +892,7 @@ module Numo; module Linalg
|
|
801
892
|
# returns lu, x, ipiv, info
|
802
893
|
Lapack.call(:gesv, a, b)[1]
|
803
894
|
when /^(sym?|her?|pos?)(sv)?$/i
|
804
|
-
func = driver[0..
|
895
|
+
func = driver[0..1].downcase+"sv"
|
805
896
|
Lapack.call(func, a, b, uplo:uplo)[1]
|
806
897
|
else
|
807
898
|
raise ArgumentError, "invalid driver: #{driver}"
|
@@ -834,8 +925,8 @@ module Numo; module Linalg
|
|
834
925
|
solve(a, b, driver:d, uplo:uplo)
|
835
926
|
when /(ge|sy|he)tr[fi]$/
|
836
927
|
d = $1
|
837
|
-
lu, piv = lu_fact(a
|
838
|
-
lu_inv(lu, piv
|
928
|
+
lu, piv = lu_fact(a)
|
929
|
+
lu_inv(lu, piv)
|
839
930
|
when /potr[fi]$/
|
840
931
|
lu = cho_fact(a, uplo:uplo)
|
841
932
|
cho_inv(lu, uplo:uplo)
|