cumo 0.2.5 → 0.3.0.pre1

@@ -142,6 +142,7 @@ Init_cumo_cuda_runtime()
 {
     VALUE mCumo = rb_define_module("Cumo");
     VALUE mCUDA = rb_define_module_under(mCumo, "CUDA");
+    rb_define_const(mCumo, "Cuda", mCUDA); // alias
     mRuntime = rb_define_module_under(mCUDA, "Runtime");
     eRuntimeError = rb_define_class_under(mCUDA, "RuntimeError", rb_eStandardError);
 
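The added constant makes `Cumo::Cuda` an alias of `Cumo::CUDA`. A minimal sketch of the effect from Ruby (hypothetical session; assumes the gem is built and loadable):

    require 'cumo'

    # Both constants now refer to the same module object.
    Cumo::Cuda.equal?(Cumo::CUDA)  #=> true
    Cumo::Cuda::Runtime            # resolves the same as Cumo::CUDA::Runtime
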
data/ext/cumo/cumo.c CHANGED
@@ -33,6 +33,8 @@ void Init_cumo_cuda_driver();
 void Init_cumo_cuda_memory_pool();
 void Init_cumo_cuda_runtime();
 void Init_cumo_cuda_nvrtc();
+void Init_cumo_cuda_cublas();
+void Init_cumo_cuda_cudnn();
 
 void
 cumo_debug_breakpoint(void)
@@ -167,4 +169,7 @@ Init_cumo()
     Init_cumo_cuda_memory_pool();
     Init_cumo_cuda_runtime();
     Init_cumo_cuda_nvrtc();
+
+    Init_cumo_cuda_cublas();
+    Init_cumo_cuda_cudnn();
 }
data/ext/cumo/extconf.rb CHANGED
@@ -47,9 +47,9 @@ rm_f 'include/cumo/extconf.h'
 MakeMakefileCuda.install!(cxx: true)
 
 if ENV['DEBUG']
-  $CFLAGS="-g -O0 -Wall"
+  $CFLAGS << " -g -O0 -Wall"
 end
-$CXXFLAGS += " -std=c++14 "
+$CXXFLAGS << " -std=c++14"
 #$CFLAGS=" $(cflags) -O3 -m64 -msse2 -funroll-loops"
 #$CFLAGS=" $(cflags) -O3"
 $INCFLAGS = "-Iinclude -Inarray -Icuda #{$INCFLAGS}"
@@ -109,6 +109,8 @@ cuda/memory_pool
 cuda/memory_pool_impl
 cuda/runtime
 cuda/nvrtc
+cuda/cudnn
+cuda/cudnn_impl
 )
 
 if RUBY_VERSION[0..3] == "2.1."
@@ -179,5 +181,9 @@ have_library('nvrtc')
 have_library('cublas')
 # have_library('cusolver')
 # have_library('curand')
+if have_library('cudnn') # TODO(sonots): cuDNN version check
+  $CFLAGS << " -DCUDNN_FOUND"
+  $CXXFLAGS << " -DCUDNN_FOUND"
+end
 
 create_makefile('cumo')
@@ -10,8 +10,8 @@ extern "C" {
 #endif
 #endif
 
-#define CUMO_VERSION "0.2.5"
-#define CUMO_VERSION_CODE 25
+#define CUMO_VERSION "0.3.0.pre1"
+#define CUMO_VERSION_CODE 301
 
 bool cumo_compatible_mode_enabled_p();
 bool cumo_show_warning_enabled_p();
@@ -0,0 +1,205 @@
+#ifndef CUMO_CUDA_CUDNN_H
+#define CUMO_CUDA_CUDNN_H
+
+#include <ruby.h>
+#ifdef CUDNN_FOUND
+#include <cudnn.h>
+#endif // CUDNN_FOUND
+
+#if defined(__cplusplus)
+extern "C" {
+#if 0
+} /* satisfy cc-mode */
+#endif
+#endif
+
+#ifdef CUDNN_FOUND
+
+VALUE cumo_na_eShapeError;
+
+#define CUMO_CUDA_CUDNN_DEFAULT_MAX_WORKSPACE_SIZE 8 * 1024 * 1024
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x,t)                 \
+    if (rb_obj_class(x)!=(t)) {                                \
+        rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
+    }
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(sz1,sz2)    \
+    if ((sz1) != (sz2)) {                         \
+        rb_raise(cumo_na_eShapeError,             \
+                 "size mismatch: %d != %d",       \
+                 (int)(sz1), (int)(sz2));         \
+    }
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_DIM_EQ(nd1,nd2)     \
+    if ((nd1) != (nd2)) {                         \
+        rb_raise(cumo_na_eShapeError,             \
+                 "dimension mismatch: %d != %d",  \
+                 (int)(nd1), (int)(nd2));         \
+    }
+
+void
+cumo_cuda_cudnn_check_status(cudnnStatus_t status);
+
+cudnnHandle_t
+cumo_cuda_cudnn_handle();
+
+// TODO: Move to more generic proper place
+static inline VALUE
+cumo_cuda_cudnn_option_value(VALUE value, VALUE default_value)
+{
+    switch(TYPE(value)) {
+    case T_NIL:
+    case T_UNDEF:
+        return default_value;
+    }
+    return value;
+}
+
+// VALUE is Ruby Array
+static inline void
+cumo_cuda_cudnn_get_int_ary(int* int_ary, VALUE ary, size_t ndim, int default_value)
+{
+    if (ary == Qnil) {
+        // fill every dimension with default_value
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = default_value;
+        }
+    } else if (TYPE(ary) == T_FIXNUM) {
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = NUM2INT(ary);
+        }
+    } else {
+        Check_Type(ary, T_ARRAY);
+        CUMO_CUDA_CUDNN_CHECK_DIM_EQ((size_t)(RARRAY_LEN(ary)), ndim);
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = NUM2INT(rb_ary_entry(ary, idim));
+        }
+    }
+}
+
+size_t
+cumo_cuda_cudnn_GetConvOutDim(
+        size_t in_dim,
+        size_t kernel_size,
+        size_t stride,
+        size_t pad);
+
+size_t
+cumo_cuda_cudnn_GetConvTransposeOutDim(
+        size_t in_dim,
+        size_t kernel_size,
+        size_t stride,
+        size_t pad);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateTensorDescriptor(
+        cudnnTensorDescriptor_t *desc,
+        VALUE a,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateFilterDescriptor(
+        cudnnFilterDescriptor_t *desc,
+        VALUE a,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateConvolutionDescriptor(
+        cudnnConvolutionDescriptor_t *desc,
+        size_t ndim,
+        int* int_stride,
+        int* int_pad,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreatePoolingDescriptor(
+        cudnnPoolingDescriptor_t *desc,
+        cudnnPoolingMode_t mode,
+        size_t ndim,
+        int* int_kernel_size,
+        int* int_stride,
+        int* int_pad);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionForwardAlgorithm(
+        cudnnConvolutionFwdAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnFilterDescriptor_t w_desc,
+        VALUE w,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnTensorDescriptor_t y_desc,
+        VALUE y,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionBackwardDataAlgorithm(
+        cudnnConvolutionBwdDataAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnFilterDescriptor_t w_desc,
+        VALUE w,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnTensorDescriptor_t y_desc,
+        VALUE y,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionBackwardFilterAlgorithm(
+        cudnnConvolutionBwdFilterAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnTensorDescriptor_t dy_desc,
+        VALUE dy,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnFilterDescriptor_t dw_desc,
+        VALUE dw,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnBatchNormMode_t
+cumo_cuda_cudnn_GetBatchNormMode(size_t ndim, int* int_axis);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateBNTensorDescriptor(
+        cudnnTensorDescriptor_t *desc,
+        cudnnTensorDescriptor_t x_desc,
+        cudnnBatchNormMode_t mode);
+
+size_t
+cumo_cuda_cudnn_ReduceShape(
+        size_t *reduced_shape,
+        size_t shape_ndim,
+        size_t *shape,
+        size_t axes_ndim,
+        int *axes,
+        char keepdims);
+
+#endif // CUDNN_FOUND
+
+#if defined(__cplusplus)
+#if 0
+{ /* satisfy cc-mode */
+#endif
+} /* extern "C" { */
+#endif
+
+#endif /* ifndef CUMO_CUDA_CUDNN_H */
@@ -0,0 +1,17 @@
+#ifndef CUMO_HASH_COMBINE_H
+#define CUMO_HASH_COMBINE_H
+
+#include <cstddef>
+
+namespace cumo {
+namespace internal {
+
+// Borrowed from boost::hash_combine
+//
+// TODO(sonots): hash combine in 64bit
+inline void HashCombine(std::size_t& seed, std::size_t hash_value) { seed ^= hash_value + 0x9e3779b9 + (seed << 6) + (seed >> 2); }
+
+} // namespace internal
+} // namespace cumo
+
+#endif /* ifndef CUMO_HASH_COMBINE_H */
@@ -26,11 +26,16 @@ char *cumo_na_get_pointer_for_write(VALUE);
 char *cumo_na_get_pointer_for_read(VALUE);
 char *cumo_na_get_pointer_for_read_write(VALUE);
 size_t cumo_na_get_offset(VALUE self);
+char* cumo_na_get_offset_pointer(VALUE);
+char* cumo_na_get_offset_pointer_for_write(VALUE);
+char* cumo_na_get_offset_pointer_for_read(VALUE);
+char* cumo_na_get_offset_pointer_for_read_write(VALUE);
 
 void cumo_na_copy_flags(VALUE src, VALUE dst);
 
 VALUE cumo_na_check_ladder(VALUE self, int start_dim);
 VALUE cumo_na_check_contiguous(VALUE self);
+VALUE cumo_na_as_contiguous_array(VALUE a);
 
 VALUE cumo_na_flatten_dim(VALUE self, int sd);
 
@@ -7,6 +7,7 @@ typedef double rtype;
 #include "float_macro.h"
 #include "cublas_v2.h"
 #include "cumo/cuda/cublas.h"
+#include "cumo/cuda/cudnn.h"
 
 #ifdef SFMT_H
 /* generates a random number on [0,1)-real-interval */
@@ -7,6 +7,7 @@ typedef float rtype;
 #include "float_macro.h"
 #include "cublas_v2.h"
 #include "cumo/cuda/cublas.h"
+#include "cumo/cuda/cudnn.h"
 
 #ifdef SFMT_H
 /* generates a random number on [0,1)-real-interval */
@@ -53,6 +53,16 @@ end
 if (is_float || is_complex) && !is_object
   def_id "gemm"
 end
+# cudnn
+if is_float && !is_complex && !is_object
+  def_id "conv"
+  def_id "conv_transpose"
+  def_id "conv_grad_w"
+  def_id "batch_norm"
+  def_id "batch_norm_backward"
+  def_id "pooling_forward"
+  def_id "pooling_backward"
+end
 
 if is_int && !is_object
   def_id "minlength" # for bincount
@@ -331,6 +341,17 @@ if (is_float || is_complex) && !is_object
   def_method "gemm"
 end
 
+# cudnn
+if is_float && !is_complex && !is_object
+  def_method "conv"
+  def_method "conv_transpose" # conv_backward_data
+  def_method "conv_grad_w" # conv_backward_filter
+  def_method "batch_norm"
+  def_method "batch_norm_backward"
+  def_method "pooling_forward" # max_pool, avg_pool
+  def_method "pooling_backward"
+end
+
 # rmsdev
 # prod
 
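These generator hunks register the cuDNN-backed method names (and their interned IDs) on the real float types only, i.e. Cumo::SFloat and Cumo::DFloat. A hedged sketch of what that implies at the Ruby level (method names taken from the diff; whether a call actually reaches cuDNN depends on the CUDNN_FOUND build flag set in extconf.rb above):

    require 'cumo'

    # Defined on the float types...
    Cumo::SFloat.method_defined?(:conv)        #=> true
    Cumo::DFloat.method_defined?(:batch_norm)  #=> true
    # ...but not on integer, complex, or object types.
    Cumo::Int32.method_defined?(:conv)         #=> false
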
@@ -0,0 +1,197 @@
+#ifdef CUDNN_FOUND
+
+<%
+  cudnn_dtype =
+    case type_name
+    when 'sfloat'
+      'CUDNN_DATA_FLOAT'
+    when 'dfloat'
+      'CUDNN_DATA_DOUBLE'
+    else
+      # CUDNN_DATA_HALF
+      raise 'not supported'
+    end
+%>
+
+// y = x.batch_norm(gamma, beta, running_mean:, running_var:, eps:, decay:, axis:, mean:, inv_std:)
+static VALUE
+<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
+{
+    cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>;
+    cudnnStatus_t status = 0;
+    cudnnHandle_t handle = 0;
+    dtype coef_alpha = 1;
+    dtype coef_beta = 0;
+
+    VALUE x=self, gamma, beta, running_mean, running_var, eps, decay, axis, mean, inv_std, y;
+    VALUE kw_hash = Qnil;
+    ID kw_table[] = {
+        rb_intern("running_mean"),
+        rb_intern("running_var"),
+        rb_intern("mean"),
+        rb_intern("inv_std"),
+        rb_intern("eps"),
+        rb_intern("decay"),
+        rb_intern("axis"),
+        rb_intern("y")
+    };
+    VALUE opts[] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
+
+    cumo_narray_t *nx, *ngamma, *nbeta;
+    size_t *x_shape, *gamma_shape, *beta_shape, reduced_shape[CUMO_NA_MAX_DIMENSION];
+    size_t x_ndim, gamma_ndim, beta_ndim, reduced_ndim;
+
+    VALUE x_cont, gamma_cont, beta_cont;
+    cudnnTensorDescriptor_t x_desc = 0;
+    cudnnTensorDescriptor_t bn_desc = 0;
+    char *x_cont_ptr, *gamma_cont_ptr, *beta_cont_ptr, *y_ptr;
+
+    cudnnBatchNormMode_t mode;
+
+    // default values
+    char *running_mean_ptr=NULL;
+    char *running_var_ptr=NULL;
+    char *mean_ptr=NULL;
+    char *inv_std_ptr=NULL;
+    double double_eps = 2e-5;
+    double double_decay = 0.9;
+    int int_axis[CUMO_NA_MAX_DIMENSION] = {0};
+    size_t axis_ndim = 1;
+
+    rb_scan_args(argc, argv, "2:", &gamma, &beta, &kw_hash);
+    rb_get_kwargs(kw_hash, kw_table, 0, 8, opts);
+    running_mean = cumo_cuda_cudnn_option_value(opts[0], Qnil);
+    running_var = cumo_cuda_cudnn_option_value(opts[1], Qnil);
+    mean = cumo_cuda_cudnn_option_value(opts[2], Qnil);
+    inv_std = cumo_cuda_cudnn_option_value(opts[3], Qnil);
+    eps = cumo_cuda_cudnn_option_value(opts[4], Qnil);
+    decay = cumo_cuda_cudnn_option_value(opts[5], Qnil);
+    axis = cumo_cuda_cudnn_option_value(opts[6], Qnil);
+    y = cumo_cuda_cudnn_option_value(opts[7], Qnil);
+
+    if (running_mean != Qnil) {
+        running_mean_ptr = cumo_na_get_offset_pointer_for_write(running_mean);
+    }
+    if (running_var != Qnil) {
+        running_var_ptr = cumo_na_get_offset_pointer_for_write(running_var);
+    }
+    if (mean != Qnil) {
+        mean_ptr = cumo_na_get_offset_pointer_for_write(mean);
+    }
+    if (inv_std != Qnil) {
+        inv_std_ptr = cumo_na_get_offset_pointer_for_write(inv_std);
+    }
+    if (eps != Qnil) {
+        double_eps = NUM2DBL(eps);
+    }
+    if (decay != Qnil) {
+        double_decay = NUM2DBL(decay);
+    }
+    if (axis != Qnil) {
+        Check_Type(axis, T_ARRAY);
+        axis_ndim = (size_t)(RARRAY_LEN(axis));
+        for (size_t idim = 0; idim < axis_ndim; ++idim) {
+            int_axis[idim] = NUM2INT(rb_ary_entry(axis, (long)idim));
+        }
+        // TODO: check axis is sorted
+    }
+
+    CumoGetNArray(x, nx);
+    CumoGetNArray(gamma, ngamma);
+    CumoGetNArray(beta, nbeta);
+    x_ndim = nx->ndim;
+    x_shape = nx->shape;
+    gamma_ndim = ngamma->ndim;
+    gamma_shape = ngamma->shape;
+    beta_ndim = nbeta->ndim;
+    beta_shape = nbeta->shape;
+
+    // TODO: Check that the sizes of gamma, beta, running_mean, running_var, mean, inv_std
+    //       each match either reduced_shape(keepdims: false) or reduced_shape(keepdims: true)
+    reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1);
+    // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim);
+    // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, beta_ndim);
+    // for (size_t idim = 0; idim < reduced_ndim; ++idim) {
+    //     CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]);
+    //     CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], beta_shape[idim]);
+    // }
+
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT);
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT);
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(beta, cT);
+    if (running_mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_mean, cT);
+    if (running_var != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_var, cT);
+    if (mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT);
+    if (inv_std != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(inv_std, cT);
+
+    x_cont = cumo_na_as_contiguous_array(x);
+    gamma_cont = cumo_na_as_contiguous_array(gamma);
+    beta_cont = cumo_na_as_contiguous_array(beta);
+    if (running_mean != Qnil && cumo_na_check_contiguous(running_mean) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "running_mean must be contiguous");
+    }
+    if (running_var != Qnil && cumo_na_check_contiguous(running_var) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "running_var must be contiguous");
+    }
+    if (mean != Qnil && cumo_na_check_contiguous(mean) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "mean must be contiguous");
+    }
+    if (inv_std != Qnil && cumo_na_check_contiguous(inv_std) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "inv_std must be contiguous");
+    }
+
+    x_cont_ptr = cumo_na_get_offset_pointer_for_read(x_cont);
+    gamma_cont_ptr = cumo_na_get_offset_pointer_for_read(gamma_cont);
+    beta_cont_ptr = cumo_na_get_offset_pointer_for_read(beta_cont);
+
+    // TODO: type and shape check
+    if (y == Qnil) y = cumo_na_new(cT, x_ndim, x_shape);
+    y_ptr = cumo_na_get_offset_pointer_for_write(y);
+
+    status = cumo_cuda_cudnn_CreateTensorDescriptor(&x_desc, x_cont, cudnn_dtype);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+
+    mode = cumo_cuda_cudnn_GetBatchNormMode(axis_ndim, int_axis);
+    status = cumo_cuda_cudnn_CreateBNTensorDescriptor(&bn_desc, x_desc, mode);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+    // TODO: bn_desc may return another type, and may need to cast gamma, beta, mean, var
+
+    handle = cumo_cuda_cudnn_handle();
+
+    status = cudnnBatchNormalizationForwardTraining(
+            handle,
+            mode,
+            (void*)&coef_alpha,
+            (void*)&coef_beta,
+            x_desc,
+            x_cont_ptr,
+            x_desc,
+            y_ptr,
+            bn_desc,
+            gamma_cont_ptr,
+            beta_cont_ptr,
+            1.0 - double_decay,
+            running_mean_ptr,
+            running_var_ptr,
+            double_eps,
+            mean_ptr,
+            inv_std_ptr);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+
+BATCH_NORM_ERROR:
+    if (x_desc) cudnnDestroyTensorDescriptor(x_desc);
+    if (bn_desc) cudnnDestroyTensorDescriptor(bn_desc);
+    cumo_cuda_cudnn_check_status(status);
+
+    return y;
+}
+
+#else // CUDNN_FOUND
+VALUE cumo_cuda_eCUDNNError;
+
+static VALUE
+<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
+{
+    rb_raise(cumo_cuda_eCUDNNError, "cuDNN is not available");
+}
+#endif // CUDNN_FOUND
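
Going by the signature comment at the top of this template, a hypothetical training-mode call from Ruby could look like the following. Shapes are illustrative only: `axis: [0, 2, 3]` reduces over the batch and spatial axes of an NCHW array, leaving per-channel statistics, and `eps`/`decay` fall back to 2e-5 and 0.9 when omitted, per the defaults above.

    require 'cumo'

    x     = Cumo::SFloat.new(32, 3, 28, 28).rand  # NCHW input
    gamma = Cumo::SFloat.ones(3)                  # per-channel scale
    beta  = Cumo::SFloat.zeros(3)                 # per-channel shift
    running_mean = Cumo::SFloat.zeros(3)
    running_var  = Cumo::SFloat.ones(3)

    # running_mean / running_var are updated in place with decay 0.9.
    y = x.batch_norm(gamma, beta,
                     running_mean: running_mean,
                     running_var: running_var,
                     axis: [0, 2, 3])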