numo-linalg 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (85) hide show
  1. checksums.yaml +7 -0
  2. data/Gemfile +4 -0
  3. data/README.md +80 -0
  4. data/Rakefile +18 -0
  5. data/ext/numo/linalg/blas/blas.c +352 -0
  6. data/ext/numo/linalg/blas/cblas.h +575 -0
  7. data/ext/numo/linalg/blas/cblas_t.h +563 -0
  8. data/ext/numo/linalg/blas/depend.erb +23 -0
  9. data/ext/numo/linalg/blas/extconf.rb +67 -0
  10. data/ext/numo/linalg/blas/gen/cogen.rb +72 -0
  11. data/ext/numo/linalg/blas/gen/decl.rb +203 -0
  12. data/ext/numo/linalg/blas/gen/desc.rb +8138 -0
  13. data/ext/numo/linalg/blas/gen/erbpp2.rb +339 -0
  14. data/ext/numo/linalg/blas/gen/replace_cblas_h.rb +27 -0
  15. data/ext/numo/linalg/blas/gen/spec.rb +93 -0
  16. data/ext/numo/linalg/blas/numo_blas.h +41 -0
  17. data/ext/numo/linalg/blas/tmpl/axpy.c +75 -0
  18. data/ext/numo/linalg/blas/tmpl/copy.c +57 -0
  19. data/ext/numo/linalg/blas/tmpl/def_c.c +3 -0
  20. data/ext/numo/linalg/blas/tmpl/def_d.c +3 -0
  21. data/ext/numo/linalg/blas/tmpl/def_s.c +3 -0
  22. data/ext/numo/linalg/blas/tmpl/def_z.c +3 -0
  23. data/ext/numo/linalg/blas/tmpl/dot.c +68 -0
  24. data/ext/numo/linalg/blas/tmpl/ger.c +114 -0
  25. data/ext/numo/linalg/blas/tmpl/init_class.c +20 -0
  26. data/ext/numo/linalg/blas/tmpl/init_module.c +12 -0
  27. data/ext/numo/linalg/blas/tmpl/lib.c +40 -0
  28. data/ext/numo/linalg/blas/tmpl/mm.c +214 -0
  29. data/ext/numo/linalg/blas/tmpl/module.c +9 -0
  30. data/ext/numo/linalg/blas/tmpl/mv.c +194 -0
  31. data/ext/numo/linalg/blas/tmpl/nrm2.c +79 -0
  32. data/ext/numo/linalg/blas/tmpl/rot.c +65 -0
  33. data/ext/numo/linalg/blas/tmpl/rotm.c +82 -0
  34. data/ext/numo/linalg/blas/tmpl/scal.c +69 -0
  35. data/ext/numo/linalg/blas/tmpl/sdsdot.c +77 -0
  36. data/ext/numo/linalg/blas/tmpl/set_prefix.c +16 -0
  37. data/ext/numo/linalg/blas/tmpl/swap.c +57 -0
  38. data/ext/numo/linalg/blas/tmpl/syr.c +102 -0
  39. data/ext/numo/linalg/blas/tmpl/syr2.c +110 -0
  40. data/ext/numo/linalg/blas/tmpl/syr2k.c +129 -0
  41. data/ext/numo/linalg/blas/tmpl/syrk.c +132 -0
  42. data/ext/numo/linalg/lapack/depend.erb +23 -0
  43. data/ext/numo/linalg/lapack/extconf.rb +45 -0
  44. data/ext/numo/linalg/lapack/gen/cogen.rb +74 -0
  45. data/ext/numo/linalg/lapack/gen/desc.rb +151278 -0
  46. data/ext/numo/linalg/lapack/gen/replace_lapacke_h.rb +32 -0
  47. data/ext/numo/linalg/lapack/gen/spec.rb +104 -0
  48. data/ext/numo/linalg/lapack/lapack.c +387 -0
  49. data/ext/numo/linalg/lapack/lapacke.h +16425 -0
  50. data/ext/numo/linalg/lapack/lapacke_config.h +119 -0
  51. data/ext/numo/linalg/lapack/lapacke_mangling.h +17 -0
  52. data/ext/numo/linalg/lapack/lapacke_t.h +10550 -0
  53. data/ext/numo/linalg/lapack/numo_lapack.h +42 -0
  54. data/ext/numo/linalg/lapack/tmpl/def_c.c +3 -0
  55. data/ext/numo/linalg/lapack/tmpl/def_d.c +7 -0
  56. data/ext/numo/linalg/lapack/tmpl/def_s.c +7 -0
  57. data/ext/numo/linalg/lapack/tmpl/def_z.c +3 -0
  58. data/ext/numo/linalg/lapack/tmpl/fact.c +179 -0
  59. data/ext/numo/linalg/lapack/tmpl/geev.c +123 -0
  60. data/ext/numo/linalg/lapack/tmpl/gels.c +232 -0
  61. data/ext/numo/linalg/lapack/tmpl/gesv.c +149 -0
  62. data/ext/numo/linalg/lapack/tmpl/gesvd.c +189 -0
  63. data/ext/numo/linalg/lapack/tmpl/ggev.c +138 -0
  64. data/ext/numo/linalg/lapack/tmpl/gqr.c +121 -0
  65. data/ext/numo/linalg/lapack/tmpl/init_class.c +20 -0
  66. data/ext/numo/linalg/lapack/tmpl/init_module.c +12 -0
  67. data/ext/numo/linalg/lapack/tmpl/lange.c +79 -0
  68. data/ext/numo/linalg/lapack/tmpl/lib.c +40 -0
  69. data/ext/numo/linalg/lapack/tmpl/module.c +9 -0
  70. data/ext/numo/linalg/lapack/tmpl/syev.c +91 -0
  71. data/ext/numo/linalg/lapack/tmpl/sygv.c +104 -0
  72. data/ext/numo/linalg/lapack/tmpl/trf.c +276 -0
  73. data/ext/numo/linalg/numo_linalg.h +115 -0
  74. data/lib/numo/linalg.rb +3 -0
  75. data/lib/numo/linalg/function.rb +1008 -0
  76. data/lib/numo/linalg/linalg.rb +7 -0
  77. data/lib/numo/linalg/loader.rb +174 -0
  78. data/lib/numo/linalg/use/atlas.rb +3 -0
  79. data/lib/numo/linalg/use/lapack.rb +3 -0
  80. data/lib/numo/linalg/use/mkl.rb +3 -0
  81. data/lib/numo/linalg/use/openblas.rb +3 -0
  82. data/lib/numo/linalg/version.rb +5 -0
  83. data/numo-linalg.gemspec +26 -0
  84. data/spec/lapack_spec.rb +13 -0
  85. metadata +172 -0
@@ -0,0 +1,276 @@
1
<%
    has_rhs = (/trs$/ =~ name)
    has_trans = (/^.(g|l|t).trs$/ =~ name)
    has_uplo = (/^.(g|pt)/ !~ name)
    has_ipiv = (/p[bfopt]tr.$/ !~ name)
    ipiv_out = (has_ipiv && /trf$/ =~ name)
    ipiv_in = (has_ipiv && /tr[is]$/ =~ name)
    is_sym = (has_uplo || /getr[is]/=~name)
%>
/*
  Feature flags derived from the LAPACK routine name by the ERB block above:
    RHS      - name ends in "trs": a right-hand-side matrix b is solved for.
    TRANS    - name matches /^.(g|l|t).trs$/ (e.g. ?getrs): a trans: option.
    UPLO     - name does not start with ?g or ?pt: an uplo: option.
    IPIV     - name does not match /p[bfopt]tr.$/ (e.g. ?potrf): a pivot array.
    IPIV_OUT - pivot array is produced (IPIV and name ends in "trf").
    IPIV_IN  - pivot array is consumed (IPIV and name ends in "tri"/"trs").
    SYM      - UPLO, or ?getri/?getrs: only one matrix dimension (n) is passed
               to LAPACK; otherwise (e.g. ?getrf) both m and n are passed.
*/
#define RHS <%= has_rhs ? "1":"0" %>
#define TRANS <%= has_trans ? "1":"0" %>
#define UPLO <%= has_uplo ? "1":"0" %>
#define IPIV <%= has_ipiv ? "1":"0" %>
#define IPIV_OUT <%= ipiv_out ? "1":"0" %>
#define IPIV_IN <%= ipiv_in ? "1":"0" %>
#define SYM <%= is_sym ? "1":"0" %>
#define args_t <%=func_name%>_args_t
#define func_p <%=func_name%>_p

