numo-tiny_linalg 0.0.4 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -24,7 +24,7 @@ struct CHeGv {
24
24
  }
25
25
  };
26
26
 
27
- template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
27
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
28
28
  class HeGv {
29
29
  public:
30
30
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -40,15 +40,15 @@ private:
40
40
  };
41
41
 
42
42
  static void iter_hegv(na_loop_t* const lp) {
43
- DType* a = (DType*)NDL_PTR(lp, 0);
44
- DType* b = (DType*)NDL_PTR(lp, 1);
45
- RType* w = (RType*)NDL_PTR(lp, 2);
43
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
44
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
45
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
46
46
  int* info = (int*)NDL_PTR(lp, 3);
47
47
  hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
48
48
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
49
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
50
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
51
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
52
  *info = static_cast<int>(i);
53
53
  }
54
54
 
@@ -63,10 +63,10 @@ private:
63
63
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
64
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
65
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
66
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
70
 
71
71
  if (CLASS_OF(a_vnary) != nary_dtype) {
72
72
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -116,52 +116,6 @@ private:
116
116
 
117
117
  return ret;
118
118
  }
119
-
120
- static lapack_int get_itype(VALUE val) {
121
- const lapack_int itype = NUM2INT(val);
122
-
123
- if (itype != 1 && itype != 2 && itype != 3) {
124
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
- }
126
-
127
- return itype;
128
- }
129
-
130
- static char get_jobz(VALUE val) {
131
- const char jobz = NUM2CHR(val);
132
-
133
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
- }
136
-
137
- return jobz;
138
- }
139
-
140
- static char get_uplo(VALUE val) {
141
- const char uplo = NUM2CHR(val);
142
-
143
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
- }
146
-
147
- return uplo;
148
- }
149
-
150
- static int get_matrix_layout(VALUE val) {
151
- const char option = NUM2CHR(val);
152
-
153
- switch (option) {
154
- case 'r':
155
- case 'R':
156
- break;
157
- case 'c':
158
- case 'C':
159
- rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
160
- break;
161
- }
162
-
163
- return LAPACK_ROW_MAJOR;
164
- }
165
119
  };
166
120
 
167
121
  } // namespace TinyLinalg
@@ -24,7 +24,7 @@ struct CHeGvd {
24
24
  }
25
25
  };
26
26
 
27
- template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
27
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
28
28
  class HeGvd {
29
29
  public:
30
30
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -40,15 +40,15 @@ private:
40
40
  };
41
41
 
42
42
  static void iter_hegvd(na_loop_t* const lp) {
43
- DType* a = (DType*)NDL_PTR(lp, 0);
44
- DType* b = (DType*)NDL_PTR(lp, 1);
45
- RType* w = (RType*)NDL_PTR(lp, 2);
43
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
44
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
45
+ rtype* w = (rtype*)NDL_PTR(lp, 2);
46
46
  int* info = (int*)NDL_PTR(lp, 3);
47
47
  hegvd_opt* opt = (hegvd_opt*)(lp->opt_ptr);
48
48
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
49
49
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
50
50
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
51
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
51
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
52
52
  *info = static_cast<int>(i);
53
53
  }
54
54
 
@@ -63,10 +63,10 @@ private:
63
63
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
64
64
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
65
65
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
66
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
67
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
68
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
69
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
66
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
67
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
68
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
69
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
70
70
 
71
71
  if (CLASS_OF(a_vnary) != nary_dtype) {
72
72
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -116,52 +116,6 @@ private:
116
116
 
117
117
  return ret;
118
118
  }
119
-
120
- static lapack_int get_itype(VALUE val) {
121
- const lapack_int itype = NUM2INT(val);
122
-
123
- if (itype != 1 && itype != 2 && itype != 3) {
124
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
125
- }
126
-
127
- return itype;
128
- }
129
-
130
- static char get_jobz(VALUE val) {
131
- const char jobz = NUM2CHR(val);
132
-
133
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
134
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
135
- }
136
-
137
- return jobz;
138
- }
139
-
140
- static char get_uplo(VALUE val) {
141
- const char uplo = NUM2CHR(val);
142
-
143
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
144
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
145
- }
146
-
147
- return uplo;
148
- }
149
-
150
- static int get_matrix_layout(VALUE val) {
151
- const char option = NUM2CHR(val);
152
-
153
- switch (option) {
154
- case 'r':
155
- case 'R':
156
- break;
157
- case 'c':
158
- case 'C':
159
- rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
160
- break;
161
- }
162
-
163
- return LAPACK_ROW_MAJOR;
164
- }
165
119
  };
