numo-tiny_linalg 0.0.4 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,7 +24,7 @@ struct CHeGv {
24
24
  }
25
25
  };
26
26
 
27
- template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
27
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
28
28
  class HeGv {
29
29
  public:
30
30
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -40,15 +40,15 @@ private:
40
40
  };
41
41
 
42
42
  static void iter_hegv(na_loop_t* const lp) {
43
- DType* a = (DType*)NDL_PTR(lp, 0);
44
- DType* b = (DType*)NDL_PTR(lp, 1);
45
- RType* w = (RType*)NDL_PTR(lp, 2);
43
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
44
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
45
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
46
46
  int* info = (int*)NDL_PTR(lp, 3);
47
47
  hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
48
48
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
49
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
50
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
51
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
52
  *info = static_cast<int>(i);
53
53
  }
54
54
 
@@ -63,10 +63,10 @@ private:
63
63
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
64
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
65
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
66
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
70
 
71
71
  if (CLASS_OF(a_vnary) != nary_dtype) {
72
72
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -116,52 +116,6 @@ private:
116
116
 
117
117
  return ret;
118
118
  }
119
-
120
- static lapack_int get_itype(VALUE val) {
121
- const lapack_int itype = NUM2INT(val);
122
-
123
- if (itype != 1 && itype != 2 && itype != 3) {
124
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
- }
126
-
127
- return itype;
128
- }
129
-
130
- static char get_jobz(VALUE val) {
131
- const char jobz = NUM2CHR(val);
132
-
133
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
- }
136
-
137
- return jobz;
138
- }
139
-
140
- static char get_uplo(VALUE val) {
141
- const char uplo = NUM2CHR(val);
142
-
143
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
- }
146
-
147
- return uplo;
148
- }
149
-
150
- static int get_matrix_layout(VALUE val) {
151
- const char option = NUM2CHR(val);
152
-
153
- switch (option) {
154
- case 'r':
155
- case 'R':
156
- break;
157
- case 'c':
158
- case 'C':
159
- rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
160
- break;
161
- }
162
-
163
- return LAPACK_ROW_MAJOR;
164
- }
165
119
  };
166
120
 
167
121
  } // namespace TinyLinalg
@@ -24,7 +24,7 @@ struct CHeGvd {
24
24
  }
25
25
  };
26
26
 
27
- template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
27
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
28
28
  class HeGvd {
29
29
  public:
30
30
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -40,15 +40,15 @@ private:
40
40
  };
41
41
 
42
42
  static void iter_hegvd(na_loop_t* const lp) {
43
- DType* a = (DType*)NDL_PTR(lp, 0);
44
- DType* b = (DType*)NDL_PTR(lp, 1);
45
- RType* w = (RType*)NDL_PTR(lp, 2);
43
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
44
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
45
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
46
46
  int* info = (int*)NDL_PTR(lp, 3);
47
47
  hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
48
48
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
49
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
50
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
51
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
52
  *info = static_cast<int>(i);
53
53
  }
54
54
 
@@ -63,10 +63,10 @@ private:
63
63
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
64
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
65
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
66
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
70
 
71
71
  if (CLASS_OF(a_vnary) != nary_dtype) {
72
72
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -116,52 +116,6 @@ private:
116
116
 
117
117
  return ret;
118
118
  }
119
-
120
- static lapack_int get_itype(VALUE val) {
121
- const lapack_int itype = NUM2INT(val);
122
-
123
- if (itype != 1 && itype != 2 && itype != 3) {
124
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
- }
126
-
127
- return itype;
128
- }
129
-
130
- static char get_jobz(VALUE val) {
131
- const char jobz = NUM2CHR(val);
132
-
133
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
- }
136
-
137
- return jobz;
138
- }
139
-
140
- static char get_uplo(VALUE val) {
141
- const char uplo = NUM2CHR(val);
142
-
143
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
- }
146
-
147
- return uplo;
148
- }
149
-
150
- static int get_matrix_layout(VALUE val) {
151
- const char option = NUM2CHR(val);
152
-
153
- switch (option) {
154
- case 'r':
155
- case 'R':
156
- break;
157
- case 'c':
158
- case 'C':
159
- rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
160
- break;
161
- }
162
-
163
- return LAPACK_ROW_MAJOR;
164
- }
165
119
  };
166
120
 
167
121
  } // namespace TinyLinalg
@@ -18,7 +18,7 @@ struct CHeGvx {
18
18
  }
19
19
  };
20
20
 
21
- template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
21
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
22
22
  class HeGvx {
23
23
  public:
24
24
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,18 +32,18 @@ private:
32
32
  char jobz;
33
33
  char range;
34
34
  char uplo;
35
- RType vl;
36
- RType vu;
35
+ rtype vl;
36
+ rtype vu;
37
37
  lapack_int il;
38
38
  lapack_int iu;
39
39
  };
