numo-linalg-alt 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/LICENSE.txt +27 -0
- data/README.md +106 -0
- data/ext/numo/linalg/blas/dot.c +72 -0
- data/ext/numo/linalg/blas/dot.h +13 -0
- data/ext/numo/linalg/blas/dot_sub.c +71 -0
- data/ext/numo/linalg/blas/dot_sub.h +13 -0
- data/ext/numo/linalg/blas/gemm.c +184 -0
- data/ext/numo/linalg/blas/gemm.h +16 -0
- data/ext/numo/linalg/blas/gemv.c +161 -0
- data/ext/numo/linalg/blas/gemv.h +16 -0
- data/ext/numo/linalg/blas/nrm2.c +67 -0
- data/ext/numo/linalg/blas/nrm2.h +13 -0
- data/ext/numo/linalg/converter.c +67 -0
- data/ext/numo/linalg/converter.h +23 -0
- data/ext/numo/linalg/extconf.rb +99 -0
- data/ext/numo/linalg/lapack/geev.c +152 -0
- data/ext/numo/linalg/lapack/geev.h +15 -0
- data/ext/numo/linalg/lapack/gelsd.c +92 -0
- data/ext/numo/linalg/lapack/gelsd.h +15 -0
- data/ext/numo/linalg/lapack/geqrf.c +72 -0
- data/ext/numo/linalg/lapack/geqrf.h +15 -0
- data/ext/numo/linalg/lapack/gesdd.c +108 -0
- data/ext/numo/linalg/lapack/gesdd.h +15 -0
- data/ext/numo/linalg/lapack/gesv.c +99 -0
- data/ext/numo/linalg/lapack/gesv.h +15 -0
- data/ext/numo/linalg/lapack/gesvd.c +152 -0
- data/ext/numo/linalg/lapack/gesvd.h +15 -0
- data/ext/numo/linalg/lapack/getrf.c +71 -0
- data/ext/numo/linalg/lapack/getrf.h +15 -0
- data/ext/numo/linalg/lapack/getri.c +82 -0
- data/ext/numo/linalg/lapack/getri.h +15 -0
- data/ext/numo/linalg/lapack/getrs.c +110 -0
- data/ext/numo/linalg/lapack/getrs.h +15 -0
- data/ext/numo/linalg/lapack/heev.c +71 -0
- data/ext/numo/linalg/lapack/heev.h +15 -0
- data/ext/numo/linalg/lapack/heevd.c +71 -0
- data/ext/numo/linalg/lapack/heevd.h +15 -0
- data/ext/numo/linalg/lapack/heevr.c +111 -0
- data/ext/numo/linalg/lapack/heevr.h +15 -0
- data/ext/numo/linalg/lapack/hegv.c +94 -0
- data/ext/numo/linalg/lapack/hegv.h +15 -0
- data/ext/numo/linalg/lapack/hegvd.c +94 -0
- data/ext/numo/linalg/lapack/hegvd.h +15 -0
- data/ext/numo/linalg/lapack/hegvx.c +133 -0
- data/ext/numo/linalg/lapack/hegvx.h +15 -0
- data/ext/numo/linalg/lapack/hetrf.c +68 -0
- data/ext/numo/linalg/lapack/hetrf.h +15 -0
- data/ext/numo/linalg/lapack/lange.c +66 -0
- data/ext/numo/linalg/lapack/lange.h +15 -0
- data/ext/numo/linalg/lapack/orgqr.c +79 -0
- data/ext/numo/linalg/lapack/orgqr.h +15 -0
- data/ext/numo/linalg/lapack/potrf.c +70 -0
- data/ext/numo/linalg/lapack/potrf.h +15 -0
- data/ext/numo/linalg/lapack/potri.c +70 -0
- data/ext/numo/linalg/lapack/potri.h +15 -0
- data/ext/numo/linalg/lapack/potrs.c +94 -0
- data/ext/numo/linalg/lapack/potrs.h +15 -0
- data/ext/numo/linalg/lapack/syev.c +71 -0
- data/ext/numo/linalg/lapack/syev.h +15 -0
- data/ext/numo/linalg/lapack/syevd.c +71 -0
- data/ext/numo/linalg/lapack/syevd.h +15 -0
- data/ext/numo/linalg/lapack/syevr.c +111 -0
- data/ext/numo/linalg/lapack/syevr.h +15 -0
- data/ext/numo/linalg/lapack/sygv.c +93 -0
- data/ext/numo/linalg/lapack/sygv.h +15 -0
- data/ext/numo/linalg/lapack/sygvd.c +93 -0
- data/ext/numo/linalg/lapack/sygvd.h +15 -0
- data/ext/numo/linalg/lapack/sygvx.c +133 -0
- data/ext/numo/linalg/lapack/sygvx.h +15 -0
- data/ext/numo/linalg/lapack/sytrf.c +72 -0
- data/ext/numo/linalg/lapack/sytrf.h +15 -0
- data/ext/numo/linalg/lapack/trtrs.c +99 -0
- data/ext/numo/linalg/lapack/trtrs.h +15 -0
- data/ext/numo/linalg/lapack/ungqr.c +79 -0
- data/ext/numo/linalg/lapack/ungqr.h +15 -0
- data/ext/numo/linalg/linalg.c +290 -0
- data/ext/numo/linalg/linalg.h +85 -0
- data/ext/numo/linalg/util.c +95 -0
- data/ext/numo/linalg/util.h +17 -0
- data/lib/numo/linalg/version.rb +10 -0
- data/lib/numo/linalg.rb +1309 -0
- data/vendor/tmp/.gitkeep +0 -0
- metadata +146 -0
@@ -0,0 +1,15 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_LAPACK_SYGVX_H
|
2
|
+
#define NUMO_LINALG_ALT_LAPACK_SYGVX_H 1
|
3
|
+
|
4
|
+
#include <lapacke.h>
|
5
|
+
|
6
|
+
#include <ruby.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
#include "../util.h"
|
12
|
+
|
13
|
+
void define_linalg_lapack_sygvx(VALUE mLapack);
|
14
|
+
|
15
|
+
#endif /* NUMO_LINALG_ALT_LAPACK_SYGVX_H */
|
@@ -0,0 +1,72 @@
|
|
1
|
+
#include "sytrf.h"
|
2
|
+
|
3
|
+
struct _sytrf_option {
|
4
|
+
int matrix_layout;
|
5
|
+
char uplo;
|
6
|
+
};
|
7
|
+
|
8
|
+
#define DEF_LINALG_FUNC(tDType, tNAryClass, fLapackFunc) \
|
9
|
+
static void _iter_##fLapackFunc(na_loop_t* const lp) { \
|
10
|
+
tDType* a = (tDType*)NDL_PTR(lp, 0); \
|
11
|
+
lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1); \
|
12
|
+
int* info = (int*)NDL_PTR(lp, 2); \
|
13
|
+
struct _sytrf_option* opt = (struct _sytrf_option*)(lp->opt_ptr); \
|
14
|
+
const lapack_int n = (lapack_int)NDL_SHAPE(lp, 0)[0]; \
|
15
|
+
const lapack_int lda = n; \
|
16
|
+
const lapack_int i = LAPACKE_##fLapackFunc(opt->matrix_layout, opt->uplo, n, a, lda, ipiv); \
|
17
|
+
*info = (int)i; \
|
18
|
+
} \
|
19
|
+
\
|
20
|
+
static VALUE _linalg_lapack_##fLapackFunc(int argc, VALUE* argv, VALUE self) { \
|
21
|
+
VALUE a_vnary = Qnil; \
|
22
|
+
VALUE kw_args = Qnil; \
|
23
|
+
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args); \
|
24
|
+
ID kw_tables[2] = { rb_intern("matrix_layout"), rb_intern("uplo") }; \
|
25
|
+
VALUE kw_values[2] = { Qundef, Qundef }; \
|
26
|
+
rb_get_kwargs(kw_args, kw_tables, 0, 2, kw_values); \
|
27
|
+
const int matrix_layout = kw_values[0] != Qundef && kw_values[0] != Qnil ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; \
|
28
|
+
const char uplo = kw_values[1] != Qundef && kw_values[1] != Qnil ? get_uplo(kw_values[1]) : 'U'; \
|
29
|
+
\
|
30
|
+
if (CLASS_OF(a_vnary) != tNAryClass) { \
|
31
|
+
a_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, a_vnary); \
|
32
|
+
} \
|
33
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) { \
|
34
|
+
a_vnary = nary_dup(a_vnary); \
|
35
|
+
} \
|
36
|
+
\
|
37
|
+
narray_t* a_nary = NULL; \
|
38
|
+
GetNArray(a_vnary, a_nary); \
|
39
|
+
if (NA_NDIM(a_nary) != 2) { \
|
40
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional"); \
|
41
|
+
return Qnil; \
|
42
|
+
} \
|
43
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) { \
|
44
|
+
rb_raise(rb_eArgError, "input array a must be square"); \
|
45
|
+
return Qnil; \
|
46
|
+
} \
|
47
|
+
\
|
48
|
+
const size_t n = NA_SHAPE(a_nary)[0]; \
|
49
|
+
size_t shape[1] = { n }; \
|
50
|
+
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } }; \
|
51
|
+
ndfunc_arg_out_t aout[2] = { { numo_cInt32, 1, shape }, { numo_cInt32, 0 } }; \
|
52
|
+
ndfunc_t ndf = { _iter_##fLapackFunc, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout }; \
|
53
|
+
struct _sytrf_option opt = { matrix_layout, uplo }; \
|
54
|
+
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary); \
|
55
|
+
\
|
56
|
+
RB_GC_GUARD(a_vnary); \
|
57
|
+
return res; \
|
58
|
+
}
|
59
|
+
|
60
|
+
DEF_LINALG_FUNC(double, numo_cDFloat, dsytrf)
|
61
|
+
DEF_LINALG_FUNC(float, numo_cSFloat, ssytrf)
|
62
|
+
DEF_LINALG_FUNC(lapack_complex_double, numo_cDComplex, zsytrf)
|
63
|
+
DEF_LINALG_FUNC(lapack_complex_float, numo_cSComplex, csytrf)
|
64
|
+
|
65
|
+
#undef DEF_LINALG_FUNC
|
66
|
+
|
67
|
+
/* Registers the ?sytrf wrappers as variadic module functions of mLapack. */
void define_linalg_lapack_sytrf(VALUE mLapack) {
#define DEFINE_SYTRF_METHOD(fn) rb_define_module_function(mLapack, #fn, RUBY_METHOD_FUNC(_linalg_lapack_##fn), -1)
  DEFINE_SYTRF_METHOD(dsytrf);
  DEFINE_SYTRF_METHOD(ssytrf);
  DEFINE_SYTRF_METHOD(zsytrf);
  DEFINE_SYTRF_METHOD(csytrf);
#undef DEFINE_SYTRF_METHOD
}
|
@@ -0,0 +1,15 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_LAPACK_SYTRF_H
|
2
|
+
#define NUMO_LINALG_ALT_LAPACK_SYTRF_H 1
|
3
|
+
|
4
|
+
#include <lapacke.h>
|
5
|
+
|
6
|
+
#include <ruby.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
#include "../util.h"
|
12
|
+
|
13
|
+
void define_linalg_lapack_sytrf(VALUE mLapack);
|
14
|
+
|
15
|
+
#endif /* NUMO_LINALG_ALT_LAPACK_SYTRF_H */
|
@@ -0,0 +1,99 @@
|
|
1
|
+
#include "trtrs.h"
|
2
|
+
|
3
|
+
struct _trtrs_option {
|
4
|
+
int matrix_layout;
|
5
|
+
char uplo;
|
6
|
+
char trans;
|
7
|
+
char diag;
|
8
|
+
};
|
9
|
+
|
10
|
+
#define DEF_LINALG_FUNC(tDType, tNAryClass, fLapackFunc) \
|
11
|
+
static void _iter_##fLapackFunc(na_loop_t* const lp) { \
|
12
|
+
tDType* a = (tDType*)NDL_PTR(lp, 0); \
|
13
|
+
tDType* b = (tDType*)NDL_PTR(lp, 1); \
|
14
|
+
int* info = (int*)NDL_PTR(lp, 2); \
|
15
|
+
struct _trtrs_option* opt = (struct _trtrs_option*)(lp->opt_ptr); \
|
16
|
+
const lapack_int n = (lapack_int)NDL_SHAPE(lp, 0)[0]; \
|
17
|
+
const lapack_int nrhs = lp->args[1].ndim == 1 ? 1 : (lapack_int)NDL_SHAPE(lp, 1)[1]; \
|
18
|
+
const lapack_int lda = n; \
|
19
|
+
const lapack_int ldb = nrhs; \
|
20
|
+
const lapack_int i = LAPACKE_##fLapackFunc(opt->matrix_layout, opt->uplo, opt->trans, opt->diag, n, nrhs, a, lda, b, ldb); \
|
21
|
+
*info = (int)i; \
|
22
|
+
} \
|
23
|
+
\
|
24
|
+
static VALUE _linalg_lapack_##fLapackFunc(int argc, VALUE* argv, VALUE self) { \
|
25
|
+
VALUE a_vnary = Qnil; \
|
26
|
+
VALUE b_vnary = Qnil; \
|
27
|
+
VALUE kw_args = Qnil; \
|
28
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args); \
|
29
|
+
ID kw_table[4] = { rb_intern("order"), rb_intern("uplo"), rb_intern("trans"), rb_intern("diag") }; \
|
30
|
+
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef }; \
|
31
|
+
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values); \
|
32
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; \
|
33
|
+
const char uplo = kw_values[1] != Qundef ? get_uplo(kw_values[1]) : 'U'; \
|
34
|
+
const char trans = kw_values[2] != Qundef ? NUM2CHR(kw_values[2]) : 'N'; \
|
35
|
+
const char diag = kw_values[3] != Qundef ? NUM2CHR(kw_values[3]) : 'N'; \
|
36
|
+
\
|
37
|
+
if (CLASS_OF(a_vnary) != tNAryClass) { \
|
38
|
+
a_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, a_vnary); \
|
39
|
+
} \
|
40
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) { \
|
41
|
+
a_vnary = nary_dup(a_vnary); \
|
42
|
+
} \
|
43
|
+
if (CLASS_OF(b_vnary) != tNAryClass) { \
|
44
|
+
b_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, b_vnary); \
|
45
|
+
} \
|
46
|
+
if (!RTEST(nary_check_contiguous(b_vnary))) { \
|
47
|
+
b_vnary = nary_dup(b_vnary); \
|
48
|
+
} \
|
49
|
+
\
|
50
|
+
narray_t* a_nary = NULL; \
|
51
|
+
GetNArray(a_vnary, a_nary); \
|
52
|
+
if (NA_NDIM(a_nary) != 2) { \
|
53
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional"); \
|
54
|
+
return Qnil; \
|
55
|
+
} \
|
56
|
+
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) { \
|
57
|
+
rb_raise(rb_eArgError, "input array a must be square"); \
|
58
|
+
return Qnil; \
|
59
|
+
} \
|
60
|
+
\
|
61
|
+
narray_t* b_nary = NULL; \
|
62
|
+
GetNArray(b_vnary, b_nary); \
|
63
|
+
const int b_n_dims = NA_NDIM(b_nary); \
|
64
|
+
if (b_n_dims != 1 && b_n_dims != 2) { \
|
65
|
+
rb_raise(rb_eArgError, "input array b must be 1- or 2-dimensional"); \
|
66
|
+
return Qnil; \
|
67
|
+
} \
|
68
|
+
\
|
69
|
+
lapack_int n = (lapack_int)NA_SHAPE(a_nary)[0]; \
|
70
|
+
lapack_int nb = (lapack_int)NA_SHAPE(b_nary)[0]; \
|
71
|
+
if (n != nb) { \
|
72
|
+
rb_raise(nary_eShapeError, "shape1[0](=%d) != shape2[0](=%d)", n, nb); \
|
73
|
+
} \
|
74
|
+
\
|
75
|
+
ndfunc_arg_in_t ain[2] = { { tNAryClass, 2 }, { OVERWRITE, b_n_dims } }; \
|
76
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } }; \
|
77
|
+
ndfunc_t ndf = { _iter_##fLapackFunc, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout }; \
|
78
|
+
struct _trtrs_option opt = { matrix_layout, uplo, trans, diag }; \
|
79
|
+
VALUE info = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary); \
|
80
|
+
VALUE ret = rb_ary_new3(2, b_vnary, info); \
|
81
|
+
\
|
82
|
+
RB_GC_GUARD(a_vnary); \
|
83
|
+
RB_GC_GUARD(b_vnary); \
|
84
|
+
return ret; \
|
85
|
+
}
|
86
|
+
|
87
|
+
DEF_LINALG_FUNC(double, numo_cDFloat, dtrtrs)
|
88
|
+
DEF_LINALG_FUNC(float, numo_cSFloat, strtrs)
|
89
|
+
DEF_LINALG_FUNC(lapack_complex_double, numo_cDComplex, ztrtrs)
|
90
|
+
DEF_LINALG_FUNC(lapack_complex_float, numo_cSComplex, ctrtrs)
|
91
|
+
|
92
|
+
#undef DEF_LINALG_FUNC
|
93
|
+
|
94
|
+
/* Registers the ?trtrs wrappers as variadic module functions of mLapack. */
void define_linalg_lapack_trtrs(VALUE mLapack) {
#define DEFINE_TRTRS_METHOD(fn) rb_define_module_function(mLapack, #fn, RUBY_METHOD_FUNC(_linalg_lapack_##fn), -1)
  DEFINE_TRTRS_METHOD(dtrtrs);
  DEFINE_TRTRS_METHOD(strtrs);
  DEFINE_TRTRS_METHOD(ztrtrs);
  DEFINE_TRTRS_METHOD(ctrtrs);
#undef DEFINE_TRTRS_METHOD
}
|
@@ -0,0 +1,15 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_LAPACK_TRTRS_H
|
2
|
+
#define NUMO_LINALG_ALT_LAPACK_TRTRS_H 1
|
3
|
+
|
4
|
+
#include <lapacke.h>
|
5
|
+
|
6
|
+
#include <ruby.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
#include "../util.h"
|
12
|
+
|
13
|
+
void define_linalg_lapack_trtrs(VALUE mLapack);
|
14
|
+
|
15
|
+
#endif /* NUMO_LINALG_ALT_LAPACK_TRTRS_H */
|
@@ -0,0 +1,79 @@
|
|
1
|
+
#include "ungqr.h"
|
2
|
+
|
3
|
+
struct _ungqr_option {
|
4
|
+
int matrix_layout;
|
5
|
+
};
|
6
|
+
|
7
|
+
#define DEF_LINALG_FUNC(tDType, tNAryClass, fLapackFunc) \
|
8
|
+
static void _iter_##fLapackFunc(na_loop_t* const lp) { \
|
9
|
+
tDType* a = (tDType*)NDL_PTR(lp, 0); \
|
10
|
+
tDType* tau = (tDType*)NDL_PTR(lp, 1); \
|
11
|
+
int* info = (int*)NDL_PTR(lp, 2); \
|
12
|
+
struct _ungqr_option* opt = (struct _ungqr_option*)(lp->opt_ptr); \
|
13
|
+
const lapack_int m = (lapack_int)NDL_SHAPE(lp, 0)[0]; \
|
14
|
+
const lapack_int n = (lapack_int)NDL_SHAPE(lp, 0)[1]; \
|
15
|
+
const lapack_int k = (lapack_int)NDL_SHAPE(lp, 1)[0]; \
|
16
|
+
const lapack_int lda = n; \
|
17
|
+
const lapack_int i = LAPACKE_##fLapackFunc(opt->matrix_layout, m, n, k, a, lda, tau); \
|
18
|
+
*info = (int)i; \
|
19
|
+
} \
|
20
|
+
\
|
21
|
+
static VALUE _linalg_lapack_##fLapackFunc(int argc, VALUE* argv, VALUE self) { \
|
22
|
+
VALUE a_vnary = Qnil; \
|
23
|
+
VALUE tau_vnary = Qnil; \
|
24
|
+
VALUE kw_args = Qnil; \
|
25
|
+
rb_scan_args(argc, argv, "2:", &a_vnary, &tau_vnary, &kw_args); \
|
26
|
+
ID kw_table[1] = { rb_intern("order") }; \
|
27
|
+
VALUE kw_values[1] = { Qundef }; \
|
28
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); \
|
29
|
+
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; \
|
30
|
+
\
|
31
|
+
if (CLASS_OF(a_vnary) != tNAryClass) { \
|
32
|
+
a_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, a_vnary); \
|
33
|
+
} \
|
34
|
+
if (!RTEST(nary_check_contiguous(a_vnary))) { \
|
35
|
+
a_vnary = nary_dup(a_vnary); \
|
36
|
+
} \
|
37
|
+
if (CLASS_OF(tau_vnary) != tNAryClass) { \
|
38
|
+
tau_vnary = rb_funcall(tNAryClass, rb_intern("cast"), 1, tau_vnary); \
|
39
|
+
} \
|
40
|
+
if (!RTEST(nary_check_contiguous(tau_vnary))) { \
|
41
|
+
tau_vnary = nary_dup(tau_vnary); \
|
42
|
+
} \
|
43
|
+
\
|
44
|
+
narray_t* a_nary = NULL; \
|
45
|
+
GetNArray(a_vnary, a_nary); \
|
46
|
+
if (NA_NDIM(a_nary) != 2) { \
|
47
|
+
rb_raise(rb_eArgError, "input array a must be 2-dimensional"); \
|
48
|
+
return Qnil; \
|
49
|
+
} \
|
50
|
+
narray_t* tau_nary = NULL; \
|
51
|
+
GetNArray(tau_vnary, tau_nary); \
|
52
|
+
if (NA_NDIM(tau_nary) != 1) { \
|
53
|
+
rb_raise(rb_eArgError, "input array tau must be 1-dimensional"); \
|
54
|
+
return Qnil; \
|
55
|
+
} \
|
56
|
+
\
|
57
|
+
ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { tNAryClass, 1 } }; \
|
58
|
+
ndfunc_arg_out_t aout[1] = { { numo_cInt32, 0 } }; \
|
59
|
+
ndfunc_t ndf = { _iter_##fLapackFunc, NO_LOOP | NDF_EXTRACT, 2, 1, ain, aout }; \
|
60
|
+
struct _ungqr_option opt = { matrix_layout }; \
|
61
|
+
VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, tau_vnary); \
|
62
|
+
\
|
63
|
+
VALUE ret = rb_ary_new3(2, a_vnary, res); \
|
64
|
+
\
|
65
|
+
RB_GC_GUARD(a_vnary); \
|
66
|
+
RB_GC_GUARD(tau_vnary); \
|
67
|
+
\
|
68
|
+
return ret; \
|
69
|
+
}
|
70
|
+
|
71
|
+
DEF_LINALG_FUNC(lapack_complex_double, numo_cDComplex, zungqr)
|
72
|
+
DEF_LINALG_FUNC(lapack_complex_float, numo_cSComplex, cungqr)
|
73
|
+
|
74
|
+
#undef DEF_LINALG_FUNC
|
75
|
+
|
76
|
+
/* Registers the complex ?ungqr wrappers as variadic module functions of mLapack. */
void define_linalg_lapack_ungqr(VALUE mLapack) {
#define DEFINE_UNGQR_METHOD(fn) rb_define_module_function(mLapack, #fn, RUBY_METHOD_FUNC(_linalg_lapack_##fn), -1)
  DEFINE_UNGQR_METHOD(zungqr);
  DEFINE_UNGQR_METHOD(cungqr);
#undef DEFINE_UNGQR_METHOD
}
|
@@ -0,0 +1,15 @@
|
|
1
|
+
#ifndef NUMO_LINALG_ALT_LAPACK_UNGQR_H
|
2
|
+
#define NUMO_LINALG_ALT_LAPACK_UNGQR_H 1
|
3
|
+
|
4
|
+
#include <lapacke.h>
|
5
|
+
|
6
|
+
#include <ruby.h>
|
7
|
+
|
8
|
+
#include <numo/narray.h>
|
9
|
+
#include <numo/template.h>
|
10
|
+
|
11
|
+
#include "../util.h"
|
12
|
+
|
13
|
+
void define_linalg_lapack_ungqr(VALUE mLapack);
|
14
|
+
|
15
|
+
#endif /* NUMO_LINALG_ALT_LAPACK_UNGQR_H */
|
@@ -0,0 +1,290 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2025 Atsushi Tatsuma
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* Redistribution and use in source and binary forms, with or without
|
6
|
+
* modification, are permitted provided that the following conditions are met:
|
7
|
+
*
|
8
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
9
|
+
* list of conditions and the following disclaimer.
|
10
|
+
*
|
11
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
12
|
+
* this list of conditions and the following disclaimer in the documentation
|
13
|
+
* and/or other materials provided with the distribution.
|
14
|
+
*
|
15
|
+
* * Neither the name of the copyright holder nor the names of its
|
16
|
+
* contributors may be used to endorse or promote products derived from
|
17
|
+
* this software without specific prior written permission.
|
18
|
+
*
|
19
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29
|
+
*/
|
30
|
+
|
31
|
+
#include "linalg.h"
|
32
|
+
|
33
|
+
/* Numo::Linalg module object; assigned in Init_linalg. */
VALUE rb_mLinalg;
/* Numo::Linalg::Blas module object (BLAS wrappers); assigned in Init_linalg. */
VALUE rb_mLinalgBlas;
/* Numo::Linalg::Lapack module object (LAPACK wrappers); assigned in Init_linalg. */
VALUE rb_mLinalgLapack;
|
36
|
+
|
37
|
+
/* Returns non-zero when klass is Numo's Bit class or one of its integer array classes. */
static int blas_char_is_integer_class(VALUE klass) {
  return klass == numo_cBit || klass == numo_cInt64 || klass == numo_cInt32 || klass == numo_cInt16 ||
         klass == numo_cInt8 || klass == numo_cUInt64 || klass == numo_cUInt32 || klass == numo_cUInt16 ||
         klass == numo_cUInt8;
}