166
120
 
167
121
  } // namespace TinyLinalg
@@ -18,7 +18,7 @@ struct CHeGvx {
18
18
  }
19
19
  };
20
20
 
21
- template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
21
+ template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
22
22
  class HeGvx {
23
23
  public:
24
24
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,18 +32,18 @@ private:
32
32
  char jobz;
33
33
  char range;
34
34
  char uplo;
35
- RType vl;
36
- RType vu;
35
+ rtype vl;
36
+ rtype vu;
37
37
  lapack_int il;
38
38
  lapack_int iu;
39
39
  };
40
40
 
41
41
  static void iter_hegvx(na_loop_t* const lp) {
42
- DType* a = (DType*)NDL_PTR(lp, 0);
43
- DType* b = (DType*)NDL_PTR(lp, 1);
42
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
43
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
44
44
  int* m = (int*)NDL_PTR(lp, 2);
45
- RType* w = (RType*)NDL_PTR(lp, 3);
46
- DType* z = (DType*)NDL_PTR(lp, 4);
45
+ rtype* w = (rtype*)NDL_PTR(lp, 3);
46
+ dtype* z = (dtype*)NDL_PTR(lp, 4);
47
47
  int* ifail = (int*)NDL_PTR(lp, 5);
48
48
  int* info = (int*)NDL_PTR(lp, 6);
49
49
  hegvx_opt* opt = (hegvx_opt*)(lp->opt_ptr);
@@ -51,8 +51,8 @@ private:
51
51
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
52
52
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
53
53
  const lapack_int ldz = opt->range != 'I' ? n : opt->iu - opt->il + 1;
54
- const RType abstol = 0.0;
55
- const lapack_int i = FncType().call(
54
+ const rtype abstol = 0.0;
55
+ const lapack_int i = LapackFn().call(
56
56
  opt->matrix_layout, opt->itype, opt->jobz, opt->range, opt->uplo, n, a, lda, b, ldb,
57
57
  opt->vl, opt->vu, opt->il, opt->iu, abstol, m, w, z, ldz, ifail);
58
58
  *info = static_cast<int>(i);
@@ -70,15 +70,15 @@ private:
70
70
  rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") };
71
71
  VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef };
72
72
  rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values);
73
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
74
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
75
- const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A';
76
- const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U';
77
- const RType vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
- const RType vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
73
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
74
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
75
+ const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A';
76
+ const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U';
77
+ const rtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0;
78
+ const rtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0;
79
79
  const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0;
80
80
  const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0;
81
- const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
81
+ const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR;
82
82
 
83
83
  if (CLASS_OF(a_vnary) != nary_dtype) {
84
84
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -132,62 +132,6 @@ private:
132
132
 
133
133
  return ret;
134
134
  }
135
-
136
- static lapack_int get_itype(VALUE val) {
137
- const lapack_int itype = NUM2INT(val);
138
-
139
- if (itype != 1 && itype != 2 && itype != 3) {
140
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
141
- }
142
-
143
- return itype;
144
- }
145
-
146
- static char get_jobz(VALUE val) {
147
- const char jobz = NUM2CHR(val);
148
-
149
- if (jobz != 'N' && jobz != 'V') {
150
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
151
- }
152
-
153
- return jobz;
154
- }
155
-
156
- static char get_range(VALUE val) {
157
- const char range = NUM2CHR(val);
158
-
159
- if (range != 'A' && range != 'V' && range != 'I') {
160
- rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'");
161
- }
162
-
163
- return range;
164
- }
165
-
166
- static char get_uplo(VALUE val) {
167
- const char uplo = NUM2CHR(val);
168
-
169
- if (uplo != 'U' && uplo != 'L') {
170
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
171
- }
172
-
173
- return uplo;
174
- }
175
-
176
- static int get_matrix_layout(VALUE val) {
177
- const char option = NUM2CHR(val);
178
-
179
- switch (option) {
180
- case 'r':
181
- case 'R':
182
- break;
183
- case 'c':
184
- case 'C':
185
- rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major.");
186
- break;
187
- }
188
-
189
- return LAPACK_ROW_MAJOR;
190
- }
191
135
  };