/* Options parsed from Ruby keywords and handed to the iterator via opt_ptr. */
typedef struct {
    int order;   /* matrix layout, from order: */
    char uplo;   /* from uplo: (only read when UPLO) */
    char trans;  /* from trans: (only read when TRANS) */
} args_t;

/* Resolved LAPACK entry point; checked with CHECK_FUNC before each call. */
static <%=func_name%>_t func_p = 0;
27
+
28
/*
  ndloop iterator: runs the LAPACK routine once on the extracted matrices.
  Argument slots (inputs, then outputs): a, [ipiv], [b], info.
*/
static void
<%=c_iter%>(na_loop_t * const lp)
{
    dtype *a;
#if RHS
    dtype *b;
    int nb, nrhs, ldb;
#endif
#if IPIV
    int *pv;
#endif
    int *info;
    int m, n, lda;
    args_t *g;

    a = (dtype*)NDL_PTR(lp,0);
#if IPIV
    pv = (int*)NDL_PTR(lp,1);
#endif
#if RHS
    b = (dtype*)NDL_PTR(lp,1+IPIV);
#endif
    info = (int*)NDL_PTR(lp,1+IPIV+RHS);
    g = (args_t*)(lp->opt_ptr);

    n = NDL_SHAPE(lp,0)[0];
    m = NDL_SHAPE(lp,0)[1];
    // leading dimension of a, in elements (NDL_STEP is divided by the element size)
    lda = NDL_STEP(lp,0) / sizeof(dtype);

#if RHS
    // same as gels.c
    if (lp->args[1+IPIV].ndim == 1) {
        // b is a single right-hand-side vector
        nrhs = 1;
        nb = NDL_SHAPE(lp,1+IPIV)[0];
        ldb = (g->order==LAPACK_COL_MAJOR) ? nb : 1;
    } else {
        // b is a matrix; its two extents are exchanged for column-major order
        nb = NDL_SHAPE(lp,1+IPIV)[0];
        nrhs = NDL_SHAPE(lp,1+IPIV)[1];
        ldb = nrhs;
        { int tmp; SWAP_IFCOL(g->order,nb,nrhs); }
    }
    //printf("order=%d m=%d n=%d nb=%d nrhs=%d lda=%d ldb=%d\n",g->order,m,n,nb,nrhs,lda,ldb);
#else
    //printf("order=%d m=%d n=%d lda=%d \n",g->order,m,n,lda);
#endif

#if SYM
    // only one dimension is passed to LAPACK; use the smaller extent
    n = min_(m,n);
#else
    { int tmp; SWAP_IFCOL(g->order,m,n); }
#endif

<%
    func_args = [ "g->order",
                  has_uplo && "g->uplo",
                  has_trans && "g->trans",
                  "n",
                  has_rhs ? "nrhs" : (!is_sym && "m"),
                  "a, lda",
                  has_ipiv && "pv",
                  has_rhs && "b, ldb",
                ].select{|x| x}.join(", ")
%>
    // LAPACK's return code is stored as the info output, then checked
    *info = (*func_p)(<%=func_args%>);
    CHECK_ERROR(*info);
}
94
+
95
/*<%
    args_v = [
        "a",
        ipiv_in && "ipiv",
        has_rhs && "b",
    ].select{|x| x}.join(", ")

    args_opt = [
        has_uplo && "uplo:'U'",
        has_trans && "trans:'N'",
        "order:'R'",
    ].select{|x| x}.join(", ")

    trf = name.sub(/.$/,"f")

    params = [
        has_rhs ? "@param a [#{class_name}] LU matrix computed by "+trf :
                  mat("a",:inplace),
        ipiv_in && "@param ipiv [Numo::Int] pivot computed by "+trf,
        has_rhs && mat("b",:inplace),
        has_uplo && opt("uplo"),
        has_trans && opt("trans"),
        opt("order"),
    ].select{|x| x}.join("\n ")

    return_type = [
        class_name,
        ipiv_out && "Numo::Int",
        "Integer"
    ].select{|x| x}.join(", ")

    return_name = [
        has_rhs ? "b" : "a",
        ipiv_out && "ipiv",
        "info"
    ].select{|x| x}.join(", ")
%>
@overload <%=name%>(<%=args_v%>, [<%=args_opt%>])
<%=params%>
@return [[<%=return_name%>]] Array<<%=return_type%>>
<%=outparam(return_name)%>

<%=description%>

*/
static VALUE
<%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
{
<% %>
    VALUE a, ans;
#if IPIV_IN
    VALUE ipiv;
#endif
#if RHS
    VALUE b;
    size_t n, nb, nrhs;
    narray_t *na2;
#endif
    narray_t *na1;
<%
    aout = [
        ipiv_out && "{cInt,1,shape_piv}",
        "{cInt,0}",
    ].select{|x| x}.join(",")
%>
#if IPIV_OUT
    size_t shape_piv[1];
#endif
    // Input slots in order: a, [ipiv], [b].  The OVERWRITE entry is the
    // array the routine writes into (b for solvers, a otherwise).
#if IPIV_IN
# if RHS
    ndfunc_arg_in_t ain[3] = {{cT,2},{cInt,1},{OVERWRITE,2}};
# else
    ndfunc_arg_in_t ain[2] = {{OVERWRITE,2},{cInt,1}};
# endif
#else
# if RHS
    ndfunc_arg_in_t ain[2] = {{cT,2},{OVERWRITE,2}};
# else
    ndfunc_arg_in_t ain[1] = {{OVERWRITE,2}};
# endif
#endif
    // Output slots: [ipiv] (sized min(rows,cols), factorizations only), then info.
    ndfunc_arg_out_t aout[1+IPIV_OUT] = {<%=aout%>};
    ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP|NDF_EXTRACT,
                    1+IPIV_IN+RHS, IPIV_OUT+1, ain,aout};

    args_t g = {0,0};
    VALUE opts[2] = {Qundef,Qundef};
    VALUE kw_hash = Qnil;
    // opts[0] <- order:, opts[1] <- uplo: (replaced by trans: when TRANS)
    ID kw_table[2] = {id_order,id_uplo};

    CHECK_FUNC(func_p,"<%=func_name%>");

#if IPIV_IN
# if RHS
    rb_scan_args(argc, argv, "3:", &a, &ipiv, &b, &kw_hash);
# else
    rb_scan_args(argc, argv, "2:", &a, &ipiv, &kw_hash);
# endif
#else
# if RHS
    rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
# else
    rb_scan_args(argc, argv, "1:", &a, &kw_hash);
# endif
#endif
    // NOTE(review): when both TRANS and UPLO are set, only trans: is parsed
    // and g.uplo stays 0 even though the iterator passes it to LAPACK.
#if TRANS
    kw_table[1] = id_trans;
    rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
    g.trans = option_trans(opts[1]);
#elif UPLO
    rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
    g.uplo = option_uplo(opts[1]);
#else
    rb_get_kwargs(kw_hash, kw_table, 0, 1, opts);
#endif
    g.order = option_order(opts[0]);

    // When a is the overwritten array, copy it (unless marked inplace) or cast it.
#if !RHS
    COPY_OR_CAST_TO(a,cT);
#endif
    // NOTE(review): with RHS, a is not cast here, so GetNArray expects a to
    // already be an NArray.
    GetNArray(a, na1);
    CHECK_DIM_GE(na1, 2);
#if IPIV_OUT
    shape_piv[0] = min_(ROW_SIZE(na1),COL_SIZE(na1));
#endif

#if RHS
    COPY_OR_CAST_TO(b,cT);
    GetNArray(b, na2);
    CHECK_DIM_GE(na2, 1);
    n = COL_SIZE(na1);
#if SYM
    n = min_(n,ROW_SIZE(na1));
#endif
    // same as gesv.c
    if (NA_NDIM(na2) == 1) {
        // a 1-D b is looped over as a single vector
        ain[1+IPIV_IN].dim = 1;
        nb = COL_SIZE(na2);
        nrhs = 1;
    } else {
        nb = ROW_SIZE(na2);
        nrhs = COL_SIZE(na2);
        { int tmp; SWAP_IFCOL(g.order,nb,nrhs); }
    }
    if (n != nb) {
        rb_raise(nary_eShapeError, "matrix dimension mismatch: "
                 "a.col(or a.row)=%"SZF"u b.row=%"SZF"u", n, nb);
    }