40
40
 
41
41
  static void iter_hegvx(na_loop_t* const lp) {
42
- DType* a = (DType*)NDL_PTR(lp, 0);
43
- DType* b = (DType*)NDL_PTR(lp, 1);
42
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
43
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
44
44
  int* m = (int*)NDL_PTR(lp, 2);
45
- RType* w = (RType*)NDL_PTR(lp, 3);
46
- DType* z = (DType*)NDL_PTR(lp, 4);
45
+ rtype* w = (rtype*)NDL_PTR(lp, 3);
46
+ dtype* z = (dtype*)NDL_PTR(lp, 4);
47
47
  int* ifail = (int*)NDL_PTR(lp, 5);
48
48
  int* info = (int*)NDL_PTR(lp, 6);
49
49
  hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
@@ -51,8 +51,8 @@ private:
51
51
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
52
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
53
  const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
- const RType abstol = 0.0;
55
- const lapack_int i = FncType().call(
54
+ const rtype abstol = 0.0;
55
+ const lapack_int i = LapackFn().call(
56
56
  opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
57
  opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
58
  *info = static_cast<int>(i);
@@ -70,15 +70,15 @@ private:
70
70
  rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
71
71
  VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
72
72
  rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
73
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
74
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
75
- const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
76
- const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
77
- const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
- const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
73
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
74
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
75
+ const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A';
76
+ const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U';
77
+ const rtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
+ const rtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
79
79
  const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
80
80
  const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
81
- const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
81
+ const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
82
82
 
83
83
  if (CLASS_OF(a_vnary) != nary_dtype) {
84
84
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -132,62 +132,6 @@ private:
132
132
 
133
133
  return ret;
134
134
  }
135
-
136
- static lapack_int get_itype(VALUE val) {
137
- const lapack_int itype = NUM2INT(val);
138
-
139
- if (itype != 1 && itype != 2 && itype != 3) {
140
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
141
- }
142
-
143
- return itype;
144
- }
145
-
146
- static char get_jobz(VALUE val) {
147
- const char jobz = NUM2CHR(val);
148
-
149
- if (jobz != 'N' && jobz != 'V') {
150
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
151
- }
152
-
153
- return jobz;
154
- }
155
-
156
- static char get_range(VALUE val) {
157
- const char range = NUM2CHR(val);
158
-
159
- if (range != 'A' && range != 'V' && range != 'I') {
160
- rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
161
- }
162
-
163
- return range;
164
- }
165
-
166
- static char get_uplo(VALUE val) {
167
- const char uplo = NUM2CHR(val);
168
-
169
- if (uplo != 'U' && uplo != 'L') {
170
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
171
- }
172
-
173
- return uplo;
174
- }
175
-
176
- static int get_matrix_layout(VALUE val) {
177
- const char option = NUM2CHR(val);
178
-
179
- switch (option) {
180
- case 'r':
181
- case 'R':
182
- break;
183
- case 'c':
184
- case 'C':
185
- rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
186
- break;
187
- }
188
-
189
- return LAPACK_ROW_MAJOR;
190
- }
191
135
  };
192
136
 
193
137
  } // namespace TinyLinalg
@@ -14,7 +14,7 @@ struct SOrgQr {
14
14
  }
15
15
  };
16
16
 
17
- template <int nary_dtype_id, typename DType, typename FncType>
17
+ template <int nary_dtype_id, typename dtype, class LapackFn>
18
18
  class OrgQr {
19
19
  public:
20
20
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -27,15 +27,15 @@ private:
27
27
  };
28
28
 
29
29
  static void iter_orgqr(na_loop_t* const lp) {
30
- DType* a = (DType*)NDL_PTR(lp, 0);
31
- DType* tau = (DType*)NDL_PTR(lp, 1);
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* tau = (dtype*)NDL_PTR(lp, 1);
32
32
  int* info = (int*)NDL_PTR(lp, 2);
33
33
  orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
34
34
  const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
35
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
36
  const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
37
  const lapack_int lda = n;
38
- const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
38
+ const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
39
  *info = static_cast<int>(i);
40
40
  }
41
41
 
@@ -49,7 +49,7 @@ private:
49
49
  ID kw_table[1] = { rb_intern("order") };
50
50
  VALUE kw_values[1] = { Qundef };
51
51
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
- const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
52
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
53
 
54
54
  if (CLASS_OF(a_vnary) != nary_dtype) {
55
55
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -90,26 +90,6 @@ private:
90
90
 
91
91
  return ret;
92
92
  }
