numo-linalg 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. checksums.yaml +7 -0
  2. data/Gemfile +4 -0
  3. data/README.md +80 -0
  4. data/Rakefile +18 -0
  5. data/ext/numo/linalg/blas/blas.c +352 -0
  6. data/ext/numo/linalg/blas/cblas.h +575 -0
  7. data/ext/numo/linalg/blas/cblas_t.h +563 -0
  8. data/ext/numo/linalg/blas/depend.erb +23 -0
  9. data/ext/numo/linalg/blas/extconf.rb +67 -0
  10. data/ext/numo/linalg/blas/gen/cogen.rb +72 -0
  11. data/ext/numo/linalg/blas/gen/decl.rb +203 -0
  12. data/ext/numo/linalg/blas/gen/desc.rb +8138 -0
  13. data/ext/numo/linalg/blas/gen/erbpp2.rb +339 -0
  14. data/ext/numo/linalg/blas/gen/replace_cblas_h.rb +27 -0
  15. data/ext/numo/linalg/blas/gen/spec.rb +93 -0
  16. data/ext/numo/linalg/blas/numo_blas.h +41 -0
  17. data/ext/numo/linalg/blas/tmpl/axpy.c +75 -0
  18. data/ext/numo/linalg/blas/tmpl/copy.c +57 -0
  19. data/ext/numo/linalg/blas/tmpl/def_c.c +3 -0
  20. data/ext/numo/linalg/blas/tmpl/def_d.c +3 -0
  21. data/ext/numo/linalg/blas/tmpl/def_s.c +3 -0
  22. data/ext/numo/linalg/blas/tmpl/def_z.c +3 -0
  23. data/ext/numo/linalg/blas/tmpl/dot.c +68 -0
  24. data/ext/numo/linalg/blas/tmpl/ger.c +114 -0
  25. data/ext/numo/linalg/blas/tmpl/init_class.c +20 -0
  26. data/ext/numo/linalg/blas/tmpl/init_module.c +12 -0
  27. data/ext/numo/linalg/blas/tmpl/lib.c +40 -0
  28. data/ext/numo/linalg/blas/tmpl/mm.c +214 -0
  29. data/ext/numo/linalg/blas/tmpl/module.c +9 -0
  30. data/ext/numo/linalg/blas/tmpl/mv.c +194 -0
  31. data/ext/numo/linalg/blas/tmpl/nrm2.c +79 -0
  32. data/ext/numo/linalg/blas/tmpl/rot.c +65 -0
  33. data/ext/numo/linalg/blas/tmpl/rotm.c +82 -0
  34. data/ext/numo/linalg/blas/tmpl/scal.c +69 -0
  35. data/ext/numo/linalg/blas/tmpl/sdsdot.c +77 -0
  36. data/ext/numo/linalg/blas/tmpl/set_prefix.c +16 -0
  37. data/ext/numo/linalg/blas/tmpl/swap.c +57 -0
  38. data/ext/numo/linalg/blas/tmpl/syr.c +102 -0
  39. data/ext/numo/linalg/blas/tmpl/syr2.c +110 -0
  40. data/ext/numo/linalg/blas/tmpl/syr2k.c +129 -0
  41. data/ext/numo/linalg/blas/tmpl/syrk.c +132 -0
  42. data/ext/numo/linalg/lapack/depend.erb +23 -0
  43. data/ext/numo/linalg/lapack/extconf.rb +45 -0
  44. data/ext/numo/linalg/lapack/gen/cogen.rb +74 -0
  45. data/ext/numo/linalg/lapack/gen/desc.rb +151278 -0
  46. data/ext/numo/linalg/lapack/gen/replace_lapacke_h.rb +32 -0
  47. data/ext/numo/linalg/lapack/gen/spec.rb +104 -0
  48. data/ext/numo/linalg/lapack/lapack.c +387 -0
  49. data/ext/numo/linalg/lapack/lapacke.h +16425 -0
  50. data/ext/numo/linalg/lapack/lapacke_config.h +119 -0
  51. data/ext/numo/linalg/lapack/lapacke_mangling.h +17 -0
  52. data/ext/numo/linalg/lapack/lapacke_t.h +10550 -0
  53. data/ext/numo/linalg/lapack/numo_lapack.h +42 -0
  54. data/ext/numo/linalg/lapack/tmpl/def_c.c +3 -0
  55. data/ext/numo/linalg/lapack/tmpl/def_d.c +7 -0
  56. data/ext/numo/linalg/lapack/tmpl/def_s.c +7 -0
  57. data/ext/numo/linalg/lapack/tmpl/def_z.c +3 -0
  58. data/ext/numo/linalg/lapack/tmpl/fact.c +179 -0
  59. data/ext/numo/linalg/lapack/tmpl/geev.c +123 -0
  60. data/ext/numo/linalg/lapack/tmpl/gels.c +232 -0
  61. data/ext/numo/linalg/lapack/tmpl/gesv.c +149 -0
  62. data/ext/numo/linalg/lapack/tmpl/gesvd.c +189 -0
  63. data/ext/numo/linalg/lapack/tmpl/ggev.c +138 -0
  64. data/ext/numo/linalg/lapack/tmpl/gqr.c +121 -0
  65. data/ext/numo/linalg/lapack/tmpl/init_class.c +20 -0
  66. data/ext/numo/linalg/lapack/tmpl/init_module.c +12 -0
  67. data/ext/numo/linalg/lapack/tmpl/lange.c +79 -0
  68. data/ext/numo/linalg/lapack/tmpl/lib.c +40 -0
  69. data/ext/numo/linalg/lapack/tmpl/module.c +9 -0
  70. data/ext/numo/linalg/lapack/tmpl/syev.c +91 -0
  71. data/ext/numo/linalg/lapack/tmpl/sygv.c +104 -0
  72. data/ext/numo/linalg/lapack/tmpl/trf.c +276 -0
  73. data/ext/numo/linalg/numo_linalg.h +115 -0
  74. data/lib/numo/linalg.rb +3 -0
  75. data/lib/numo/linalg/function.rb +1008 -0
  76. data/lib/numo/linalg/linalg.rb +7 -0
  77. data/lib/numo/linalg/loader.rb +174 -0
  78. data/lib/numo/linalg/use/atlas.rb +3 -0
  79. data/lib/numo/linalg/use/lapack.rb +3 -0
  80. data/lib/numo/linalg/use/mkl.rb +3 -0
  81. data/lib/numo/linalg/use/openblas.rb +3 -0
  82. data/lib/numo/linalg/version.rb +5 -0
  83. data/numo-linalg.gemspec +26 -0
  84. data/spec/lapack_spec.rb +13 -0
  85. metadata +172 -0
@@ -0,0 +1,276 @@
/*
 * ERB template for LAPACK(E) wrappers in the "tr?" family: routine names
 * ending in "trf", "tri" or "trs", as tested by the regexes in the first
 * ERB block below. Flags derived from the routine name become the
 * preprocessor macros RHS/TRANS/UPLO/IPIV/IPIV_OUT/IPIV_IN/SYM, which select
 * the arguments handled by the iterator and by the Ruby-visible method:
 *   RHS      - name ends in "trs": a right-hand side b is overwritten and returned
 *   TRANS    - name matches ".[glt].trs": a trans keyword is accepted
 *   UPLO     - 2nd letter is not "g" and letters 2-3 are not "pt": uplo keyword
 *   IPIV     - name does not match "p[bfopt]tr.": a pivot array is involved
 *   IPIV_OUT - IPIV and name ends in "trf": pivots are an output
 *   IPIV_IN  - IPIV and name ends in "tri"/"trs": pivots are an input
 *   SYM      - UPLO, or getri/getrs: only the order n (no m) is passed
 * NOTE(review): the set of routines instantiated from this template is chosen
 * by the code generator, not here -- confirm there.
 */
<%
has_rhs = (/trs$/ =~ name)
has_trans = (/^.(g|l|t).trs$/ =~ name)
has_uplo = (/^.(g|pt)/ !~ name)
has_ipiv = (/p[bfopt]tr.$/ !~ name)
ipiv_out = (has_ipiv && /trf$/ =~ name)
ipiv_in = (has_ipiv && /tr[is]$/ =~ name)
is_sym = (has_uplo || /getr[is]/=~name)
%>
#define RHS <%= has_rhs ? "1":"0" %>
#define TRANS <%= has_trans ? "1":"0" %>
#define UPLO <%= has_uplo ? "1":"0" %>
#define IPIV <%= has_ipiv ? "1":"0" %>
#define IPIV_OUT <%= ipiv_out ? "1":"0" %>
#define IPIV_IN <%= ipiv_in ? "1":"0" %>
#define SYM <%= is_sym ? "1":"0" %>
#define args_t <%=func_name%>_args_t
#define func_p <%=func_name%>_p

