cumo 0.2.5 → 0.3.0.pre1
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -1
- data/README.md +12 -1
- data/cumo.gemspec +1 -1
- data/ext/cumo/cuda/cudnn.c +80 -0
- data/ext/cumo/cuda/cudnn_impl.cpp +572 -0
- data/ext/cumo/cuda/runtime.c +1 -0
- data/ext/cumo/cumo.c +5 -0
- data/ext/cumo/extconf.rb +8 -2
- data/ext/cumo/include/cumo.h +2 -2
- data/ext/cumo/include/cumo/cuda/cudnn.h +205 -0
- data/ext/cumo/include/cumo/hash_combine.hpp +17 -0
- data/ext/cumo/include/cumo/intern.h +5 -0
- data/ext/cumo/include/cumo/types/dfloat.h +1 -0
- data/ext/cumo/include/cumo/types/sfloat.h +1 -0
- data/ext/cumo/narray/gen/spec.rb +21 -0
- data/ext/cumo/narray/gen/tmpl/batch_norm.c +197 -0
- data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +191 -0
- data/ext/cumo/narray/gen/tmpl/conv.c +216 -0
- data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +183 -0
- data/ext/cumo/narray/gen/tmpl/conv_transpose.c +244 -0
- data/ext/cumo/narray/gen/tmpl/gemm.c +14 -0
- data/ext/cumo/narray/gen/tmpl/pooling_backward.c +136 -0
- data/ext/cumo/narray/gen/tmpl/pooling_forward.c +136 -0
- data/ext/cumo/narray/narray.c +29 -0
- data/lib/cumo/cuda.rb +1 -0
- data/lib/cumo/cuda/cudnn.rb +88 -0
- metadata +18 -5
data/ext/cumo/cuda/runtime.c
CHANGED
@@ -142,6 +142,7 @@ Init_cumo_cuda_runtime()
 {
     VALUE mCumo = rb_define_module("Cumo");
     VALUE mCUDA = rb_define_module_under(mCumo, "CUDA");
+    rb_define_const(mCumo, "Cuda", mCUDA); // alias
     mRuntime = rb_define_module_under(mCUDA, "Runtime");
     eRuntimeError = rb_define_class_under(mCUDA, "RuntimeError", rb_eStandardError);
 
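With this change `Cumo::Cuda` is defined as an alias of `Cumo::CUDA`, so either spelling resolves to the same module. A minimal check (illustrative Ruby session, assuming the gem builds and loads in your environment):

  require 'cumo'

  # Both constants name the same module object once Init_cumo_cuda_runtime has run.
  Cumo::Cuda.equal?(Cumo::CUDA)  # => true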
data/ext/cumo/cumo.c
CHANGED
@@ -33,6 +33,8 @@ void Init_cumo_cuda_driver();
 void Init_cumo_cuda_memory_pool();
 void Init_cumo_cuda_runtime();
 void Init_cumo_cuda_nvrtc();
+void Init_cumo_cuda_cublas();
+void Init_cumo_cuda_cudnn();
 
 void
 cumo_debug_breakpoint(void)
@@ -167,4 +169,7 @@ Init_cumo()
     Init_cumo_cuda_memory_pool();
     Init_cumo_cuda_runtime();
     Init_cumo_cuda_nvrtc();
+
+    Init_cumo_cuda_cublas();
+    Init_cumo_cuda_cudnn();
 }
data/ext/cumo/extconf.rb
CHANGED
@@ -47,9 +47,9 @@ rm_f 'include/cumo/extconf.h'
 MakeMakefileCuda.install!(cxx: true)
 
 if ENV['DEBUG']
-  $CFLAGS
+  $CFLAGS << " -g -O0 -Wall"
 end
-$CXXFLAGS
+$CXXFLAGS << " -std=c++14"
 #$CFLAGS=" $(cflags) -O3 -m64 -msse2 -funroll-loops"
 #$CFLAGS=" $(cflags) -O3"
 $INCFLAGS = "-Iinclude -Inarray -Icuda #{$INCFLAGS}"
@@ -109,6 +109,8 @@ cuda/memory_pool
 cuda/memory_pool_impl
 cuda/runtime
 cuda/nvrtc
+cuda/cudnn
+cuda/cudnn_impl
 )
 
 if RUBY_VERSION[0..3] == "2.1."
@@ -179,5 +181,9 @@ have_library('nvrtc')
 have_library('cublas')
 # have_library('cusolver')
 # have_library('curand')
+if have_library('cudnn') # TODO(sonots): cuDNN version check
+  $CFLAGS << " -DCUDNN_FOUND"
+  $CXXFLAGS << " -DCUDNN_FOUND"
+end
 
 create_makefile('cumo')
data/ext/cumo/include/cumo/cuda/cudnn.h
ADDED
@@ -0,0 +1,205 @@
+#ifndef CUMO_CUDA_CUDNN_H
+#define CUMO_CUDA_CUDNN_H
+
+#include <ruby.h>
+#ifdef CUDNN_FOUND
+#include <cudnn.h>
+#endif // CUDNN_FOUND
+
+#if defined(__cplusplus)
+extern "C" {
+#if 0
+} /* satisfy cc-mode */
+#endif
+#endif
+
+#ifdef CUDNN_FOUND
+
+VALUE cumo_na_eShapeError;
+
+#define CUMO_CUDA_CUDNN_DEFAULT_MAX_WORKSPACE_SIZE 8 * 1024 * 1024
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x,t) \
+    if (rb_obj_class(x)!=(t)) { \
+        rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
+    }
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(sz1,sz2) \
+    if ((sz1) != (sz2)) { \
+        rb_raise(cumo_na_eShapeError, \
+                 "size mismatch: %d != %d", \
+                 (int)(sz1), (int)(sz2)); \
+    }
+
+// TODO: Move to proper generic place
+#define CUMO_CUDA_CUDNN_CHECK_DIM_EQ(nd1,nd2) \
+    if ((nd1) != (nd2)) { \
+        rb_raise(cumo_na_eShapeError, \
+                 "dimention mismatch: %d != %d", \
+                 (int)(nd1), (int)(nd2)); \
+    }
+
+void
+cumo_cuda_cudnn_check_status(cudnnStatus_t status);
+
+cudnnHandle_t
+cumo_cuda_cudnn_handle();
+
+// TODO: Move to more generic proper place
+static inline VALUE
+cumo_cuda_cudnn_option_value(VALUE value, VALUE default_value)
+{
+    switch(TYPE(value)) {
+      case T_NIL:
+      case T_UNDEF:
+        return default_value;
+    }
+    return value;
+}
+
+// VALUE is Ruby Array
+static inline void
+cumo_cuda_cudnn_get_int_ary(int* int_ary, VALUE ary, size_t ndim, int default_value)
+{
+    if (ary == Qnil) {
+        // default to 1
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = default_value;
+        }
+    } else if (TYPE(ary) == T_FIXNUM) {
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = NUM2INT(ary);
+        }
+    } else {
+        Check_Type(ary, T_ARRAY);
+        CUMO_CUDA_CUDNN_CHECK_DIM_EQ((size_t)(RARRAY_LEN(ary)), ndim);
+        for (size_t idim = 0; idim < ndim; ++idim) {
+            int_ary[idim] = NUM2INT(rb_ary_entry(ary, idim));
+        }
+    }
+}
+
+size_t
+cumo_cuda_cudnn_GetConvOutDim(
+        size_t in_dim,
+        size_t kernel_size,
+        size_t stride,
+        size_t pad);
+
+size_t
+cumo_cuda_cudnn_GetConvTransposeOutDim(
+        size_t in_dim,
+        size_t kernel_size,
+        size_t stride,
+        size_t pad);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateTensorDescriptor(
+        cudnnTensorDescriptor_t *desc,
+        VALUE a,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateFilterDescriptor(
+        cudnnFilterDescriptor_t *desc,
+        VALUE a,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateConvolutionDescriptor(
+        cudnnConvolutionDescriptor_t *desc,
+        size_t ndim,
+        int* int_stride,
+        int* int_pad,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreatePoolingDescriptor(
+        cudnnPoolingDescriptor_t *desc,
+        cudnnPoolingMode_t mode,
+        size_t ndim,
+        int* int_kernel_size,
+        int* int_stride,
+        int* int_pad);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionForwardAlgorithm(
+        cudnnConvolutionFwdAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnFilterDescriptor_t w_desc,
+        VALUE w,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnTensorDescriptor_t y_sec,
+        VALUE y,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionBackwardDataAlgorithm(
+        cudnnConvolutionBwdDataAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnFilterDescriptor_t w_desc,
+        VALUE w,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnTensorDescriptor_t y_desc,
+        VALUE y,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnStatus_t
+cumo_cuda_cudnn_FindConvolutionBackwardFilterAlgorithm(
+        cudnnConvolutionBwdFilterAlgoPerf_t *perf_result,
+        cudnnHandle_t handle,
+        cudnnTensorDescriptor_t x_desc,
+        VALUE x,
+        cudnnTensorDescriptor_t dy_desc,
+        VALUE dy,
+        cudnnConvolutionDescriptor_t conv_desc,
+        cudnnFilterDescriptor_t dw_desc,
+        VALUE dw,
+        size_t max_workspace_size,
+        int* int_stride,
+        int* int_pad,
+        size_t ndim,
+        cudnnDataType_t cudnn_dtype);
+
+cudnnBatchNormMode_t
+cumo_cuda_cudnn_GetBatchNormMode(size_t ndim, int* int_axis);
+
+cudnnStatus_t
+cumo_cuda_cudnn_CreateBNTensorDescriptor(
+        cudnnTensorDescriptor_t *desc,
+        cudnnTensorDescriptor_t x_desc,
+        cudnnBatchNormMode_t mode);
+
+size_t
+cumo_cuda_cudnn_ReduceShape(
+        size_t *reduced_shape,
+        size_t shape_ndim,
+        size_t *shape,
+        size_t axes_ndim,
+        int *axes,
+        char keepdims);
+
+#endif // CUDNN_FOUND
+
+#if defined(__cplusplus)
+#if 0
+{ /* satisfy cc-mode */
+#endif
+} /* extern "C" { */
+#endif
+
+#endif /* ifndef CUMO_CUDA_CUDNN_H */
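`cumo_cuda_cudnn_get_int_ary` is what lets the Ruby-facing kwargs accept either a single Integer (broadcast to every spatial dimension) or an Array with one entry per dimension. A sketch of what that means at the Ruby level, assuming the conv template exposes `stride:` and `pad:` kwargs (the kwarg names and shapes here are assumptions, not taken from this diff):

  require 'cumo'

  x = Cumo::SFloat.new(1, 3, 32, 32).seq  # NCHW input
  w = Cumo::SFloat.new(8, 3, 3, 3).seq    # OIHW filter

  # An Integer is broadcast to all spatial dims; an Array gives one value per dim.
  y1 = x.conv(w, stride: 1, pad: 1)
  y2 = x.conv(w, stride: [1, 1], pad: [1, 1])  # expected to be equivalent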
data/ext/cumo/include/cumo/hash_combine.hpp
ADDED
@@ -0,0 +1,17 @@
+#ifndef CUMO_HASH_COMBINE_H
+#define CUMO_HASH_COMBINE_H
+
+#include <cstddef>
+
+namespace cumo {
+namespace internal {
+
+// Borrowed from boost::hash_combine
+//
+// TODO(sonots): hash combine in 64bit
+inline void HashCombine(std::size_t& seed, std::size_t hash_value) { seed ^= hash_value + 0x9e3779b9 + (seed << 6) + (seed >> 2); }
+
+} // namespace internal
+} // namespace cumo
+
+#endif /* ifndef CUMO_HASH_COMBINE_H */
data/ext/cumo/include/cumo/intern.h
CHANGED
@@ -26,11 +26,16 @@ char *cumo_na_get_pointer_for_write(VALUE);
 char *cumo_na_get_pointer_for_read(VALUE);
 char *cumo_na_get_pointer_for_read_write(VALUE);
 size_t cumo_na_get_offset(VALUE self);
+char* cumo_na_get_offset_pointer(VALUE);
+char* cumo_na_get_offset_pointer_for_write(VALUE);
+char* cumo_na_get_offset_pointer_for_read(VALUE);
+char* cumo_na_get_offset_pointer_for_read_write(VALUE);
 
 void cumo_na_copy_flags(VALUE src, VALUE dst);
 
 VALUE cumo_na_check_ladder(VALUE self, int start_dim);
 VALUE cumo_na_check_contiguous(VALUE self);
+VALUE cumo_na_as_contiguous_array(VALUE a);
 
 VALUE cumo_na_flatten_dim(VALUE self, int sd);
 
data/ext/cumo/narray/gen/spec.rb
CHANGED
@@ -53,6 +53,16 @@ end
 if (is_float || is_complex) && !is_object
   def_id "gemm"
 end
+# cudnn
+if is_float && !is_complex && !is_object
+  def_id "conv"
+  def_id "conv_transpose"
+  def_id "conv_grad_w"
+  def_id "batch_norm"
+  def_id "batch_norm_backward"
+  def_id "pooling_forward"
+  def_id "pooling_backward"
+end
 
 if is_int && !is_object
   def_id "minlength" # for bincount
@@ -331,6 +341,17 @@ if (is_float || is_complex) && !is_object
   def_method "gemm"
 end
 
+# cudnn
+if is_float && !is_complex && !is_object
+  def_method "conv"
+  def_method "conv_transpose" # conv_backward_data
+  def_method "conv_grad_w" # conv_backward_filter
+  def_method "batch_norm"
+  def_method "batch_norm_backward"
+  def_method "pooling_forward" # max_pool, avg_pool
+  def_method "pooling_backward"
+end
+
 # rmsdev
 # prod
 
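The spec change generates the cuDNN-backed methods only for the real float dtypes (SFloat, DFloat), not for integer, complex, or object arrays. An illustrative check of the expected result:

  require 'cumo'

  # conv & co. are generated for real float dtypes only (expected behavior).
  Cumo::SFloat.method_defined?(:conv)  # => true
  Cumo::DFloat.method_defined?(:conv)  # => true
  Cumo::Int32.method_defined?(:conv)   # => false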
data/ext/cumo/narray/gen/tmpl/batch_norm.c
ADDED
@@ -0,0 +1,197 @@
+#ifdef CUDNN_FOUND
+
+<%
+  cudnn_dtype =
+      case type_name
+      when 'sfloat'
+          'CUDNN_DATA_FLOAT'
+      when 'dfloat'
+          'CUDNN_DATA_DOUBLE'
+      else
+          # CUDNN_DATA_HALF
+          raise 'not supported'
+      end
+%>
+
+// y = x.batch_norm(gamma, beta, running_mean:, running_var:, eps:, decay:, axis:, mean:, inv_std:)
+static VALUE
+<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
+{
+    cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>;
+    cudnnStatus_t status = 0;
+    cudnnHandle_t handle = 0;
+    dtype coef_alpha = 1;
+    dtype coef_beta = 0;
+
+    VALUE x=self, gamma, beta, running_mean, running_var, eps, decay, axis, mean, inv_std, y;
+    VALUE kw_hash = Qnil;
+    ID kw_table[] = {
+        rb_intern("running_mean"),
+        rb_intern("running_var"),
+        rb_intern("mean"),
+        rb_intern("inv_std"),
+        rb_intern("eps"),
+        rb_intern("decay"),
+        rb_intern("axis"),
+        rb_intern("y")
+    };
+    VALUE opts[] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
+
+    cumo_narray_t *nx, *ngamma, *nbeta;
+    size_t *x_shape, *gamma_shape, *beta_shape, reduced_shape[CUMO_NA_MAX_DIMENSION];
+    size_t x_ndim, gamma_ndim, beta_ndim, reduced_ndim;
+
+    VALUE x_cont, gamma_cont, beta_cont;
+    cudnnTensorDescriptor_t x_desc = 0;
+    cudnnTensorDescriptor_t bn_desc = 0;
+    char *x_cont_ptr, *gamma_cont_ptr, *beta_cont_ptr, *y_ptr;
+
+    cudnnBatchNormMode_t mode;
+
+    // default values
+    char *running_mean_ptr=NULL;
+    char *running_var_ptr=NULL;
+    char *mean_ptr=NULL;
+    char *inv_std_ptr=NULL;
+    double double_eps = 2e-5;
+    double double_decay = 0.9;
+    int int_axis[CUMO_NA_MAX_DIMENSION] = {0};
+    size_t axis_ndim = 1;
+
+    rb_scan_args(argc, argv, "2:", &gamma, &beta, &kw_hash);
+    rb_get_kwargs(kw_hash, kw_table, 0, 8, opts);
+    running_mean = cumo_cuda_cudnn_option_value(opts[0], Qnil);
+    running_var = cumo_cuda_cudnn_option_value(opts[1], Qnil);
+    mean = cumo_cuda_cudnn_option_value(opts[2], Qnil);
+    inv_std = cumo_cuda_cudnn_option_value(opts[3], Qnil);
+    eps = cumo_cuda_cudnn_option_value(opts[4], Qnil);
+    decay = cumo_cuda_cudnn_option_value(opts[5], Qnil);
+    axis = cumo_cuda_cudnn_option_value(opts[6], Qnil);
+    y = cumo_cuda_cudnn_option_value(opts[7], Qnil);
+
+    if (running_mean != Qnil) {
+        running_mean_ptr = cumo_na_get_offset_pointer_for_write(running_mean);
+    }
+    if (running_var != Qnil) {
+        running_var_ptr = cumo_na_get_offset_pointer_for_write(running_var);
+    }
+    if (mean != Qnil) {
+        mean_ptr = cumo_na_get_offset_pointer_for_write(mean);
+    }
+    if (inv_std != Qnil) {
+        inv_std_ptr = cumo_na_get_offset_pointer_for_write(inv_std);
+    }
+    if (eps != Qnil) {
+        double_eps = NUM2DBL(eps);
+    }
+    if (decay != Qnil) {
+        double_decay = NUM2DBL(decay);
+    }
+    if (axis != Qnil) {
+        Check_Type(axis, T_ARRAY);
+        axis_ndim = (size_t)(RARRAY_LEN(axis));
+        for (size_t idim = 0; idim < axis_ndim; ++idim) {
+            int_axis[idim] = NUM2INT(rb_ary_entry(axis, (long)idim));
+        }
+        // TODO: check axis is sorted
+    }
+
+    CumoGetNArray(x, nx);
+    CumoGetNArray(gamma, ngamma);
+    CumoGetNArray(beta, nbeta);
+    x_ndim = nx->ndim;
+    x_shape = nx->shape;
+    gamma_ndim = ngamma->ndim;
+    gamma_shape = ngamma->shape;
+    beta_ndim = nbeta->ndim;
+    beta_shape = nbeta->shape;
+
+    // TODO: Size check of gammma, beta, running_mean, running_var, mean, inv_std
+    // are equivalent with either of reduced_shape(keepdims: false) or reduced_shape(keepdims: true)
+    reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1);
+    // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim);
+    // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, beta_ndim);
+    // for (size_t idim = 0; idim < reduced_ndim; ++idim) {
+    //     CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]);
+    //     CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], beta_shape[idim]);
+    // }
+
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT);
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT);
+    CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(beta, cT);
+    if (running_mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_mean, cT);
+    if (running_var != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_var, cT);
+    if (mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT);
+    if (inv_std != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(inv_std, cT);
+
+    x_cont = cumo_na_as_contiguous_array(x);
+    gamma_cont = cumo_na_as_contiguous_array(gamma);
+    beta_cont = cumo_na_as_contiguous_array(beta);
+    if (running_mean != Qnil && cumo_na_check_contiguous(running_mean) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "running_mean must be contiguous");
+    }
+    if (running_var != Qnil && cumo_na_check_contiguous(running_var) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "running_var must be contiguous");
+    }
+    if (mean != Qnil && cumo_na_check_contiguous(mean) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "mean must be contiguous");
+    }
+    if (inv_std != Qnil && cumo_na_check_contiguous(inv_std) != Qtrue) {
+        rb_raise(rb_eRuntimeError, "inv_std must be contiguous");
+    }
+
+    x_cont_ptr = cumo_na_get_offset_pointer_for_read(x_cont);
+    gamma_cont_ptr = cumo_na_get_offset_pointer_for_read(gamma_cont);
+    beta_cont_ptr = cumo_na_get_offset_pointer_for_read(beta_cont);
+
+    // TODO: type and shape check
+    if (y == Qnil) y = cumo_na_new(cT, x_ndim, x_shape);
+    y_ptr = cumo_na_get_offset_pointer_for_write(y);
+
+    status = cumo_cuda_cudnn_CreateTensorDescriptor(&x_desc, x_cont, cudnn_dtype);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+
+    mode = cumo_cuda_cudnn_GetBatchNormMode(axis_ndim, int_axis);
+    status = cumo_cuda_cudnn_CreateBNTensorDescriptor(&bn_desc, x_desc, mode);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+    // TODO: bn_desc may return another type, and may need to cast gamma, beta, mean, var
+
+    handle = cumo_cuda_cudnn_handle();
+
+    status = cudnnBatchNormalizationForwardTraining(
+            handle,
+            mode,
+            (void*)&coef_alpha,
+            (void*)&coef_beta,
+            x_desc,
+            x_cont_ptr,
+            x_desc,
+            y_ptr,
+            bn_desc,
+            gamma_cont_ptr,
+            beta_cont_ptr,
+            1.0 - double_decay,
+            running_mean_ptr,
+            running_var_ptr,
+            double_eps,
+            mean_ptr,
+            inv_std_ptr);
+    if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
+
+BATCH_NORM_ERROR:
+    if (x_desc) cudnnDestroyTensorDescriptor(x_desc);
+    if (bn_desc) cudnnDestroyTensorDescriptor(bn_desc);
+    cumo_cuda_cudnn_check_status(status);
+
+    return y;
+}
+
+#else // CUDNN_FOUND
+VALUE cumo_cuda_eCUDNNError;
+
+static VALUE
+<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
+{
+    rb_raise(cumo_cuda_eCUDNNError, "cuDNN is not available");
+}
+#endif // CUDNN_FOUND