/**
 * Determines the BLAS type prefix for the elements of a Ruby array.
 *
 * Plain Ruby arrays are converted with Numo::NArray.asarray before inspection.
 * The prefix is promoted while scanning:
 *   - DComplex always yields 'z';
 *   - SComplex yields 'c', or 'z' when a double-precision type was seen;
 *   - DFloat yields 'd', or 'z' when a complex type was seen;
 *   - SFloat yields 's' only if nothing was seen yet;
 *   - Bit/integer arrays yield 'd' only if nothing was seen yet.
 * Elements of any other class are ignored.
 *
 * @param nary_arr Ruby Array of arguments.
 * @return one of 's', 'd', 'c', 'z', or 'n' when no element has a supported type.
 */
char blas_char(VALUE nary_arr) {
  char type = 'n';
  const long len = RARRAY_LEN(nary_arr);

  for (long idx = 0; idx < len; idx++) {
    VALUE elem = rb_ary_entry(nary_arr, idx);
    if (RB_TYPE_P(elem, T_ARRAY)) {
      elem = rb_funcall(numo_cNArray, rb_intern("asarray"), 1, elem);
    }

    const VALUE klass = CLASS_OF(elem);
    if (klass == numo_cDComplex) {
      type = 'z';
    } else if (klass == numo_cSComplex) {
      if (type == 'd') {
        type = 'z';
      } else if (type == 'n' || type == 's') {
        type = 'c';
      }
    } else if (klass == numo_cDFloat) {
      type = (type == 'c' || type == 'z') ? 'z' : 'd';
    } else if (klass == numo_cSFloat) {
      if (type == 'n') {
        type = 's';
      }
    } else if (blas_char_is_integer_class(klass)) {
      if (type == 'n') {
        type = 'd';
      }
    }
  }

  return type;
}
|
73
|
+
|
74
|
+
/**
 * Ruby: Numo::Linalg.blas_char(*args) -> String
 * Returns the one-character BLAS prefix for the given arguments.
 * Raises TypeError when none of the arguments has a BLAS-compatible type.
 */