/* Options handed from the Ruby method to the iterator through opt_ptr. */
typedef struct {
    int order;
    char uplo;
    char trans;
} args_t;

/* Function pointer to the LAPACK routine; checked with CHECK_FUNC before use
   (presumably CHECK_FUNC also resolves it -- macro defined elsewhere). */
static <%=func_name%>_t func_p = 0;

// Per-matrix loop body run by na_ndloop3. Argument slots in lp:
// 0 = a; 1 = pivots when IPIV (input for tri/trs, output for trf);
// next = b when RHS; last = info (output).
static void
<%=c_iter%>(na_loop_t * const lp)
{
    dtype *a;
#if RHS
    dtype *b;
    int nb, nrhs, ldb;
#endif
#if IPIV
    int *pv;
#endif
    int *info;
    int m, n, lda;
    args_t *g;

    a = (dtype*)NDL_PTR(lp,0);
#if IPIV
    pv = (int*)NDL_PTR(lp,1);
#endif
#if RHS
    b = (dtype*)NDL_PTR(lp,1+IPIV);
#endif
    info = (int*)NDL_PTR(lp,1+IPIV+RHS);
    g = (args_t*)(lp->opt_ptr);

    // Extents of a as seen by the loop; lda is the step of its first
    // dimension expressed in elements.
    n = NDL_SHAPE(lp,0)[0];
    m = NDL_SHAPE(lp,0)[1];
    lda = NDL_STEP(lp,0) / sizeof(dtype);

#if RHS
    // same as gels.c
    if (lp->args[1+IPIV].ndim == 1) {
        // 1-D b: a single right-hand side.
        nrhs = 1;
        nb = NDL_SHAPE(lp,1+IPIV)[0];
        ldb = (g->order==LAPACK_COL_MAJOR) ? nb : 1;
    } else {
        nb = NDL_SHAPE(lp,1+IPIV)[0];
        nrhs = NDL_SHAPE(lp,1+IPIV)[1];
        ldb = nrhs;
        // SWAP_IFCOL is defined elsewhere; it presumably swaps its two
        // operands through the local tmp when order is column-major.
        { int tmp; SWAP_IFCOL(g->order,nb,nrhs); }
    }
    //printf("order=%d m=%d n=%d nb=%d nrhs=%d lda=%d ldb=%d\n",g->order,m,n,nb,nrhs,lda,ldb);
#else
    //printf("order=%d m=%d n=%d lda=%d \n",g->order,m,n,lda);
#endif

#if SYM
    // Only the order is passed to LAPACK: use the smaller extent.
    n = min_(m,n);
#else
    { int tmp; SWAP_IFCOL(g->order,m,n); }
#endif

    // Build the LAPACK argument list from the flags; entries for features
    // the routine does not have are dropped by select.
<%
func_args = [ "g->order",
              has_uplo && "g->uplo",
              has_trans && "g->trans",
              "n",
              has_rhs ? "nrhs" : (!is_sym && "m"),
              "a, lda",
              has_ipiv && "pv",
              has_rhs && "b, ldb",
            ].select{|x| x}.join(", ")
%>
    *info = (*func_p)(<%=func_args%>);
    CHECK_ERROR(*info);
}

// YARD documentation for the Ruby method below, assembled by ERB.
/*<%
args_v = [
  "a",
  ipiv_in && "ipiv",
  has_rhs && "b",
].select{|x| x}.join(", ")

args_opt = [
  has_uplo && "uplo:'U'",
  has_trans && "trans:'N'",
  "order:'R'",
].select{|x| x}.join(", ")

trf = name.sub(/.$/,"f")

params = [
  has_rhs ? "@param a [#{class_name}] LU matrix computed by "+trf :
  mat("a",:inplace),
  ipiv_in && "@param ipiv [Numo::Int] pivot computed by "+trf,
  has_rhs && mat("b",:inplace),
  has_uplo && opt("uplo"),
  has_trans && opt("trans"),
  opt("order"),
].select{|x| x}.join("\n ")

return_type = [
  class_name,
  ipiv_out && "Numo::Int",
  "Integer"
].select{|x| x}.join(", ")

return_name = [
  has_rhs ? "b" : "a",
  ipiv_out && "ipiv",
  "info"
].select{|x| x}.join(", ")
%>
@overload <%=name%>(<%=args_v%>, [<%=args_opt%>])
<%=params%>
@return [[<%=return_name%>]] Array<<%=return_type%>>
<%=outparam(return_name)%>

<%=description%>

*/
static VALUE
<%=c_func(-1)%>(int argc, VALUE const argv[], VALUE UNUSED(mod))
{
    <% %>
    VALUE a, ans;
#if IPIV_IN
    VALUE ipiv;
#endif
#if RHS
    VALUE b;
    size_t n, nb, nrhs;
    narray_t *na2;
#endif
    narray_t *na1;
<%
aout = [
  ipiv_out && "{cInt,1,shape_piv}",
  "{cInt,0}",
].select{|x| x}.join(",")
%>
#if IPIV_OUT
    size_t shape_piv[1];
#endif
    // Input specs, in the order the arrays are passed to na_ndloop3 below
    // (a, [ipiv], [b]); the array the routine writes into is OVERWRITE.
#if IPIV_IN
# if RHS
    ndfunc_arg_in_t ain[3] = {{cT,2},{cInt,1},{OVERWRITE,2}};
# else
    ndfunc_arg_in_t ain[2] = {{OVERWRITE,2},{cInt,1}};
# endif
#else
# if RHS
    ndfunc_arg_in_t ain[2] = {{cT,2},{OVERWRITE,2}};
# else
    ndfunc_arg_in_t ain[1] = {{OVERWRITE,2}};
# endif
#endif
    // Outputs: the pivot array (trf only), then the info code.
    ndfunc_arg_out_t aout[1+IPIV_OUT] = {<%=aout%>};
    ndfunc_t ndf = {&<%=c_iter%>, NO_LOOP|NDF_EXTRACT,
                    1+IPIV_IN+RHS, IPIV_OUT+1, ain,aout};

    args_t g = {0,0};
    VALUE opts[2] = {Qundef,Qundef};
    VALUE kw_hash = Qnil;
    // Keyword slots: [0] order, [1] uplo (replaced by trans for TRANS routines).
    ID kw_table[2] = {id_order,id_uplo};

    CHECK_FUNC(func_p,"<%=func_name%>");

    // Positional arguments a, [ipiv], [b], followed by keyword options.
#if IPIV_IN
# if RHS
    rb_scan_args(argc, argv, "3:", &a, &ipiv, &b, &kw_hash);
# else
    rb_scan_args(argc, argv, "2:", &a, &ipiv, &kw_hash);
# endif
#else
# if RHS
    rb_scan_args(argc, argv, "2:", &a, &b, &kw_hash);
# else
    rb_scan_args(argc, argv, "1:", &a, &kw_hash);
# endif
#endif
#if TRANS
    kw_table[1] = id_trans;
    rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
    g.trans = option_trans(opts[1]);
#elif UPLO
    rb_get_kwargs(kw_hash, kw_table, 0, 2, opts);
    g.uplo = option_uplo(opts[1]);
#else
    rb_get_kwargs(kw_hash, kw_table, 0, 1, opts);
#endif
    g.order = option_order(opts[0]);

#if !RHS
    // a is overwritten: COPY_OR_CAST_TO copies it (unless marked inplace)
    // or casts it to cT.
    COPY_OR_CAST_TO(a,cT);
#endif
    GetNArray(a, na1);
    CHECK_DIM_GE(na1, 2);
#if IPIV_OUT
    // Pivot vector length is min(rows, cols) of a.
    shape_piv[0] = min_(ROW_SIZE(na1),COL_SIZE(na1));
#endif

#if RHS
    // b is overwritten with the solution: copy or cast it first.
    COPY_OR_CAST_TO(b,cT);
    GetNArray(b, na2);
    CHECK_DIM_GE(na2, 1);
    n = COL_SIZE(na1);
#if SYM
    n = min_(n,ROW_SIZE(na1));
#endif
    // same as gesv.c
    if (NA_NDIM(na2) == 1) {
        // A 1-D b is a single right-hand side.
        ain[1+IPIV_IN].dim = 1;
        nb = COL_SIZE(na2);
        nrhs = 1;
    } else {
        nb = ROW_SIZE(na2);
        nrhs = COL_SIZE(na2);
        { int tmp; SWAP_IFCOL(g.order,nb,nrhs); }
    }
    if (n != nb) {
        rb_raise(nary_eShapeError, "matrix dimension mismatch: "
                 "a.col(or a.row)=%"SZF"u b.row=%"SZF"u", n, nb);
    }
#endif

    // Return [b, info] (trs), [a, info] (tri), or [a, ipiv, info] (trf);
    // ans holds the outputs produced by the iterator.
#if IPIV_IN
# if RHS
    ans = na_ndloop3(&ndf, &g, 3, a, ipiv, b);
    return rb_assoc_new(b, ans);
# else
    ans = na_ndloop3(&ndf, &g, 2, a, ipiv);
    return rb_assoc_new(a, ans);
# endif
#else
# if RHS
    ans = na_ndloop3(&ndf, &g, 2, a, b);
    return rb_assoc_new(b, ans);
# else
    ans = na_ndloop3(&ndf, &g, 1, a);
#  if IPIV_OUT
    return rb_ary_unshift(ans, a);
#  else
    return rb_assoc_new(a, ans);
#  endif
# endif
#endif
}