#endif

    // Return [overwritten array, outputs...]; with two outputs, ans is an
    // Array [ipiv, info] and a is prepended to it.
#if IPIV_IN
# if RHS
    ans = na_ndloop3(&ndf, &g, 3, a, ipiv, b);
    return rb_assoc_new(b, ans);
# else
    ans = na_ndloop3(&ndf, &g, 2, a, ipiv);
    return rb_assoc_new(a, ans);
# endif
#else
# if RHS
    ans = na_ndloop3(&ndf, &g, 2, a, b);
    return rb_assoc_new(b, ans);
# else
    ans = na_ndloop3(&ndf, &g, 1, a);
#  if IPIV_OUT
    return rb_ary_unshift(ans, a);
#  else
    return rb_assoc_new(a, ans);
#  endif
# endif
#endif
}
267
+
268
/* Drop this template's local macros; presumably several templates are
   expanded into the same generated source, so the names must be freed. */
#undef args_t
#undef func_p
#undef RHS
#undef TRANS
#undef UPLO
#undef IPIV
#undef IPIV_OUT
#undef IPIV_IN
#undef SYM
@@ -0,0 +1,115 @@
1
/*
  UNUSED(name): declare a parameter that is intentionally unused, so the
  compiler does not warn about it.

  The previous clang branch expanded to `__unused name`, but `__unused` is a
  BSD/Darwin header macro that is not defined on every platform (e.g. with
  glibc), where the expansion is a syntax error.  GCC and clang both accept
  the attribute form used here.
*/
#if defined(__GNUC__) || defined(__clang__)
# define UNUSED(name) name __attribute__((unused))
#else
# define UNUSED(name) name
#endif
6
+
7
/* cI / cUI: Numo integer classes whose width matches C `int`. */
#if SIZEOF_INT == 4
#define cI numo_cInt32
#define cUI numo_cUInt32
#elif SIZEOF_INT==8
#define cI numo_cInt64
#define cUI numo_cUInt64
#endif

/* cSZ / cSSZ: Numo integer classes whose width matches `size_t`
   (unsigned) and its signed counterpart. */
#if SIZEOF_SIZE_T == 4
#define cSZ numo_cUInt32
#define cSSZ numo_cInt32
#elif SIZEOF_SIZE_T == 8
#define cSZ numo_cUInt64
#define cSSZ numo_cInt64
#endif

/* Short names for the floating-point / complex Numo classes. */
#define cDF numo_cDFloat
#define cDC numo_cDComplex
#define cSF numo_cSFloat
#define cSC numo_cSComplex
/* cInt / cUInt: used for pivot and info arrays (C `int`-sized). */
#define cInt cI
#define cUInt cUI
29
+
30
/* Declared here; defined outside this header (presumably in Numo::NArray). */
extern VALUE na_expand_dims(VALUE self, VALUE vdim);
31
+
32
/* Larger / smaller of two values.  Arguments may be evaluated twice. */
#define max_(m,n) ((m) > (n) ? (m) : (n))
#define min_(m,n) ((m) < (n) ? (m) : (n))

/* Extents of the last two axes of a narray_t: rows, then columns. */
#define ROW_SIZE(na) ((na)->shape[(na)->ndim - 2])
#define COL_SIZE(na) ((na)->shape[(na)->ndim - 1])
37
+
38
/* Raise TypeError unless the class of x is exactly t. */
#define CHECK_NARRAY_TYPE(x,t)                                  \
    if (CLASS_OF(x)!=(t)) {                                     \
        rb_raise(rb_eTypeError,"invalid NArray type (class)");  \
    }

/* Raise ShapeError unless na has at least nd dimensions.
   NOTE(review): the original author questioned the choice of error class. */
#define CHECK_DIM_GE(na,nd)                                     \
    if ((na)->ndim<(nd)) {                                      \
        rb_raise(nary_eShapeError,                              \
                 "n-dimension=%d, but >=%d is expected",        \
                 (na)->ndim, (nd));                             \
    }

/* Raise ShapeError unless na1 has exactly nd dimensions. */
#define CHECK_DIM_EQ(na1,nd)                                    \
    if ((na1)->ndim != (nd)) {                                  \
        rb_raise(nary_eShapeError,                              \
                 "dimention mismatch: %d != %d",                \
                 (na1)->ndim, (nd));                            \
    }

/* Raise ShapeError unless the last two axes of na have equal length. */
#define CHECK_SQUARE(name,na)                                           \
    if ((na)->shape[(na)->ndim-1] != (na)->shape[(na)->ndim-2]) {       \
        rb_raise(nary_eShapeError,"%s is not square matrix",name);      \
    }

/* Raise ShapeError unless na has at least sz elements. */
#define CHECK_SIZE_GE(na,sz)                                    \
    if ((na)->size < (size_t)(sz)) {                            \
        rb_raise(nary_eShapeError,                              \
                 "NArray size must be >= %"SZF"u",(size_t)(sz));\
    }

/* Raise ShapeError if na has no elements. */
#define CHECK_NON_EMPTY(na)                                     \
    if ((na)->size==0) {                                        \
        rb_raise(nary_eShapeError,"empty NArray");              \
    }

/* Raise ShapeError unless n == m.  Both are printed as size_t, so the
   unsigned conversion is used (the signed "d" conversion mismatched). */
#define CHECK_SIZE_EQ(n,m)                                      \
    if ((n)!=(m)) {                                             \
        rb_raise(nary_eShapeError,                              \
                 "size mismatch: %"SZF"u != %"SZF"u",           \
                 (size_t)(n),(size_t)(m));                      \
    }

/* Raise ShapeError unless na1 and na2 have identical rank and extents.
   Arguments are parenthesized so expressions can be passed safely. */
#define CHECK_SAME_SHAPE(na1,na2)                               \
    { int i;                                                    \
      CHECK_DIM_EQ((na1),(na2)->ndim);                          \
      for (i=0; i<(na1)->ndim; i++) {                           \
          CHECK_SIZE_EQ((na1)->shape[i],(na2)->shape[i]);       \
      }                                                         \
    }

/* Raise ShapeError unless m == n; sm/sn are the names shown in the message.
   Values are cast to int to match the %d conversions. */
#define CHECK_INT_EQ(sm,m,sn,n)                                 \
    if ((m) != (n)) {                                           \
        rb_raise(nary_eShapeError,                              \
                 "%s must be == %s: %s=%d %s=%d",               \
                 sm,sn,sm,(int)(m),sn,(int)(n));                \
    }

/* Raise ShapeError if the leading dimension ld is smaller than n.
   NOTE(review): the message says max(n,1) but only ld < n is checked.
   Values are cast to int to match the %d conversions. */
#define CHECK_LEADING_GE(sld,ld,sn,n)                           \
    if ((ld) < (n)) {                                           \
        rb_raise(nary_eShapeError,                              \
                 "%s must be >= max(%s,1): %s=%d %s=%d",        \
                 sld,sn,sld,(int)(ld),sn,(int)(n));             \
    }

/* Make `a` a private writable array of class T: copy it when it is already
   a T (unless marked inplace), otherwise convert it with T.cast. */
#define COPY_OR_CAST_TO(a,T)                                    \
    {                                                           \
        if (CLASS_OF(a) == (T)) {                               \
            if (!TEST_INPLACE(a)) {                             \
                a = na_copy(a);                                 \
            }                                                   \
        } else {                                                \
            a = rb_funcall(T,rb_intern("cast"),1,a);            \
        }                                                       \
    }
114
+
115
/* Exchange a and b through a variable named `tmp`, which the caller must
   declare (with a suitable type) in the enclosing scope. */
#define swap(a,b) {tmp=a;a=b;b=tmp;}
@@ -0,0 +1,3 @@
1
require "numo/linalg/linalg"