static VALUE linalg_blas_char(int argc, VALUE* argv, VALUE self) {
  VALUE args = Qnil;
  rb_scan_args(argc, argv, "*", &args);

  const char bchr = blas_char(args);
  if (bchr != 'n') {
    return rb_str_new(&bchr, 1);
  }

  rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
  return Qnil;
}
|
86
|
+
|
87
|
+
/**
 * Ruby: Numo::Linalg::Blas.call(fn_name, *args, **kwargs)
 *
 * Calls the method on self whose name is fn_name prefixed with the BLAS char of
 * args (e.g. :gemv with DFloat arguments -> self.dgemv(*args, **kwargs)).
 * Raises TypeError when no argument has a BLAS-compatible type, and TypeError
 * (from rb_to_symbol) when fn_name is neither a Symbol nor a String.
 */
static VALUE linalg_blas_call(int argc, VALUE* argv, VALUE self) {
  VALUE fn_name = Qnil;
  VALUE nary_arr = Qnil;
  VALUE kw_args = Qnil;
  rb_scan_args(argc, argv, "1*:", &fn_name, &nary_arr, &kw_args);

  const char type = blas_char(nary_arr);
  if (type == 'n') {
    rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
    return Qnil;
  }

  /* Build "<type><fn_name>" as a Ruby string so long names are never silently truncated. */
  VALUE fn_str = rb_str_new(&type, 1);
  rb_str_append(fn_str, rb_sym2str(rb_to_symbol(fn_name)));
  const ID fn_id = rb_intern_str(fn_str);

  /*
   * Pass the splat array's storage directly instead of copying it into an
   * ALLOCA_N buffer sized by the (caller-controlled) argument count.
   * RARRAY_LENINT raises instead of silently truncating oversize lengths.
   */
  VALUE ret = Qnil;
  if (NIL_P(kw_args)) {
    ret = rb_funcallv(self, fn_id, RARRAY_LENINT(nary_arr), RARRAY_CONST_PTR(nary_arr));
  } else {
    VALUE args = rb_ary_dup(nary_arr);
    rb_ary_push(args, kw_args);
    ret = rb_funcallv_kw(self, fn_id, RARRAY_LENINT(args), RARRAY_CONST_PTR(args), RB_PASS_KEYWORDS);
    RB_GC_GUARD(args);
  }

  RB_GC_GUARD(fn_str);
  RB_GC_GUARD(nary_arr);
  return ret;
}
|
123
|
+
|
124
|
+
/*
 * Returns non-zero when x is Fortran-contiguous but not C-contiguous, i.e. when
 * its transpose is a C-contiguous view.
 */