/* Undefine the per-routine macros, presumably so the next routine generated
   into the same C file can define them again. */
#undef args_t
#undef func_p
#undef RHS
#undef TRANS
#undef UPLO
#undef IPIV
#undef IPIV_OUT
#undef IPIV_IN
#undef SYM
@@ -0,0 +1,115 @@
1
+ #if defined __clang__
2
+ # define UNUSED(name) __unused name
3
+ #else
4
+ # define UNUSED(name) name
5
+ #endif
6
+
7
+ #if SIZEOF_INT == 4
8
+ #define cI numo_cInt32
9
+ #define cUI numo_cUInt32
10
+ #elif SIZEOF_INT==8
11
+ #define cI numo_cInt64
12
+ #define cUI numo_cUInt64
13
+ #endif
14
+
15
+ #if SIZEOF_SIZE_T == 4
16
+ #define cSZ numo_cUInt32
17
+ #define cSSZ numo_cInt32
18
+ #elif SIZEOF_SIZE_T == 8
19
+ #define cSZ numo_cUInt64
20
+ #define cSSZ numo_cInt64
21
+ #endif
22
+
23
+ #define cDF numo_cDFloat
24
+ #define cDC numo_cDComplex
25
+ #define cSF numo_cSFloat
26
+ #define cSC numo_cSComplex
27
+ #define cInt cI
28
+ #define cUInt cUI
29
+
30
+ extern VALUE na_expand_dims(VALUE self, VALUE vdim);
31
+
32
+ #define max_(m,n) (((m)>(n)) ? (m):(n))
33
+ #define min_(m,n) (((m)<(n)) ? (m):(n))
34
+
35
+ #define ROW_SIZE(na) ((na)->shape[(na)->ndim-2])
36
+ #define COL_SIZE(na) ((na)->shape[(na)->ndim-1])
37
+
38
+ #define CHECK_NARRAY_TYPE(x,t) \
39
+ if (CLASS_OF(x)!=(t)) { \
40
+ rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
41
+ }
42
+
43
+ // Error Class ??
44
+ #define CHECK_DIM_GE(na,nd) \
45
+ if ((na)->ndim<(nd)) { \
46
+ rb_raise(nary_eShapeError, \
47
+ "n-dimension=%d, but >=%d is expected", \
48
+ (na)->ndim, (nd)); \
49
+ }
50
+
51
+ #define CHECK_DIM_EQ(na1,nd) \
52
+ if ((na1)->ndim != (nd)) { \
53
+ rb_raise(nary_eShapeError, \
54
+ "dimention mismatch: %d != %d", \
55
+ (na1)->ndim, (nd)); \
56
+ }
57
+
58
+ #define CHECK_SQUARE(name,na) \
59
+ if ((na)->shape[(na)->ndim-1] != (na)->shape[(na)->ndim-2]) { \
60
+ rb_raise(nary_eShapeError,"%s is not square matrix",name); \
61
+ }
62
+
63
+ #define CHECK_SIZE_GE(na,sz) \
64
+ if ((na)->size < (size_t)(sz)) { \
65
+ rb_raise(nary_eShapeError, \
66
+ "NArray size must be >= %"SZF"u",(size_t)(sz));\
67
+ }
68
+
69
+ #define CHECK_NON_EMPTY(na) \
70
+ if ((na)->size==0) { \
71
+ rb_raise(nary_eShapeError,"empty NArray"); \
72
+ }
73
+
74
+ #define CHECK_SIZE_EQ(n,m) \
75
+ if ((n)!=(m)) { \
76
+ rb_raise(nary_eShapeError, \
77
+ "size mismatch: %"SZF"d != %"SZF"d", \
78
+ (size_t)(n),(size_t)(m)); \
79
+ }
80
+
81
+ #define CHECK_SAME_SHAPE(na1,na2) \
82
+ { int i; \
83
+ CHECK_DIM_EQ(na1,na2->ndim); \
84
+ for (i=0; i<na1->ndim; i++) { \
85
+ CHECK_SIZE_EQ(na1->shape[i],na2->shape[i]); \
86
+ } \
87
+ }
88
+
89
+ #define CHECK_INT_EQ(sm,m,sn,n) \
90
+ if ((m) != (n)) { \
91
+ rb_raise(nary_eShapeError, \
92
+ "%s must be == %s: %s=%d %s=%d", \
93
+ sm,sn,sm,m,sn,n); \
94
+ }
95
+
96
+ // Error Class ??
97
+ #define CHECK_LEADING_GE(sld,ld,sn,n) \
98
+ if ((ld) < (n)) { \
99
+ rb_raise(nary_eShapeError, \
100
+ "%s must be >= max(%s,1): %s=%d %s=%d", \
101
+ sld,sn,sld,ld,sn,n); \
102
+ }
103
+
104
+ #define COPY_OR_CAST_TO(a,T) \
105
+ { \
106
+ if (CLASS_OF(a) == (T)) { \
107
+ if (!TEST_INPLACE(a)) { \
108
+ a = na_copy(a); \
109
+ } \
110
+ } else { \
111
+ a = rb_funcall(T,rb_intern("cast"),1,a); \
112
+ } \
113
+ }
114
+
115
+ #define swap(a,b) {tmp=a;a=b;b=tmp;}
@@ -0,0 +1,3 @@
1
# Entry point for `require "numo/linalg"`: load the Ruby-side definitions,
# then ask the Loader to load the native library
# (NOTE(review): backend selection lives in loader.rb -- not visible here).
require "numo/linalg/linalg"

::Numo::Linalg::Loader.load_library()
@@ -0,0 +1,1008 @@
1
+ module Numo; module Linalg
2
+
3
module Blas

  # Composed routine names whose exported BLAS name differs
  # (complex nrm2 variants carry a real-type prefix).
  FIXNAME = {
    cnrm2: :csnrm2,
    znrm2: :dznrm2,
  }

  class << self
    # Call BLAS function prefixed with BLAS char ([sdcz])
    # defined from data-types of arguments.
    # @param [Symbol] func function name without BLAS char.
    # @param args arguments passed to Blas function.
    # @example
    #   c = Numo::Linalg::Blas.call(:gemm, a, b)
    def call(func,*args)
      prefixed = :"#{Linalg.blas_char(*args)}#{func}"
      send(FIXNAME[prefixed] || prefixed, *args)
    end
  end

end
24
+
25
module Lapack

  # Composed routine names that LAPACK exports under another name
  # (the complex counterparts of ?orgqr are ?ungqr).
  FIXNAME = {
    corgqr: :cungqr,
    zorgqr: :zungqr,
  }

  class << self
    # Call LAPACK function prefixed with BLAS char ([sdcz])
    # defined from data-types of arguments.
    # @param [Symbol,String] func function name without BLAS char.
    # @param args arguments passed to Lapack function.
    # @example
    #   s = Numo::Linalg::Lapack.call(:gesv, a)
    def call(func,*args)
      prefixed = :"#{Linalg.blas_char(*args)}#{func}"
      send(FIXNAME[prefixed] || prefixed, *args)
    end
  end

end
46
+
47
# Routine prefix character for each supported NArray element class.
BLAS_CHAR = {
  SFloat   => "s",
  DFloat   => "d",
  SComplex => "c",
  DComplex => "z",
}

module_function

