numo-linalg 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Gemfile +4 -0
- data/README.md +80 -0
- data/Rakefile +18 -0
- data/ext/numo/linalg/blas/blas.c +352 -0
- data/ext/numo/linalg/blas/cblas.h +575 -0
- data/ext/numo/linalg/blas/cblas_t.h +563 -0
- data/ext/numo/linalg/blas/depend.erb +23 -0
- data/ext/numo/linalg/blas/extconf.rb +67 -0
- data/ext/numo/linalg/blas/gen/cogen.rb +72 -0
- data/ext/numo/linalg/blas/gen/decl.rb +203 -0
- data/ext/numo/linalg/blas/gen/desc.rb +8138 -0
- data/ext/numo/linalg/blas/gen/erbpp2.rb +339 -0
- data/ext/numo/linalg/blas/gen/replace_cblas_h.rb +27 -0
- data/ext/numo/linalg/blas/gen/spec.rb +93 -0
- data/ext/numo/linalg/blas/numo_blas.h +41 -0
- data/ext/numo/linalg/blas/tmpl/axpy.c +75 -0
- data/ext/numo/linalg/blas/tmpl/copy.c +57 -0
- data/ext/numo/linalg/blas/tmpl/def_c.c +3 -0
- data/ext/numo/linalg/blas/tmpl/def_d.c +3 -0
- data/ext/numo/linalg/blas/tmpl/def_s.c +3 -0
- data/ext/numo/linalg/blas/tmpl/def_z.c +3 -0
- data/ext/numo/linalg/blas/tmpl/dot.c +68 -0
- data/ext/numo/linalg/blas/tmpl/ger.c +114 -0
- data/ext/numo/linalg/blas/tmpl/init_class.c +20 -0
- data/ext/numo/linalg/blas/tmpl/init_module.c +12 -0
- data/ext/numo/linalg/blas/tmpl/lib.c +40 -0
- data/ext/numo/linalg/blas/tmpl/mm.c +214 -0
- data/ext/numo/linalg/blas/tmpl/module.c +9 -0
- data/ext/numo/linalg/blas/tmpl/mv.c +194 -0
- data/ext/numo/linalg/blas/tmpl/nrm2.c +79 -0
- data/ext/numo/linalg/blas/tmpl/rot.c +65 -0
- data/ext/numo/linalg/blas/tmpl/rotm.c +82 -0
- data/ext/numo/linalg/blas/tmpl/scal.c +69 -0
- data/ext/numo/linalg/blas/tmpl/sdsdot.c +77 -0
- data/ext/numo/linalg/blas/tmpl/set_prefix.c +16 -0
- data/ext/numo/linalg/blas/tmpl/swap.c +57 -0
- data/ext/numo/linalg/blas/tmpl/syr.c +102 -0
- data/ext/numo/linalg/blas/tmpl/syr2.c +110 -0
- data/ext/numo/linalg/blas/tmpl/syr2k.c +129 -0
- data/ext/numo/linalg/blas/tmpl/syrk.c +132 -0
- data/ext/numo/linalg/lapack/depend.erb +23 -0
- data/ext/numo/linalg/lapack/extconf.rb +45 -0
- data/ext/numo/linalg/lapack/gen/cogen.rb +74 -0
- data/ext/numo/linalg/lapack/gen/desc.rb +151278 -0
- data/ext/numo/linalg/lapack/gen/replace_lapacke_h.rb +32 -0
- data/ext/numo/linalg/lapack/gen/spec.rb +104 -0
- data/ext/numo/linalg/lapack/lapack.c +387 -0
- data/ext/numo/linalg/lapack/lapacke.h +16425 -0
- data/ext/numo/linalg/lapack/lapacke_config.h +119 -0
- data/ext/numo/linalg/lapack/lapacke_mangling.h +17 -0
- data/ext/numo/linalg/lapack/lapacke_t.h +10550 -0
- data/ext/numo/linalg/lapack/numo_lapack.h +42 -0
- data/ext/numo/linalg/lapack/tmpl/def_c.c +3 -0
- data/ext/numo/linalg/lapack/tmpl/def_d.c +7 -0
- data/ext/numo/linalg/lapack/tmpl/def_s.c +7 -0
- data/ext/numo/linalg/lapack/tmpl/def_z.c +3 -0
- data/ext/numo/linalg/lapack/tmpl/fact.c +179 -0
- data/ext/numo/linalg/lapack/tmpl/geev.c +123 -0
- data/ext/numo/linalg/lapack/tmpl/gels.c +232 -0
- data/ext/numo/linalg/lapack/tmpl/gesv.c +149 -0
- data/ext/numo/linalg/lapack/tmpl/gesvd.c +189 -0
- data/ext/numo/linalg/lapack/tmpl/ggev.c +138 -0
- data/ext/numo/linalg/lapack/tmpl/gqr.c +121 -0
- data/ext/numo/linalg/lapack/tmpl/init_class.c +20 -0
- data/ext/numo/linalg/lapack/tmpl/init_module.c +12 -0
- data/ext/numo/linalg/lapack/tmpl/lange.c +79 -0
- data/ext/numo/linalg/lapack/tmpl/lib.c +40 -0
- data/ext/numo/linalg/lapack/tmpl/module.c +9 -0
- data/ext/numo/linalg/lapack/tmpl/syev.c +91 -0
- data/ext/numo/linalg/lapack/tmpl/sygv.c +104 -0
- data/ext/numo/linalg/lapack/tmpl/trf.c +276 -0
- data/ext/numo/linalg/numo_linalg.h +115 -0
- data/lib/numo/linalg.rb +3 -0
- data/lib/numo/linalg/function.rb +1008 -0
- data/lib/numo/linalg/linalg.rb +7 -0
- data/lib/numo/linalg/loader.rb +174 -0
- data/lib/numo/linalg/use/atlas.rb +3 -0
- data/lib/numo/linalg/use/lapack.rb +3 -0
- data/lib/numo/linalg/use/mkl.rb +3 -0
- data/lib/numo/linalg/use/openblas.rb +3 -0
- data/lib/numo/linalg/version.rb +5 -0
- data/numo-linalg.gemspec +26 -0
- data/spec/lapack_spec.rb +13 -0
- metadata +172 -0
@@ -0,0 +1,276 @@
|
|
1
|
+
<%
|
2
|
+
has_rhs = (/trs$/ =~ name)
|
3
|
+
has_trans = (/^.(g|l|t).trs$/ =~ name)
|
4
|
+
has_uplo = (/^.(g|pt)/ !~ name)
|
5
|
+
has_ipiv = (/p[bfopt]tr.$/ !~ name)
|
6
|
+
ipiv_out = (has_ipiv && /trf$/ =~ name)
|
7
|
+
ipiv_in = (has_ipiv && /tr[is]$/ =~ name)
|
8
|
+
is_sym = (has_uplo || /getr[is]/=~name)
|
9
|
+
%>
|
10
|
+
#define RHS <%= has_rhs ? "1":"0" %>
|
11
|
+
#define TRANS <%= has_trans ? "1":"0" %>
|
12
|
+
#define UPLO <%= has_uplo ? "1":"0" %>
|
13
|
+
#define IPIV <%= has_ipiv ? "1":"0" %>
|
14
|
+
#define IPIV_OUT <%= ipiv_out ? "1":"0" %>
|
15
|
+
#define IPIV_IN <%= ipiv_in ? "1":"0" %>
|
16
|
+
#define SYM <%= is_sym ? "1":"0" %>
|
17
|
+
#define args_t <%=func_name%>_args_t
|
18
|
+
#define func_p <%=func_name%>_p
|
19
|
+
|
20
|
+
typedef struct {
|
21
|
+
int order;
|
22
|
+
char uplo;
|
23
|
+
char trans;
|
24
|
+
} args_t;
|
25
|
+
|
26
|
+
static <%=func_name%>_t func_p = 0;
|
27
|
+
|
28
|
+
static void
|
29
|
+
<%=c_iter%>(na_loop_t * const lp)
|
30
|
+
{
|
31
|
+
dtype *a;
|
32
|
+
#if RHS
|
33
|
+
dtype *b;
|
34
|
+
int nb, nrhs, ldb;
|
35
|
+
#endif
|
36
|
+
#if IPIV
|
37
|
+
int *pv;
|
38
|
+
#endif
|
39
|
+
int *info;
|
40
|
+
int m, n, lda;
|
41
|
+
args_t *g;
|
42
|
+
|
43
|
+
a = (dtype*)NDL_PTR(lp,0);
|
44
|
+
#if IPIV
|
45
|
+
pv = (int*)NDL_PTR(lp,1);
|
46
|
+
#endif
|
47
|
+
#if RHS
|
48
|
+
b = (dtype*)NDL_PTR(lp,1+IPIV);
|
49
|
+
#endif
|
50
|
+
info = (int*)NDL_PTR(lp,1+IPIV+RHS);
|
51
|
+
g = (args_t*)(lp->opt_ptr);
|
52
|
+
|
53
|
+
n = NDL_SHAPE(lp,0)[0];
|
54
|
+
m = NDL_SHAPE(lp,0)[1];
|
55
|
+
lda = NDL_STEP(lp,0) / sizeof(dtype);
|
56
|
+
|
57
|
+
#if RHS
|
58
|
+
// same as gels.c
|
59
|
+
if (lp->args[1+IPIV].ndim == 1) {
|
60
|
+
nrhs = 1;
|
61
|
+
nb = NDL_SHAPE(lp,1+IPIV)[0];
|
62
|
+
ldb = (g->order==LAPACK_COL_MAJOR) ? nb : 1;
|
63
|
+
} else {
|
64
|
+
nb = NDL_SHAPE(lp,1+IPIV)[0];
|
65
|
+
nrhs = NDL_SHAPE(lp,1+IPIV)[1];
|
66
|
+
ldb = nrhs;
|
67
|
+
{ int tmp; SWAP_IFCOL(g->order,nb,nrhs); }
|
68
|
+
}
|
69
|
+
//printf("order=%d m=%d n=%d nb=%d nrhs=%d lda=%d ldb=%d\n",g->order,m,n,nb,nrhs,lda,ldb);
|
70
|
+
#else
|
71
|
+
//printf("order=%d m=%d n=%d lda=%d \n",g->order,m,n,lda);
|
72
|
+
#endif
|
73
|
+
|
74
|
+
#if SYM
|
75
|
+
n = min_(m,n);
|
76
|
+
#else
|
77
|
+
{ int tmp; SWAP_IFCOL(g->order,m,n); }
|
78
|
+
#endif
|
79
|
+
|
80
|
+
<%
|
81
|
+
func_args = [ "g->order",
|
82
|
+
has_uplo && "g->uplo",
|
83
|
+
has_trans && "g->trans",
|
84
|
+
"n",
|
85
|
+
has_rhs ? "nrhs" : (!is_sym && "m"),
|
86
|
+
"a, lda",
|
87
|
+
has_ipiv && "pv",
|
88
|
+
has_rhs && "b, ldb",
|
89
|
+
].select{|x| x}.join(", ")
|
90
|
+
%>
|
91
|
+
*info = (*func_p)(<%=func_args%>);
|
92
|
+
CHECK_ERROR(*info);
|
93
|
+
}
|
94
|
+
|
95
|
+
/*<%
|
96
|
+
args_v = [
|
97
|
+
"a",
|
98
|
+
ipiv_in && "ipiv",
|
99
|
+
has_rhs && "b",
|
100
|
+
].select{|x| x}.join(", ")
|
101
|
+
|
102
|
+
args_opt = [
|
103
|
+
has_uplo && "uplo:'U'",
|
104
|
+
has_trans && "trans:'N'",
|
105
|
+
"order:'R'",
|
106
|
+
].select{|x| x}.join(", ")
|
107
|
+
|
108
|
+
trf = name.sub(/.$/,"f")
|
109
|
+
|
110
|
+
params = [
|
111
|
+
has_rhs ? "@param a [#{class_name}] LU matrix computed by "+trf :
|
112
|
+
mat("a",:inplace),
|
113
|
+
ipiv_in && "@param ipiv [Numo::Int] pivot computed by "+trf,
|
114
|
+
has_rhs && mat("b",:inplace),
|
115
|
+
has_uplo && opt("uplo"),
|
116
|
+
has_trans && opt("trans"),
|
117
|
+
opt("order"),
|
118
|
+
].select{|x| x}.join("\n ")
|
119
|
+
|
120
|
+
return_type = [
|
121
|
+
class_name,
|
122
|
+
ipiv_out && "Numo::Int",
|
123
|
+
"Integer"
|
124
|
+
].select{|x| x}.join(", ")
|
125
|
+
|
126
|
+
return_name = [
|
127
|
+
has_rhs ? "b" : "a",
|
128
|
+
ipiv_out && "ipiv",
|
129
|
+
"info"
|
130
|
+
].select{|x| x}.join(", ")
|
131
|
+
%>
|
132
|
+
@overload <%=name%>(<%=args_v%>, [<%=args_opt%>])
|
133
|
+
<%=params%>
|
134
|
+
@return [[<%=return_name%>]] Array<<%=return_type%>>
|
135
|
+
<%=outparam(return_name)%>
|
136
|
+
|
137
|
+
<%=description%>
|
138
|
+
|
139
|
+
*/
|
140
|
+
static VALUE
|
141
|
+
<%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
|
142
|
+
{
|
143
|
+
<% %>
|
144
|
+
VALUE a, ans;
|
145
|
+
#if IPIV_IN
|
146
|
+
VALUE ipiv;
|
147
|
+
#endif
|
148
|
+
#if RHS
|
149
|
+
VALUE b;
|
150
|
+
size_t n, nb, nrhs;
|
151
|
+
narray_t *na2;
|
152
|
+
#endif
|
153
|
+
narray_t *na1;
|
154
|
+
<%
|
155
|
+
aout = [
|
156
|
+
ipiv_out && "{cInt,1,shape_piv}",
|
157
|
+
"{cInt,0}",
|
158
|
+
].select{|x| x}.join(",")
|
159
|
+
%>
|
160
|
+
#if IPIV_OUT
|
161
|
+
size_t shape_piv[1];
|
162
|
+
#endif
|
163
|
+
#if IPIV_IN
|
164
|
+
# if RHS
|
165
|
+
ndfunc_arg_in_t ain[3] = {{cT,2},{cInt,1},{OVERWRITE,2}};
|
166
|
+
# else
|
167
|
+
ndfunc_arg_in_t ain[2] = {{OVERWRITE,2},{cInt,1}};
|
168
|
+
# endif
|
169
|
+
#else
|
170
|
+
# if RHS
|
171
|
+
ndfunc_arg_in_t ain[2] = {{cT,2},{OVERWRITE,2}};
|
172
|
+
# else
|
173
|
+
ndfunc_arg_in_t ain[1] = {{OVERWRITE,2}};
|
174
|
+
# endif
|
175
|
+
#endif
|
176
|
+
ndfunc_arg_out_t aout[1+IPIV_OUT] = {<%=aout%>};
|
177
|
+
ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP|NDF_EXTRACT,
|
178
|
+
1+IPIV_IN+RHS, IPIV_OUT+1, ain,aout};
|
179
|
+
|
180
|
+
args_t g = {0,0};
|
181
|
+
VALUE opts[2] = {Qundef,Qundef};
|
182
|
+
VALUE kw_hash = Qnil;
|
183
|
+
ID kw_table[2] = {id_order,id_uplo};
|
184
|
+
|
185
|
+
CHECK_FUNC(func_p,"<%=func_name%>");
|
186
|
+
|
187
|
+
#if IPIV_IN
|
188
|
+
# if RHS
|
189
|
+
rb_scan_args(argc, argv, "3:", &a, &ipiv, &b, &kw_hash);
|
190
|
+
# else
|
191
|
+
rb_scan_args(argc, argv, "2:", &a, &ipiv, &kw_hash);
|
192
|
+
# endif
|
193
|
+
#else
|
194
|
+
# if RHS
|
195
|
+
rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
|
196
|
+
# else
|
197
|
+
rb_scan_args(argc, argv, "1:", &a, &kw_hash);
|
198
|
+
# endif
|
199
|
+
#endif
|
200
|
+
#if TRANS
|
201
|
+
kw_table[1] = id_trans;
|
202
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
|
203
|
+
g.trans = option_trans(opts[1]);
|
204
|
+
#elif UPLO
|
205
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
|
206
|
+
g.uplo = option_uplo(opts[1]);
|
207
|
+
#else
|
208
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 1, opts);
|
209
|
+
#endif
|
210
|
+
g.order = option_order(opts[0]);
|
211
|
+
|
212
|
+
#if !RHS
|
213
|
+
COPY_OR_CAST_TO(a,cT);
|
214
|
+
#endif
|
215
|
+
GetNArray(a, na1);
|
216
|
+
CHECK_DIM_GE(na1, 2);
|
217
|
+
#if IPIV_OUT
|
218
|
+
shape_piv[0] = min_(ROW_SIZE(na1),COL_SIZE(na1));
|
219
|
+
#endif
|
220
|
+
|
221
|
+
#if RHS
|
222
|
+
COPY_OR_CAST_TO(b,cT);
|
223
|
+
GetNArray(b, na2);
|
224
|
+
CHECK_DIM_GE(na2, 1);
|
225
|
+
n = COL_SIZE(na1);
|
226
|
+
#if SYM
|
227
|
+
n = min_(n,ROW_SIZE(na1));
|
228
|
+
#endif
|
229
|
+
// same as gesv.c
|
230
|
+
if (NA_NDIM(na2) == 1) {
|
231
|
+
ain[1+IPIV_IN].dim = 1;
|
232
|
+
nb = COL_SIZE(na2);
|
233
|
+
nrhs = 1;
|
234
|
+
} else {
|
235
|
+
nb = ROW_SIZE(na2);
|
236
|
+
nrhs = COL_SIZE(na2);
|
237
|
+
{ int tmp; SWAP_IFCOL(g.order,nb,nrhs); }
|
238
|
+
}
|
239
|
+
if (n != nb) {
|
240
|
+
rb_raise(nary_eShapeError, "matrix dimension mismatch: "
|
241
|
+
"a.col(or a.row)=%"SZF"u b.row=%"SZF"u", n, nb);
|
242
|
+
}
|
243
|
+
#endif
|
244
|
+
|
245
|
+
#if IPIV_IN
|
246
|
+
# if RHS
|
247
|
+
ans = na_ndloop3(&ndf, &g, 3, a, ipiv, b);
|
248
|
+
return rb_assoc_new(b, ans);
|
249
|
+
# else
|
250
|
+
ans = na_ndloop3(&ndf, &g, 2, a, ipiv);
|
251
|
+
return rb_assoc_new(a, ans);
|
252
|
+
# endif
|
253
|
+
#else
|
254
|
+
# if RHS
|
255
|
+
ans = na_ndloop3(&ndf, &g, 2, a, b);
|
256
|
+
return rb_assoc_new(b, ans);
|
257
|
+
# else
|
258
|
+
ans = na_ndloop3(&ndf, &g, 1, a);
|
259
|
+
# if IPIV_OUT
|
260
|
+
return rb_ary_unshift(ans, a);
|
261
|
+
# else
|
262
|
+
return rb_assoc_new(a, ans);
|
263
|
+
# endif
|
264
|
+
# endif
|
265
|
+
#endif
|
266
|
+
}
|
267
|
+
|
268
|
+
#undef args_t
|
269
|
+
#undef func_p
|
270
|
+
#undef RHS
|
271
|
+
#undef TRANS
|
272
|
+
#undef UPLO
|
273
|
+
#undef IPIV
|
274
|
+
#undef IPIV_OUT
|
275
|
+
#undef IPIV_IN
|
276
|
+
#undef SYM
|
@@ -0,0 +1,115 @@
|
|
1
|
+
#if defined __clang__
|
2
|
+
# define UNUSED(name) __unused name
|
3
|
+
#else
|
4
|
+
# define UNUSED(name) name
|
5
|
+
#endif
|
6
|
+
|
7
|
+
#if SIZEOF_INT == 4
|
8
|
+
#define cI numo_cInt32
|
9
|
+
#define cUI numo_cUInt32
|
10
|
+
#elif SIZEOF_INT==8
|
11
|
+
#define cI numo_cInt64
|
12
|
+
#define cUI numo_cUInt64
|
13
|
+
#endif
|
14
|
+
|
15
|
+
#if SIZEOF_SIZE_T == 4
|
16
|
+
#define cSZ numo_cUInt32
|
17
|
+
#define cSSZ numo_cInt32
|
18
|
+
#elif SIZEOF_SIZE_T == 8
|
19
|
+
#define cSZ numo_cUInt64
|
20
|
+
#define cSSZ numo_cInt64
|
21
|
+
#endif
|
22
|
+
|
23
|
+
#define cDF numo_cDFloat
|
24
|
+
#define cDC numo_cDComplex
|
25
|
+
#define cSF numo_cSFloat
|
26
|
+
#define cSC numo_cSComplex
|
27
|
+
#define cInt cI
|
28
|
+
#define cUInt cUI
|
29
|
+
|
30
|
+
extern VALUE na_expand_dims(VALUE self, VALUE vdim);
|
31
|
+
|
32
|
+
#define max_(m,n) (((m)>(n)) ? (m):(n))
|
33
|
+
#define min_(m,n) (((m)<(n)) ? (m):(n))
|
34
|
+
|
35
|
+
#define ROW_SIZE(na) ((na)->shape[(na)->ndim-2])
|
36
|
+
#define COL_SIZE(na) ((na)->shape[(na)->ndim-1])
|
37
|
+
|
38
|
+
#define CHECK_NARRAY_TYPE(x,t) \
|
39
|
+
if (CLASS_OF(x)!=(t)) { \
|
40
|
+
rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
|
41
|
+
}
|
42
|
+
|
43
|
+
// Error Class ??
|
44
|
+
#define CHECK_DIM_GE(na,nd) \
|
45
|
+
if ((na)->ndim<(nd)) { \
|
46
|
+
rb_raise(nary_eShapeError, \
|
47
|
+
"n-dimension=%d, but >=%d is expected", \
|
48
|
+
(na)->ndim, (nd)); \
|
49
|
+
}
|
50
|
+
|
51
|
+
/*
 * Raise nary_eShapeError unless the number of dimensions of `na1`
 * (read through `(na1)->ndim`) equals `nd`.  Both values are formatted
 * with %d in the error message.
 */