static int linalg_dot_is_fortran_only(VALUE x) {
  return !RTEST(nary_check_contiguous(x)) && RTEST(rb_funcall(x, rb_intern("fortran_contiguous?"), 0));
}

/* Calls Numo::Linalg::Blas.<type><suffix>(x, y, **kw_args), e.g. type 'd' + "gemv" -> dgemv. */
static VALUE linalg_dot_call_blas_kw(const char type, const char* suffix, VALUE x, VALUE y, VALUE kw_args) {
  char fn_name[8];
  snprintf(fn_name, sizeof(fn_name), "%c%s", type, suffix);
  VALUE argv[3] = { x, y, kw_args };
  return rb_funcallv_kw(rb_mLinalgBlas, rb_intern(fn_name), 3, argv, RB_PASS_KEYWORDS);
}

/**
 * Ruby: Numo::Linalg.dot(a, b)
 *
 * Computes the dot product of a and b with BLAS:
 *   - 1-D . 1-D: ?dot for real types, ?dotu (no conjugation) for complex types;
 *   - 1-D . 2-D: ?gemv computing b^T a;
 *   - 2-D . 1-D: ?gemv computing a b;
 *   - otherwise: ?gemm computing a b.
 * Arguments that are not NArrays are converted with Numo::NArray.asarray.
 * A matrix operand that is Fortran-contiguous (but not C-contiguous) is passed
 * as its C-contiguous transpose together with the matching trans flag.
 * Raises TypeError when neither argument has a BLAS-compatible type.
 */