# Determine the BLAS/LAPACK prefix character ([sdcz]) for a set of
# arguments. Starting from Ruby Float, the element classes of all
# NArray arguments (and of Arrays, via NArray.array_type) are upcast
# together; other arguments such as option hashes are ignored.
# @param args arguments that will be passed to the BLAS/LAPACK routine.
# @return [String] one of "s", "d", "c", "z".
# @raise [TypeError] if the combined type has no BLAS/LAPACK prefix.
def blas_char(*args)
  combined = args.inject(Float) do |acc, arg|
    klass =
      case arg
      when NArray then arg.class
      when Array  then NArray.array_type(arg)
      end
    (klass && klass < NArray) ? klass::UPCAST[acc] : acc
  end
  BLAS_CHAR[combined] || raise(TypeError,"invalid data type for BLAS/LAPACK")
end
73
+
74
+ # module methods
75
+
76
+ ## Matrix and vector products
77
+
78
# Dot product.
#
# Dispatches on the ranks of the operands:
# vector . vector -> Blas dot; vector . matrix -> gemv on b with trans:'t';
# matrix . vector -> gemv; matrix . matrix -> gemm.
#
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
# @param b [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
# @return [Numo::NArray] result of dot product
def dot(a, b)
  a = NArray.asarray(a)
  b = NArray.asarray(b)
  a_is_vec = (a.ndim == 1)
  b_is_vec = (b.ndim == 1)
  if a_is_vec && b_is_vec
    Blas.call(:dot, a, b)
  elsif a_is_vec
    Blas.call(:gemv, b, a, trans:'t')
  elsif b_is_vec
    Blas.call(:gemv, a, b)
  else
    Blas.call(:gemm, a, b)
  end
end
102
+
103
# Matrix product, computed with BLAS gemm.
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @param b [Numo::NArray] matrix (>= 2-dimensional NArray)
# @return [Numo::NArray] result of matrix product
def matmul(a, b)
  return Blas.call(:gemm, a, b)
end
110
+
111
# Compute a square matrix `a` to the power `n`.
#
# * If n > 0: return `a**n`.
# * If n == 0: return identity matrix.
# * If n < 0: return `(a*\*-1)*\*n.abs`.
#
# For n > 3 the power is computed by binary exponentiation
# (O(log n) matrix products).
#
# @param a [Numo::NArray] square matrix (>= 2-dimensional NArray).
# @param n [Integer] the exponent.
# @raise [NArray::ShapeError] if the last two dimensions differ.
# @raise [ArgumentError] if n is not an Integer.
# @example
#   i = Numo::DFloat[[0, 1], [-1, 0]]
#   => Numo::DFloat#shape=[2,2]
#   [[0, 1],
#    [-1, 0]]
#   Numo::Linalg.matrix_power(i,3)
#   => Numo::DFloat#shape=[2,2]
#   [[0, -1],
#    [1, 0]]
#   Numo::Linalg.matrix_power(i,0)
#   => Numo::DFloat#shape=[2,2]
#   [[1, 0],
#    [0, 1]]
#   Numo::Linalg.matrix_power(i,-3)
#   => Numo::DFloat#shape=[2,2]
#   [[0, 1],
#    [-1, 0]]
#
#   q = Numo::DFloat.zeros(4,4)
#   q[0..1,0..1] = -i
#   q[2..3,2..3] = i
#   q
#   => Numo::DFloat#shape=[4,4]
#   [[-0, -1, 0, 0],
#    [1, -0, 0, 0],
#    [0, 0, 0, 1],
#    [0, 0, -1, 0]]
#   Numo::Linalg.matrix_power(q,2)
#   => Numo::DFloat#shape=[4,4]
#   [[-1, 0, 0, 0],
#    [0, -1, 0, 0],
#    [0, 0, -1, 0],
#    [0, 0, 0, -1]]

def matrix_power(a, n)
  a = NArray.asarray(a)
  m,k = a.shape[-2..-1]
  unless m==k
    raise NArray::ShapeError, "input must be a square array"
  end
  unless Integer===n
    raise ArgumentError, "exponent must be an integer"
  end
  if n == 0
    return a.class.eye(m)
  elsif n < 0
    a = inv(a)
    n = n.abs
  end
  if n <= 3
    r = a
    (n-1).times do
      r = matmul(r,a)
    end
  else
    # Absorb the trailing zero bits of n by repeated squaring, so that
    # r can start from a instead of an identity matrix.
    while n.even?
      a = matmul(a,a)
      n >>= 1
    end
    r = a
    # Consume the remaining higher bits. The loop stops as soon as n is
    # exhausted, so no unused final square is computed.
    while (n >>= 1) > 0
      a = matmul(a,a)
      r = matmul(r,a) if n.odd?
    end
  end
  r
end
189
+
190
+
191
+ ## factorization
192
+
193
# Computes a QR factorization of a real or complex M-by-N matrix A: A = Q \* R.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param mode [String or Symbol] (case-insensitive)
#   - "reduce" -- returns both Q and R,
#   - "r" -- returns only R,
#   - "economic" -- returns both Q and R but computed in economy-size,
#   - "raw" -- returns QR and TAU used in LAPACK.
# @return [r] if mode:"r"
# @return [[q,r]] if mode:"reduce" or "economic"
# @return [[qr,tau]] if mode:"raw" (LAPACK geqrf result)
# @raise [ArgumentError] if mode is not one of the above.

def qr(a, mode:"reduce")
  # Normalize before any mode-dependent decision, so that Symbols and
  # upper-case spellings select the same branches (including the
  # economy-size R below).
  mode = mode.to_s.downcase
  unless %w[reduce r economic raw].include?(mode)
    raise ArgumentError, "invalid mode:#{mode}"
  end
  qr,tau, = Lapack.call(:geqrf, a)
  return [qr,tau] if mode == "raw"

  *shp,m,n = qr.shape
  # Economy size keeps only the leading n rows of R when m >= n.
  r = (m >= n && mode == "economic") ?
    qr[false, 0...n, true].triu : qr.triu
  return r if mode == "r"

  if m < n
    q, = Lapack.call(:orgqr, qr[false, 0...m], tau)
  elsif mode == "economic"
    q, = Lapack.call(:orgqr, qr, tau)
  else
    # Full Q is m-by-m: embed the m-by-n reflectors into a square array.
    qqr = qr.class.zeros(*(shp+[m,m]))
    qqr[false,0...n] = qr
    q, = Lapack.call(:orgqr, qqr, tau)
  end
  [q,r]
end
232
+
233
+
234
# Computes the Singular Value Decomposition (SVD) of a M-by-N matrix A,
# and the left and/or right singular vectors. The SVD is written
#
#   A = U * SIGMA * transpose(V)
#
# where SIGMA is an M-by-N matrix which is zero except for its
# min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and
# V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA
# are the singular values of A; they are real and non-negative, and
# are returned in descending order. The first min(m,n) columns of U
# and V are the left and right singular vectors of A. Note that the
# routine returns V**T, not V.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK solver from 'svd'
#   (gesvd) or 'sdd' (gesdd; 'turbo' is an alias). (optional, default='svd')
# @param job [String or Symbol]
#   - 'A': all M columns of U and all N rows of V\*\*T are returned in
#     the arrays U and VT.
#   - 'S': the first min(M,N) columns of U and the first min(M,N)
#     rows of V\*\*T are returned in the arrays U and VT.
#   - 'N': no columns of U or rows of V\*\*T are computed.
# @return [[sigma,u,vt]] SVD result. Array<Numo::NArray>
# @raise [ArgumentError] for an unknown job or driver.

def svd(a, driver:'svd', job:'A')
  if (/^[ASN]/i =~ job).nil?
    raise ArgumentError, "invalid job: #{job.inspect}"
  end
  drv = driver.to_s
  if /^(ge)?sdd$/i =~ drv || drv == "turbo"
    result = Lapack.call(:gesdd, a, jobz:job)
  elsif /^(ge)?svd$/i =~ drv
    result = Lapack.call(:gesvd, a, jobu:job, jobvt:job)
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
  result[0..2]
end
271
+
272
# Computes the Singular Values of a M-by-N matrix A.
# The SVD is written
#
#   A = U * SIGMA * transpose(V)
#
# where SIGMA is an M-by-N matrix which is zero except for its
# min(m,n) diagonal elements. The diagonal elements of SIGMA
# are the singular values of A; they are real and non-negative, and
# are returned in descending order. No singular vectors are computed.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK solver from 'svd'
#   (gesvd) or 'sdd' (gesdd; 'turbo' is an alias). (optional, default='svd')
# @return [Numo::NArray] returns SIGMA (singular values).
# @raise [ArgumentError] for an unknown driver.