#define CHECK_DIM_EQ(na1,nd)                            \
    if ((na1)->ndim != (nd)) {                          \
        rb_raise(nary_eShapeError,                      \
                 "dimension mismatch: %d != %d",        \
                 (na1)->ndim, (nd));                    \
    }
|
57
|
+
|
58
|
+
#define CHECK_SQUARE(name,na) \
|
59
|
+
if ((na)->shape[(na)->ndim-1] != (na)->shape[(na)->ndim-2]) { \
|
60
|
+
rb_raise(nary_eShapeError,"%s is not square matrix",name); \
|
61
|
+
}
|
62
|
+
|
63
|
+
#define CHECK_SIZE_GE(na,sz) \
|
64
|
+
if ((na)->size < (size_t)(sz)) { \
|
65
|
+
rb_raise(nary_eShapeError, \
|
66
|
+
"NArray size must be >= %"SZF"u",(size_t)(sz));\
|
67
|
+
}
|
68
|
+
|
69
|
+
#define CHECK_NON_EMPTY(na) \
|
70
|
+
if ((na)->size==0) { \
|
71
|
+
rb_raise(nary_eShapeError,"empty NArray"); \
|
72
|
+
}
|
73
|
+
|
74
|
+
/*
 * Raise nary_eShapeError unless `n` equals `m`.
 * Both operands are cast to size_t for the message, so they are printed
 * with the unsigned conversion (%<SZF>u), matching CHECK_SIZE_GE; the
 * signed conversion would mismatch the argument type and show large
 * sizes as negative numbers.
 */