192
136
 
193
137
  } // namespace TinyLinalg
@@ -14,7 +14,7 @@ struct SOrgQr {
14
14
  }
15
15
  };
16
16
 
17
- template <int nary_dtype_id, typename DType, typename FncType>
17
+ template <int nary_dtype_id, typename dtype, class LapackFn>
18
18
  class OrgQr {
19
19
  public:
20
20
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -27,15 +27,15 @@ private:
27
27
  };
28
28
 
29
29
  static void iter_orgqr(na_loop_t* const lp) {
30
- DType* a = (DType*)NDL_PTR(lp, 0);
31
- DType* tau = (DType*)NDL_PTR(lp, 1);
30
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
31
+ dtype* tau = (dtype*)NDL_PTR(lp, 1);
32
32
  int* info = (int*)NDL_PTR(lp, 2);
33
33
  orgqr_opt* opt = (orgqr_opt*)(lp->opt_ptr);
34
34
  const lapack_int m = NDL_SHAPE(lp, 0)[0];
35
35
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
36
36
  const lapack_int k = NDL_SHAPE(lp, 1)[0];
37
37
  const lapack_int lda = n;
38
- const lapack_int i = FncType().call(opt->matrix_layout, m, n, k, a, lda, tau);
38
+ const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, k, a, lda, tau);
39
39
  *info = static_cast<int>(i);
40
40
  }
41
41
 
@@ -49,7 +49,7 @@ private:
49
49
  ID kw_table[1] = { rb_intern("order") };
50
50
  VALUE kw_values[1] = { Qundef };
51
51
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
52
- const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
52
+ const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
53
53
 
54
54
  if (CLASS_OF(a_vnary) != nary_dtype) {
55
55
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -90,26 +90,6 @@ private:
90
90
 
91
91
  return ret;
92
92
  }
93
-
94
- static int get_matrix_layout(VALUE val) {
95
- const char* option_str = StringValueCStr(val);
96
-
97
- if (std::strlen(option_str) > 0) {
98
- switch (option_str[0]) {
99
- case 'r':
100
- case 'R':
101
- break;
102
- case 'c':
103
- case 'C':
104
- rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
105
- break;
106
- }
107
- }
108
-
109
- RB_GC_GUARD(val);
110
-
111
- return LAPACK_ROW_MAJOR;
112
- }
113
93
  };
114
94
 
115
95
  } // namespace TinyLinalg
@@ -16,7 +16,7 @@ struct SSyGv {
16
16
  }
17
17
  };
18
18
 
19
- template <int nary_dtype_id, typename DType, typename FncType>
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
20
  class SyGv {
21
21
  public:
22
22
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,15 +32,15 @@ private:
32
32
  };
33
33
 
34
34
  static void iter_sygv(na_loop_t* const lp) {
35
- DType* a = (DType*)NDL_PTR(lp, 0);
36
- DType* b = (DType*)NDL_PTR(lp, 1);
37
- DType* w = (DType*)NDL_PTR(lp, 2);
35
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
36
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
37
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
38
38
  int* info = (int*)NDL_PTR(lp, 3);
39
39
  sygv_opt* opt = (sygv_opt*)(lp->opt_ptr);
40
40
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
41
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
42
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
43
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
44
  *info = static_cast<int>(i);
45
45
  }
46
46
 
@@ -54,10 +54,10 @@ private:
54
54
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
55
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
56
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
57
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
61
 
62
62
  if (CLASS_OF(a_vnary) != nary_dtype) {
63
63
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -107,52 +107,6 @@ private:
107
107
 
108
108
  return ret;
109
109
  }