def svdvals(a, driver:'svd')
  drv = driver.to_s
  if /^(ge)?sdd$/i =~ drv || drv == "turbo"
    Lapack.call(:gesdd, a, jobz:'N').first
  elsif /^(ge)?svd$/i =~ drv
    Lapack.call(:gesvd, a, jobu:'N', jobvt:'N').first
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
297
+
298
+
299
# Computes an LU factorization of a M-by-N matrix A
# using partial pivoting with row interchanges.
#
# The factorization has the form
#
#   A = P * L * U
#
# where P is a permutation matrix, L is lower triangular with unit
# diagonal elements (lower trapezoidal if m > n), and U is upper
# triangular (upper trapezoidal if m < n).
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' (case-insensitive; 'ge','sy','he' and the
#   routine-style names 'getrf','sytrf','hetrf' are also accepted).
#   (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (omitted when driver:"gen")
# @return [[lu, ipiv]]
#   - **lu** [Numo::NArray] -- The factors L and U from the factorization
#     `A = P*L*U`; the unit diagonal elements of L are not stored.
#   - **ipiv** [Numo::NArray] -- The pivot indices; for 1 <= i <= min(M,N),
#     row i of the matrix was interchanged with row IPIV(i).
# @raise [ArgumentError] if driver is not recognized.

def lu_fact(a, driver:"gen", uplo:"U")
  case driver.to_s
  when /^gen?(trf)?$/i
    Lapack.call(:getrf, a)[0..1]
  when /^(sym?|her?)(trf)?$/i
    # The LAPACK routine prefix is the first two letters only ("sy"/"he");
    # three letters would build e.g. "symtrf" for driver:"sym".
    func = driver.to_s[0..1].downcase+"trf"
    Lapack.call(func, a, uplo:uplo)[0..1]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
332
+
333
# Computes the inverse of a matrix using the LU factorization
# computed by Numo::Linalg.lu_fact.
#
# This method inverts U and then computes inv(A) by solving the system
#
#   inv(A)*L = inv(U)
#
# for inv(A).
#
# @param lu [Numo::NArray] matrix containing the factors L and U
#   from the factorization `A = P*L*U` as computed by
#   Numo::Linalg.lu_fact.
# @param ipiv [Numo::NArray] The pivot indices from
#   Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
#   interchanged with row IPIV(i).
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' (case-insensitive; 'ge','sy','he' and the
#   routine-style names 'getri','sytri','hetri' are also accepted).
#   (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (omitted when driver:"gen")
# @return [Numo::NArray] the inverse of the original matrix A.
# @raise [ArgumentError] if driver is not recognized.

