cumo 0.2.5 → 0.3.0.pre1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -142,6 +142,7 @@ Init_cumo_cuda_runtime()
 {
     VALUE mCumo = rb_define_module("Cumo");
     VALUE mCUDA = rb_define_module_under(mCumo, "CUDA");
+    rb_define_const(mCumo, "Cuda", mCUDA); // alias
     mRuntime = rb_define_module_under(mCUDA, "Runtime");
     eRuntimeError = rb_define_class_under(mCUDA, "RuntimeError", rb_eStandardError);
 
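The new rb_define_const line registers Cumo::Cuda as a second name for the existing Cumo::CUDA module. A minimal Ruby illustration of the effect (not part of the diff):

    require 'cumo'
    Cumo::Cuda.equal?(Cumo::CUDA)   #=> true; one module object, two constants
    Cumo::Cuda::Runtime             # resolves to the same module as Cumo::CUDA::Runtime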
data/ext/cumo/cumo.c CHANGED
@@ -33,6 +33,8 @@ void Init_cumo_cuda_driver();
 void Init_cumo_cuda_memory_pool();
 void Init_cumo_cuda_runtime();
 void Init_cumo_cuda_nvrtc();
+void Init_cumo_cuda_cublas();
+void Init_cumo_cuda_cudnn();
 
 void
 cumo_debug_breakpoint(void)
@@ -167,4 +169,7 @@ Init_cumo()
     Init_cumo_cuda_memory_pool();
     Init_cumo_cuda_runtime();
     Init_cumo_cuda_nvrtc();
+
+    Init_cumo_cuda_cublas();
+    Init_cumo_cuda_cudnn();
 }
data/ext/cumo/extconf.rb CHANGED
@@ -47,9 +47,9 @@ rm_f 'include/cumo/extconf.h'
 MakeMakefileCuda.install!(cxx: true)
 
 if ENV['DEBUG']
-  $CFLAGS="-g -O0 -Wall"
+  $CFLAGS << " -g -O0 -Wall"
 end
-$CXXFLAGS += " -std=c++14 "
+$CXXFLAGS << " -std=c++14"
 #$CFLAGS=" $(cflags) -O3 -m64 -msse2 -funroll-loops"
 #$CFLAGS=" $(cflags) -O3"
 $INCFLAGS = "-Iinclude -Inarray -Icuda #{$INCFLAGS}"
@@ -109,6 +109,8 @@ cuda/memory_pool
 cuda/memory_pool_impl
 cuda/runtime
 cuda/nvrtc
+cuda/cudnn
+cuda/cudnn_impl
 )
 
 if RUBY_VERSION[0..3] == "2.1."
@@ -179,5 +181,9 @@ have_library('nvrtc')
 have_library('cublas')
 # have_library('cusolver')
 # have_library('curand')
+if have_library('cudnn') # TODO(sonots): cuDNN version check
+  $CFLAGS << " -DCUDNN_FOUND"
+  $CXXFLAGS << " -DCUDNN_FOUND"
+end
 
 create_makefile('cumo')
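The detection above gates everything cuDNN-related behind a compile-time flag: if mkmf finds libcudnn, -DCUDNN_FOUND is added for both C and C++ compilation, and the new headers and templates test it with #ifdef. A standalone sketch of the same probe (hypothetical, for illustration only):

    require 'mkmf'
    if have_library('cudnn')          # tries to link a test program against -lcudnn
      $CFLAGS   << ' -DCUDNN_FOUND'   # C translation units see #ifdef CUDNN_FOUND
      $CXXFLAGS << ' -DCUDNN_FOUND'   # and so do C++ ones
    end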
@@ -10,8 +10,8 @@ extern "C" {
 #endif
 #endif
 
-#define CUMO_VERSION "0.2.5"
-#define CUMO_VERSION_CODE 25
+#define CUMO_VERSION "0.3.0.pre1"
+#define CUMO_VERSION_CODE 301
 
 bool cumo_compatible_mode_enabled_p();
 bool cumo_show_warning_enabled_p();
@@ -0,0 +1,205 @@
+#ifndef CUMO_CUDA_CUDNN_H
+#define CUMO_CUDA_CUDNN_H
+
+#include <ruby.h>
+#ifdef CUDNN_FOUND
+#include <cudnn.h>
+#endif // CUDNN_FOUND
+
+#if defined(__cplusplus)
+extern "C" {
+#if 0
+} /* satisfy cc-mode */
+#endif
+#endif
+
+#ifdef CUDNN_FOUND
+
+VALUE cumo_na_eShapeError;
+
+#define CUMO_CUDA_CUDNN_DEFAULT_MAX_WORKSPACE_SIZE 8 * 1024 * 1024
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x,t)                 \
+    if (rb_obj_class(x)!=(t)) {                                \
+        rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
+    }
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(sz1,sz2) \
+    if ((sz1) != (sz2)) {                      \
+        rb_raise(cumo_na_eShapeError,          \
+                 "size mismatch: %d != %d",    \
+                 (int)(sz1), (int)(sz2));      \
+    }
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_DIM_EQ(nd1,nd2)    \
+    if ((nd1) != (nd2)) {                        \
+        rb_raise(cumo_na_eShapeError,            \
+                 "dimention mismatch: %d != %d", \
+                 (int)(nd1), (int)(nd2));        \
+    }
+
+void
+cumo_cuda_cudnn_check_status(cudnnStatus_t status);
+
+cudnnHandle_t
+cumo_cuda_cudnn_handle();
+
+// TODO: Move to more generic proper place
+static inline VALUE
+cumo_cuda_cudnn_option_value(VALUE value, VALUE default_value)
+{
+    switch(TYPE(value)) {
+    case T_NIL:
+    case T_UNDEF:
+        return default_value;
+    }
+    return value;
+}
+
+// VALUE is Ruby Array
+static inline void
+cumo_cuda_cudnn_get_int_ary(int* int_ary, VALUE ary, size_t ndim, int default_value)
+{
+    if (ary == Qnil) {
+        // default to 1
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = default_value;
+        }
+    } else if (TYPE(ary) == T_FIXNUM) {
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = NUM2INT(ary);
+        }
+    } else {
+        Check_Type(ary, T_ARRAY);
+        CUMO_CUDA_CUDNN_CHECK_DIM_EQ((size_t)(RARRAY_LEN(ary)), ndim);
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = NUM2INT(rb_ary_entry(ary, idim));
+        }
+    }
+}
+
+size_t
+cumo_cuda_cudnn_GetConvOutDim(
+        size_t in_dim,
+        size_t kernel_size,
+        size_t stride,
+        size_t pad);
+
+size_t
+cumo_cuda_cudnn_GetConvTransposeOutDim(
+        size_t in_dim,
+        size_t kernel_size,
+        size_t stride,
+        size_t pad);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateTensorDescriptor(
+        cudnnTensorDescriptor_t *desc,
+        VALUE a,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateFilterDescriptor(
+        cudnnFilterDescriptor_t *desc,
+        VALUE a,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateConvolutionDescriptor(
+        cudnnConvolutionDescriptor_t *desc,
+        size_t ndim,
+        int* int_stride,
+        int* int_pad,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreatePoolingDescriptor(
+        cudnnPoolingDescriptor_t *desc,
+        cudnnPoolingMode_t mode,
+        size_t ndim,
+        int* int_kernel_size,
+        int* int_stride,
+        int* int_pad);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionForwardAlgorithm(
+        cudnnConvolutionFwdAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnFilterDescriptor_t w_desc,
+        VALUE w,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnTensorDescriptor_t y_sec,
+        VALUE y,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionBackwardDataAlgorithm(
+        cudnnConvolutionBwdDataAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnFilterDescriptor_t w_desc,
+        VALUE w,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnTensorDescriptor_t y_desc,
+        VALUE y,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionBackwardFilterAlgorithm(
+        cudnnConvolutionBwdFilterAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnTensorDescriptor_t dy_desc,
+        VALUE dy,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnFilterDescriptor_t dw_desc,
+        VALUE dw,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnBatchNormMode_t
+cumo_cuda_cudnn_GetBatchNormMode(size_t ndim, int* int_axis);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateBNTensorDescriptor(
+        cudnnTensorDescriptor_t *desc,
+        cudnnTensorDescriptor_t x_desc,
+        cudnnBatchNormMode_t mode);
+
+size_t
+cumo_cuda_cudnn_ReduceShape(
+        size_t *reduced_shape,
+        size_t shape_ndim,
+        size_t *shape,
+        size_t axes_ndim,
+        int *axes,
+        char keepdims);
+
+#endif // CUDNN_FOUND
+
+#if defined(__cplusplus)
+#if 0
+{ /* satisfy cc-mode */
+#endif
+} /* extern "C" { */
+#endif
+
+#endif /* ifndef CUMO_CUDA_CUDNN_H */
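cumo_cuda_cudnn_GetConvOutDim and cumo_cuda_cudnn_GetConvTransposeOutDim are only declared here; their definitions (presumably in the new cuda/cudnn_impl source) are not part of this hunk. Assuming they implement the standard convolution output-size arithmetic, a Ruby sketch of the formulas:

    # Assumed arithmetic, not taken from the diff:
    def conv_out_dim(in_dim, kernel_size, stride, pad)
      (in_dim + 2 * pad - kernel_size) / stride + 1
    end

    def conv_transpose_out_dim(in_dim, kernel_size, stride, pad)
      stride * (in_dim - 1) + kernel_size - 2 * pad
    end

    conv_out_dim(28, 3, 1, 1)            #=> 28 (3x3 kernel, stride 1, pad 1 preserves size)
    conv_transpose_out_dim(28, 3, 1, 1)  #=> 28 (the transpose inverts the mapping)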
@@ -0,0 +1,17 @@
+#ifndef CUMO_HASH_COMBINE_H
+#define CUMO_HASH_COMBINE_H
+
+#include <cstddef>
+
+namespace cumo {
+namespace internal {
+
+// Borrowed from boost::hash_combine
+//
+// TODO(sonots): hash combine in 64bit
+inline void HashCombine(std::size_t& seed, std::size_t hash_value) { seed ^= hash_value + 0x9e3779b9 + (seed << 6) + (seed >> 2); }
+
+} // namespace internal
+} // namespace cumo
+
+#endif /* ifndef CUMO_HASH_COMBINE_H */
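HashCombine folds one hash value into a running seed using boost's recipe: the 32-bit golden-ratio constant 0x9e3779b9 plus two shifts (as the TODO notes, still 32-bit only). A line-for-line Ruby transcription for illustration, masked to 32 bits since Ruby integers are unbounded:

    def hash_combine(seed, hash_value)
      (seed ^ (hash_value + 0x9e3779b9 + (seed << 6) + (seed >> 2))) & 0xffffffff
    end

    seed = 0
    seed = hash_combine(seed, 4)    # e.g. fold in an ndim
    seed = hash_combine(seed, 128)  # then a channel count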
@@ -26,11 +26,16 @@ char *cumo_na_get_pointer_for_write(VALUE);
 char *cumo_na_get_pointer_for_read(VALUE);
 char *cumo_na_get_pointer_for_read_write(VALUE);
 size_t cumo_na_get_offset(VALUE self);
+char* cumo_na_get_offset_pointer(VALUE);
+char* cumo_na_get_offset_pointer_for_write(VALUE);
+char* cumo_na_get_offset_pointer_for_read(VALUE);
+char* cumo_na_get_offset_pointer_for_read_write(VALUE);
 
 void cumo_na_copy_flags(VALUE src, VALUE dst);
 
 VALUE cumo_na_check_ladder(VALUE self, int start_dim);
 VALUE cumo_na_check_contiguous(VALUE self);
+VALUE cumo_na_as_contiguous_array(VALUE a);
 
 VALUE cumo_na_flatten_dim(VALUE self, int sd);
 
@@ -7,6 +7,7 @@ typedef double rtype;
 #include "float_macro.h"
 #include "cublas_v2.h"
 #include "cumo/cuda/cublas.h"
+#include "cumo/cuda/cudnn.h"
 
 #ifdef SFMT_H
 /* generates a random number on [0,1)-real-interval */
@@ -7,6 +7,7 @@ typedef float rtype;
 #include "float_macro.h"
 #include "cublas_v2.h"
 #include "cumo/cuda/cublas.h"
+#include "cumo/cuda/cudnn.h"
 
 #ifdef SFMT_H
 /* generates a random number on [0,1)-real-interval */
@@ -53,6 +53,16 @@ end
 if (is_float || is_complex) && !is_object
   def_id "gemm"
 end
+# cudnn
+if is_float && !is_complex && !is_object
+  def_id "conv"
+  def_id "conv_transpose"
+  def_id "conv_grad_w"
+  def_id "batch_norm"
+  def_id "batch_norm_backward"
+  def_id "pooling_forward"
+  def_id "pooling_backward"
+end
 
 if is_int && !is_object
   def_id "minlength" # for bincount
@@ -331,6 +341,17 @@ if (is_float || is_complex) && !is_object
   def_method "gemm"
 end
 
+# cudnn
+if is_float && !is_complex && !is_object
+  def_method "conv"
+  def_method "conv_transpose" # conv_backward_data
+  def_method "conv_grad_w" # conv_backward_filter
+  def_method "batch_norm"
+  def_method "batch_norm_backward"
+  def_method "pooling_forward" # max_pool, avg_pool
+  def_method "pooling_backward"
+end
+
 # rmsdev
 # prod
 
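Together with the def_id hunk above, this defines the cuDNN-backed methods only for the real float types, skipping complex, integer, and object types. An illustrative (hypothetical) check from Ruby once the extension is built with cuDNN:

    Cumo::SFloat.method_defined?(:conv)         #=> true
    Cumo::DFloat.method_defined?(:batch_norm)   #=> true
    Cumo::Int32.method_defined?(:conv)          #=> false (is_float is false)
    Cumo::SComplex.method_defined?(:conv)       #=> false (is_complex is excluded)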
@@ -0,0 +1,197 @@
+#ifdef CUDNN_FOUND
+
+<%
+  cudnn_dtype =
+    case type_name
+    when 'sfloat'
+      'CUDNN_DATA_FLOAT'
+    when 'dfloat'
+      'CUDNN_DATA_DOUBLE'
+    else
+      # CUDNN_DATA_HALF
+      raise 'not supported'
+    end
+%>
+
+// y = x.batch_norm(gamma, beta, running_mean:, running_var:, eps:, decay:, axis:, mean:, inv_std:)
+static VALUE
+<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
+{
+    cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>;
+    cudnnStatus_t status = 0;
+    cudnnHandle_t handle = 0;
+    dtype coef_alpha = 1;
+    dtype coef_beta = 0;
+
+    VALUE x=self, gamma, beta, running_mean, running_var, eps, decay, axis, mean, inv_std, y;
+    VALUE kw_hash = Qnil;
+    ID kw_table[] = {
+        rb_intern("running_mean"),
+        rb_intern("running_var"),
+        rb_intern("mean"),
+        rb_intern("inv_std"),
+        rb_intern("eps"),
+        rb_intern("decay"),
+        rb_intern("axis"),
+        rb_intern("y")
+    };
+    VALUE opts[] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
+
+    cumo_narray_t *nx, *ngamma, *nbeta;
+    size_t *x_shape, *gamma_shape, *beta_shape, reduced_shape[CUMO_NA_MAX_DIMENSION];
+    size_t x_ndim, gamma_ndim, beta_ndim, reduced_ndim;
+
+    VALUE x_cont, gamma_cont, beta_cont;
+    cudnnTensorDescriptor_t x_desc = 0;
+    cudnnTensorDescriptor_t bn_desc = 0;
+    char *x_cont_ptr, *gamma_cont_ptr, *beta_cont_ptr, *y_ptr;
+
+    cudnnBatchNormMode_t mode;
+
+    // default values
+    char *running_mean_ptr=NULL;
+    char *running_var_ptr=NULL;
+    char *mean_ptr=NULL;
+    char *inv_std_ptr=NULL;
+    double double_eps = 2e-5;
+    double double_decay = 0.9;
+    int int_axis[CUMO_NA_MAX_DIMENSION] = {0};
+    size_t axis_ndim = 1;
+
+    rb_scan_args(argc, argv, "2:", &gamma, &beta, &kw_hash);
+    rb_get_kwargs(kw_hash, kw_table, 0, 8, opts);
+    running_mean = cumo_cuda_cudnn_option_value(opts[0], Qnil);
+    running_var = cumo_cuda_cudnn_option_value(opts[1], Qnil);
+    mean = cumo_cuda_cudnn_option_value(opts[2], Qnil);
+    inv_std = cumo_cuda_cudnn_option_value(opts[3], Qnil);
+    eps = cumo_cuda_cudnn_option_value(opts[4], Qnil);
+    decay = cumo_cuda_cudnn_option_value(opts[5], Qnil);
+    axis = cumo_cuda_cudnn_option_value(opts[6], Qnil);
+    y = cumo_cuda_cudnn_option_value(opts[7], Qnil);
+
+    if (running_mean != Qnil) {
+        running_mean_ptr = cumo_na_get_offset_pointer_for_write(running_mean);
+    }
+    if (running_var != Qnil) {
+        running_var_ptr = cumo_na_get_offset_pointer_for_write(running_var);
+    }
+    if (mean != Qnil) {
+        mean_ptr = cumo_na_get_offset_pointer_for_write(mean);
+    }
+    if (inv_std != Qnil) {
+        inv_std_ptr = cumo_na_get_offset_pointer_for_write(inv_std);
+    }
+    if (eps != Qnil) {
+        double_eps = NUM2DBL(eps);
+    }
+    if (decay != Qnil) {
+        double_decay = NUM2DBL(decay);
+    }
+    if (axis != Qnil) {
+        Check_Type(axis, T_ARRAY);
+        axis_ndim = (size_t)(RARRAY_LEN(axis));
+        for (size_t idim = 0; idim < axis_ndim; ++idim) {
+            int_axis[idim] = NUM2INT(rb_ary_entry(axis, (long)idim));
+        }
+        // TODO: check axis is sorted
+    }
+
+    CumoGetNArray(x, nx);
+    CumoGetNArray(gamma, ngamma);
+    CumoGetNArray(beta, nbeta);
+    x_ndim = nx->ndim;
+    x_shape = nx->shape;
+    gamma_ndim = ngamma->ndim;
+    gamma_shape = ngamma->shape;
+    beta_ndim = nbeta->ndim;
+    beta_shape = nbeta->shape;
+
+    // TODO: Size check of gammma, beta, running_mean, running_var, mean, inv_std
+    // are equivalent with either of reduced_shape(keepdims: false) or reduced_shape(keepdims: true)
+    reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1);
+    // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim);
+    // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, beta_ndim);
+    // for (size_t idim = 0; idim < reduced_ndim; ++idim) {
+    //     CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]);
+    //     CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], beta_shape[idim]);
+    // }
+
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT);
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT);
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(beta, cT);
+    if (running_mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_mean, cT);
+    if (running_var != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_var, cT);
+    if (mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT);
+    if (inv_std != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(inv_std, cT);
+
+    x_cont = cumo_na_as_contiguous_array(x);
+    gamma_cont = cumo_na_as_contiguous_array(gamma);
+    beta_cont = cumo_na_as_contiguous_array(beta);
+    if (running_mean != Qnil && cumo_na_check_contiguous(running_mean) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "running_mean must be contiguous");
+    }
+    if (running_var != Qnil && cumo_na_check_contiguous(running_var) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "running_var must be contiguous");
+    }
+    if (mean != Qnil && cumo_na_check_contiguous(mean) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "mean must be contiguous");
+    }
+    if (inv_std != Qnil && cumo_na_check_contiguous(inv_std) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "inv_std must be contiguous");
+    }
+
+    x_cont_ptr = cumo_na_get_offset_pointer_for_read(x_cont);
+    gamma_cont_ptr = cumo_na_get_offset_pointer_for_read(gamma_cont);
+    beta_cont_ptr = cumo_na_get_offset_pointer_for_read(beta_cont);
+
+    // TODO: type and shape check
+    if (y == Qnil) y = cumo_na_new(cT, x_ndim, x_shape);
+    y_ptr = cumo_na_get_offset_pointer_for_write(y);
+
+    status = cumo_cuda_cudnn_CreateTensorDescriptor(&x_desc, x_cont, cudnn_dtype);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+
+    mode = cumo_cuda_cudnn_GetBatchNormMode(axis_ndim, int_axis);
+    status = cumo_cuda_cudnn_CreateBNTensorDescriptor(&bn_desc, x_desc, mode);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+    // TODO: bn_desc may return another type, and may need to cast gamma, beta, mean, var
+
+    handle = cumo_cuda_cudnn_handle();
+
+    status = cudnnBatchNormalizationForwardTraining(
+            handle,
+            mode,
+            (void*)&coef_alpha,
+            (void*)&coef_beta,
+            x_desc,
+            x_cont_ptr,
+            x_desc,
+            y_ptr,
+            bn_desc,
+            gamma_cont_ptr,
+            beta_cont_ptr,
+            1.0 - double_decay,
+            running_mean_ptr,
+            running_var_ptr,
+            double_eps,
+            mean_ptr,
+            inv_std_ptr);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+
+BATCH_NORM_ERROR:
+    if (x_desc) cudnnDestroyTensorDescriptor(x_desc);
+    if (bn_desc) cudnnDestroyTensorDescriptor(bn_desc);
+    cumo_cuda_cudnn_check_status(status);
+
+    return y;
+}
+
+#else // CUDNN_FOUND
+VALUE cumo_cuda_eCUDNNError;
+
+static VALUE
+<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
+{
+    rb_raise(cumo_cuda_eCUDNNError, "cuDNN is not available");
+}
+#endif // CUDNN_FOUND
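The comment at the top of the template documents the Ruby-level signature. A hypothetical call, with shapes chosen to match the keepdims reduction that ReduceShape(..., 1) performs (per-channel statistics for NCHW input, reducing over axis [0, 2, 3]):

    x     = Cumo::SFloat.new(2, 3, 5, 5).rand   # (batch, channel, height, width)
    gamma = Cumo::SFloat.ones(1, 3, 1, 1)
    beta  = Cumo::SFloat.zeros(1, 3, 1, 1)
    mean  = Cumo::SFloat.zeros(1, 3, 1, 1)
    var   = Cumo::SFloat.ones(1, 3, 1, 1)
    y = x.batch_norm(gamma, beta,
                     running_mean: mean, running_var: var,
                     eps: 2e-5, decay: 0.9, axis: [0, 2, 3])

When the gem is built without cuDNN, the #else branch makes the same call raise cumo_cuda_eCUDNNError with "cuDNN is not available".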