# Requiring "numo/linalg" immediately calls Loader.load_library; which
# backend libraries get loaded is decided by the Loader (not shown here).
Numo::Linalg::Loader.load_library
@@ -0,0 +1,1008 @@
1
+ module Numo; module Linalg
2
+
3
module Blas

  # Generated "<char><func>" names that the extension defines under a
  # different name (the complex 2-norm routines).
  FIXNAME = {
    cnrm2: :csnrm2,
    znrm2: :dznrm2,
  }

  # Dispatch a type-generic BLAS call to its type-specific routine.
  #
  # The [sdcz] prefix is chosen by Linalg.blas_char from the data types of
  # +args+; the resulting name is then remapped through FIXNAME if needed.
  #
  # @param [Symbol] func function name without BLAS char.
  # @param args arguments passed to Blas function.
  # @example
  #   c = Numo::Linalg::Blas.call(:gemm, a, b)
  def self.call(func, *args)
    prefix = Linalg.blas_char(*args)
    generic = :"#{prefix}#{func}"
    routine = FIXNAME.fetch(generic, generic)
    send(routine, *args)
  end

end
24
+
25
module Lapack

  # Generated "<char><func>" names whose complex variants use the unitary
  # ("un") LAPACK name instead of the orthogonal ("or") one.
  FIXNAME = {
    corgqr: :cungqr,
    zorgqr: :zungqr,
  }

  # Dispatch a type-generic LAPACK call to its type-specific routine.
  #
  # The [sdcz] prefix is chosen by Linalg.blas_char from the data types of
  # +args+; the resulting name is then remapped through FIXNAME if needed.
  #
  # @param [Symbol,String] func function name without BLAS char.
  # @param args arguments passed to Lapack function.
  # @example
  #   x, = Numo::Linalg::Lapack.call(:gesv, a, b)
  def self.call(func, *args)
    prefix = Linalg.blas_char(*args)
    generic = :"#{prefix}#{func}"
    routine = FIXNAME.fetch(generic, generic)
    send(routine, *args)
  end

end
46
+
47
# BLAS/LAPACK routine prefix for each supported NArray element class.
# Looked up by blas_char; any other class is rejected there.
BLAS_CHAR =
{
  SFloat => "s",
  DFloat => "d",
  SComplex => "c",
  DComplex => "z",
}

# Methods defined below are module functions (callable as Numo::Linalg.xxx).
module_function
56
+
57
# Determine the BLAS/LAPACK type prefix for a call.
#
# Starting from Float, the type is upcast through the UPCAST table of every
# NArray argument (or of the NArray class inferred for an Array argument).
# Other arguments, such as option hashes, are ignored.
#
# @param args arguments of the BLAS/LAPACK call.
# @return [String] one of "s", "d", "c", "z".
# @raise [TypeError] if the resulting type has no BLAS/LAPACK prefix.
def blas_char(*args)
  type = args.inject(Float) do |acc, arg|
    klass =
      if NArray === arg
        arg.class
      elsif Array === arg
        NArray.array_type(arg)
      end
    (klass && klass < NArray) ? klass::UPCAST[acc] : acc
  end
  BLAS_CHAR[type] || raise(TypeError,"invalid data type for BLAS/LAPACK")
end
73
+
74
+ # module methods
75
+
76
+ ## Matrix and vector products
77
+
78
# Dot product.
#
# Dispatches on the number of dimensions:
# vector.vector -> dot, vector.matrix -> gemv with trans:'t',
# matrix.vector -> gemv, matrix.matrix -> gemm.
#
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
# @param b [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
# @return [Numo::NArray] result of dot product
def dot(a, b)
  a = NArray.asarray(a)
  b = NArray.asarray(b)
  a_is_vector = (a.ndim == 1)
  b_is_vector = (b.ndim == 1)
  if a_is_vector && b_is_vector
    Blas.call(:dot, a, b)
  elsif a_is_vector
    # vector . matrix is computed as transpose(matrix) . vector
    Blas.call(:gemv, b, a, trans:'t')
  elsif b_is_vector
    Blas.call(:gemv, a, b)
  else
    Blas.call(:gemm, a, b)
  end
end
102
+
103
# Matrix product, computed with BLAS gemm.
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @param b [Numo::NArray] matrix (>= 2-dimensional NArray)
# @return [Numo::NArray] result of matrix product
def matmul(a, b)
  product = Blas.call(:gemm, a, b)
  product
end
110
+
111
# Compute a square matrix `a` to the power `n`.
#
# * If n > 0: return `a**n`.
# * If n == 0: return identity matrix.
# * If n < 0: return `(a*\*-1)*\*n.abs`.
#
# Exponents above 3 use binary exponentiation (repeated squaring).
#
# @param a [Numo::NArray] square matrix (>= 2-dimensional NArray).
# @param n [Integer] the exponent.
# @raise [Numo::NArray::ShapeError] if the last two axes differ in length.
# @raise [ArgumentError] if n is not an Integer.
# @example
#   i = Numo::DFloat[[0, 1], [-1, 0]]
#   # => Numo::DFloat#shape=[2,2]
#   # [[0, 1],
#   #  [-1, 0]]
#   Numo::Linalg.matrix_power(i,3)
#   # => Numo::DFloat#shape=[2,2]
#   # [[0, -1],
#   #  [1, 0]]
#   Numo::Linalg.matrix_power(i,0)
#   # => Numo::DFloat#shape=[2,2]
#   # [[1, 0],
#   #  [0, 1]]
#   Numo::Linalg.matrix_power(i,-3)
#   # => Numo::DFloat#shape=[2,2]
#   # [[0, 1],
#   #  [-1, 0]]
#
#   q = Numo::DFloat.zeros(4,4)
#   q[0..1,0..1] = -i
#   q[2..3,2..3] = i
#   q
#   # => Numo::DFloat#shape=[4,4]
#   # [[-0, -1, 0, 0],
#   #  [1, -0, 0, 0],
#   #  [0, 0, 0, 1],
#   #  [0, 0, -1, 0]]
#   Numo::Linalg.matrix_power(q,2)
#   # => Numo::DFloat#shape=[4,4]
#   # [[-1, 0, 0, 0],
#   #  [0, -1, 0, 0],
#   #  [0, 0, -1, 0],
#   #  [0, 0, 0, -1]]
def matrix_power(a, n)
  a = NArray.asarray(a)
  m,k = a.shape[-2..-1]
  unless m==k
    raise NArray::ShapeError, "input must be a square array"
  end
  unless Integer===n
    raise ArgumentError, "exponent must be an integer"
  end
  if n == 0
    # NOTE(review): for stacked (>2-D) input this returns a single m-by-m identity.
    return a.class.eye(m)
  elsif n < 0
    a = inv(a)
    n = n.abs
  end
  if n <= 3
    r = a
    (n-1).times do
      r = matmul(r,a)
    end
  else
    # Strip trailing zero bits: a becomes a**(2**t), n becomes odd.
    while (n & 1) == 0
      a = matmul(a,a)
      n >>= 1
    end
    r = a          # accounts for the lowest set bit
    n >>= 1
    # Square once per remaining bit; the previous loop also squared after
    # the highest bit had been consumed, wasting one matrix product.
    while n != 0
      a = matmul(a,a)
      r = matmul(r,a) if (n & 1) != 0
      n >>= 1
    end
  end
  r
end
189
+
190
+
191
+ ## factorization
192
+
193
# Computes a QR factorization of an M-by-N matrix A: A = Q \* R.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param mode [String or Symbol] (case-insensitive)
#   - "reduce" -- returns both Q and R,
#   - "r" -- returns only R,
#   - "economic" -- returns both Q and R but computed in economy-size,
#   - "raw" -- returns QR and TAU used in LAPACK.
# @return [r] if mode:"r"
# @return [[q,r]] if mode:"reduce" or "economic"
# @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
# @raise [ArgumentError] if mode is not one of the values above.
def qr(a, mode:"reduce")
  # Normalize first: previously the economy-size test ran on the raw
  # argument, so mode: :economic produced a full R with an economy-size Q.
  mode = mode.to_s.downcase
  unless %w[reduce r economic raw].include?(mode)
    raise ArgumentError, "invalid mode:#{mode}"
  end
  qr,tau, = Lapack.call(:geqrf, a)
  return [qr,tau] if mode == "raw"

  *shp,m,n = qr.shape
  r = (m >= n && mode == "economic") ?
    qr[false, 0...n, true].triu : qr.triu
  return r if mode == "r"

  if m < n
    q, = Lapack.call(:orgqr, qr[false, 0...m], tau)
  elsif mode == "economic"
    q, = Lapack.call(:orgqr, qr, tau)
  else
    # Full Q is m-by-m: embed the m-by-n reflectors in a square array.
    qqr = qr.class.zeros(*(shp+[m,m]))
    qqr[false,0...n] = qr
    q, = Lapack.call(:orgqr, qqr, tau)
  end
  [q,r]