93
-
94
- static int get_matrix_layout(VALUE val) {
95
- const char* option_str = StringValueCStr(val);
96
-
97
- if (std::strlen(option_str) > 0) {
98
- switch (option_str[0]) {
99
- case 'r':
100
- case 'R':
101
- break;
102
- case 'c':
103
- case 'C':
104
- rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
- break;
106
- }
107
- }
108
-
109
- RB_GC_GUARD(val);
110
-
111
- return LAPACK_ROW_MAJOR;
112
- }
113
93
  };
114
94
 
115
95
  } // namespace TinyLinalg
@@ -16,7 +16,7 @@ struct SSyGv {
16
16
  }
17
17
  };
18
18
 
19
- template <int nary_dtype_id, typename DType, typename FncType>
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
20
  class SyGv {
21
21
  public:
22
22
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,15 +32,15 @@ private:
32
32
  };
33
33
 
34
34
  static void iter_sygv(na_loop_t* const lp) {
35
- DType* a = (DType*)NDL_PTR(lp, 0);
36
- DType* b = (DType*)NDL_PTR(lp, 1);
37
- DType* w = (DType*)NDL_PTR(lp, 2);
35
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
36
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
37
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
38
38
  int* info = (int*)NDL_PTR(lp, 3);
39
39
  sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
40
40
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
41
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
42
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
43
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
44
  *info = static_cast<int>(i);
45
45
  }
46
46
 
@@ -54,10 +54,10 @@ private:
54
54
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
55
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
56
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
57
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
61
 
62
62
  if (CLASS_OF(a_vnary) != nary_dtype) {
63
63
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -107,52 +107,6 @@ private:
107
107
 
108
108
  return ret;
109
109
  }
110
-
111
- static lapack_int get_itype(VALUE val) {
112
- const lapack_int itype = NUM2INT(val);
113
-
114
- if (itype != 1 && itype != 2 && itype != 3) {
115
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
- }
117
-
118
- return itype;
119
- }
120
-
121
- static char get_jobz(VALUE val) {
122
- const char jobz = NUM2CHR(val);
123
-
124
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
- }
127
-
128
- return jobz;
129
- }
130
-
131
- static char get_uplo(VALUE val) {
132
- const char uplo = NUM2CHR(val);
133
-
134
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
- }
137
-
138
- return uplo;
139
- }
140
-
141
- static int get_matrix_layout(VALUE val) {
142
- const char option = NUM2CHR(val);
143
-
144
- switch (option) {
145
- case 'r':
146
- case 'R':
147
- break;
148
- case 'c':
149
- case 'C':
150
- rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
151
- break;
152
- }
153
-
154
- return LAPACK_ROW_MAJOR;
155
- }
156
110
  };
157
111
 
158
112
  } // namespace TinyLinalg
@@ -16,7 +16,7 @@ struct SSyGvd {
16
16
  }
17
17
  };
18
18
 
19
- template <int nary_dtype_id, typename DType, typename FncType>
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
20
  class SyGvd {
21
21
  public:
22
22
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,15 +32,15 @@ private:
32
32
  };
33
33
 
34
34
  static void iter_sygvd(na_loop_t* const lp) {
35
- DType* a = (DType*)NDL_PTR(lp, 0);
36
- DType* b = (DType*)NDL_PTR(lp, 1);
37
- DType* w = (DType*)NDL_PTR(lp, 2);
35
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
36
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
37
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
38
38
  int* info = (int*)NDL_PTR(lp, 3);
39
39
  sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
40
40
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
41
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
42
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
43
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
44
  *info = static_cast<int>(i);
45
45
  }
46
46
 
@@ -54,10 +54,10 @@ private:
54
54
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
55
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
56
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
57
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
61
 
62
62
  if (CLASS_OF(a_vnary) != nary_dtype) {
63
63
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -107,52 +107,6 @@ private:
107
107
 
108
108
  return ret;
109
109
  }
110
-
111
- static lapack_int get_itype(VALUE val) {
112
- const lapack_int itype = NUM2INT(val);
113
-
114
- if (itype != 1 && itype != 2 && itype != 3) {
115
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
- }
117
-
118
- return itype;
119
- }
120
-
121
- static char get_jobz(VALUE val) {
122
- const char jobz = NUM2CHR(val);
123
-
124
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
- }
127
-
128
- return jobz;
129
- }
130
-
131
- static char get_uplo(VALUE val) {
132
- const char uplo = NUM2CHR(val);
133
-
134
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
- }
137
-
138
- return uplo;
139
- }
140
-
141
- static int get_matrix_layout(VALUE val) {
142
- const char option = NUM2CHR(val);
143
-
144
- switch (option) {
145
- case 'r':
146
- case 'R':
147
- break;
148
- case 'c':
149
- case 'C':
150
- rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
151
- break;
152
- }
153
-
154
- return LAPACK_ROW_MAJOR;
155
- }
156
110
  };
157
111
 
158
112
  } // namespace TinyLinalg