numo-tiny_linalg 0.0.3 → 0.1.0

This diff shows the changes between the publicly released 0.0.3 and 0.1.0 package versions as they appear in their public registry, and is provided for informational purposes only.
@@ -0,0 +1,115 @@
+ namespace TinyLinalg {
+
+ struct ZHeGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_double* a,
+                   lapack_int lda, lapack_complex_double* b,
+                   lapack_int ldb, double* w) {
+     return LAPACKE_zhegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct CHeGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_float* a,
+                   lapack_int lda, lapack_complex_float* b,
+                   lapack_int ldb, float* w) {
+     return LAPACKE_chegv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
+ class HeGv {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegv), -1);
+   }
+
+ private:
+   struct hegv_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_hegv(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* b = (dtype*)NDL_PTR(lp, 1);
+     rtype* w = (rtype*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_hegv(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_hegv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     hegv_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
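Note: hegv solves the generalized eigenproblem A·x = λ·B·x for a Hermitian A and Hermitian positive-definite B, overwriting both inputs (`OVERWRITE` in `ain`). A minimal usage sketch, assuming 0.1.0 registers this template as `Numo::TinyLinalg::Lapack.zhegv`/`.chegv` (the module function name is bound at registration time, not in this file):

```ruby
require 'numo/tiny_linalg'

a = Numo::DComplex[[2.0, Complex(0, 1)], [Complex(0, -1), 3.0]] # Hermitian A
b = Numo::DComplex[[2.0, 0.0], [0.0, 2.0]]                      # Hermitian positive-definite B

# Pass copies: the arrays are overwritten in place. Returns [a, b, w, info];
# with jobz: 'V', the eigenvectors replace the contents of a.
v, _bf, w, info = Numo::TinyLinalg::Lapack.zhegv(a.dup, b.dup, itype: 1, jobz: 'V', uplo: 'U')
raise "zhegv failed (info=#{info})" unless info.zero?
```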
@@ -0,0 +1,115 @@
+ namespace TinyLinalg {
+
+ struct ZHeGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_double* a,
+                   lapack_int lda, lapack_complex_double* b,
+                   lapack_int ldb, double* w) {
+     return LAPACKE_zhegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct CHeGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, lapack_complex_float* a,
+                   lapack_int lda, lapack_complex_float* b,
+                   lapack_int ldb, float* w) {
+     return LAPACKE_chegvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
+ class HeGvd {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvd), -1);
+   }
+
+ private:
+   struct hegvd_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_hegvd(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* b = (dtype*)NDL_PTR(lp, 1);
+     rtype* w = (rtype*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_hegvd(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_hegvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     hegvd_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
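hegvd exposes LAPACK's divide-and-conquer driver with the same Ruby-visible interface as hegv; it is typically faster for large matrices at the cost of extra workspace. Continuing the previous sketch, under the same naming assumptions:

```ruby
# Identical call shape to zhegv; only the underlying LAPACK driver differs.
v, _bf, w, info = Numo::TinyLinalg::Lapack.zhegvd(a.dup, b.dup, jobz: 'V', uplo: 'U')
```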
@@ -0,0 +1,137 @@
+ namespace TinyLinalg {
+
+ struct ZHeGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, lapack_complex_double* a, lapack_int lda, lapack_complex_double* b, lapack_int ldb,
+                   double vl, double vu, lapack_int il, lapack_int iu,
+                   double abstol, lapack_int* m, double* w, lapack_complex_double* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_zhegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ struct CHeGvx {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz, char range, char uplo,
+                   lapack_int n, lapack_complex_float* a, lapack_int lda, lapack_complex_float* b, lapack_int ldb,
+                   float vl, float vu, lapack_int il, lapack_int iu,
+                   float abstol, lapack_int* m, float* w, lapack_complex_float* z, lapack_int ldz, lapack_int* ifail) {
+     return LAPACKE_chegvx(matrix_layout, itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, ifail);
+   }
+ };
+
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
+ class HeGvx {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_hegvx), -1);
+   }
+
+ private:
+   struct hegvx_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char range;
+     char uplo;
+     rtype vl;
+     rtype vu;
+     lapack_int il;
+     lapack_int iu;
+   };
+
+   static void iter_hegvx(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* b = (dtype*)NDL_PTR(lp, 1);
+     int* m = (int*)NDL_PTR(lp, 2);
+     rtype* w = (rtype*)NDL_PTR(lp, 3);
+     dtype* z = (dtype*)NDL_PTR(lp, 4);
+     int* ifail = (int*)NDL_PTR(lp, 5);
+     int* info = (int*)NDL_PTR(lp, 6);
+     hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
+     const rtype abstol = 0.0;
+     const lapack_int i = LapackFn().call(
+       opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
+       opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_hegvx(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+     VALUE nary_rtype = NaryTypes[nary_rtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[9] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("range"), rb_intern("uplo"),
+                        rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
+     VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
+     const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A';
+     const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U';
+     const rtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
+     const rtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
+     const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
+     const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
+     const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t m = range != 'I' ? n : iu - il + 1;
+     size_t w_shape[1] = { m };
+     size_t z_shape[2] = { n, m };
+     size_t ifail_shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[5] = { { numo_cInt32, 0 }, { nary_rtype, 1, w_shape }, { nary_dtype, 2, z_shape }, { numo_cInt32, 1, ifail_shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_hegvx, NO_LOOP | NDF_EXTRACT, 2, 5, ain, aout };
+     hegvx_opt opt = { matrix_layout, itype, jobz, range, uplo, vl, vu, il, iu };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(7, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1), rb_ary_entry(res, 2),
+                             rb_ary_entry(res, 3), rb_ary_entry(res, 4));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
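hegvx computes only a selected subset of eigenpairs: `range: 'I'` with `il`/`iu` picks a 1-based index range, `range: 'V'` with `vl`/`vu` a value interval; `m` reports how many eigenvalues were found and `ifail` flags eigenvectors that failed to converge. Note the iterator fixes `abstol` at 0.0, i.e. LAPACK's default tolerance. A sketch under the same naming assumptions as above:

```ruby
a = Numo::DComplex[[4.0, Complex(0, 1)], [Complex(0, -1), 5.0]]
b = Numo::DComplex.eye(2)

# Request only the smallest eigenpair; indices are 1-based, as in LAPACK.
_af, _bf, m, w, z, ifail, info =
  Numo::TinyLinalg::Lapack.zhegvx(a.dup, b.dup, range: 'I', il: 1, iu: 1, jobz: 'V')
# m == 1 here; w[0...m] holds the eigenvalues, the columns of z the eigenvectors.
```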
@@ -14,7 +14,7 @@ struct SOrgQr {
    }
  };
 
- template <int nary_dtype_id, typename DType, typename FncType>
+ template <int nary_dtype_id, typename dtype, class LapackFn>
  class OrgQr {
  public:
    static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -27,15 +27,15 @@ private:
    };
 
    static void iter_orgqr(na_loop_t* const lp) {
-     DType* a = (DType*)NDL_PTR(lp, 0);
-     DType* tau = (DType*)NDL_PTR(lp, 1);
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* tau = (dtype*)NDL_PTR(lp, 1);
      int* info = (int*)NDL_PTR(lp, 2);
      orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
      const lapack_int m = NDL_SHAPE(lp, 0)[0];
      const lapack_int n = NDL_SHAPE(lp, 0)[1];
      const lapack_int k = NDL_SHAPE(lp, 1)[0];
      const lapack_int lda = n;
-     const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
+     const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, k, a, lda, tau);
      *info = static_cast<int>(i);
    }
 
@@ -49,7 +49,7 @@ private:
      ID kw_table[1] = { rb_intern("order") };
      VALUE kw_values[1] = { Qundef };
      rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
-     const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
+     const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
 
      if (CLASS_OF(a_vnary) != nary_dtype) {
        a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -90,26 +90,6 @@ private:
 
      return ret;
    }
-
-   static int get_matrix_layout(VALUE val) {
-     const char* option_str = StringValueCStr(val);
-
-     if (std::strlen(option_str) > 0) {
-       switch (option_str[0]) {
-       case 'r':
-       case 'R':
-         break;
-       case 'c':
-       case 'C':
-         rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
-         break;
-       }
-     }
-
-     RB_GC_GUARD(val);
-
-     return LAPACK_ROW_MAJOR;
-   }
  };
 
  } // namespace TinyLinalg
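This hunk is a refactor: the template parameters are renamed (`DType`/`FncType` → `dtype`/`LapackFn`) to match the new generalized-eigenvalue classes above, and the class-local `get_matrix_layout` (whose warning still mentioned `getrf`) is dropped in favor of the shared `Util().get_matrix_layout` helper. For reference, orgqr is normally paired with geqrf to materialize the Q factor; a hedged sketch, assuming the gem exposes both as module functions returning `[a, tau, info]` and `[a, info]` respectively:

```ruby
x = Numo::DFloat.new(4, 3).rand
qr, tau, info = Numo::TinyLinalg::Lapack.dgeqrf(x.dup)
q, info = Numo::TinyLinalg::Lapack.dorgqr(qr, tau) # qr is overwritten with Q
```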
@@ -0,0 +1,112 @@
+ namespace TinyLinalg {
+
+ struct DSyGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, double* a, lapack_int lda,
+                   double* b, lapack_int ldb, double* w) {
+     return LAPACKE_dsygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct SSyGv {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, float* a, lapack_int lda,
+                   float* b, lapack_int ldb, float* w) {
+     return LAPACKE_ssygv(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, typename dtype, class LapackFn>
+ class SyGv {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygv), -1);
+   }
+
+ private:
+   struct sygv_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_sygv(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* b = (dtype*)NDL_PTR(lp, 1);
+     dtype* w = (dtype*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygv(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygv, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     sygv_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
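sygv is the real-symmetric counterpart of hegv; note that `w` shares the input dtype here, so this template takes no separate rtype parameter. A sketch, assuming registration as `Numo::TinyLinalg::Lapack.dsygv`/`.ssygv`:

```ruby
a = Numo::DFloat[[3.0, 1.0], [1.0, 2.0]] # symmetric A
b = Numo::DFloat[[2.0, 0.0], [0.0, 1.0]] # symmetric positive-definite B
v, _bf, w, info = Numo::TinyLinalg::Lapack.dsygv(a.dup, b.dup, jobz: 'V')
```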
@@ -0,0 +1,112 @@
+ namespace TinyLinalg {
+
+ struct DSyGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, double* a, lapack_int lda,
+                   double* b, lapack_int ldb, double* w) {
+     return LAPACKE_dsygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ struct SSyGvd {
+   lapack_int call(int matrix_layout, lapack_int itype, char jobz,
+                   char uplo, lapack_int n, float* a, lapack_int lda,
+                   float* b, lapack_int ldb, float* w) {
+     return LAPACKE_ssygvd(matrix_layout, itype, jobz, uplo, n, a, lda, b, ldb, w);
+   }
+ };
+
+ template <int nary_dtype_id, typename dtype, class LapackFn>
+ class SyGvd {
+ public:
+   static void define_module_function(VALUE mLapack, const char* fnc_name) {
+     rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_sygvd), -1);
+   }
+
+ private:
+   struct sygvd_opt {
+     int matrix_layout;
+     lapack_int itype;
+     char jobz;
+     char uplo;
+   };
+
+   static void iter_sygvd(na_loop_t* const lp) {
+     dtype* a = (dtype*)NDL_PTR(lp, 0);
+     dtype* b = (dtype*)NDL_PTR(lp, 1);
+     dtype* w = (dtype*)NDL_PTR(lp, 2);
+     int* info = (int*)NDL_PTR(lp, 3);
+     sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
+     const lapack_int n = NDL_SHAPE(lp, 0)[1];
+     const lapack_int lda = NDL_SHAPE(lp, 0)[0];
+     const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
+     const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
+     *info = static_cast<int>(i);
+   }
+
+   static VALUE tiny_linalg_sygvd(int argc, VALUE* argv, VALUE self) {
+     VALUE nary_dtype = NaryTypes[nary_dtype_id];
+
+     VALUE a_vnary = Qnil;
+     VALUE b_vnary = Qnil;
+     VALUE kw_args = Qnil;
+     rb_scan_args(argc, argv, "2:", &a_vnary, &b_vnary, &kw_args);
+     ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
+     VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
+     rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
+     const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
+     const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
+     const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
+     const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
+
+     if (CLASS_OF(a_vnary) != nary_dtype) {
+       a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(a_vnary))) {
+       a_vnary = nary_dup(a_vnary);
+     }
+     if (CLASS_OF(b_vnary) != nary_dtype) {
+       b_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, b_vnary);
+     }
+     if (!RTEST(nary_check_contiguous(b_vnary))) {
+       b_vnary = nary_dup(b_vnary);
+     }
+
+     narray_t* a_nary = nullptr;
+     GetNArray(a_vnary, a_nary);
+     if (NA_NDIM(a_nary) != 2) {
+       rb_raise(rb_eArgError, "input array a must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
+       rb_raise(rb_eArgError, "input array a must be square");
+       return Qnil;
+     }
+     narray_t* b_nary = nullptr;
+     GetNArray(b_vnary, b_nary);
+     if (NA_NDIM(b_nary) != 2) {
+       rb_raise(rb_eArgError, "input array b must be 2-dimensional");
+       return Qnil;
+     }
+     if (NA_SHAPE(b_nary)[0] != NA_SHAPE(b_nary)[1]) {
+       rb_raise(rb_eArgError, "input array b must be square");
+       return Qnil;
+     }
+
+     const size_t n = NA_SHAPE(a_nary)[1];
+     size_t shape[1] = { n };
+     ndfunc_arg_in_t ain[2] = { { OVERWRITE, 2 }, { OVERWRITE, 2 } };
+     ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
+     ndfunc_t ndf = { iter_sygvd, NO_LOOP | NDF_EXTRACT, 2, 2, ain, aout };
+     sygvd_opt opt = { matrix_layout, itype, jobz, uplo };
+     VALUE res = na_ndloop3(&ndf, &opt, 2, a_vnary, b_vnary);
+     VALUE ret = rb_ary_new3(4, a_vnary, b_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));
+
+     RB_GC_GUARD(a_vnary);
+     RB_GC_GUARD(b_vnary);
+
+     return ret;
+   }
+ };
+
+ } // namespace TinyLinalg
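sygvd mirrors sygv with the divide-and-conquer driver. A quick sanity check for any of these drivers (with the default itype of 1) is the residual of A·V = B·V·diag(w); a sketch continuing the dsygv example above:

```ruby
v, _bf, w, info = Numo::TinyLinalg::Lapack.dsygvd(a.dup, b.dup, jobz: 'V')
residual = (a.dot(v) - b.dot(v) * w).abs.max # w broadcasts across the columns of B*V
puts format('max residual: %.2e', residual)  # ~1e-15 for well-conditioned input
```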