end
232
+
233
+
234
# Computes the Singular Value Decomposition (SVD) of an M-by-N matrix A,
# and the left and/or right singular vectors. The SVD is written
#
#   A = U * SIGMA * transpose(V)
#
# where SIGMA is an M-by-N matrix which is zero except for its
# min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and
# V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA
# are the singular values of A; they are real and non-negative, and
# are returned in descending order. The first min(m,n) columns of U
# and V are the left and right singular vectors of A. Note that the
# routine returns V**T, not V.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK solver from 'svd',
#   'sdd' (also 'gesvd', 'gesdd', 'turbo'). (optional, default='svd')
# @param job [String or Symbol]
#   - 'A': all M columns of U and all N rows of V\*\*T are returned in
#     the arrays U and VT.
#   - 'S': the first min(M,N) columns of U and the first min(M,N)
#     rows of V\*\*T are returned in the arrays U and VT.
#   - 'N': no columns of U or rows of V\*\*T are computed.
# @return [[sigma,u,vt]] SVD result. Array<Numo::NArray>
# @raise [ArgumentError] for an unknown job or driver.
def svd(a, driver:'svd', job:'A')
  raise ArgumentError, "invalid job: #{job.inspect}" unless /^[ASN]/i =~ job
  result =
    case driver.to_s
    when /^(ge)?sdd$/i, "turbo"
      Lapack.call(:gesdd, a, jobz:job)
    when /^(ge)?svd$/i
      Lapack.call(:gesvd, a, jobu:job, jobvt:job)
    else
      raise ArgumentError, "invalid driver: #{driver}"
    end
  result[0..2]
end
271
+
272
# Computes the Singular Values of an M-by-N matrix A.
# The SVD is written
#
#   A = U * SIGMA * transpose(V)
#
# where SIGMA is an M-by-N matrix which is zero except for its
# min(m,n) diagonal elements. The diagonal elements of SIGMA
# are the singular values of A; they are real and non-negative, and
# are returned in descending order. No singular vectors are computed.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK solver from 'svd',
#   'sdd' (also 'gesvd', 'gesdd', 'turbo'). (optional, default='svd')
# @return [Numo::NArray] returns SIGMA (singular values).
# @raise [ArgumentError] for an unknown driver.
def svdvals(a, driver:'svd')
  lapack_result =
    case driver.to_s
    when /^(ge)?sdd$/i, "turbo"
      Lapack.call(:gesdd, a, jobz:'N')
    when /^(ge)?svd$/i
      Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
    else
      raise ArgumentError, "invalid driver: #{driver}"
    end
  lapack_result[0]
end
297
+
298
+
299
# Computes an LU factorization of an M-by-N matrix A
# using partial pivoting with row interchanges.
#
# The factorization has the form
#
#   A = P * L * U
#
# where P is a permutation matrix, L is lower triangular with unit
# diagonal elements (lower trapezoidal if m > n), and U is upper
# triangular (upper trapezoidal if m < n).
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' (also 'ge','sy','he' with an optional 'trf'
#   suffix). (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (omitted when driver:"gen")
# @return [[lu, ipiv]]
#   - **lu** [Numo::NArray] -- The factors L and U from the factorization
#     `A = P*L*U`; the unit diagonal elements of L are not stored.
#   - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
#     row i of the matrix was interchanged with row IPIV(i).
# @raise [ArgumentError] for an unknown driver.
def lu_fact(a, driver:"gen", uplo:"U")
  case driver.to_s
  when /^gen?(trf)?$/i
    Lapack.call(:getrf, a)[0..1]
  when /^(sym?|her?)(trf)?$/i
    # The routines are ?sytrf / ?hetrf, so keep only the two-letter matrix
    # type. Taking three characters turned "sym" into "symtrf".
    func = driver.to_s[0..1].downcase + "trf"
    Lapack.call(func, a, uplo:uplo)[0..1]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