#define CHECK_SIZE_EQ(n,m)                                      \
    if ((n)!=(m)) {                                             \
        rb_raise(nary_eShapeError,                              \
                 "size mismatch: %"SZF"u != %"SZF"u",           \
                 (size_t)(n),(size_t)(m));                      \
    }
|
80
|
+
|
81
|
+
#define CHECK_SAME_SHAPE(na1,na2) \
|
82
|
+
{ int i; \
|
83
|
+
CHECK_DIM_EQ(na1,na2->ndim); \
|
84
|
+
for (i=0; i<na1->ndim; i++) { \
|
85
|
+
CHECK_SIZE_EQ(na1->shape[i],na2->shape[i]); \
|
86
|
+
} \
|
87
|
+
}
|
88
|
+
|
89
|
+
#define CHECK_INT_EQ(sm,m,sn,n) \
|
90
|
+
if ((m) != (n)) { \
|
91
|
+
rb_raise(nary_eShapeError, \
|
92
|
+
"%s must be == %s: %s=%d %s=%d", \
|
93
|
+
sm,sn,sm,m,sn,n); \
|
94
|
+
}
|
95
|
+
|
96
|
+
// Error Class ??
|
97
|
+
#define CHECK_LEADING_GE(sld,ld,sn,n) \
|
98
|
+
if ((ld) < (n)) { \
|
99
|
+
rb_raise(nary_eShapeError, \
|
100
|
+
"%s must be >= max(%s,1): %s=%d %s=%d", \
|
101
|
+
sld,sn,sld,ld,sn,n); \
|
102
|
+
}
|
103
|
+
|
104
|
+
#define COPY_OR_CAST_TO(a,T) \
|
105
|
+
{ \
|
106
|
+
if (CLASS_OF(a) == (T)) { \
|
107
|
+
if (!TEST_INPLACE(a)) { \
|
108
|
+
a = na_copy(a); \
|
109
|
+
} \
|
110
|
+
} else { \
|
111
|
+
a = rb_funcall(T,rb_intern("cast"),1,a); \
|
112
|
+
} \
|
113
|
+
}
|
114
|
+
|
115
|
+
#define swap(a,b) {tmp=a;a=b;b=tmp;}
|
data/lib/numo/linalg.rb
ADDED
@@ -0,0 +1,1008 @@
|
|
1
|
+
module Numo; module Linalg
|
2
|
+
|
3
|
+
module Blas
|
4
|
+
|
5
|
+
FIXNAME =
|
6
|
+
{
|
7
|
+
cnrm2: :csnrm2,
|
8
|
+
znrm2: :dznrm2,
|
9
|
+
}
|
10
|
+
|
11
|
+
# Call BLAS function prefixed with BLAS char ([sdcz])
|
12
|
+
# defined from data-types of arguments.
|
13
|
+
# @param [Symbol] func function name without BLAS char.
|
14
|
+
# @param args arguments passed to Blas function.
|
15
|
+
# @example
|
16
|
+
# c = Numo::Linalg::Blas.call(:gemm, a, b)
|
17
|
+
def self.call(func,*args)
|
18
|
+
fn = (Linalg.blas_char(*args) + func.to_s).to_sym
|
19
|
+
fn = FIXNAME[fn] || fn
|
20
|
+
send(fn,*args)
|
21
|
+
end
|
22
|
+
|
23
|
+
end
|
24
|
+
|
25
|
+
module Lapack
|
26
|
+
|
27
|
+
FIXNAME =
|
28
|
+
{
|
29
|
+
corgqr: :cungqr,
|
30
|
+
zorgqr: :zungqr,
|
31
|
+
}
|
32
|
+
|
33
|
+
# Call LAPACK function prefixed with BLAS char ([sdcz])
|
34
|
+
# defined from data-types of arguments.
|
35
|
+
# @param [Symbol,String] func function name without BLAS char.
|
36
|
+
# @param args arguments passed to Lapack function.
|
37
|
+
# @example
|
38
|
+
# s = Numo::Linalg::Lapack.call(:gesv, a)
|
39
|
+
def self.call(func,*args)
|
40
|
+
fn = (Linalg.blas_char(*args) + func.to_s).to_sym
|
41
|
+
fn = FIXNAME[fn] || fn
|
42
|
+
send(fn,*args)
|
43
|
+
end
|
44
|
+
|
45
|
+
end
|
46
|
+
|
47
|
+
BLAS_CHAR =
|
48
|
+
{
|
49
|
+
SFloat => "s",
|
50
|
+
DFloat => "d",
|
51
|
+
SComplex => "c",
|
52
|
+
DComplex => "z",
|
53
|
+
}
|
54
|
+
|
55
|
+
module_function
|
56
|
+
|
57
|
+
# Select the BLAS/LAPACK type prefix for a list of arguments.
# Starting from Float, the type is upcast (via the class' UPCAST table)
# with every argument that is an NArray, or an Array whose
# NArray.array_type is an NArray subclass; other arguments are ignored.
# @return [String] the BLAS_CHAR entry for the resulting type.
# @raise [TypeError] if the resulting type has no BLAS_CHAR entry.
def blas_char(*args)
  upcast = args.inject(Float) do |type, arg|
    klass =
      if NArray === arg
        arg.class
      elsif Array === arg
        NArray.array_type(arg)
      end
    (klass && klass < NArray) ? klass::UPCAST[type] : type
  end
  BLAS_CHAR[upcast] || raise(TypeError,"invalid data type for BLAS/LAPACK")