static VALUE linalg_dot(VALUE self, VALUE a_, VALUE b_) {
  VALUE a = IsNArray(a_) ? a_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, a_);
  VALUE b = IsNArray(b_) ? b_ : rb_funcall(numo_cNArray, rb_intern("asarray"), 1, b_);

  VALUE arg_arr = rb_ary_new3(2, a, b);
  const char type = blas_char(arg_arr);
  if (type == 'n') {
    rb_raise(rb_eTypeError, "invalid data type for BLAS/LAPACK");
    return Qnil;
  }

  narray_t* a_nary = NULL;
  narray_t* b_nary = NULL;
  GetNArray(a, a_nary);
  GetNArray(b, b_nary);
  const int a_ndim = NA_NDIM(a_nary);
  const int b_ndim = NA_NDIM(b_nary);

  VALUE ret = Qnil;
  if (a_ndim == 1 && b_ndim == 1) {
    ID fn_id = type == 'c' || type == 'z' ? rb_intern("dotu") : rb_intern("dot");
    ret = rb_funcall(rb_mLinalgBlas, rb_intern("call"), 3, ID2SYM(fn_id), a, b);
  } else if (a_ndim == 1) {
    /* b^T a: use b's transpose directly when that view is C-contiguous. */
    VALUE kw_args = rb_hash_new();
    if (linalg_dot_is_fortran_only(b)) {
      b = rb_funcall(b, rb_intern("transpose"), 0);
      rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
    } else {
      rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
    }
    ret = linalg_dot_call_blas_kw(type, "gemv", b, a, kw_args);
  } else if (b_ndim == 1) {
    /* a b: the layout test must look at a (the matrix), not at the vector b. */
    VALUE kw_args = rb_hash_new();
    if (linalg_dot_is_fortran_only(a)) {
      a = rb_funcall(a, rb_intern("transpose"), 0);
      rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("T"));
    } else {
      rb_hash_aset(kw_args, ID2SYM(rb_intern("trans")), rb_str_new_cstr("N"));
    }
    ret = linalg_dot_call_blas_kw(type, "gemv", a, b, kw_args);
  } else {
    /* a b: each operand's transposition is decided from its own layout. */
    VALUE kw_args = rb_hash_new();
    if (linalg_dot_is_fortran_only(a)) {
      a = rb_funcall(a, rb_intern("transpose"), 0);
      rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("T"));
    } else {
      rb_hash_aset(kw_args, ID2SYM(rb_intern("transa")), rb_str_new_cstr("N"));
    }
    if (linalg_dot_is_fortran_only(b)) {
      b = rb_funcall(b, rb_intern("transpose"), 0);
      rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("T"));
    } else {
      rb_hash_aset(kw_args, ID2SYM(rb_intern("transb")), rb_str_new_cstr("N"));
    }
    ret = linalg_dot_call_blas_kw(type, "gemm", a, b, kw_args);
  }

  RB_GC_GUARD(a);
  RB_GC_GUARD(b);
  RB_GC_GUARD(arg_arr);

  return ret;
}
|
199
|
+
|
200
|
+
void Init_linalg(void) {
|
201
|
+
rb_require("numo/narray");
|
202
|
+
|
203
|
+
/**
|
204
|
+
* Document-module: Numo::Linalg
|
205
|
+
* Numo::Linalg is a subset library from Numo::Linalg consisting only of methods used in Machine Learning algorithms.
|
206
|
+
*/
|
207
|
+
rb_mLinalg = rb_define_module_under(rb_mNumo, "Linalg");
|
208
|
+
/**
|
209
|
+
* Document-module: Numo::Linalg::Blas
|
210
|
+
* Numo::Linalg::Blas is wrapper module of BLAS functions.
|
211
|
+
* @!visibility private
|
212
|
+
*/
|
213
|
+
rb_mLinalgBlas = rb_define_module_under(rb_mLinalg, "Blas");
|
214
|
+
/**
|
215
|
+
* Document-module: Numo::Linalg::Lapack
|
216
|
+
* Numo::Linalg::Lapack is wrapper module of LAPACK functions.
|
217
|
+
* @!visibility private
|
218
|
+
*/
|
219
|
+
rb_mLinalgLapack = rb_define_module_under(rb_mLinalg, "Lapack");
|
220
|
+
|
221
|
+
/* The version of OpenBLAS used in background library. */
|
222
|
+
rb_define_const(rb_mLinalg, "OPENBLAS_VERSION", rb_str_new_cstr(OPENBLAS_VERSION));
|
223
|
+
|
224
|
+
/**
|
225
|
+
* Returns BLAS char ([sdcz]) defined by data-type of arguments.
|
226
|
+
*
|
227
|
+
* @overload blas_char(a, ...) -> String
|
228
|
+
* @param [Numo::NArray] a
|
229
|
+
* @return [String]
|
230
|
+
*/
|
231
|
+
rb_define_module_function(rb_mLinalg, "blas_char", RUBY_METHOD_FUNC(linalg_blas_char), -1);
|
232
|
+
/**
|
233
|
+
* Calculates dot product of two vectors / matrices.
|
234
|
+
*
|
235
|
+
* @overload dot(a, b) -> [Float|Complex|Numo::NArray]
|
236
|
+
* @param [Numo::NArray] a
|
237
|
+
* @param [Numo::NArray] b
|
238
|
+
* @return [Float|Complex|Numo::NArray]
|
239
|
+
*/
|
240
|
+
rb_define_module_function(rb_mLinalg, "dot", RUBY_METHOD_FUNC(linalg_dot), 2);
|
241
|
+
/**
|
242
|
+
* Calls BLAS function prefixed with BLAS char.
|
243
|
+
*
|
244
|
+
* @overload call(func, *args)
|
245
|
+
* @param func [Symbol] BLAS function name without BLAS char.
|
246
|
+
* @param args arguments of BLAS function.
|
247
|
+
* @example
|
248
|
+
* Numo::Linalg::Blas.call(:gemv, a, b)
|
249
|
+
*/
|
250
|
+
rb_define_singleton_method(rb_mLinalgBlas, "call", RUBY_METHOD_FUNC(linalg_blas_call), -1);
|
251
|
+
|
252
|
+
define_linalg_blas_dot(rb_mLinalgBlas);
|
253
|
+
define_linalg_blas_dot_sub(rb_mLinalgBlas);
|
254
|
+
define_linalg_blas_gemm(rb_mLinalgBlas);
|
255
|
+
define_linalg_blas_gemv(rb_mLinalgBlas);
|
256
|
+
define_linalg_blas_nrm2(rb_mLinalgBlas);
|
257
|
+
define_linalg_lapack_geqrf(rb_mLinalgLapack);
|
258
|
+
define_linalg_lapack_orgqr(rb_mLinalgLapack);
|
259
|
+
define_linalg_lapack_ungqr(rb_mLinalgLapack);
|
260
|
+
define_linalg_lapack_geev(rb_mLinalgLapack);
|
261
|
+
define_linalg_lapack_gesv(rb_mLinalgLapack);
|
262
|
+
define_linalg_lapack_gesvd(rb_mLinalgLapack);
|
263
|
+
define_linalg_lapack_gesdd(rb_mLinalgLapack);
|
264
|
+
define_linalg_lapack_getrf(rb_mLinalgLapack);
|
265
|
+
define_linalg_lapack_getri(rb_mLinalgLapack);
|
266
|
+
define_linalg_lapack_getrs(rb_mLinalgLapack);
|
267
|
+
define_linalg_lapack_trtrs(rb_mLinalgLapack);
|
268
|
+
define_linalg_lapack_potrf(rb_mLinalgLapack);
|
269
|
+
define_linalg_lapack_potri(rb_mLinalgLapack);
|
270
|
+
define_linalg_lapack_potrs(rb_mLinalgLapack);
|
271
|
+
define_linalg_lapack_syev(rb_mLinalgLapack);
|
272
|
+
define_linalg_lapack_heev(rb_mLinalgLapack);
|
273
|
+
define_linalg_lapack_syevd(rb_mLinalgLapack);
|
274
|
+
define_linalg_lapack_heevd(rb_mLinalgLapack);
|
275
|
+
define_linalg_lapack_syevr(rb_mLinalgLapack);
|
276
|
+
define_linalg_lapack_heevr(rb_mLinalgLapack);
|
277
|
+
define_linalg_lapack_sygv(rb_mLinalgLapack);
|
278
|
+
define_linalg_lapack_hegv(rb_mLinalgLapack);
|
279
|
+
define_linalg_lapack_sygvd(rb_mLinalgLapack);
|
280
|
+
define_linalg_lapack_hegvd(rb_mLinalgLapack);
|
281
|
+
define_linalg_lapack_sygvx(rb_mLinalgLapack);
|
282
|
+
define_linalg_lapack_hegvx(rb_mLinalgLapack);
|
283
|
+
define_linalg_lapack_lange(rb_mLinalgLapack);
|
284
|
+
define_linalg_lapack_gelsd(rb_mLinalgLapack);
|
285
|
+
define_linalg_lapack_sytrf(rb_mLinalgLapack);
|
286
|
+
define_linalg_lapack_hetrf(rb_mLinalgLapack);
|
287
|
+
|
288
|
+
rb_define_alias(rb_singleton_class(rb_mLinalgBlas), "znrm2", "dznrm2");
|
289
|
+
rb_define_alias(rb_singleton_class(rb_mLinalgBlas), "cnrm2", "scnrm2");
|
290
|
+
}
|