def lu_inv(lu, ipiv, driver:"gen", uplo:"U")
  case driver.to_s
  when /^gen?(tri)?$/i
    Lapack.call(:getri, lu, ipiv)[0]
  when /^(sym?|her?)(tri)?$/i
    # The LAPACK routine prefix is the first two letters only ("sy"/"he").
    func = driver.to_s[0..1].downcase+"tri"
    Lapack.call(func, lu, ipiv, uplo:uplo)[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
365
+
366
# Solves a system of linear equations
#
#   A * X = B or A**T * X = B
#
# with a N-by-N matrix A using the LU factorization computed by
# Numo::Linalg.lu_fact
#
# @param lu [Numo::NArray] matrix containing the factors L and U
#   from the factorization `A = P*L*U` as computed by
#   Numo::Linalg.lu_fact.
# @param ipiv [Numo::NArray] The pivot indices from
#   Numo::Linalg.lu_fact; for 1<=i<=N, row i of the matrix was
#   interchanged with row IPIV(i).
# @param b [Numo::NArray] the right hand side matrix B.
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' (case-insensitive; 'ge','sy','he' and the
#   routine-style names 'getrs','sytrs','hetrs' are also accepted).
#   (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (omitted when driver:"gen")
# @param trans [String or Symbol]
#   Specifies the form of the system of equations
#   (only used with driver:"gen"):
#
#   - If 'N': `A * X = B` (No transpose).
#   - If 'T': `A**T * X = B` (Transpose).
#   - If 'C': `A**H * X = B` (Conjugate transpose; same as 'T' for real A).
# @return [Numo::NArray] the solution matrix X.
# @raise [ArgumentError] if driver is not recognized.

def lu_solve(lu, ipiv, b, driver:"gen", uplo:"U", trans:"N")
  case driver.to_s
  when /^gen?(trs)?$/i
    Lapack.call(:getrs, lu, ipiv, b, trans:trans)[0]
  when /^(sym?|her?)(trs)?$/i
    # The LAPACK routine prefix is the first two letters only ("sy"/"he").
    func = driver.to_s[0..1].downcase+"trs"
    Lapack.call(func, lu, ipiv, b, uplo:uplo)[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
404
+
405
+
406
# Computes the Cholesky factorization of a symmetric/Hermitian
# positive definite matrix A (LAPACK ?potrf). The factorization has the form
#
#   A = U**H * U, if UPLO = 'U', or
#   A = L * L**H, if UPLO = 'L',
#
# where U is an upper triangular matrix and L is lower triangular.
# @param a [Numo::NArray] n-by-n symmetric matrix A (>= 2-dimensional NArray)
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle.
# @return [Numo::NArray] the factor U or L.

def cho_fact(a, uplo:'U')
  factor_and_info = Lapack.call(:potrf, a, uplo:uplo)
  factor_and_info.first
end
#alias cholesky cho_fact
422
+
423
# Computes the inverse of a symmetric/Hermitian positive definite
# matrix A (LAPACK ?potri) using the Cholesky factorization
# `A = U**T*U` or `A = L*L**T` computed by Linalg.cho_fact.
#
# @param a [Numo::NArray] the triangular factor U or L from the
#   Cholesky factorization, as computed by Linalg.cho_fact.
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle.
# @return [Numo::NArray] the upper or lower triangle of the
#   (symmetric) inverse of A.

def cho_inv(a, uplo:'U')
  inverse_and_info = Lapack.call(:potri, a, uplo:uplo)
  inverse_and_info.first
end
437
+
438
# Solves a system of linear equations
#   A*X = B
# with a symmetric/Hermitian positive definite matrix A (LAPACK ?potrs)
# using the Cholesky factorization
# `A = U**T*U` or `A = L*L**T` computed by Linalg.cho_fact.
# @param a [Numo::NArray] the triangular factor U or L from the
#   Cholesky factorization, as computed by Linalg.cho_fact.
# @param b [Numo::NArray] the right hand side matrix B.
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle.
# @return [Numo::NArray] the solution matrix X.

def cho_solve(a, b, uplo:'U')
  solution_and_info = Lapack.call(:potrs, a, b, uplo:uplo)
  solution_and_info.first
end
453
+
454
+
455
+ ## Matrix eigenvalues
456
+
457
# Computes the eigenvalues and, optionally, the left and/or right
# eigenvectors for a square nonsymmetric matrix A (LAPACK ?geev).
# For real input the eigenvalues are assembled as `wr + wi*i`, and the
# requested eigenvectors are converted with _make_complex_eigvecs.
#
# @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensional NArray)
# @param left [Bool] (optional) If true, left eigenvectors are computed.
# @param right [Bool] (optional) If true, right eigenvectors are computed.
# @return [[w,vl,vr]]
#   - **w** [Numo::NArray] -- The eigenvalues.
#   - **vl** [Numo::NArray] -- The left eigenvectors if left is true, otherwise nil.
#   - **vr** [Numo::NArray] -- The right eigenvectors if right is true, otherwise nil.

def eig(a, left:false, right:true)
  if /c|z/ =~ blas_char(a)
    w, vl, vr = Lapack.call(:geev, a, jobvl:left, jobvr:right)
  else
    re, im, vl, vr = Lapack.call(:geev, a, jobvl:left, jobvr:right)
    w = re + im * Complex::I
    vl = _make_complex_eigvecs(w, vl) if left
    vr = _make_complex_eigvecs(w, vr) if right
  end
  [w, vl, vr]
end
481
+
482
# Computes the eigenvalues and, optionally, the eigenvectors
# for a square symmetric/hermitian matrix A.
#
# @param a [Numo::NArray] square symmetric/hermitian matrix (>= 2-dimensional NArray)
# @param vals_only [Bool] (optional) If true, only eigenvalues are computed;
#   otherwise eigenvectors are computed as well.
# @param uplo [String or Symbol] (optional) Access upper ('U') or lower ('L')
#   triangle. The default value (false) is passed to LAPACK unchanged.
# @param turbo [Bool] (optional) If true, call ?sygv/?hegv instead of ?syev/?heev.
# @return [[w,v]]
#   - **w** [Numo::NArray] -- The eigenvalues.
#   - **v** [Numo::NArray] -- The eigenvectors when vals_only is false;
#     otherwise whatever LAPACK returns in that slot (presumably nil).

def eigh(a, vals_only:false, uplo:false, turbo:false)
  func =
    if /c|z/ =~ blas_char(a)
      turbo ? :hegv : :heev
    else
      turbo ? :sygv : :syev
    end
  # jobz 'V' requests eigenvectors too; 'N' requests eigenvalues only.
  jobz = vals_only ? 'N' : 'V'
  w, v, = Lapack.call(func, a, uplo:uplo, jobz:jobz)
  [w, v]
end
504
+
505
# Computes the eigenvalues only for a square nonsymmetric matrix A
# (LAPACK ?geev with jobvl='N', jobvr='N'). For real input the result
# is assembled as `wr + wi*i`.
#
# @param a [Numo::NArray] square nonsymmetric matrix (>= 2-dimensional NArray)
# @return [Numo::NArray] eigenvalues

def eigvals(a)
  complex_input = (/c|z/ =~ blas_char(a))
  outputs = Lapack.call(:geev, a, jobvl:'N', jobvr:'N')
  if complex_input
    w, = outputs
  else
    re, im, = outputs
    w = re + im * Complex::I
  end
  w
end
521
+
522
# Computes the eigenvalues for a square symmetric/hermitian matrix A.
#
# @param a [Numo::NArray] square symmetric/hermitian matrix
#   (>= 2-dimensional NArray)
# @param uplo [String or Symbol] (optional) Access upper ('U') or lower ('L')
#   triangle. The default value (false) is passed to LAPACK unchanged.
# @param turbo [Bool] (optional) If true, call ?sygv/?hegv instead of ?syev/?heev.
# @return [Numo::NArray] eigenvalues

def eigvalsh(a, uplo:false, turbo:false)
  func =
    if /c|z/ =~ blas_char(a)
      turbo ? :hegv : :heev
    else
      turbo ? :sygv : :syev
    end
  # jobz 'N': compute eigenvalues only.
  Lapack.call(func, a, uplo:uplo, jobz:'N').first
end
540
+
541
+
542
+ ## Norms and other numbers
543
+
544
# Compute matrix or vector norm.
#
# | ord    | matrix norm            | vector norm                 |
# | ------ | ---------------------- | --------------------------- |
# | nil    | Frobenius norm         | 2-norm                      |
# | 'fro'  | Frobenius norm         | -                           |
# | 'inf'  | x.abs.sum(axis:-1).max | x.abs.max                   |
# | '-inf' | x.abs.sum(axis:-1).min | x.abs.min                   |
# | 0      | -                      | (x.ne 0).sum                |
# | 1      | x.abs.sum(axis:-2).max | same as below               |
# | -1     | x.abs.sum(axis:-2).min | same as below               |
# | 2      | 2-norm (max sing_vals) | same as below               |
# | -2     | min sing_vals          | same as below               |
# | other  | -                      | (x.abs**ord).sum**(1.0/ord) |
#
# @param a [Numo::NArray] matrix or vector (>= 1-dimensional NArray)
# @param ord [String, Symbol, Integer or Float] Order of the norm (see the
#   table). An integral Float such as 2.0 is treated like the Integer 2.
# @param axis [Integer or Array] Applied axes (optional). One axis selects
#   a vector norm, two axes select a matrix norm. Defaults to the last
#   axis for 1-D input and to the last two axes otherwise.
# @param keepdims [Bool] If true, the applied axes are left in the
#   result, at their original positions, with size one (optional).
# @return [Numo::NArray] norm result
# @raise [ArgumentError] on an invalid ord or axis, or zero-dimensional input
def norm(a, ord=nil, axis:nil, keepdims:false)
  a = Numo::NArray.asarray(a)

  # An integral Float order (e.g. 2.0) is mapped to the Integer order so it
  # selects the same branch as 2; its string form "2.0" matches no branch.
  if ord.is_a?(Float) && ord.finite? && ord == ord.to_i
    ord = ord.to_i
  end

  perm = nil # axis permutation applied to `a` below, if any

  # check axis
  if axis
    case axis
    when Integer
      axis = [axis]
    when Array
      if axis.size < 1 || axis.size > 2
        raise ArgumentError, "axis option should be 1- or 2-element array"
      end
    else
      raise ArgumentError, "invalid option for axis: #{axis}"
    end
    axis.each do |i|
      if i < -a.ndim || i >= a.ndim
        raise ArgumentError, "axis (#{i}) is out of range"
      end
    end
    if a.ndim > 1
      # move the applied axes to the end (keeping their given order)
      idx = (0...a.ndim).to_a
      tmp = []
      axis.each do |i|
        x = idx[i]
        if x.nil?
          raise ArgumentError, "axis contains same dimension"
        end
        tmp << x
        idx[i] = nil
      end
      idx.compact!
      idx.concat(tmp)
      a = a.transpose(*idx)
      perm = idx
    elsif axis.size > 1
      # a 1-D array has a single axis, so two in-range axes must coincide
      raise ArgumentError, "axis contains same dimension"
    end
  else
    case a.ndim
    when 0
      raise ArgumentError, "zero-dimensional array"
    when 1
      axis = [-1]
    else
      axis = [-2,-1]
    end
  end

  # calculate norm
  case axis.size

  when 1 # vector
    k = keepdims
    ord ||= 2 # default
    if ord.is_a?(Float) && ord.finite?
      # non-integral order (integral ones were converted above):
      # generic p-norm from the table's "other" row
      r = (a.abs**ord).sum(axis:-1, keepdims:k)**(1.0/ord)
    else
      case ord.to_s
      when "0"
        r = a.class.cast(a.ne(0)).sum(axis:-1, keepdims:k)
      when "1"
        r = a.abs.sum(axis:-1, keepdims:k)
      when "2"
        r = Blas.call(:nrm2, a, keepdims:k)
      when /^-?\d+$/
        o = ord.to_i
        r = (a.abs**o).sum(axis:-1, keepdims:k)**(1.0/o)
      when /^inf(inity)?$/i
        r = a.abs.max(axis:-1, keepdims:k)
      when /^-inf(inity)?$/i
        r = a.abs.min(axis:-1, keepdims:k)
      else
        raise ArgumentError, "ord (#{ord}) is invalid for vector norm"
      end
    end
    if keepdims && perm && NArray === r
      # The reduced axis was moved last by the transpose above; restore the
      # original axis order so the size-one axis sits where it was requested.
      r = r.transpose(*(0...perm.size).map { |j| perm.index(j) })
    end

  when 2 # matrix
    if keepdims
      # axes are within range here; mark them for re-insertion as size one
      fixdims = [true] * a.ndim
      axis.each { |i| fixdims[i] = :new }
    end
    ord ||= "fro" # default
    case ord.to_s
    when "1"
      r, = Lapack.call(:lange, a, '1')
    when "-1"
      r = a.abs.sum(axis:-2).min(axis:-1)
    when "2"
      svd, = Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
      r = svd.max(axis:-1)
    when "-2"
      svd, = Lapack.call(:gesvd, a, jobu:'N', jobvt:'N')
      r = svd.min(axis:-1)
    when /^f(ro)?$/i
      r, = Lapack.call(:lange, a, 'F')
    when /^inf(inity)?$/i
      r, = Lapack.call(:lange, a, 'I')
    when /^-inf(inity)?$/i
      r = a.abs.sum(axis:-1).min(axis:-1)
    else
      raise ArgumentError, "ord (#{ord}) is invalid for matrix norm"
    end
    if keepdims
      if NArray === r
        r = r[*fixdims]
      else
        r = a.class.new(1,1).store(r)
      end
    end
  end
  r
end
670
+
671
# Compute the condition number of a matrix
# using the norm with one of the following orders.
#
# | ord    | matrix norm            |
# | ------ | ---------------------- |
# | nil    | 2-norm using SVD       |
# | 'fro'  | Frobenius norm         |
# | 'inf'  | x.abs.sum(axis:-1).max |
# | '-inf' | x.abs.sum(axis:-1).min |
# | 1      | x.abs.sum(axis:-2).max |
# | -1     | x.abs.sum(axis:-2).min |
# | 2      | 2-norm (max sing_vals) |
# | -2     | min sing_vals          |
#
# With ord=nil the result is the ratio of the first to the last singular
# value returned by svdvals; for any other ord it is
# `norm(a, ord) * norm(inv(a), ord)` taken over the last two axes.
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @param ord [String, Symbol or Integer] Order of the norm (optional).
# @return [Numo::NArray] cond result
# @example
#   a = Numo::DFloat[[1, 0, -1], [0, 1, 0], [1, 0, 1]]
#   => Numo::DFloat#shape=[3,3]
#   [[1, 0, -1],
#    [0, 1, 0],
#    [1, 0, 1]]
#   LA = Numo::Linalg
#   LA.cond(a)
#   => 1.4142135623730951
#   LA.cond(a, 'fro')
#   => 3.1622776601683795
#   LA.cond(a, 'inf')
#   => 2.0
#   LA.cond(a, '-inf')
#   => 1.0
#   LA.cond(a, 1)
#   => 2.0
#   LA.cond(a, -1)
#   => 1.0
#   LA.cond(a, 2)
#   => 1.4142135623730951
#   LA.cond(a, -2)
#   => 0.7071067811865475
#   (LA.svdvals(a)).min*(LA.svdvals(LA.inv(a))).min
#   => 0.7071067811865475
def cond(a,ord=nil)
  unless ord.nil?
    lhs = norm(a, ord, axis:[-2,-1])
    return lhs * norm(inv(a), ord, axis:[-2,-1])
  end
  # default: ratio of the first to the last singular value
  sv = svdvals(a)
  sv[false, 0] / sv[false, -1]
end
719
+
720
# Determinant of a matrix
#
# Computed from the LU factorization returned by ?getrf: the product of
# the diagonal of the packed LU factor, negated once per row interchange
# recorded in the pivot indices.
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @return [Float or Complex or Numo::NArray]
def det(a)
  lu, piv, = Lapack.call(:getrf, a)
  # Pivot indices are compared with 1..n (1-based); each mismatch counts
  # as one row interchange.
  n = piv.shape[-1]
  unpivoted = piv.new_narray.store(piv.class.new(n).seq(1))
  parity = piv.eq(unpivoted).count_false(axis:-1) % 2
  # parity 0 -> +1, parity 1 -> -1
  sign = parity * -2 + 1
  lu.diagonal.prod(axis:-1) * sign
end
732
+
733
# Natural logarithm of the determinant of a matrix
#
# @param a [Numo::NArray] matrix (>= 2-dimensional NArray)
# @return [[sign,logdet]]
#   - **sign** -- A number representing the sign of the determinant
#     (+1/-1 for real input, a unit-modulus value for complex input,
#     0 for a singular matrix).
#   - **logdet** -- The natural log of the absolute value of the
#     determinant (-Infinity for a singular matrix).
#   For a single (2-D) matrix both are scalars; for stacked matrices
#   (> 2 dimensions) both hold one entry per matrix.
def slogdet(a)
  lu, piv, = Lapack.call(:getrf, a)
  # getrf pivots are 1-based: every entry differing from 1..n counts as a
  # row interchange, and their parity gives the permutation's sign.
  idx = piv.new_narray.store(piv.class.new(piv.shape[-1]).seq(1))
  m = piv.eq(idx).count_false(axis:-1) % 2
  sign = m * -2 + 1

  lud = lu.diagonal
  zero = lud.eq(0)
  # single matrix with a zero pivot: it is singular
  if lud.ndim == 1 && zero.any?
    return 0, (-Float::INFINITY)
  end
  lud_abs = lud.abs
  # Reduce along the last axis only, so stacked input gets one sign per
  # matrix rather than a single product over the whole stack.
  sign *= (lud/lud_abs).prod(axis:-1)
  logdet = NMath.log(lud_abs).sum(axis:-1)
  if lud.ndim > 1
    singular = zero.count_true(axis:-1).gt(0)
    # A zero pivot made lud/lud_abs NaN (0/0) for those matrices; their
    # determinant is 0, so report sign 0 (logdet is already -Infinity
    # since log(0) == -Infinity).
    sign[singular] = 0 if singular.any?
  end
  [sign, logdet]
end
754
+
755
# Compute matrix rank of array using SVD
# *Rank* is the number of singular values greater than *tol*.
#
# @param m [Numo::NArray] matrix (>= 2-dimensional NArray). For 0-/1-D
#   input the result is 1 if any element is non-zero, otherwise 0.
# @param tol [Float] threshold below which singular values are
#   considered to be zero. If *tol* is nil,
#   `tol = sing_vals.max * m.shape[-2..-1].max * EPSILON`.
# @param driver [String or Symbol] choose LAPACK solver from 'svd',
#   'sdd' ('gesvd', 'gesdd' and 'turbo' (same as 'sdd') are also
#   accepted). (optional, default='svd')
# @return [Integer or Numo::NArray] rank (one count per matrix for
#   stacked input)
def matrix_rank(m, tol:nil, driver:'svd')
  m = Numo::NArray.asarray(m)
  # scalars and vectors: rank is 1 unless every element is zero
  if m.ndim < 2
    return m.ne(0).any? ? 1 : 0
  end
  name = driver.to_s
  if name =~ /^(ge)?sdd$/ || name == "turbo"
    # divide-and-conquer SVD, singular values only
    s = Lapack.call(:gesdd, m, jobz:'N')[0]
  elsif name =~ /^(ge)?svd$/
    s = Lapack.call(:gesvd, m, jobu:'N', jobvt:'N')[0]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
  # default tolerance: largest singular value * max(rows, cols) * EPSILON
  tol ||= s.max(axis:-1, keepdims:true) * (m.shape[-2..-1].max * s.class::EPSILON)
  (s > tol).count(axis:-1)
end
783
+
784
+
785
+ ## Solving equations and inverting matrices
786
+
787
# Solves linear equation `a * x = b` for `x`
# from square matrix `a`
# @param a [Numo::NArray] n-by-n square matrix (>= 2-dimensional NArray)
# @param b [Numo::NArray] n-by-nrhs right-hand-side matrix (>=
#   1-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   'gen','sym','her' or 'pos' (the two-letter prefixes 'ge','sy','he','po'
#   and the routine names 'gesv','sysv','hesv','posv' are accepted too,
#   case-insensitively). (optional, default='gen')
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (not used when driver:"gen")
# @return [Numo::NArray] The solution matrix/vector X.
# @raise [ArgumentError] if driver is not recognized
def solve(a, b, driver:"gen", uplo:'U')
  case driver.to_s
  when /^gen?(sv)?$/i
    # returns lu, x, ipiv, info
    Lapack.call(:gesv, a, b)[1]
  when /^(sym?|her?|pos?)(sv)?$/i
    # Only the two-letter prefix names the LAPACK routine: "sym", "sysv"
    # and "sy" must all map to "sysv" (taking three characters would
    # produce invalid names such as "symsv" or "syssv").
    func = driver.to_s[0, 2].downcase + "sv"
    Lapack.call(func, a, b, uplo:uplo)[1]
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
810
+
811
# Inverse matrix from square matrix `a`
# @param a [Numo::NArray] n-by-n square matrix (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver
#   (optional, default='getrf'):
#   - ('ge'|'sy'|'he'|'po') + 'sv': solves `a * x = I` via solve;
#   - ('ge'|'sy'|'he') + ('trf'|'tri'): uses lu_fact / lu_inv;
#   - 'potrf' or 'potri': uses cho_fact / cho_inv.
# @param uplo [String or Symbol] optional, default='U'. Access upper
#   ('U') or lower ('L') triangle. (not used by solve's 'ge' driver)
# @return [Numo::NArray] The inverse matrix.
# @raise [ArgumentError] if driver is not recognized
# @example
#   Numo::Linalg.inv(a,driver:'getrf')
#   => Numo::DFloat#shape=[2,2]
#   [[-2, 1],
#    [1.5, -0.5]]
#   a.dot(Numo::Linalg.inv(a,driver:'getrf'))
#   => Numo::DFloat#shape=[2,2]
#   [[1, 0],
#    [8.88178e-16, 1]]
def inv(a, driver:"getrf", uplo:'U')
  case driver
  when /(ge|sy|he|po)sv$/
    prefix = $1
    # solve a * x = I for x
    identity = a.new_zeros.eye
    solve(a, identity, driver:prefix, uplo:uplo)
  when /(ge|sy|he)tr[fi]$/
    prefix = $1
    factor, pivots = lu_fact(a, driver:prefix, uplo:uplo)
    lu_inv(factor, pivots, driver:prefix, uplo:uplo)
  when /potr[fi]$/
    cho_inv(cho_fact(a, uplo:uplo), uplo:uplo)
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
end
846
+
847
# Computes the minimum-norm solution to a linear least squares
# problem:
#
#   minimize 2-norm(| b - A*x |)
#
# using LAPACK ?gelsd / ?gelss (SVD of A) or ?gelsy (complete orthogonal
# factorization of A). A is an M-by-N matrix which may be rank-deficient.
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param b [Numo::NArray] m-by-nrhs right-hand-side matrix b
#   (>= 1-dimensional NArray); a 1-D b is treated as a single column
# @param driver [String or Symbol] choose LAPACK driver from
#   'lsd','lss','lsy' (optional, default='lsd')
# @param rcond [Float] (optional, default=-1)
#   RCOND is used to determine the effective rank of A.
#   Singular values `S(i) <= RCOND*S(1)` are treated as zero.
#   If RCOND < 0, machine precision is used instead.
# @return [[x, resids, rank, s]]
#   - **x** -- The solution matrix/vector X (a vector when b is 1-D).
#   - **resids** -- Sums of residues, the squared 2-norm of each column of
#     `b - a x`. Only computed when M > N, the driver is 'lsd' or 'lss',
#     and (for a single matrix) the rank equals N; otherwise nil.
#     For stacked matrices it is computed for every matrix, regardless
#     of rank.
#   - **rank** -- The effective rank of A, i.e.,
#     the number of singular values which are greater than RCOND*S(1).
#   - **s** -- The singular values of A as returned by ?gelsd/?gelss;
#     nil if 'gelsy' is used.
# @raise [NArray::ShapeError] if A and b have different row counts
# @raise [ArgumentError] if driver is not recognized
def lstsq(a, b, driver:'lsd', rcond:-1)
  a = NArray.asarray(a)
  b = NArray.asarray(b)
  # A 1-D right-hand side is solved as a single column and squeezed back
  # to a vector at the end.
  vector_rhs = (b.ndim == 1)
  b = b[true,:new] if vector_rhs
  m = a.shape[-2]
  n = a.shape[-1]
  if m != b.shape[-2]
    raise NArray::ShapeError, "size mismatch: A-row and B-row"
  end
  if m < n
    # LAPACK writes the N-row solution back into B, so B needs N rows.
    shp = b.shape
    shp[-2] = n
    b2 = b.class.zeros(*shp)
    b2[false,0...m,true] = b
    b = b2
  end
  case driver.to_s
  when /^(ge)?lsd$/i
    # returns x, s, rank, info
    x, s, rank, = Lapack.call(:gelsd, a, b, rcond:rcond)
  when /^(ge)?lss$/i
    # returns v, x, s, rank, info
    _, x, s, rank, = Lapack.call(:gelss, a, b, rcond:rcond)
  when /^(ge)?lsy$/i
    jpvt = Int32.zeros(*a[false,0,true].shape)
    # returns v, x, jpvt, rank, info
    _, x, _, rank, = Lapack.call(:gelsy, a, b, jpvt, rcond:rcond)
    s = nil
  else
    raise ArgumentError, "invalid driver: #{driver}"
  end
  resids = nil
  if m > n
    if /ls(d|s)$/i =~ driver.to_s
      # rows n...m of the returned B hold the residual components
      case rank
      when n
        resids = (x[n..-1,true].abs**2).sum(axis:0)
      when NArray
        # NOTE: computed for every matrix in the stack, including
        # rank-deficient ones (per-matrix masking is not used here).
        resids = (x[false,n..-1,true].abs**2).sum(axis:-2)
      end
    end
    x = x[false,0...n,true]
  end
  if vector_rhs
    # `false` keeps any leading (stacked) dimensions intact
    x = x[false,0]
    resids &&= resids[false,0]
  end
  [x, resids, rank, s]
end
934
+
935
# Compute the (Moore-Penrose) pseudo-inverse of a matrix
# using svd or lstsq.
#
# @param a [Numo::NArray] m-by-n matrix A (>= 2-dimensional NArray)
# @param driver [String or Symbol] choose LAPACK driver from
#   SVD ('svd', 'sdd') or Least square ('lsd','lss','lsy')
#   (optional, default='svd')
# @param rcond [Float] (optional, default=nil)
#   Relative cutoff for small singular values.
#   With 'svd'/'sdd', singular values `s <= rcond * s.max` are treated as
#   zero; if rcond is nil or negative, `1e3*EPSILON` (SFloat) or
#   `1e6*EPSILON` (other types) is used instead.
#   With 'lsd'/'lss'/'lsy', rcond is passed to lstsq; nil is passed as -1
#   (machine precision).
# @return [Numo::NArray]
# @raise [NArray::ShapeError] if a has fewer than 2 dimensions
# @raise [ArgumentError] if rcond is not Numeric or driver is unknown
# @example
#   a = Numo::DFloat.new(5,3).rand_norm
#   => Numo::DFloat#shape=[5,3]
#   [[-0.581255, -0.168354, 0.586895],
#    [-0.595142, -0.802802, -0.326106],
#    [0.282922, 1.68427, 0.918499],
#    [-0.0485384, -0.464453, -0.992194],
#    [0.413794, -0.60717, -0.699695]]
#   b = Numo::Linalg.pinv(a,driver:"svd")
#   => Numo::DFloat(view)#shape=[3,5]
#   [[-0.360863, -0.813125, -0.353367, -0.891963, 0.877253],
#    [-0.227645, 0.162939, 0.696655, 0.787685, -0.469346],
#    [0.408671, -0.308323, -0.337807, -1.13833, 0.228051]]
#   (a-a.dot(b.dot(a))).abs.max
#   => 5.551115123125783e-16
def pinv(a, driver:"svd", rcond:nil)
  a = NArray.asarray(a)
  if a.ndim < 2
    raise NArray::ShapeError, "2-d array is required"
  end
  # Validate before factorizing. The parentheses/`unless` matter:
  # `! Numeric === rcond` would parse as `(!Numeric) === rcond`.
  unless rcond.nil? || Numeric === rcond
    raise ArgumentError, "rcond must be Numeric"
  end
  case driver
  when /^(ge)?s[dv]d$/
    s, u, vh = svd(a, driver:driver, job:'S')
    if rcond.nil? || rcond < 0
      rcond = ((SFloat===s) ? 1e3 : 1e6) * s.class::EPSILON
    end
    # keep only singular values above the relative cutoff
    keep = (s > rcond * s.max(axis:-1, keepdims:true))
    if keep.all?
      r = s.reciprocal
    else
      r = s.new_zeros
      r[keep] = s[keep].reciprocal
    end
    # pinv(A) = (U * diag(1/s) * Vh)^H
    u *= r[false,:new,true]
    dot(u,vh).conj.swapaxes(-2,-1)
  when /^(ge)?ls[dsy]$/
    b = a.class.eye(a.shape[-2])
    # lstsq documents -1 (machine precision) as its default rcond
    x, = lstsq(a, b, driver:driver, rcond:(rcond.nil? ? -1 : rcond))
    x
  else
    raise ArgumentError, "#{driver.inspect} is not one of drivers: "+
      "svd, sdd, lsd, lss, lsy"
  end
end
994
+
995
private

# @!visibility private
# Converts the real eigenvector array returned by real ?geev into complex
# eigenvectors. Where w[j] has a positive imaginary part, the entry at
# index j along the last axis is the real part and the entry at j+1 the
# imaginary part; the eigenvector for w[j+1] is the complex conjugate.
#
# @param w [Numo::NArray] complex eigenvalues
# @param vin [Numo::NArray] real eigenvector array from ?geev
# @return [Numo::NArray] eigenvectors cast to w's (complex) class
def _make_complex_eigvecs(w, vin) # :nodoc:
  v = w.class.cast(vin)
  # Broadcast the per-eigenvalue mask to vin.shape. The parentheses are
  # required: `|` binds tighter than `>` in Ruby.
  m = ((w.imag > 0) | Bit.zeros(*vin.shape)).where
  v[m].imag = vin[m+1]
  v[m+1] = v[m].conj
  v
end
# A bare `private` cancels `module_function` mode, so without this the
# helper would exist only as an instance method and module-level calls
# such as `Linalg.eig` (self == the module) could not reach it.
# Re-export it as a module function, but keep that copy private too.
module_function :_make_complex_eigvecs
private_class_method :_make_complex_eigvecs
1006
+
1007
+ end
1008
+ end