end
|
73
|
+
|
74
|
+
# module methods
|
75
|
+
|
76
|
+
## Matrix and vector products
|
77
|
+
|
78
|
+
# Dot product.
|
79
|
+
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
|
80
|
+
# @param b [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
|
81
|
+
# @return [Numo::NArray] result of dot product
|
82
|
+
# Dispatch a dot product to the BLAS routine that fits the operand ranks:
#   1-D . 1-D  -> dot
#   1-D . N-D  -> gemv on (b, a) with trans:'t'
#   N-D . 1-D  -> gemv
#   N-D . N-D  -> gemm
def dot(a, b)
  a = NArray.asarray(a)
  b = NArray.asarray(b)
  b_is_vector = (b.ndim == 1)
  if a.ndim == 1
    return Blas.call(:dot, a, b) if b_is_vector
    return Blas.call(:gemv, b, a, trans:'t')
  end
  return Blas.call(:gemv, a, b) if b_is_vector
  Blas.call(:gemm, a, b)
end
|
102
|
+
|
103
|
+
# Matrix product.
|
104
|
+
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
|
105
|
+
# @param b [Numo::NArray] matrix (>= 2-dimensional NArray)
|
106
|
+
# @return [Numo::NArray] result of matrix product
|
107
|
+
def matmul(a, b)
|
108
|
+
Blas.call(:gemm, a, b)
|
109
|
+
end
|
110
|
+
|
111
|
+
# Compute a square matrix `a` to the power `n`.
|
112
|
+
#
|
113
|
+
# * If n > 0: return `a**n`.
|
114
|
+
# * If n == 0: return identity matrix.
|
115
|
+
# * If n < 0: return `(a*\*-1)*\*n.abs`.
|
116
|
+
#
|
117
|
+
# @param a [Numo::NArray] square matrix (>= 2-dimensional NArray).
|
118
|
+
# @param n [Integer] the exponent.
|
119
|
+
# @example
|
120
|
+
# i = Numo::DFloat[[0, 1], [-1, 0]]
|
121
|
+
# => Numo::DFloat#shape=[2,2]
|
122
|
+
# [[0, 1],
|
123
|
+
# [-1, 0]]
|
124
|
+
# Numo::Linalg.matrix_power(i,3)
|
125
|
+
# => Numo::DFloat#shape=[2,2]
|
126
|
+
# [[0, -1],
|
127
|
+
# [1, 0]]
|
128
|
+
# Numo::Linalg.matrix_power(i,0)
|
129
|
+
# => Numo::DFloat#shape=[2,2]
|
130
|
+
# [[1, 0],
|
131
|
+
# [0, 1]]
|
132
|
+
# Numo::Linalg.matrix_power(i,-3)
|
133
|
+
# => Numo::DFloat#shape=[2,2]
|
134
|
+
# [[0, 1],
|
135
|
+
# [-1, 0]]
|
136
|
+
#
|
137
|
+
# q = Numo::DFloat.zeros(4,4)
|
138
|
+
# q[0..1,0..1] = -i
|
139
|
+
# q[2..3,2..3] = i
|
140
|
+
# q
|
141
|
+
# => Numo::DFloat#shape=[4,4]
|
142
|
+
# [[-0, -1, 0, 0],
|
143
|
+
# [1, -0, 0, 0],
|
144
|
+
# [0, 0, 0, 1],
|
145
|
+
# [0, 0, -1, 0]]
|
146
|
+
# Numo::Linalg.matrix_power(q,2)
|
147
|
+
# => Numo::DFloat#shape=[4,4]
|
148
|
+
# [[-1, 0, 0, 0],
|
149
|
+
# [0, -1, 0, 0],
|
150
|
+
# [0, 0, -1, 0],
|
151
|
+
# [0, 0, 0, -1]]
|
152
|
+
|
153
|
+
# Raise the square matrix `a` to the integer power `n`.
# @raise [NArray::ShapeError] if `a` has fewer than 2 dimensions or its
#   last two dimensions differ.
# @raise [ArgumentError] if `n` is not an Integer.
def matrix_power(a, n)
  a = NArray.asarray(a)
  # Without this guard a 0-D/1-D input gives shape[-2..-1] == nil, so
  # m == k == nil and the square check below would pass by accident.
  if a.ndim < 2
    raise NArray::ShapeError, "input must be a square array"
  end
  m,k = a.shape[-2..-1]
  unless m==k
    raise NArray::ShapeError, "input must be a square array"
  end
  unless Integer===n
    raise ArgumentError, "exponent must be an integer"
  end
  if n == 0
    return a.class.eye(m)
  elsif n < 0
    # negative exponent: work on the inverse with |n|
    a = inv(a)
    n = n.abs
  end
  if n <= 3
    # small exponents: plain repeated multiplication
    r = a
    (n-1).times do
      r = matmul(r,a)
    end
  else
    # binary exponentiation: first square away the trailing zero bits ...
    while (n & 1) == 0
      a = matmul(a,a)
      n >>= 1
    end
    # ... the lowest set bit contributes the current power of `a` ...
    r = a
    n >>= 1
    # ... then fold in one squared factor per remaining set bit.
    # The loop stops as soon as no bits are left, so no unused squaring
    # is computed after the last factor has been multiplied in.
    while n != 0
      a = matmul(a,a)
      r = matmul(r,a) if (n & 1) != 0
      n >>= 1
    end
  end
  r
end
|
189
|
+
|
190
|
+
|
191
|
+
## factorization
|
192
|
+
|
193
|
+
# Computes a QR factorization of a real or complex M-by-N matrix A: A = Q \* R.
|
194
|
+
#
|
195
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
196
|
+
# @param mode [String]
|
197
|
+
# - "reduce" -- returns both Q and R,
|
198
|
+
# - "r" -- returns only R,
|
199
|
+
# - "economic" -- returns both Q and R but computed in economy-size,
|
200
|
+
# - "raw" -- returns QR and TAU used in LAPACK.
|
201
|
+
# @return [r] if mode:"r"
|
202
|
+
# @return [[q,r]] if mode:"reduce" or "economic"
|
203
|
+
# @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
|
204
|
+
|
205
|
+
# QR factorization built on LAPACK geqrf (factorization) and orgqr
# (explicit Q).  Returns R, [Q, R] or [QR, TAU] depending on `mode`.
# @raise [ArgumentError] if `mode` is not one of "reduce", "r",
#   "economic" (alias "economy") or "raw".
def qr(a, mode:"reduce")
  # Normalise the mode before it is consulted anywhere, so that Symbols
  # and mixed-case strings (:economic, "Raw") take exactly the same
  # paths as their lowercase String forms.  "economy" is accepted as an
  # alias of "economic".
  mode = mode.to_s.downcase
  mode = "economic" if mode == "economy"

  qr,tau, = Lapack.call(:geqrf, a)
  *shp,m,n = qr.shape
  # In economic/raw mode with m >= n only the leading n rows are kept
  # for R; otherwise the upper triangle of the whole factor is used.
  r = (m >= n && %w[economic raw].include?(mode)) ?
        qr[false, 0...n, true].triu : qr.triu
  case mode
  when "r"
    return r
  when "raw"
    return [qr,tau]
  when "reduce","economic"
    # fall through to the computation of Q below
  else
    raise ArgumentError, "invalid mode:#{mode}"
  end
  if m < n
    # only the first m columns of the factor are passed to orgqr
    q, = Lapack.call(:orgqr, qr[false, 0...m], tau)
  elsif mode == "economic"
    q, = Lapack.call(:orgqr, qr, tau)
  else
    # full-size Q: copy the factor into the first n columns of a zeroed
    # m-by-m workspace before calling orgqr
    qqr = qr.class.zeros(*(shp+[m,m]))
    qqr[false,0...n] = qr
    q, = Lapack.call(:orgqr, qqr, tau)
  end
  return [q,r]