110
-
111
- static lapack_int get_itype(VALUE val) {
112
- const lapack_int itype = NUM2INT(val);
113
-
114
- if (itype != 1 && itype != 2 && itype != 3) {
115
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
- }
117
-
118
- return itype;
119
- }
120
-
121
- static char get_jobz(VALUE val) {
122
- const char jobz = NUM2CHR(val);
123
-
124
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
- }
127
-
128
- return jobz;
129
- }
130
-
131
- static char get_uplo(VALUE val) {
132
- const char uplo = NUM2CHR(val);
133
-
134
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
- }
137
-
138
- return uplo;
139
- }
140
-
141
- static int get_matrix_layout(VALUE val) {
142
- const char option = NUM2CHR(val);
143
-
144
- switch (option) {
145
- case 'r':
146
- case 'R':
147
- break;
148
- case 'c':
149
- case 'C':
150
- rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
151
- break;
152
- }
153
-
154
- return LAPACK_ROW_MAJOR;
155
- }
156
110
  };
157
111
 
158
112
  } // namespace TinyLinalg
@@ -16,7 +16,7 @@ struct SSyGvd {
16
16
  }
17
17
  };
18
18
 
19
- template <int nary_dtype_id, typename DType, typename FncType>
19
+ template <int nary_dtype_id, typename dtype, class LapackFn>
20
20
  class SyGvd {
21
21
  public:
22
22
  static void define_module_function(VALUE mLapack, const char* fnc_name) {
@@ -32,15 +32,15 @@ private:
32
32
  };
33
33
 
34
34
  static void iter_sygvd(na_loop_t* const lp) {
35
- DType* a = (DType*)NDL_PTR(lp, 0);
36
- DType* b = (DType*)NDL_PTR(lp, 1);
37
- DType* w = (DType*)NDL_PTR(lp, 2);
35
+ dtype* a = (dtype*)NDL_PTR(lp, 0);
36
+ dtype* b = (dtype*)NDL_PTR(lp, 1);
37
+ dtype* w = (dtype*)NDL_PTR(lp, 2);
38
38
  int* info = (int*)NDL_PTR(lp, 3);
39
39
  sygvd_opt* opt = (sygvd_opt*)(lp->opt_ptr);
40
40
  const lapack_int n = NDL_SHAPE(lp, 0)[1];
41
41
  const lapack_int lda = NDL_SHAPE(lp, 0)[0];
42
42
  const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
43
- const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
43
+ const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
44
44
  *info = static_cast<int>(i);
45
45
  }
46
46
 
@@ -54,10 +54,10 @@ private:
54
54
  ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
55
55
  VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
56
56
  rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
57
- const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
58
- const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
59
- const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
60
- const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
57
+ const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
58
+ const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
59
+ const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
60
+ const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
61
61
 
62
62
  if (CLASS_OF(a_vnary) != nary_dtype) {
63
63
  a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
@@ -107,52 +107,6 @@ private:
107
107
 
108
108
  return ret;
109
109
  }
110
-
111
- static lapack_int get_itype(VALUE val) {
112
- const lapack_int itype = NUM2INT(val);
113
-
114
- if (itype != 1 && itype != 2 && itype != 3) {
115
- rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
116
- }
117
-
118
- return itype;
119
- }
120
-
121
- static char get_jobz(VALUE val) {
122
- const char jobz = NUM2CHR(val);
123
-
124
- if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
125
- rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
126
- }
127
-
128
- return jobz;
129
- }
130
-
131
- static char get_uplo(VALUE val) {
132
- const char uplo = NUM2CHR(val);
133
-
134
- if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
135
- rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
136
- }
137
-
138
- return uplo;
139
- }
140
-
141
- static int get_matrix_layout(VALUE val) {
142
- const char option = NUM2CHR(val);
143
-
144
- switch (option) {
145
- case 'r':
146
- case 'R':
147
- break;
148
- case 'c':
149
- case 'C':
150
- rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major.");
151
- break;
152
- }
153
-
154
- return LAPACK_ROW_MAJOR;
155
- }
156
110
  };
157
111
 
158
112
  } // namespace TinyLinalg