332
+
333
# Computes the inverse of a matrix using the LU factorization
# computed by Numo::Linalg.lu_fact.
#
# This method inverts U and then computes inv(A) by solving the system
#
#   inv(A)*L = inv(U)
#
# for inv(A).
#
# @param lu [Numo::NArray] matrix containing the factors L and U
#   from the factorization `A = P*L*U` as computed by
#   Numo::Linalg.lu_fact.
# @param ipiv [Numo::NArray] The pivot indices from
#   Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
#   interchanged with row IPIV(i).
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' (also 'ge','sy','he' with an optional 'tri'
#   suffix). (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (omitted when driver:"gen")
# @return [Numo::NArray] the inverse of the original matrix A.
# @raise [ArgumentError] for an unknown driver.
def lu_inv(lu, ipiv, driver:"gen", uplo:"U")
  case driver.to_s
  when /^gen?(tri)?$/i
    Lapack.call(:getri, lu, ipiv)[0]
  when /^(sym?|her?)(tri)?$/i
    # The routines are ?sytri / ?hetri, so keep only the two-letter matrix
    # type. Taking three characters turned "sym" into "symtri".
    func = driver.to_s[0..1].downcase + "tri"
    Lapack.call(func, lu, ipiv, uplo:uplo)[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
365
+
366
# Solves a system of linear equations
#
#   A * X = B or A**T * X = B
#
# with an N-by-N matrix A using the LU factorization computed by
# Numo::Linalg.lu_fact
#
# @param lu [Numo::NArray] matrix containing the factors L and U
#   from the factorization `A = P*L*U` as computed by
#   Numo::Linalg.lu_fact.
# @param ipiv [Numo::NArray] The pivot indices from
#   Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
#   interchanged with row IPIV(i).
# @param b [Numo::NArray] the right hand side matrix B.
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' (also 'ge','sy','he' with an optional 'trs'
#   suffix). (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (omitted when driver:"gen")
# @param trans [String or Symbol]
#   Specifies the form of the system of equations
#   (used only with driver:"gen"):
#
#   - If 'N': `A * X = B` (No transpose).
#   - If 'T': `A**T * X = B` (Transpose).
#   - If 'C': `A**H * X = B` (Conjugate transpose; same as 'T' for real A).
# @return [Numo::NArray] the solution matrix X.
# @raise [ArgumentError] for an unknown driver.
def lu_solve(lu, ipiv, b, driver:"gen", uplo:"U", trans:"N")
  case driver.to_s
  when /^gen?(trs)?$/i
    Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
  when /^(sym?|her?)(trs)?$/i
    # The routines are ?sytrs / ?hetrs, so keep only the two-letter matrix
    # type. Taking three characters turned "sym" into "symtrs".
    func = driver.to_s[0..1].downcase + "trs"
    Lapack.call(func, lu, ipiv, b, uplo:uplo)[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
404
+
405
+
406
# Computes the Cholesky factorization of a symmetric/Hermitian
# positive definite matrix A (LAPACK ?potrf). The factorization has the form
#
#   A = U**H * U,  if UPLO = 'U', or
#   A = L  * L**H, if UPLO = 'L',
#
# where U is an upper triangular matrix and L is lower triangular.
#
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle.
# @return [Numo::NArray] the factor U or L.
def cho_fact(a, uplo:'U')
  potrf_result = Lapack.call(:potrf, a, uplo:uplo)
  potrf_result.first
end
#alias cholesky cho_fact
422
+
423
# Computes the inverse of a symmetric/Hermitian
# positive definite matrix A (LAPACK ?potri) using the Cholesky
# factorization `A = U**T*U` or `A = L*L**T` computed by Linalg.cho_fact.
#
# @param a [Numo::NArray] the triangular factor U or L from the
#   Cholesky factorization, as computed by Linalg.cho_fact.
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle.
# @return [Numo::NArray] the upper or lower triangle of the
#   (symmetric) inverse of A.
def cho_inv(a, uplo:'U')
  potri_result = Lapack.call(:potri, a, uplo:uplo)
  potri_result.first
end
437
+
438
# Solves a system of linear equations
#   A*X = B
# with a symmetric/Hermitian positive definite matrix A (LAPACK ?potrs)
# using the Cholesky factorization
# `A = U**T*U` or `A = L*L**T` computed by Linalg.cho_fact.
#
# @param a [Numo::NArray] the triangular factor U or L from the
#   Cholesky factorization, as computed by Linalg.cho_fact.
# @param b [Numo::NArray] the right hand side matrix B.
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle.
# @return [Numo::NArray] the solution matrix X.
def cho_solve(a, b, uplo:'U')
  potrs_result = Lapack.call(:potrs, a, b, uplo:uplo)
  potrs_result.first
end
453
+
454
+
455
+ ## Matrix eigenvalues
456
+
457
# Computes the eigenvalues and, optionally, the left and/or right
# eigenvectors for a square nonsymmetric matrix A (LAPACK ?geev).
#
# For real input, LAPACK returns real and imaginary parts separately; they
# are combined into complex eigenvalues and eigenvectors here.
#
# @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensional NArray)
# @param left [Bool] (optional) If true, left eigenvectors are computed.
# @param right [Bool] (optional) If true, right eigenvectors are computed.
# @return [[w,vl,vr]]
#   - **w** [Numo::NArray] -- The eigenvalues.
#   - **vl** [Numo::NArray] -- The left eigenvectors if left is true, otherwise nil.
#   - **vr** [Numo::NArray] -- The right eigenvectors if right is true, otherwise nil.
def eig(a, left:false, right:true)
  jobvl = left
  jobvr = right
  if /c|z/ =~ blas_char(a)
    w, vl, vr, = Lapack.call(:geev, a, jobvl:jobvl, jobvr:jobvr)
  else
    wr, wi, vl, vr, = Lapack.call(:geev, a, jobvl:jobvl, jobvr:jobvr)
    w = wr + wi * Complex::I
    vl = _make_complex_eigvecs(w,vl) if left
    vr = _make_complex_eigvecs(w,vr) if right
  end
  [w, vl, vr]
end
481
+
482
# Computes the eigenvalues and, optionally, the eigenvectors
# for a square symmetric/hermitian matrix A.
#
# Uses ?heev (complex) or ?syev (real); with turbo: true, ?hegv / ?sygv.
# NOTE(review): ?hegv/?sygv are generalized eigensolvers that normally take
# a second matrix B, but only `a` is passed here; verify turbo mode.
#
# @param a [Numo::NArray] square symmetric/hermitian matrix (>= 2-dimensional NArray)
# @param vals_only [Bool] (optional, default=false) If true, only
#   eigenvalues are computed; otherwise eigenvectors are computed as well.
# @param uplo [String, Symbol or false] (optional, default=false) Access
#   upper ('U') or lower ('L') triangle; passed through to LAPACK as given.
# @param turbo [Bool] (optional, default=false) select ?hegv/?sygv.
# @return [[w,v]]
#   - **w** [Numo::NArray] -- The eigenvalues.
#   - **v** [Numo::NArray] -- The eigenvectors if vals_only is false, otherwise nil.
def eigh(a, vals_only:false, uplo:false, turbo:false)
  jobz = vals_only ? 'N' : 'V'   # 'V': eigenvectors too, 'N': eigenvalues only
  func =
    if /c|z/ =~ blas_char(a)
      turbo ? :hegv : :heev
    else
      turbo ? :sygv : :syev
    end
  w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
  [w, v]
end
504
+
505
# Computes the eigenvalues only for a square nonsymmetric matrix A
# (LAPACK ?geev with no eigenvectors).
#
# For real input, the real and imaginary parts returned by LAPACK are
# combined into complex eigenvalues.
#
# @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensional NArray)
# @return [Numo::NArray] eigenvalues
def eigvals(a)
  complex = /c|z/ =~ blas_char(a)
  result = Lapack.call(:geev, a, jobvl:'N', jobvr:'N')
  if complex
    result[0]
  else
    wr, wi, = result
    wr + wi * Complex::I
  end
end
521
+
522
# Computes the eigenvalues for a square symmetric/hermitian matrix A.
#
# Uses ?heev (complex) or ?syev (real); with turbo: true, ?hegv / ?sygv.
# NOTE(review): ?hegv/?sygv normally take a second matrix B, but only `a`
# is passed here; verify turbo mode.
#
# @param a [Numo::NArray] square symmetric/hermitian matrix
#   (>= 2-dimensional NArray)
# @param uplo [String, Symbol or false] (optional, default=false) Access
#   upper ('U') or lower ('L') triangle; passed through to LAPACK as given.
# @param turbo [Bool] (optional, default=false) select ?hegv/?sygv.
# @return [Numo::NArray] eigenvalues
def eigvalsh(a, uplo:false, turbo:false)
  func =
    if /c|z/ =~ blas_char(a)
      turbo ? :hegv : :heev
    else
      turbo ? :sygv : :syev
    end
  # jobz 'N': compute eigenvalues only
  Lapack.call(func, a, uplo:uplo, jobz:'N')[0]
end
540
+
541
+
542
+ ## Norms and other numbers
543
+
544
# Compute matrix or vector norm.
#
# | ord   | matrix norm            | vector norm                 |
# | ----- | ---------------------- | --------------------------- |
# | nil   | Frobenius norm         | 2-norm                      |
# | 'fro' | Frobenius norm         | -                           |
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max                   |
# | 0     | -                      | (x.ne 0).sum                |
# | 1     | x.abs.sum(axis:-2).max | same as below               |
# | 2     | 2-norm (max sing_vals) | same as below               |
# | other | -                      | (x.abs**ord).sum**(1.0/ord) |
#
# For matrices, -1, -2 and '-inf' take the minimum instead of the
# maximum; for vectors, '-inf' gives x.abs.min.
#
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
# @param ord [String or Symbol or Integer] Order of the norm (optional).
# @param axis [Integer or Array] Applied axes (optional). One axis selects
#   a vector norm, two axes select a matrix norm.
# @param keepdims [Bool] If true, the applied axes are left in
#   result with size one (optional).
# @return [Numo::NArray] norm result
# @raise [ArgumentError] if axis or ord is invalid.

def norm(a, ord=nil, axis:nil, keepdims:false)
  a = Numo::NArray.asarray(a)

  # check axis
  if axis
    case axis
    when Integer
      axis = [axis]
    when Array
      if axis.size < 1 || axis.size > 2
        raise ArgumentError, "axis option should be 1- or 2-element array"
      end
    else
      raise ArgumentError, "invalid option for axis: #{axis}"
    end
    # Move the applied axes to the end, so the reductions below can
    # always work on the last one or two axes.
    if a.ndim > 1
      idx = (0...a.ndim).to_a
      tmp = []
      axis.each do |i|
        # Check range first; otherwise an out-of-range axis would be
        # misreported as a duplicate (idx[i] is nil in both cases).
        if i < -a.ndim || i >= a.ndim
          raise ArgumentError, "axis (#{i}) is out of range"
        end
        x = idx[i]
        if x.nil?
          raise ArgumentError, "axis contains same dimension"
        end
        tmp << x
        idx[i] = nil
      end
      idx.compact!
      idx.concat(tmp)
      a = a.transpose(*idx)
    end
  else
    case a.ndim
    when 0
      raise ArgumentError, "zero-dimensional array"
    when 1
      axis = [-1]
    else
      axis = [-2,-1]
    end
  end

  # calculate norm
  case axis.size

  when 1 # vector
    k = keepdims
    ord ||= 2 # default
    case ord.to_s
    when "0"
      r = a.class.cast(a.ne(0)).sum(axis:-1, keepdims:k)
    when "1"
      r = a.abs.sum(axis:-1, keepdims:k)
    when "2"
      r = Blas.call(:nrm2, a, keepdims:k)
    when /^-?\d+$/
      o = ord.to_i
      r = (a.abs**o).sum(axis:-1, keepdims:k)**(1.0/o)
    when /^inf(inity)?$/i
      r = a.abs.max(axis:-1, keepdims:k)
    when /^-inf(inity)?$/i
      r = a.abs.min(axis:-1, keepdims:k)
    else
      raise ArgumentError, "ord (#{ord}) is invalid for vector norm"
    end

  when 2 # matrix
    if keepdims
      # Index pattern that re-inserts size-1 dims at the original axis positions.
      fixdims = [true] * a.ndim
      axis.each do |i|
        if i < -a.ndim || i >= a.ndim
          raise ArgumentError, "axis (#{i}) is out of range"
        end
        fixdims[i] = :new
      end
    end
    ord ||= "fro" # default
    case ord.to_s
    when "1"
      r, = Lapack.call(:lange, a, '1')
    when "-1"
      r = a.abs.sum(axis:-2).min(axis:-1)
    when "2"
      svd, = Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
      r = svd.max(axis:-1)
    when "-2"
      svd, = Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
      r = svd.min(axis:-1)
    when /^f(ro)?$/i
      r, = Lapack.call(:lange, a, 'F')
    when /^inf(inity)?$/i
      r, = Lapack.call(:lange, a, 'I')
    when /^-inf(inity)?$/i
      r = a.abs.sum(axis:-1).min(axis:-1)
    else
      raise ArgumentError, "ord (#{ord}) is invalid for matrix norm"
    end
    if keepdims
      if NArray===r
        r = r[*fixdims]
      else
        # scalar result (single matrix): wrap it as a 1x1 array
        r = a.class.new(1,1).store(r)
      end
    end
  end
  return r
end
670
+
671
# Compute the condition number of a matrix, i.e.
# `norm(a, ord) * norm(inv(a), ord)`, using one of the following orders.
#
# | ord   | matrix norm            |
# | ----- | ---------------------- |
# | nil   | 2-norm using SVD       |
# | 'fro' | Frobenius norm         |
# | 'inf' | x.abs.sum(axis:-1).max |
# | 1     | x.abs.sum(axis:-2).max |
# | 2     | 2-norm (max sing_vals) |
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray); it must be
#   square unless ord is nil, because inv(a) is used.
# @param ord [String or Symbol] Order of the norm.
# @return [Numo::NArray] cond result
# @example
#   a = Numo::DFloat[[1, 0, -1], [0, 1, 0], [1, 0, 1]]
#   => Numo::DFloat#shape=[3,3]
#   [[1, 0, -1],
#    [0, 1, 0],
#    [1, 0, 1]]
#   LA = Numo::Linalg
#   LA.cond(a)
#   => 1.4142135623730951
#   LA.cond(a, 'fro')
#   => 3.1622776601683795
#   LA.cond(a, 'inf')
#   => 2.0
#   LA.cond(a, '-inf')
#   => 1.0
#   LA.cond(a, 1)
#   => 2.0
#   LA.cond(a, -1)
#   => 1.0
#   LA.cond(a, 2)
#   => 1.4142135623730951
#   LA.cond(a, -2)
#   => 0.7071067811865475
#   (LA.svdvals(a)).min*(LA.svdvals(LA.inv(a))).min
#   => 0.7071067811865475

def cond(a,ord=nil)
  unless ord.nil?
    return norm(a, ord, axis:[-2,-1]) * norm(inv(a), ord, axis:[-2,-1])
  end
  # Default: first over last singular value (largest / smallest,
  # assuming svdvals returns them in decreasing order as LAPACK does).
  sv = svdvals(a)
  sv[false, 0] / sv[false, -1]
end
719
+
720
# Determinant of a matrix, computed from the LU factorization returned
# by LAPACK ?getrf: the product of the diagonal of the factor times the
# sign of the row permutation.
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @return [Float or Complex or Numo::NArray]

def det(a)
  lu, ipiv, = Lapack.call(:getrf, a)
  # ipiv is compared against 1..n: an entry that differs from its
  # 1-based position counts as one row interchange.
  n = ipiv.shape[-1]
  unpivoted = ipiv.new_narray.store(ipiv.class.new(n).seq(1))
  odd = ipiv.eq(unpivoted).count_false(axis:-1) % 2
  # odd number of interchanges -> -1, even -> +1
  perm_sign = odd * -2 + 1
  lu.diagonal.prod(axis:-1) * perm_sign
end
732
+
733
# Natural logarithm of the determinant of a matrix, computed from the
# LU factorization returned by LAPACK ?getrf.
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @return [[sign,logdet]]
#   - **sign** -- A number representing the sign of the determinant.
#   - **logdet** -- The natural log of the absolute value of the determinant.
#   If any diagonal element of the LU factor is zero, [0, -Infinity] is
#   returned. NOTE(review): for stacked input this check spans the whole
#   stack, so a single singular matrix makes the entire result [0, -Infinity].

def slogdet(a)
  lu, piv, = Lapack.call(:getrf, a)
  # sign of the row permutation: each pivot differing from its 1-based
  # position is one interchange
  idx = piv.new_narray.store(piv.class.new(piv.shape[-1]).seq(1))
  m = piv.eq(idx).count_false(axis:-1) % 2
  sign = m * -2 + 1

  lud = lu.diagonal
  if (lud.eq 0).any?
    return 0, (-Float::INFINITY)
  end
  lud_abs = lud.abs
  # Reduce along the diagonal only (axis -1), like the logdet sum below,
  # so stacked matrices keep their own sign instead of a product over all.
  sign *= (lud/lud_abs).prod(axis:-1)
  [sign, NMath.log(lud_abs).sum(axis:-1)]
end
754
+
755
# Compute matrix rank of array using SVD.
# *Rank* is the number of singular values greater than *tol*.
# For an array with fewer than 2 dimensions, the result is 1 if any
# element is nonzero and 0 otherwise.
#
# @param m [Numo::NArray] matrix (>= 2-dimensional NArray)
# @param tol [Float] threshold below which singular values are
#   considered to be zero. If *tol* is nil,
#   `tol = sing_vals.max * max(m.shape[-2..-1]) * EPSILON`.
# @param driver [String or Symbol] choose LAPACK solver: 'svd'/'gesvd'
#   (?gesvd) or 'sdd'/'gesdd'/'turbo' (?gesdd). (optional, default='svd')
# @raise [ArgumentError] if driver is not recognized.

def matrix_rank(m, tol:nil, driver:'svd')
  m = Numo::NArray.asarray(m)
  if m.ndim < 2
    return m.ne(0).any? ? 1 : 0
  end

  name = driver.to_s
  sv =
    if name =~ /^(ge)?sdd$/ || name == "turbo"
      Lapack.call(:gesdd, m, jobz:'N')[0]
    elsif name =~ /^(ge)?svd$/
      Lapack.call(:gesvd, m, jobu:'N', jobvt:'N')[0]
    else
      raise ArgumentError, "invalid driver: #{driver}"
    end
  # Default threshold scales with the largest singular value, the larger
  # of the two matrix dimensions and the dtype's machine epsilon.
  tol ||= sv.max(axis:-1, keepdims:true) *
          (m.shape[-2..-1].max * sv.class::EPSILON)
  (sv > tol).count(axis:-1)
end
783
+
784
+
785
+ ## Solving equations and inverting matrices
786
+
787
# Solves linear equation `a * x = b` for `x`
# from square matrix `a`.
#
# @param a [Numo::NArray] n-by-n square matrix (>= 2-dimensional NArray)
# @param b [Numo::NArray] n-by-nrhs right-hand-side matrix (>=
#   1-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' or 'pos', optionally spelled as the routine name
#   ('gesv','sysv','hesv','posv'). (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (ignored when driver:"gen")
# @return [Numo::NArray] The solution matrix/vector X.
# @raise [ArgumentError] if driver is not recognized.

def solve(a, b, driver:"gen", uplo:'U')
  case driver.to_s
  when /^gen?(sv)?$/i
    # returns lu, x, ipiv, info
    Lapack.call(:gesv, a, b)[1]
  when /^(sym?|her?|pos?)(sv)?$/i
    # The routine prefix is always the first two letters ("sy", "he",
    # "po"): e.g. "sym" -> "sysv", "posv" -> "posv". Taking three letters
    # would build invalid names such as "symsv" or "syssv".
    func = driver.to_s[0..1].downcase + "sv"
    Lapack.call(func, a, b, uplo:uplo)[1]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
810
+
811
# Inverse matrix from square matrix `a`.
#
# @param a [Numo::NArray] n-by-n square matrix (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver
#   ('ge'|'sy'|'he'|'po') + ("sv"|"trf")
#   (optional, default='getrf')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (ignored when driver:"ge")
# @return [Numo::NArray] The inverse matrix.
# @raise [ArgumentError] if driver is not recognized.
# @example
#   Numo::Linalg.inv(a,driver:'getrf')
#   => Numo::DFloat#shape=[2,2]
#   [[-2, 1],
#    [1.5, -0.5]]
#   a.dot(Numo::Linalg.inv(a,driver:'getrf'))
#   => Numo::DFloat#shape=[2,2]
#   [[1, 0],
#    [8.88178e-16, 1]]

def inv(a, driver:"getrf", uplo:'U')
  case driver
  when /(ge|sy|he|po)sv$/
    # ?sv drivers: solve A * X = I.
    kind = $1
    solve(a, a.new_zeros.eye, driver:kind, uplo:uplo)
  when /(ge|sy|he)tr[fi]$/
    # factorize, then invert from the factors
    kind = $1
    factor, pivots = lu_fact(a, driver:kind, uplo:uplo)
    lu_inv(factor, pivots, driver:kind, uplo:uplo)
  when /potr[fi]$/
    # Cholesky factorization, then invert from the factor
    cho_inv(cho_fact(a, uplo:uplo), uplo:uplo)
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
846
+
847
# Computes the minimum-norm solution to a linear least squares
# problem:
#
#   minimize 2-norm(| b - A*x |)
#
# using the singular value decomposition (SVD) of A.
# A is an M-by-N matrix which may be rank-deficient.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param b [Numo::NArray] m-by-nrhs right-hand-side matrix b
#   (>= 1-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   'lsd','lss','lsy' (optional, default='lsd')
# @param rcond [Float] (optional, default=-1)
#   RCOND is used to determine the effective rank of A.
#   Singular values `S(i) <= RCOND*S(1)` are treated as zero.
#   If RCOND < 0, machine precision is used instead.
# @return [[x, resids, rank, s]]
#   - **x** -- The solution matrix/vector X.
#   - **resids** -- Sums of squared residuals (squared 2-norm of each
#     column of `b - a x`). nil unless M > N, the driver is 'lsd' or
#     'lss', and rank == N. For stacked input (rank is an NArray) it is
#     computed without checking each matrix's rank.
#   - **rank** -- The effective rank of A, i.e.,
#     the number of singular values which are greater than RCOND*S(1).
#   - **s** -- The singular values of A in decreasing order.
#     nil if 'lsy' is used.
# @raise [NArray::ShapeError] if A and b have different numbers of rows.
# @raise [ArgumentError] if driver is not recognized.

def lstsq(a, b, driver:'lsd', rcond:-1)
  a = NArray.asarray(a)
  b = NArray.asarray(b)
  # A 1-d right-hand side is solved as a single column and squeezed back
  # at the end.
  b_orig = nil
  if b.shape.size==1
    b_orig = b
    b = b_orig[true,:new]
  end
  m = a.shape[-2]
  n = a.shape[-1]
  if m != b.shape[-2]
    raise NArray::ShapeError, "size mismatch: A-row and B-row"
  end
  if m < n
    # The overwritten b holds the n-row solution, so b needs n rows.
    shp = b.shape
    shp[-2] = n
    b2 = b.class.zeros(*shp)
    b2[false,0...m,true] = b
    b = b2
  end
  case driver.to_s
  when /^(ge)?lsd$/i
    # x, s, rank, info
    x, s, rank, = Lapack.call(:gelsd, a, b, rcond:rcond)
  when /^(ge)?lss$/i
    # v, x, s, rank, info
    _, x, s, rank, = Lapack.call(:gelss, a, b, rcond:rcond)
  when /^(ge)?lsy$/i
    jpvt = Int32.zeros(*a[false,0,true].shape)
    # v, x, jpvt, rank, info
    _, x, _, rank, = Lapack.call(:gelsy, a, b, jpvt, rcond:rcond)
    s = nil
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
  resids = nil
  if m > n
    if /ls(d|s)$/i =~ driver.to_s
      case rank
      when n, NArray
        # Rows n...m of the overwritten b hold the residual components.
        # NOTE(review): for stacked input, rank-deficient matrices are not
        # masked out (masked assignment along batch axes is not supported yet).
        resids = (x[false,n..-1,true].abs**2).sum(axis:-2)
      end
    end
    x = x[false,0...n,true]
  end
  if b_orig
    # `false` keeps any leading batch axes (x[true,0] only works for 2-d x).
    x = x[false,0]
    resids &&= resids[false,0]
  end
  [x, resids, rank, s]
end
934
+
935
# Compute the (Moore-Penrose) pseudo-inverse of a matrix
# using svd or lstsq.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   SVD ('svd', 'sdd') or Least square ('lsd','lss','lsy')
#   (optional, default='svd')
# @param rcond [Float] (optional, default=nil)
#   Singular values `S(i) <= RCOND*max(S)` are treated as zero.
#   For SVD drivers, nil or a negative value selects a default of
#   1e3*EPSILON for SFloat singular values and 1e6*EPSILON otherwise.
#   For least-square drivers it is passed to lstsq, with nil mapped
#   to -1 (machine precision).
# @return [Numo::NArray]
# @raise [NArray::ShapeError] if a has fewer than 2 dimensions.
# @raise [ArgumentError] if rcond is not Numeric (SVD drivers) or
#   driver is not recognized.
# @example
#   a = Numo::DFloat.new(5,3).rand_norm
#   => Numo::DFloat#shape=[5,3]
#   [[-0.581255, -0.168354, 0.586895],
#    [-0.595142, -0.802802, -0.326106],
#    [0.282922, 1.68427, 0.918499],
#    [-0.0485384, -0.464453, -0.992194],
#    [0.413794, -0.60717, -0.699695]]
#   b = Numo::Linalg.pinv(a,driver:"svd")
#   => Numo::DFloat(view)#shape=[3,5]
#   [[-0.360863, -0.813125, -0.353367, -0.891963, 0.877253],
#    [-0.227645, 0.162939, 0.696655, 0.787685, -0.469346],
#    [0.408671, -0.308323, -0.337807, -1.13833, 0.228051]]
#   (a-a.dot(b.dot(a))).abs.max
#   => 5.551115123125783e-16

def pinv(a, driver:"svd", rcond:nil)
  a = NArray.asarray(a)
  if a.ndim < 2
    raise NArray::ShapeError, "2-d array is required"
  end
  case driver
  when /^(ge)?s[dv]d$/
    # Validate before `rcond < 0`, which would otherwise fail with an
    # unrelated error on non-numeric input. (`! Numeric === rcond` would
    # parse as `(!Numeric) === rcond` and never fire.)
    unless rcond.nil? || Numeric === rcond
      raise ArgumentError, "rcond must be Numeric"
    end
    s, u, vh = svd(a, driver:driver, job:'S')
    if rcond.nil? || rcond < 0
      rcond = ((SFloat===s) ? 1e3 : 1e6) * s.class::EPSILON
    end
    # keep singular values above rcond * (largest singular value)
    cond = (s > rcond * s.max(axis:-1, keepdims:true))
    if cond.all?
      r = s.reciprocal
    else
      r = s.new_zeros
      r[cond] = s[cond].reciprocal
    end
    # Scale the columns of u by r, then take the conjugate transpose of
    # u*diag(r)*vh (assuming a = u*diag(s)*vh).
    u *= r[false,:new,true]
    dot(u,vh).conj.swapaxes(-2,-1)
  when /^(ge)?ls[dsy]$/
    b = a.class.eye(a.shape[-2])
    # lstsq's own default is -1 (machine precision); do not forward nil.
    x, = lstsq(a, b, driver:driver, rcond:(rcond.nil? ? -1 : rcond))
    x
  else
    raise ArgumentError, "#{driver.inspect} is not one of drivers: "+
                         "svd, sdd, lsd, lss, lsy"
  end
end
994
+
995
private

# @!visibility private
# Builds complex eigenvectors from the real eigenvector matrix returned
# by real ?geev. For each column j whose eigenvalue has w[j].imag > 0,
# column j becomes vin[:,j] + i*vin[:,j+1], and column j+1 becomes its
# complex conjugate (LAPACK stores the real and imaginary parts of a
# conjugate pair in consecutive columns).
#
# @param w [Numo::NArray] complex eigenvalues (along the last axis)
# @param vin [Numo::NArray] real eigenvector matrix from ?geev
# @return [Numo::NArray] complex eigenvectors of the same shape as vin
def _make_complex_eigvecs(w, vin) # :nodoc:
  v = w.class.cast(vin)
  # Broadcast the per-eigenvalue mask to vin.shape, selecting every
  # element of each column j with w[j].imag > 0. The parentheses are
  # required: `|` binds tighter than `>` in Ruby.
  m = ((w.imag > 0) | Bit.zeros(*vin.shape)).where
  # m+1 is the flat index of the same row in the next column.
  v[m].imag = vin[m+1]
  v[m+1] = v[m].conj
  v
end
1006
+
1007
+ end
1008
+ end