end
|
232
|
+
|
233
|
+
|
234
|
+
# Computes the Singular Value Decomposition (SVD) of a M-by-N matrix A,
|
235
|
+
# and the left and/or right singular vectors. The SVD is written
|
236
|
+
#
|
237
|
+
# A = U * SIGMA * transpose(V)
|
238
|
+
#
|
239
|
+
# where SIGMA is an M-by-N matrix which is zero except for its
|
240
|
+
# min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and
|
241
|
+
# V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA
|
242
|
+
# are the singular values of A; they are real and non-negative, and
|
243
|
+
# are returned in descending order. The first min(m,n) columns of U
|
244
|
+
# and V are the left and right singular vectors of A. Note that the
|
245
|
+
# routine returns V**T, not V.
|
246
|
+
#
|
247
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
248
|
+
# @param driver [String or Symbol] choose LAPACK solver from 'svd',
|
249
|
+
# 'sdd'. (optional, default='svd')
|
250
|
+
# @param job [String or Symbol]
|
251
|
+
# - 'A': all M columns of U and all N rows of V\*\*T are returned in
|
252
|
+
# the arrays U and VT.
|
253
|
+
# - 'S': the first min(M,N) columns of U and the first min(M,N)
|
254
|
+
# rows of V\*\*T are returned in the arrays U and VT.
|
255
|
+
# - 'N': no columns of U or rows of V\*\*T are computed.
|
256
|
+
# @return [[sigma,u,vt]] SVD result. Array<Numo::NArray>
|
257
|
+
|
258
|
+
def svd(a, driver:'svd', job:'A')
|
259
|
+
unless /^[ASN]/i =~ job
|
260
|
+
raise ArgumentError, "invalid job: #{job.inspect}"
|
261
|
+
end
|
262
|
+
case driver.to_s
|
263
|
+
when /^(ge)?sdd$/i, "turbo"
|
264
|
+
Lapack.call(:gesdd, a, jobz:job)[0..2]
|
265
|
+
when /^(ge)?svd$/i
|
266
|
+
Lapack.call(:gesvd, a, jobu:job, jobvt:job)[0..2]
|
267
|
+
else
|
268
|
+
raise ArgumentError, "invalid driver: #{driver}"
|
269
|
+
end
|
270
|
+
end
|
271
|
+
|
272
|
+
# Computes the Singular Values of a M-by-N matrix A.
|
273
|
+
# The SVD is written
|
274
|
+
#
|
275
|
+
# A = U * SIGMA * transpose(V)
|
276
|
+
#
|
277
|
+
# where SIGMA is an M-by-N matrix which is zero except for its
|
278
|
+
# min(m,n) diagonal elements. The diagonal elements of SIGMA
|
279
|
+
# are the singular values of A; they are real and non-negative, and
|
280
|
+
# are returned in descending order.
|
281
|
+
#
|
282
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
283
|
+
# @param driver [String or Symbol] choose LAPACK solver from 'svd',
|
284
|
+
# 'sdd'. (optional, default='svd')
|
285
|
+
# @return [Numo::NArray] returns SIGMA (singular values).
|
286
|
+
|
287
|
+
def svdvals(a, driver:'svd')
|
288
|
+
case driver.to_s
|
289
|
+
when /^(ge)?sdd$/i, "turbo"
|
290
|
+
Lapack.call(:gesdd, a, jobz:'N')[0]
|
291
|
+
when /^(ge)?svd$/i
|
292
|
+
Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')[0]
|
293
|
+
else
|
294
|
+
raise ArgumentError, "invalid driver: #{driver}"
|
295
|
+
end
|
296
|
+
end
|
297
|
+
|
298
|
+
|
299
|
+
# Computes an LU factorization of a M-by-N matrix A
|
300
|
+
# using partial pivoting with row interchanges.
|
301
|
+
#
|
302
|
+
# The factorization has the form
|
303
|
+
#
|
304
|
+
# A = P * L * U
|
305
|
+
#
|
306
|
+
# where P is a permutation matrix, L is lower triangular with unit
|
307
|
+
# diagonal elements (lower trapezoidal if m > n), and U is upper
|
308
|
+
# triangular (upper trapezoidal if m < n).
|
309
|
+
#
|
310
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
311
|
+
# @param driver [String or Symbol] choose LAPACK driver from
|
312
|
+
# 'gen','sym','her'. (optional, default='gen')
|
313
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
314
|
+
# ('U') or lower ('L') triangle. (omitted when driver:"gen")
|
315
|
+
# @return [[lu, ipiv]]
|
316
|
+
# - **lu** [Numo::NArray] -- The factors L and U from the factorization
|
317
|
+
# `A = P*L*U`; the unit diagonal elements of L are not stored.
|
318
|
+
# - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
|
319
|
+
# row i of the matrix was interchanged with row IPIV(i).
|
320
|
+
|
321
|
+
# LU-type factorization front-end: dispatches to Lapack getrf for the
# general driver, or to <sy|he>trf for the symmetric/hermitian drivers,
# and returns the first two results of the LAPACK call.
# @raise [ArgumentError] if `driver` is not recognised.
def lu_fact(a, driver:"gen", uplo:"U")
  case driver.to_s
  when /^gen?(trf)?$/i
    Lapack.call(:getrf, a)[0..1]
  when /^(sym?|her?)(trf)?$/i
    # Only the first two letters ("sy"/"he") form the routine name;
    # keeping three would turn "sym" into "symtrf" and "sytrf" into
    # "syttrf".
    func = driver.to_s[0, 2].downcase+"trf"
    Lapack.call(func, a, uplo:uplo)[0..1]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
|
332
|
+
|
333
|
+
# Computes the inverse of a matrix using the LU factorization
|
334
|
+
# computed by Numo::Linalg.lu_fact.
|
335
|
+
#
|
336
|
+
# This method inverts U and then computes inv(A) by solving the system
|
337
|
+
#
|
338
|
+
# inv(A)*L = inv(U)
|
339
|
+
#
|
340
|
+
# for inv(A).
|
341
|
+
#
|
342
|
+
# @param lu [Numo::NArray] matrix containing the factors L and U
|
343
|
+
# from the factorization `A = P*L*U` as computed by
|
344
|
+
# Numo::Linalg.lu_fact.
|
345
|
+
# @param ipiv [Numo::NArray] The pivot indices from
|
346
|
+
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
347
|
+
# interchanged with row IPIV(i).
|
348
|
+
# @param driver [String or Symbol] choose LAPACK driver from
|
349
|
+
# 'gen','sym','her'. (optional, default='gen')
|
350
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
351
|
+
# ('U') or lower ('L') triangle. (omitted when driver:"gen")
|
352
|
+
# @return [Numo::NArray] the inverse of the original matrix A.
|
353
|
+
|
354
|
+
# Inverse from a previously computed factorization: dispatches to Lapack
# getri for the general driver, or to <sy|he>tri for the
# symmetric/hermitian drivers, and returns the first result.
# @raise [ArgumentError] if `driver` is not recognised.
def lu_inv(lu, ipiv, driver:"gen", uplo:"U")
  case driver.to_s
  when /^gen?(tri)?$/i
    Lapack.call(:getri, lu, ipiv)[0]
  when /^(sym?|her?)(tri)?$/i
    # Only the first two letters ("sy"/"he") form the routine name;
    # keeping three would turn "sym" into "symtri".
    func = driver.to_s[0, 2].downcase+"tri"
    Lapack.call(func, lu, ipiv, uplo:uplo)[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
|
365
|
+
|
366
|
+
# Solves a system of linear equations
|
367
|
+
#
|
368
|
+
# A * X = B or A**T * X = B
|
369
|
+
#
|
370
|
+
# with a N-by-N matrix A using the LU factorization computed by
|
371
|
+
# Numo::Linalg.lu_fact
|
372
|
+
#
|
373
|
+
# @param lu [Numo::NArray] matrix containing the factors L and U
|
374
|
+
# from the factorization `A = P*L*U` as computed by
|
375
|
+
# Numo::Linalg.lu_fact.
|
376
|
+
# @param ipiv [Numo::NArray] The pivot indices from
|
377
|
+
# Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
|
378
|
+
# interchanged with row IPIV(i).
|
379
|
+
# @param b [Numo::NArray] the right hand side matrix B.
|
380
|
+
# @param driver [String or Symbol] choose LAPACK driver from
|
381
|
+
# 'gen','sym','her'. (optional, default='gen')
|
382
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
383
|
+
# ('U') or lower ('L') triangle. (omitted when driver:"gen")
|
384
|
+
# @param trans [String or Symbol]
|
385
|
+
# Specifies the form of the system of equations
|
386
|
+
# (omitted if not driver:"gen"):
|
387
|
+
#
|
388
|
+
# - If 'N': `A * X = B` (No transpose).
|
389
|
+
# - If 'T': `A*\*T* X = B` (Transpose).
|
390
|
+
# - If 'C': `A*\*H* X = B` (Conjugate transpose; same as 'T' for real matrices).
|
391
|
+
# @return [Numo::NArray] the solution matrix X.
|
392
|
+
|
393
|
+
# Solve a linear system from a previously computed factorization:
# dispatches to Lapack getrs (passing `trans`) for the general driver,
# or to <sy|he>trs (passing `uplo`) for the symmetric/hermitian drivers,
# and returns the first result.
# @raise [ArgumentError] if `driver` is not recognised.
def lu_solve(lu, ipiv, b, driver:"gen", uplo:"U", trans:"N")
  case driver.to_s
  when /^gen?(trs)?$/i
    Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
  when /^(sym?|her?)(trs)?$/i
    # Only the first two letters ("sy"/"he") form the routine name;
    # keeping three would turn "sym" into "symtrs".
    func = driver.to_s[0, 2].downcase+"trs"
    Lapack.call(func, lu, ipiv, b, uplo:uplo)[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
|
404
|
+
|
405
|
+
|
406
|
+
# Computes the Cholesky factorization of a symmetric/Hermitian
|
407
|
+
# positive definite matrix A. The factorization has the form
|
408
|
+
#
|
409
|
+
# A = U**H * U, if UPLO = 'U', or
|
410
|
+
# A = L * L**H, if UPLO = 'L',
|
411
|
+
#
|
412
|
+
# where U is an upper triangular matrix and L is lower triangular
|
413
|
+
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
|
414
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
415
|
+
# ('U') or lower ('L') triangle.
|
416
|
+
# @return [Numo::NArray] the factor U or L.
|
417
|
+
|
418
|
+
def cho_fact(a, uplo:'U')
|
419
|
+
Lapack.call(:potrf, a, uplo:uplo)[0]
|
420
|
+
end
|
421
|
+
#alias cholesky cho_fact
|
422
|
+
|
423
|
+
# Computes the inverse of a symmetric/Hermitian
|
424
|
+
# positive definite matrix A using the Cholesky factorization
|
425
|
+
# `A = U**T*U` or `A = L*L**T` computed by Linalg.cho_fact.
|
426
|
+
#
|
427
|
+
# @param a [Numo::NArray] the triangular factor U or L from the
|
428
|
+
# Cholesky factorization, as computed by Linalg.cho_fact.
|
429
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
430
|
+
# ('U') or lower ('L') triangle.
|
431
|
+
# @return [Numo::NArray] the upper or lower triangle of the
|
432
|
+
# (symmetric) inverse of A.
|
433
|
+
|
434
|
+
def cho_inv(a, uplo:'U')
|
435
|
+
Lapack.call(:potri, a, uplo:uplo)[0]
|
436
|
+
end
|
437
|
+
|
438
|
+
# Solves a system of linear equations
|
439
|
+
# A*X = B
|
440
|
+
# with a symmetric/Hermitian positive definite matrix A
|
441
|
+
# using the Cholesky factorization
|
442
|
+
# `A = U**T*U` or `A = L*L**T` computed by Linalg.cho_fact.
|
443
|
+
# @param a [Numo::NArray] the triangular factor U or L from the
|
444
|
+
# Cholesky factorization, as computed by Linalg.cho_fact.
|
445
|
+
# @param b [Numo::NArray] the right hand side matrix B.
|
446
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
447
|
+
# ('U') or lower ('L') triangle.
|
448
|
+
# @return [Numo::NArray] the solution matrix X.
|
449
|
+
|
450
|
+
def cho_solve(a, b, uplo:'U')
  # Delegate to the LAPACK potrs routine; the solution is the first
  # element of the result it hands back.
  result = Lapack.call(:potrs, a, b, uplo: uplo)
  result[0]
end
|
453
|
+
|
454
|
+
|
455
|
+
## Matrix eigenvalues
|
456
|
+
|
457
|
+
# Computes the eigenvalues and, optionally, the left and/or right
|
458
|
+
# eigenvectors for a square nonsymmetric matrix A.
|
459
|
+
#
|
460
|
+
# @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensional NArray)
|
461
|
+
# @param left [Bool] (optional) If true, left eigenvectors are computed.
|
462
|
+
# @param right [Bool] (optional) If true, right eigenvectors are computed.
|
463
|
+
# @return [[w,vl,vr]]
|
464
|
+
# - **w** [Numo::NArray] -- The eigenvalues.
|
465
|
+
# - **vl** [Numo::NArray] -- The left eigenvectors if left is true, otherwise nil.
|
466
|
+
# - **vr** [Numo::NArray] -- The right eigenvectors if right is true, otherwise nil.
|
467
|
+
|
468
|
+
def eig(a, left:false, right:true)
  # The boolean flags are forwarded to geev unchanged as jobvl/jobvr.
  opts = { jobvl: left, jobvr: right }
  if /c|z/ === blas_char(a)
    # Complex input: geev hands back the eigenvalues as one array.
    w, vl, vr, = Lapack.call(:geev, a, **opts)
  else
    # Real input: geev hands back real and imaginary parts separately;
    # combine them and rebuild complex eigenvectors where requested.
    wr, wi, vl, vr, = Lapack.call(:geev, a, **opts)
    w = wr + wi * Complex::I
    vl = _make_complex_eigvecs(w, vl) if left
    vr = _make_complex_eigvecs(w, vr) if right
  end
  [w, vl, vr]
end
|
481
|
+
|
482
|
+
# Computes the eigenvalues and, optionally, the left and/or right
|
483
|
+
# eigenvectors for a square symmetric/hermitian matrix A.
|
484
|
+
#
|
485
|
+
# @param a [Numo::NArray] square symmetric/hermitian matrix (>= 2-dimensional NArray)
|
486
|
+
# @param vals_only [Bool] (optional) If false, eigenvectors are computed.
|
487
|
+
# @param uplo [String or Symbol] (optional, default='U')
|
488
|
+
# Access upper ('U') or lower ('L') triangle.
|
489
|
+
# @return [[w,v]]
|
490
|
+
# - **w** [Numo::NArray] -- The eigenvalues.
|
491
|
+
# - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
|
492
|
+
|
493
|
+
def eigh(a, vals_only:false, uplo:false, turbo:false)
  # jobz: 'N' asks for eigenvalues only, 'V' for eigenvalues and eigenvectors.
  jobz = vals_only ? 'N' : 'V'
  # NOTE(review): with turbo the generalized drivers hegv/sygv are selected,
  # which in LAPACK solve A*x = lambda*B*x -- confirm that calling them
  # without a B matrix is intended.
  func =
    if /c|z/ === blas_char(a)
      turbo ? :hegv : :heev
    else
      turbo ? :sygv : :syev
    end
  w, v, = Lapack.call(func, a, uplo: uplo, jobz: jobz)
  [w, v]
end
|
504
|
+
|
505
|
+
# Computes the eigenvalues only for a square nonsymmetric matrix A.
|
506
|
+
#
|
507
|
+
# @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensional NArray)
|
508
|
+
# @return [Numo::NArray] eigenvalues
|
509
|
+
|
510
|
+
def eigvals(a)
  # Neither left nor right eigenvectors are requested.
  opts = { jobvl: 'N', jobvr: 'N' }
  if /c|z/ === blas_char(a)
    # Complex input: the eigenvalues come back as a single array.
    w, = Lapack.call(:geev, a, **opts)
    return w
  end
  # Real input: combine the separately returned real and imaginary parts.
  wr, wi, = Lapack.call(:geev, a, **opts)
  wr + wi * Complex::I
end
|
521
|
+
|
522
|
+
# Computes the eigenvalues for a square symmetric/hermitian matrix A.
|
523
|
+
#
|
524
|
+
# @param a [Numo::NArray] square symmetric/hermitian matrix
|
525
|
+
# (>= 2-dimensional NArray)
|
526
|
+
# @param uplo [String or Symbol] (optional, default='U')
|
527
|
+
# Access upper ('U') or lower ('L') triangle.
|
528
|
+
# @return [Numo::NArray] eigenvalues
|
529
|
+
|
530
|
+
def eigvalsh(a, uplo:false, turbo:false)
  # NOTE(review): with turbo the generalized drivers hegv/sygv are selected,
  # which in LAPACK solve A*x = lambda*B*x -- confirm that calling them
  # without a B matrix is intended.
  func =
    if /c|z/ === blas_char(a)
      turbo ? :hegv : :heev
    else
      turbo ? :sygv : :syev
    end
  # jobz 'N': eigenvalues only; they are the first element of the result.
  result = Lapack.call(func, a, uplo: uplo, jobz: 'N')
  result[0]
end
|
540
|
+
|
541
|
+
|
542
|
+
## Norms and other numbers
|
543
|
+
|
544
|
+
# Compute matrix or vector norm.
|
545
|
+
#
|
546
|
+
# | ord | matrix norm | vector norm |
|
547
|
+
# | ----- | ---------------------- | --------------------------- |
|
548
|
+
# | nil | Frobenius norm | 2-norm |
|
549
|
+
# | 'fro' | Frobenius norm | - |
|
550
|
+
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max |
|
551
|
+
# | 0 | - | (x.ne 0).sum |
|
552
|
+
# | 1 | x.abs.sum(axis:-2).max | same as below |
|
553
|
+
# | 2 | 2-norm (max sing_vals) | same as below |
|
554
|
+
# | other | - | (x.abs**ord).sum**(1.0/ord) |
|
555
|
+
#
|
556
|
+
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
|
557
|
+
# @param ord [Integer, String or Symbol] Order of the norm.
|
558
|
+
# @param axis [Integer or Array] Applied axes (optional).
|
559
|
+
# @param keepdims [Bool] If true, the applied axes are left in
|
560
|
+
# result with size one (optional).
|
561
|
+
# @return [Numo::NArray] norm result
|
562
|
+
|
563
|
+
def norm(a, ord=nil, axis:nil, keepdims:false)
  # Computes a vector norm (one axis) or a matrix norm (two axes) of `a`.
  # Raises ArgumentError for an invalid axis or an unsupported `ord`.
  a = Numo::NArray.asarray(a)

  # check axis
  if axis
    case axis
    when Integer
      axis = [axis]
    when Array
      if axis.size < 1 || axis.size > 2
        raise ArgumentError, "axis option should be 1- or 2-element array"
      end
    else
      raise ArgumentError, "invalid option for axis: #{axis}"
    end
    # swap axes: move the selected axes to the end, so the reductions
    # below can always work on the trailing dimension(s).
    if a.ndim > 1
      idx = (0...a.ndim).to_a
      tmp = []
      axis.each do |i|
        x = idx[i]
        # nil means the slot was already taken (duplicate axis) or the
        # index lies outside 0...ndim.
        if x.nil?
          raise ArgumentError, "axis contains same dimension"
        end
        tmp << x
        idx[i] = nil
      end
      idx.compact!
      idx.concat(tmp)
      a = a.transpose(*idx)
    end
  else
    case a.ndim
    when 0
      raise ArgumentError, "zero-dimensional array"
    when 1
      axis = [-1]
    else
      axis = [-2,-1]
    end
  end

  # calculate norm
  case axis.size

  when 1  # vector
    k = keepdims
    ord ||= 2  # default
    case ord.to_s
    when "0"
      r = a.class.cast(a.ne(0)).sum(axis:-1, keepdims:k)
    when "1"
      r = a.abs.sum(axis:-1, keepdims:k)
    when "2"
      r = Blas.call(:nrm2, a, keepdims:k)
    when /^-?\d+$/
      o = ord.to_i
      r = (a.abs**o).sum(axis:-1, keepdims:k)**(1.0/o)
    when /^inf(inity)?$/i
      r = a.abs.max(axis:-1, keepdims:k)
    when /^-inf(inity)?$/i
      r = a.abs.min(axis:-1, keepdims:k)
    else
      raise ArgumentError, "ord (#{ord}) is invalid for vector norm"
    end

  when 2  # matrix
    if keepdims
      # Index template used at the end to re-insert size-one dimensions
      # at the positions of the reduced axes.
      fixdims = [true] * a.ndim
      axis.each do |i|
        if i < -a.ndim || i >= a.ndim
          raise ArgumentError, "axis (%d) is out of range" % i
        end
        fixdims[i] = :new
      end
    end
    ord ||= "fro"  # default
    case ord.to_s
    when "1"
      r, = Lapack.call(:lange, a, '1')
    when "-1"
      r = a.abs.sum(axis:-2).min(axis:-1)
    when "2"
      svd, = Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
      r = svd.max(axis:-1)
    when "-2"
      svd, = Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
      r = svd.min(axis:-1)
    when /^f(ro)?$/i
      r, = Lapack.call(:lange, a, 'F')
    when /^inf(inity)?$/i
      r, = Lapack.call(:lange, a, 'I')
    when /^-inf(inity)?$/i
      r = a.abs.sum(axis:-1).min(axis:-1)
    else
      raise ArgumentError, "ord (#{ord}) is invalid for matrix norm"
    end
    if keepdims
      if NArray===r
        r = r[*fixdims]
      else
        # Scalar result: wrap it in a 1x1 array of the input's class.
        r = a.class.new(1,1).store(r)
      end
    end
  end
  return r
end
|
670
|
+
|
671
|
+
# Compute the condition number of a matrix
|
672
|
+
# using the norm with one of the following order.
|
673
|
+
#
|
674
|
+
# | ord | matrix norm |
|
675
|
+
# | ----- | ---------------------- |
|
676
|
+
# | nil | 2-norm using SVD |
|
677
|
+
# | 'fro' | Frobenius norm |
|
678
|
+
# | 'inf' | x.abs.sum(axis:-1).max |
|
679
|
+
# | 1 | x.abs.sum(axis:-2).max |
|
680
|
+
# | 2 | 2-norm (max sing_vals) |
|
681
|
+
#
|
682
|
+
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
|
683
|
+
# @param ord [String or Symbol] Order of the norm.
|
684
|
+
# @return [Numo::NArray] cond result
|
685
|
+
# @example
|
686
|
+
# a = Numo::DFloat[[1, 0, -1], [0, 1, 0], [1, 0, 1]]
|
687
|
+
# => Numo::DFloat#shape=[3,3]
|
688
|
+
# [[1, 0, -1],
|
689
|
+
# [0, 1, 0],
|
690
|
+
# [1, 0, 1]]
|
691
|
+
# LA = Numo::Linalg
|
692
|
+
# LA.cond(a)
|
693
|
+
# => 1.4142135623730951
|
694
|
+
# LA.cond(a, 'fro')
|
695
|
+
# => 3.1622776601683795
|
696
|
+
# LA.cond(a, 'inf')
|
697
|
+
# => 2.0
|
698
|
+
# LA.cond(a, '-inf')
|
699
|
+
# => 1.0
|
700
|
+
# LA.cond(a, 1)
|
701
|
+
# => 2.0
|
702
|
+
# LA.cond(a, -1)
|
703
|
+
# => 1.0
|
704
|
+
# LA.cond(a, 2)
|
705
|
+
# => 1.4142135623730951
|
706
|
+
# LA.cond(a, -2)
|
707
|
+
# => 0.7071067811865475
|
708
|
+
# (LA.svdvals(a)).min*(LA.svdvals(LA.inv(a))).min
|
709
|
+
# => 0.7071067811865475
|
710
|
+
|
711
|
+
def cond(a,ord=nil)
  if ord.nil?
    # Default: ratio of the first to the last singular value.
    s = svdvals(a)
    return s[false, 0] / s[false, -1]
  end
  # Otherwise: norm of the matrix times the norm of its inverse,
  # both taken over the last two axes.
  direct = norm(a, ord, axis: [-2, -1])
  inverse = norm(inv(a), ord, axis: [-2, -1])
  direct * inverse
end
|
719
|
+
|
720
|
+
# Determinant of a matrix
|
721
|
+
#
|
722
|
+
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
|
723
|
+
# @return [Float or Complex or Numo::NArray]
|
724
|
+
|
725
|
+
def det(a)
  factored, pivots, = Lapack.call(:getrf, a)
  # Reference sequence 1, 2, ... along the last axis, stored into an
  # array shaped like the pivot array.
  sequence = pivots.new_narray.store(pivots.class.new(pivots.shape[-1]).seq(1))
  # Parity of the number of pivot entries that differ from the sequence.
  parity = pivots.eq(sequence).count_false(axis: -1) % 2
  # Maps parity 0 -> +1 and parity 1 -> -1.
  sign = parity * -2 + 1
  factored.diagonal.prod(axis: -1) * sign
end
|
732
|
+
|
733
|
+
# Natural logarithm of the determinant of a matrix
|
734
|
+
#
|
735
|
+
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
|
736
|
+
# @return [[sign,logdet]]
|
737
|
+
# - **sign** -- A number representing the sign of the determinant.
|
738
|
+
# - **logdet** -- The natural log of the absolute value of the determinant.
|
739
|
+
|
740
|
+
def slogdet(a)
  # Returns [sign, logdet] computed from the getrf factorization of `a`.
  lu, piv, = Lapack.call(:getrf, a)
  # Reference sequence 1, 2, ... used to count the pivot entries that differ.
  idx = piv.new_narray.store(piv.class.new(piv.shape[-1]).seq(1))
  m = piv.eq(idx).count_false(axis:-1) % 2
  # Maps parity 0 -> +1 and parity 1 -> -1.
  sign = m * -2 + 1

  lud = lu.diagonal
  # NOTE(review): a zero anywhere on the diagonal returns scalars here,
  # even when `a` is a stack of matrices -- confirm this is intended.
  if (lud.eq 0).any?
    return 0, (-Float::INFINITY)
  end
  lud_abs = lud.abs
  # Reduce along the last axis only, like the log-sum below, so that each
  # matrix of a stacked input gets its own sign.
  sign *= (lud/lud_abs).prod(axis:-1)
  [sign, NMath.log(lud_abs).sum(axis:-1)]
end
|
754
|
+
|
755
|
+
# Compute matrix rank of array using SVD
|
756
|
+
# *Rank* is the number of singular values greater than *tol*.
|
757
|
+
#
|
758
|
+
# @param m [Numo::NArray] matrix (>= 2-dimensional NArray)
|
759
|
+
# @param tol [Float] threshold below which singular values are
|
760
|
+
# considered to be zero. If *tol* is nil,
|
761
|
+
# `tol = sing_vals.max() * m.shape.max * EPSILON`.
|
762
|
+
# @param driver [String or Symbol] choose LAPACK solver from 'svd',
|
763
|
+
# 'sdd'. (optional, default='svd')
|
764
|
+
|
765
|
+
def matrix_rank(m, tol:nil, driver:'svd')
  m = Numo::NArray.asarray(m)
  # Fewer than two dimensions: rank is 1 if any element is nonzero, else 0.
  return (m.ne(0).any? ? 1 : 0) if m.ndim < 2

  # Singular values only (no U / V**T requested).
  s =
    case driver.to_s
    when /^(ge)?sdd$/, "turbo"
      Lapack.call(:gesdd, m, jobz:'N')[0]
    when /^(ge)?svd$/
      Lapack.call(:gesvd, m, jobu:'N', jobvt:'N')[0]
    else
      raise ArgumentError, "invalid driver: #{driver}"
    end
  # Default tolerance: largest singular value * larger matrix dimension * EPSILON.
  tol ||= s.max(axis:-1, keepdims:true) * (m.shape[-2..-1].max * s.class::EPSILON)
  (s > tol).count(axis:-1)
end
|
783
|
+
|
784
|
+
|
785
|
+
## Solving equations and inverting matrices
|
786
|
+
|
787
|
+
# Solves linear equation `a * x = b` for `x`
|
788
|
+
# from square matrix `a`
|
789
|
+
# @param a [Numo::NArray] n-by-n square matrix (>= 2-dimensional NArray)
|
790
|
+
# @param b [Numo::NArray] n-by-nrhs right-hand-side matrix (>=
|
791
|
+
# 1-dimensional NArray)
|
792
|
+
# @param driver [String or Symbol] choose LAPACK driver from
|
793
|
+
# 'gen','sym','her' or 'pos'. (optional, default='gen')
|
794
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
795
|
+
# ('U') or lower ('L') triangle. (omitted when driver:"gen")
|
796
|
+
# @return [Numo::NArray] The solution matrix/vector X.
|
797
|
+
|
798
|
+
def solve(a, b, driver:"gen", uplo:'U')
  # Solves a * x = b with the LAPACK ?sv driver selected by `driver`.
  # Raises ArgumentError for an unrecognized driver name.
  name = driver.to_s
  case name
  when /^gen?(sv)?$/i
    # returns lu, x, ipiv, info
    Lapack.call(:gesv, a, b)[1]
  when /^(sym?|her?|pos?)(sv)?$/i
    # Only the first two letters form the routine prefix, so that
    # 'sym'/'her'/'pos' map to sysv/hesv/posv.
    func = name[0, 2].downcase + "sv"
    Lapack.call(func, a, b, uplo:uplo)[1]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
|
810
|
+
|
811
|
+
# Inverse matrix from square matrix `a`
|
812
|
+
# @param a [Numo::NArray] n-by-n square matrix (>= 2-dimensional NArray)
|
813
|
+
# @param driver [String or Symbol] choose LAPACK driver
|
814
|
+
# ('ge'|'sy'|'he'|'po') + ("sv"|"trf")
|
815
|
+
# (optional, default='getrf')
|
816
|
+
# @param uplo [String or Symbol] optional, default='U'. Access upper
|
817
|
+
# ('U') or lower ('L') triangle. (omitted when driver:"ge")
|
818
|
+
# @return [Numo::NArray] The inverse matrix.
|
819
|
+
# @example
|
820
|
+
# Numo::Linalg.inv(a,driver:'getrf')
|
821
|
+
# => Numo::DFloat#shape=[2,2]
|
822
|
+
# [[-2, 1],
|
823
|
+
# [1.5, -0.5]]
|
824
|
+
# a.dot(Numo::Linalg.inv(a,driver:'getrf'))
|
825
|
+
# => Numo::DFloat#shape=[2,2]
|
826
|
+
# [[1, 0],
|
827
|
+
# [8.88178e-16, 1]]
|
828
|
+
|
829
|
+
def inv(a, driver:"getrf", uplo:'U')
  case driver
  when /(ge|sy|he|po)sv$/
    # Solve a * x = I with the matching ?sv driver.
    kind = Regexp.last_match(1)
    identity = a.new_zeros.eye
    solve(a, identity, driver: kind, uplo: uplo)
  when /(ge|sy|he)tr[fi]$/
    # Factorize first, then invert from the factors.
    kind = Regexp.last_match(1)
    factors, pivots = lu_fact(a, driver: kind, uplo: uplo)
    lu_inv(factors, pivots, driver: kind, uplo: uplo)
  when /potr[fi]$/
    # Cholesky factorization followed by inversion from the factor.
    chol = cho_fact(a, uplo: uplo)
    cho_inv(chol, uplo: uplo)
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
|
846
|
+
|
847
|
+
# Computes the minimum-norm solution to a linear least squares
|
848
|
+
# problem:
|
849
|
+
#
|
850
|
+
# minimize 2-norm(| b - A*x |)
|
851
|
+
#
|
852
|
+
# using the singular value decomposition (SVD) of A.
|
853
|
+
# A is an M-by-N matrix which may be rank-deficient.
|
854
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
855
|
+
# @param b [Numo::NArray] m-by-nrhs right-hand-side matrix b
|
856
|
+
# (>= 1-dimensional NArray)
|
857
|
+
# @param driver [String or Symbol] choose LAPACK driver from
|
858
|
+
# 'lsd','lss','lsy' (optional, default='lsd')
|
859
|
+
# @param rcond [Float] (optional, default=-1)
|
860
|
+
# RCOND is used to determine the effective rank of A.
|
861
|
+
# Singular values `S(i) <= RCOND*S(1)` are treated as zero.
|
862
|
+
# If RCOND < 0, machine precision is used instead.
|
863
|
+
# @return [[x, resids, rank, s]]
|
864
|
+
# - **x** -- The solution matrix/vector X.
|
865
|
+
# - **resids** -- Sums of residues, squared 2-norm for each column in
|
866
|
+
# `b - a x`. If matrix_rank(a) < N or > M, or 'gelsy' is used,
|
867
|
+
# this is an empty array.
|
868
|
+
# - **rank** -- The effective rank of A, i.e.,
|
869
|
+
# the number of singular values which are greater than RCOND*S(1).
|
870
|
+
# - **s** -- The singular values of A in decreasing order.
|
871
|
+
# Returns nil if 'gelsy' is used.
|
872
|
+
|
873
|
+
def lstsq(a, b, driver:'lsd', rcond:-1)
  # Least-squares solve of a * x = b through gelsd / gelss / gelsy.
  # Returns [x, resids, rank, s]; raises NArray::ShapeError when the row
  # counts of a and b differ and ArgumentError for an unknown driver.
  a = NArray.asarray(a)
  b = NArray.asarray(b)
  # A 1-d right-hand side is handled as a single column and the extra
  # axis is dropped again before returning.
  b_is_vector = b.shape.size == 1
  b = b[true, :new] if b_is_vector
  m = a.shape[-2]
  n = a.shape[-1]
  if m != b.shape[-2]
    raise NArray::ShapeError, "size mismatch: A-row and B-row"
  end
  if m < n   # need to extend b matrix
    shp = b.shape
    shp[-2] = n
    b2 = b.class.zeros(*shp)
    b2[false, 0...m, true] = b
    b = b2
  end
  case driver.to_s
  when /^(ge)?lsd$/i
    # x, s, rank, info
    x, s, rank, = Lapack.call(:gelsd, a, b, rcond:rcond)
  when /^(ge)?lss$/i
    # v, x, s, rank, info
    _, x, s, rank, = Lapack.call(:gelss, a, b, rcond:rcond)
  when /^(ge)?lsy$/i
    jpvt = Int32.zeros(*a[false, 0, true].shape)
    # v, x, jpvt, rank, info
    _, x, _, rank, = Lapack.call(:gelsy, a, b, jpvt, rcond:rcond)
    s = nil
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
  resids = nil
  if m > n
    if /ls(d|s)$/i =~ driver
      case rank
      when n
        resids = (x[n..-1, true].abs**2).sum(axis:0)
      when NArray
        # Stacked input: residuals are computed for every matrix,
        # without masking on rank == n.
        resids = (x[false, n..-1, true].abs**2).sum(axis:-2)
      end
    end
    # Keep only the first n rows as the solution.
    x = x[false, 0...n, true]
  end
  if b_is_vector
    x = x[true, 0]
    resids &&= resids[false, 0]
  end
  [x, resids, rank, s]
end
|
934
|
+
|
935
|
+
# Compute the (Moore-Penrose) pseudo-inverse of a matrix
|
936
|
+
# using svd or lstsq.
|
937
|
+
#
|
938
|
+
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
|
939
|
+
# @param driver [String or Symbol] choose LAPACK driver from
|
940
|
+
# SVD ('svd', 'sdd') or Least square ('lsd','lss','lsy')
|
941
|
+
# (optional, default='svd')
|
942
|
+
# @param rcond [Float] (optional, default=-1)
|
943
|
+
# RCOND is used to determine the effective rank of A.
|
944
|
+
# Singular values `S(i) <= RCOND*S(1)` are treated as zero.
|
945
|
+
# If RCOND < 0, machine precision is used instead.
|
946
|
+
# @return [Numo::NArray]
|
947
|
+
# @example
|
948
|
+
# a = Numo::DFloat.new(5,3).rand_norm
|
949
|
+
# => Numo::DFloat#shape=[5,3]
|
950
|
+
# [[-0.581255, -0.168354, 0.586895],
|
951
|
+
# [-0.595142, -0.802802, -0.326106],
|
952
|
+
# [0.282922, 1.68427, 0.918499],
|
953
|
+
# [-0.0485384, -0.464453, -0.992194],
|
954
|
+
# [0.413794, -0.60717, -0.699695]]
|
955
|
+
# b = Numo::Linalg.pinv(a,driver:"svd")
|
956
|
+
# => Numo::DFloat(view)#shape=[3,5]
|
957
|
+
# [[-0.360863, -0.813125, -0.353367, -0.891963, 0.877253],
|
958
|
+
# [-0.227645, 0.162939, 0.696655, 0.787685, -0.469346],
|
959
|
+
# [0.408671, -0.308323, -0.337807, -1.13833, 0.228051]]
|
960
|
+
# (a-a.dot(b.dot(a))).abs.max
|
961
|
+
# => 5.551115123125783e-16
|
962
|
+
|
963
|
+
def pinv(a, driver:"svd", rcond:nil)
  # Pseudo-inverse of `a` through an SVD driver or a least-squares driver.
  # Raises NArray::ShapeError for fewer than two dimensions and
  # ArgumentError for an unknown driver or a non-numeric rcond.
  a = NArray.asarray(a)
  if a.ndim < 2
    raise NArray::ShapeError, "2-d array is required"
  end
  case driver
  when /^(ge)?s[dv]d$/
    s, u, vh = svd(a, driver:driver, job:'S')
    # Validate the type before comparing; `! Numeric === rcond` would parse
    # as `(!Numeric) === rcond` and never be true.
    unless rcond.nil? || Numeric === rcond
      raise ArgumentError, "rcond must be Numeric"
    end
    if rcond.nil? || rcond < 0
      rcond = ((SFloat===s) ? 1e3 : 1e6) * s.class::EPSILON
    end
    # Singular values above the relative cutoff are inverted; the rest
    # are treated as zero.
    mask = (s > rcond * s.max(axis:-1, keepdims:true))
    if mask.all?
      r = s.reciprocal
    else
      r = s.new_zeros
      r[mask] = s[mask].reciprocal
    end
    u *= r[false,:new,true]
    dot(u,vh).conj.swapaxes(-2,-1)
  when /^(ge)?ls[dsy]$/
    b = a.class.eye(a.shape[-2])
    # lstsq's own default rcond is -1; pass that instead of nil.
    x, = lstsq(a, b, driver:driver, rcond:(rcond.nil? ? -1 : rcond))
    x
  else
    raise ArgumentError, "#{driver.inspect} is not one of drivers: "+
                         "svd, sdd, lsd, lss, lsy"
  end
end
|
994
|
+
|
995
|
+
private
|
996
|
+
|
997
|
+
# @!visibility private
|
998
|
+
def _make_complex_eigvecs(w, vin) # :nodoc:
  v = w.class.cast(vin)
  # Ruby's `|` binds tighter than `>`, so the comparison's right-hand side
  # is `0 | Bit.zeros(...)`; comparing against it broadcasts the test to
  # vin.shape.
  zero_bits = 0 | Bit.zeros(*vin.shape)
  m = (w.imag > zero_bits).where
  # At the selected positions the following element of vin supplies the
  # imaginary part, and that following element becomes the conjugate.
  v[m].imag = vin[m+1]
  v[m+1] = v[m].conj
  v
end
|
1006
|
+
|
1007
|
+
end
|
1008
|
+
end
|