cumo 0.2.5 → 0.3.0.pre1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -1
- data/README.md +12 -1
- data/cumo.gemspec +1 -1
- data/ext/cumo/cuda/cudnn.c +80 -0
- data/ext/cumo/cuda/cudnn_impl.cpp +572 -0
- data/ext/cumo/cuda/runtime.c +1 -0
- data/ext/cumo/cumo.c +5 -0
- data/ext/cumo/extconf.rb +8 -2
- data/ext/cumo/include/cumo.h +2 -2
- data/ext/cumo/include/cumo/cuda/cudnn.h +205 -0
- data/ext/cumo/include/cumo/hash_combine.hpp +17 -0
- data/ext/cumo/include/cumo/intern.h +5 -0
- data/ext/cumo/include/cumo/types/dfloat.h +1 -0
- data/ext/cumo/include/cumo/types/sfloat.h +1 -0
- data/ext/cumo/narray/gen/spec.rb +21 -0
- data/ext/cumo/narray/gen/tmpl/batch_norm.c +197 -0
- data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +191 -0
- data/ext/cumo/narray/gen/tmpl/conv.c +216 -0
- data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +183 -0
- data/ext/cumo/narray/gen/tmpl/conv_transpose.c +244 -0
- data/ext/cumo/narray/gen/tmpl/gemm.c +14 -0
- data/ext/cumo/narray/gen/tmpl/pooling_backward.c +136 -0
- data/ext/cumo/narray/gen/tmpl/pooling_forward.c +136 -0
- data/ext/cumo/narray/narray.c +29 -0
- data/lib/cumo/cuda.rb +1 -0
- data/lib/cumo/cuda/cudnn.rb +88 -0
- metadata +18 -5
data/ext/cumo/cuda/runtime.c
CHANGED
@@ -142,6 +142,7 @@ Init_cumo_cuda_runtime()
|
|
142
142
|
{
|
143
143
|
VALUE mCumo = rb_define_module("Cumo");
|
144
144
|
VALUE mCUDA = rb_define_module_under(mCumo, "CUDA");
|
145
|
+
rb_define_const(mCumo, "Cuda", mCUDA); // alias
|
145
146
|
mRuntime = rb_define_module_under(mCUDA, "Runtime");
|
146
147
|
eRuntimeError = rb_define_class_under(mCUDA, "RuntimeError", rb_eStandardError);
|
147
148
|
|
data/ext/cumo/cumo.c
CHANGED
@@ -33,6 +33,8 @@ void Init_cumo_cuda_driver();
|
|
33
33
|
void Init_cumo_cuda_memory_pool();
|
34
34
|
void Init_cumo_cuda_runtime();
|
35
35
|
void Init_cumo_cuda_nvrtc();
|
36
|
+
void Init_cumo_cuda_cublas();
|
37
|
+
void Init_cumo_cuda_cudnn();
|
36
38
|
|
37
39
|
void
|
38
40
|
cumo_debug_breakpoint(void)
|
@@ -167,4 +169,7 @@ Init_cumo()
|
|
167
169
|
Init_cumo_cuda_memory_pool();
|
168
170
|
Init_cumo_cuda_runtime();
|
169
171
|
Init_cumo_cuda_nvrtc();
|
172
|
+
|
173
|
+
Init_cumo_cuda_cublas();
|
174
|
+
Init_cumo_cuda_cudnn();
|
170
175
|
}
|
data/ext/cumo/extconf.rb
CHANGED
@@ -47,9 +47,9 @@ rm_f 'include/cumo/extconf.h'
|
|
47
47
|
MakeMakefileCuda.install!(cxx: true)
|
48
48
|
|
49
49
|
if ENV['DEBUG']
|
50
|
-
$CFLAGS
|
50
|
+
$CFLAGS << " -g -O0 -Wall"
|
51
51
|
end
|
52
|
-
$CXXFLAGS
|
52
|
+
$CXXFLAGS << " -std=c++14"
|
53
53
|
#$CFLAGS=" $(cflags) -O3 -m64 -msse2 -funroll-loops"
|
54
54
|
#$CFLAGS=" $(cflags) -O3"
|
55
55
|
$INCFLAGS = "-Iinclude -Inarray -Icuda #{$INCFLAGS}"
|
@@ -109,6 +109,8 @@ cuda/memory_pool
|
|
109
109
|
cuda/memory_pool_impl
|
110
110
|
cuda/runtime
|
111
111
|
cuda/nvrtc
|
112
|
+
cuda/cudnn
|
113
|
+
cuda/cudnn_impl
|
112
114
|
)
|
113
115
|
|
114
116
|
if RUBY_VERSION[0..3] == "2.1."
|
@@ -179,5 +181,9 @@ have_library('nvrtc')
|
|
179
181
|
have_library('cublas')
|
180
182
|
# have_library('cusolver')
|
181
183
|
# have_library('curand')
|
184
|
+
if have_library('cudnn') # TODO(sonots): cuDNN version check
|
185
|
+
$CFLAGS << " -DCUDNN_FOUND"
|
186
|
+
$CXXFLAGS << " -DCUDNN_FOUND"
|
187
|
+
end
|
182
188
|
|
183
189
|
create_makefile('cumo')
|
data/ext/cumo/include/cumo.h
CHANGED
@@ -0,0 +1,205 @@
|
|
1
|
+
#ifndef CUMO_CUDA_CUDNN_H
|
2
|
+
#define CUMO_CUDA_CUDNN_H
|
3
|
+
|
4
|
+
#include <ruby.h>
|
5
|
+
#ifdef CUDNN_FOUND
|
6
|
+
#include <cudnn.h>
|
7
|
+
#endif // CUDNN_FOUND
|
8
|
+
|
9
|
+
#if defined(__cplusplus)
|
10
|
+
extern "C" {
|
11
|
+
#if 0
|
12
|
+
} /* satisfy cc-mode */
|
13
|
+
#endif
|
14
|
+
#endif
|
15
|
+
|
16
|
+
#ifdef CUDNN_FOUND
|
17
|
+
|
18
|
+
VALUE cumo_na_eShapeError;
|
19
|
+
|
20
|
+
#define CUMO_CUDA_CUDNN_DEFAULT_MAX_WORKSPACE_SIZE 8 * 1024 * 1024
|
21
|
+
|
22
|
+
// TODO: Move to proper generic place
|
23
|
+
#define CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x,t) \
|
24
|
+
if (rb_obj_class(x)!=(t)) { \
|
25
|
+
rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
|
26
|
+
}
|
27
|
+
|
28
|
+
// TODO: Move to proper generic place
|
29
|
+
#define CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(sz1,sz2) \
|
30
|
+
if ((sz1) != (sz2)) { \
|
31
|
+
rb_raise(cumo_na_eShapeError, \
|
32
|
+
"size mismatch: %d != %d", \
|
33
|
+
(int)(sz1), (int)(sz2)); \
|
34
|
+
}
|
35
|
+
|
36
|
+
// TODO: Move to proper generic place
|
37
|
+
#define CUMO_CUDA_CUDNN_CHECK_DIM_EQ(nd1,nd2) \
|
38
|
+
if ((nd1) != (nd2)) { \
|
39
|
+
rb_raise(cumo_na_eShapeError, \
|
40
|
+
"dimention mismatch: %d != %d", \
|
41
|
+
(int)(nd1), (int)(nd2)); \
|
42
|
+
}
|
43
|
+
|
44
|
+
void
|
45
|
+
cumo_cuda_cudnn_check_status(cudnnStatus_t status);
|
46
|
+
|
47
|
+
cudnnHandle_t
|
48
|
+
cumo_cuda_cudnn_handle();
|
49
|
+
|
50
|
+
// TODO: Move to more generic proper place
|
51
|
+
static inline VALUE
|
52
|
+
cumo_cuda_cudnn_option_value(VALUE value, VALUE default_value)
|
53
|
+
{
|
54
|
+
switch(TYPE(value)) {
|
55
|
+
case T_NIL:
|
56
|
+
case T_UNDEF:
|
57
|
+
return default_value;
|
58
|
+
}
|
59
|
+
return value;
|
60
|
+
}
|
61
|
+
|
62
|
+
// VALUE is Ruby Array
|
63
|
+
static inline void
|
64
|
+
cumo_cuda_cudnn_get_int_ary(int* int_ary, VALUE ary, size_t ndim, int default_value)
|
65
|
+
{
|
66
|
+
if (ary == Qnil) {
|
67
|
+
// default to 1
|
68
|
+
for (size_t idim = 0; idim < ndim; ++idim) {
|
69
|
+
int_ary[idim] = default_value;
|
70
|
+
}
|
71
|
+
} else if (TYPE(ary) == T_FIXNUM) {
|
72
|
+
for (size_t idim = 0; idim < ndim; ++idim) {
|
73
|
+
int_ary[idim] = NUM2INT(ary);
|
74
|
+
}
|
75
|
+
} else {
|
76
|
+
Check_Type(ary, T_ARRAY);
|
77
|
+
CUMO_CUDA_CUDNN_CHECK_DIM_EQ((size_t)(RARRAY_LEN(ary)), ndim);
|
78
|
+
for (size_t idim = 0; idim < ndim; ++idim) {
|
79
|
+
int_ary[idim] = NUM2INT(rb_ary_entry(ary, idim));
|
80
|
+
}
|
81
|
+
}
|
82
|
+
}
|
83
|
+
|
84
|
+
size_t
|
85
|
+
cumo_cuda_cudnn_GetConvOutDim(
|
86
|
+
size_t in_dim,
|
87
|
+
size_t kernel_size,
|
88
|
+
size_t stride,
|
89
|
+
size_t pad);
|
90
|
+
|
91
|
+
size_t
|
92
|
+
cumo_cuda_cudnn_GetConvTransposeOutDim(
|
93
|
+
size_t in_dim,
|
94
|
+
size_t kernel_size,
|
95
|
+
size_t stride,
|
96
|
+
size_t pad);
|
97
|
+
|
98
|
+
cudnnStatus_t
|
99
|
+
cumo_cuda_cudnn_CreateTensorDescriptor(
|
100
|
+
cudnnTensorDescriptor_t *desc,
|
101
|
+
VALUE a,
|
102
|
+
cudnnDataType_t cudnn_dtype);
|
103
|
+
|
104
|
+
cudnnStatus_t
|
105
|
+
cumo_cuda_cudnn_CreateFilterDescriptor(
|
106
|
+
cudnnFilterDescriptor_t *desc,
|
107
|
+
VALUE a,
|
108
|
+
cudnnDataType_t cudnn_dtype);
|
109
|
+
|
110
|
+
cudnnStatus_t
|
111
|
+
cumo_cuda_cudnn_CreateConvolutionDescriptor(
|
112
|
+
cudnnConvolutionDescriptor_t *desc,
|
113
|
+
size_t ndim,
|
114
|
+
int* int_stride,
|
115
|
+
int* int_pad,
|
116
|
+
cudnnDataType_t cudnn_dtype);
|
117
|
+
|
118
|
+
cudnnStatus_t
|
119
|
+
cumo_cuda_cudnn_CreatePoolingDescriptor(
|
120
|
+
cudnnPoolingDescriptor_t *desc,
|
121
|
+
cudnnPoolingMode_t mode,
|
122
|
+
size_t ndim,
|
123
|
+
int* int_kernel_size,
|
124
|
+
int* int_stride,
|
125
|
+
int* int_pad);
|
126
|
+
|
127
|
+
cudnnStatus_t
|
128
|
+
cumo_cuda_cudnn_FindConvolutionForwardAlgorithm(
|
129
|
+
cudnnConvolutionFwdAlgoPerf_t *perf_result,
|
130
|
+
cudnnHandle_t handle,
|
131
|
+
cudnnTensorDescriptor_t x_desc,
|
132
|
+
VALUE x,
|
133
|
+
cudnnFilterDescriptor_t w_desc,
|
134
|
+
VALUE w,
|
135
|
+
cudnnConvolutionDescriptor_t conv_desc,
|
136
|
+
cudnnTensorDescriptor_t y_sec,
|
137
|
+
VALUE y,
|
138
|
+
size_t max_workspace_size,
|
139
|
+
int* int_stride,
|
140
|
+
int* int_pad,
|
141
|
+
size_t ndim,
|
142
|
+
cudnnDataType_t cudnn_dtype);
|
143
|
+
|
144
|
+
cudnnStatus_t
|
145
|
+
cumo_cuda_cudnn_FindConvolutionBackwardDataAlgorithm(
|
146
|
+
cudnnConvolutionBwdDataAlgoPerf_t *perf_result,
|
147
|
+
cudnnHandle_t handle,
|
148
|
+
cudnnFilterDescriptor_t w_desc,
|
149
|
+
VALUE w,
|
150
|
+
cudnnTensorDescriptor_t x_desc,
|
151
|
+
VALUE x,
|
152
|
+
cudnnConvolutionDescriptor_t conv_desc,
|
153
|
+
cudnnTensorDescriptor_t y_desc,
|
154
|
+
VALUE y,
|
155
|
+
size_t max_workspace_size,
|
156
|
+
int* int_stride,
|
157
|
+
int* int_pad,
|
158
|
+
size_t ndim,
|
159
|
+
cudnnDataType_t cudnn_dtype);
|
160
|
+
|
161
|
+
cudnnStatus_t
|
162
|
+
cumo_cuda_cudnn_FindConvolutionBackwardFilterAlgorithm(
|
163
|
+
cudnnConvolutionBwdFilterAlgoPerf_t *perf_result,
|
164
|
+
cudnnHandle_t handle,
|
165
|
+
cudnnTensorDescriptor_t x_desc,
|
166
|
+
VALUE x,
|
167
|
+
cudnnTensorDescriptor_t dy_desc,
|
168
|
+
VALUE dy,
|
169
|
+
cudnnConvolutionDescriptor_t conv_desc,
|
170
|
+
cudnnFilterDescriptor_t dw_desc,
|
171
|
+
VALUE dw,
|
172
|
+
size_t max_workspace_size,
|
173
|
+
int* int_stride,
|
174
|
+
int* int_pad,
|
175
|
+
size_t ndim,
|
176
|
+
cudnnDataType_t cudnn_dtype);
|
177
|
+
|
178
|
+
cudnnBatchNormMode_t
|
179
|
+
cumo_cuda_cudnn_GetBatchNormMode(size_t ndim, int* int_axis);
|
180
|
+
|
181
|
+
cudnnStatus_t
|
182
|
+
cumo_cuda_cudnn_CreateBNTensorDescriptor(
|
183
|
+
cudnnTensorDescriptor_t *desc,
|
184
|
+
cudnnTensorDescriptor_t x_desc,
|
185
|
+
cudnnBatchNormMode_t mode);
|
186
|
+
|
187
|
+
size_t
|
188
|
+
cumo_cuda_cudnn_ReduceShape(
|
189
|
+
size_t *reduced_shape,
|
190
|
+
size_t shape_ndim,
|
191
|
+
size_t *shape,
|
192
|
+
size_t axes_ndim,
|
193
|
+
int *axes,
|
194
|
+
char keepdims);
|
195
|
+
|
196
|
+
#endif // CUDNN_FOUND
|
197
|
+
|
198
|
+
#if defined(__cplusplus)
|
199
|
+
#if 0
|
200
|
+
{ /* satisfy cc-mode */
|
201
|
+
#endif
|
202
|
+
} /* extern "C" { */
|
203
|
+
#endif
|
204
|
+
|
205
|
+
#endif /* ifndef CUMO_CUDA_CUDNN_H */
|
@@ -0,0 +1,17 @@
|
|
1
|
+
#ifndef CUMO_HASH_COMBINE_H
|
2
|
+
#define CUMO_HASH_COMBINE_H
|
3
|
+
|
4
|
+
#include <cstddef>
|
5
|
+
|
6
|
+
namespace cumo {
|
7
|
+
namespace internal {
|
8
|
+
|
9
|
+
// Borrowed from boost::hash_combine
|
10
|
+
//
|
11
|
+
// TODO(sonots): hash combine in 64bit
|
12
|
+
inline void HashCombine(std::size_t& seed, std::size_t hash_value) { seed ^= hash_value + 0x9e3779b9 + (seed << 6) + (seed >> 2); }
|
13
|
+
|
14
|
+
} // namespace internal
|
15
|
+
} // namespace cumo
|
16
|
+
|
17
|
+
#endif /* ifndef CUMO_HASH_COMBINE_H */
|
@@ -26,11 +26,16 @@ char *cumo_na_get_pointer_for_write(VALUE);
|
|
26
26
|
char *cumo_na_get_pointer_for_read(VALUE);
|
27
27
|
char *cumo_na_get_pointer_for_read_write(VALUE);
|
28
28
|
size_t cumo_na_get_offset(VALUE self);
|
29
|
+
char* cumo_na_get_offset_pointer(VALUE);
|
30
|
+
char* cumo_na_get_offset_pointer_for_write(VALUE);
|
31
|
+
char* cumo_na_get_offset_pointer_for_read(VALUE);
|
32
|
+
char* cumo_na_get_offset_pointer_for_read_write(VALUE);
|
29
33
|
|
30
34
|
void cumo_na_copy_flags(VALUE src, VALUE dst);
|
31
35
|
|
32
36
|
VALUE cumo_na_check_ladder(VALUE self, int start_dim);
|
33
37
|
VALUE cumo_na_check_contiguous(VALUE self);
|
38
|
+
VALUE cumo_na_as_contiguous_array(VALUE a);
|
34
39
|
|
35
40
|
VALUE cumo_na_flatten_dim(VALUE self, int sd);
|
36
41
|
|
data/ext/cumo/narray/gen/spec.rb
CHANGED
@@ -53,6 +53,16 @@ end
|
|
53
53
|
if (is_float || is_complex) && !is_object
|
54
54
|
def_id "gemm"
|
55
55
|
end
|
56
|
+
# cudnn
|
57
|
+
if is_float && !is_complex && !is_object
|
58
|
+
def_id "conv"
|
59
|
+
def_id "conv_transpose"
|
60
|
+
def_id "conv_grad_w"
|
61
|
+
def_id "batch_norm"
|
62
|
+
def_id "batch_norm_backward"
|
63
|
+
def_id "pooling_forward"
|
64
|
+
def_id "pooling_backward"
|
65
|
+
end
|
56
66
|
|
57
67
|
if is_int && !is_object
|
58
68
|
def_id "minlength" # for bincount
|
@@ -331,6 +341,17 @@ if (is_float || is_complex) && !is_object
|
|
331
341
|
def_method "gemm"
|
332
342
|
end
|
333
343
|
|
344
|
+
# cudnn
|
345
|
+
if is_float && !is_complex && !is_object
|
346
|
+
def_method "conv"
|
347
|
+
def_method "conv_transpose" # conv_backward_data
|
348
|
+
def_method "conv_grad_w" # conv_backward_filter
|
349
|
+
def_method "batch_norm"
|
350
|
+
def_method "batch_norm_backward"
|
351
|
+
def_method "pooling_forward" # max_pool, avg_pool
|
352
|
+
def_method "pooling_backward"
|
353
|
+
end
|
354
|
+
|
334
355
|
# rmsdev
|
335
356
|
# prod
|
336
357
|
|
@@ -0,0 +1,197 @@
|
|
1
|
+
#ifdef CUDNN_FOUND
|
2
|
+
|
3
|
+
<%
|
4
|
+
cudnn_dtype =
|
5
|
+
case type_name
|
6
|
+
when 'sfloat'
|
7
|
+
'CUDNN_DATA_FLOAT'
|
8
|
+
when 'dfloat'
|
9
|
+
'CUDNN_DATA_DOUBLE'
|
10
|
+
else
|
11
|
+
# CUDNN_DATA_HALF
|
12
|
+
raise 'not supported'
|
13
|
+
end
|
14
|
+
%>
|
15
|
+
|
16
|
+
// y = x.batch_norm(gamma, beta, running_mean:, running_var:, eps:, decay:, axis:, mean:, inv_std:)
|
17
|
+
static VALUE
|
18
|
+
<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
|
19
|
+
{
|
20
|
+
cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>;
|
21
|
+
cudnnStatus_t status = 0;
|
22
|
+
cudnnHandle_t handle = 0;
|
23
|
+
dtype coef_alpha = 1;
|
24
|
+
dtype coef_beta = 0;
|
25
|
+
|
26
|
+
VALUE x=self, gamma, beta, running_mean, running_var, eps, decay, axis, mean, inv_std, y;
|
27
|
+
VALUE kw_hash = Qnil;
|
28
|
+
ID kw_table[] = {
|
29
|
+
rb_intern("running_mean"),
|
30
|
+
rb_intern("running_var"),
|
31
|
+
rb_intern("mean"),
|
32
|
+
rb_intern("inv_std"),
|
33
|
+
rb_intern("eps"),
|
34
|
+
rb_intern("decay"),
|
35
|
+
rb_intern("axis"),
|
36
|
+
rb_intern("y")
|
37
|
+
};
|
38
|
+
VALUE opts[] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
39
|
+
|
40
|
+
cumo_narray_t *nx, *ngamma, *nbeta;
|
41
|
+
size_t *x_shape, *gamma_shape, *beta_shape, reduced_shape[CUMO_NA_MAX_DIMENSION];
|
42
|
+
size_t x_ndim, gamma_ndim, beta_ndim, reduced_ndim;
|
43
|
+
|
44
|
+
VALUE x_cont, gamma_cont, beta_cont;
|
45
|
+
cudnnTensorDescriptor_t x_desc = 0;
|
46
|
+
cudnnTensorDescriptor_t bn_desc = 0;
|
47
|
+
char *x_cont_ptr, *gamma_cont_ptr, *beta_cont_ptr, *y_ptr;
|
48
|
+
|
49
|
+
cudnnBatchNormMode_t mode;
|
50
|
+
|
51
|
+
// default values
|
52
|
+
char *running_mean_ptr=NULL;
|
53
|
+
char *running_var_ptr=NULL;
|
54
|
+
char *mean_ptr=NULL;
|
55
|
+
char *inv_std_ptr=NULL;
|
56
|
+
double double_eps = 2e-5;
|
57
|
+
double double_decay = 0.9;
|
58
|
+
int int_axis[CUMO_NA_MAX_DIMENSION] = {0};
|
59
|
+
size_t axis_ndim = 1;
|
60
|
+
|
61
|
+
rb_scan_args(argc, argv, "2:", &gamma, &beta, &kw_hash);
|
62
|
+
rb_get_kwargs(kw_hash, kw_table, 0, 8, opts);
|
63
|
+
running_mean = cumo_cuda_cudnn_option_value(opts[0], Qnil);
|
64
|
+
running_var = cumo_cuda_cudnn_option_value(opts[1], Qnil);
|
65
|
+
mean = cumo_cuda_cudnn_option_value(opts[2], Qnil);
|
66
|
+
inv_std = cumo_cuda_cudnn_option_value(opts[3], Qnil);
|
67
|
+
eps = cumo_cuda_cudnn_option_value(opts[4], Qnil);
|
68
|
+
decay = cumo_cuda_cudnn_option_value(opts[5], Qnil);
|
69
|
+
axis = cumo_cuda_cudnn_option_value(opts[6], Qnil);
|
70
|
+
y = cumo_cuda_cudnn_option_value(opts[7], Qnil);
|
71
|
+
|
72
|
+
if (running_mean != Qnil) {
|
73
|
+
running_mean_ptr = cumo_na_get_offset_pointer_for_write(running_mean);
|
74
|
+
}
|
75
|
+
if (running_var != Qnil) {
|
76
|
+
running_var_ptr = cumo_na_get_offset_pointer_for_write(running_var);
|
77
|
+
}
|
78
|
+
if (mean != Qnil) {
|
79
|
+
mean_ptr = cumo_na_get_offset_pointer_for_write(mean);
|
80
|
+
}
|
81
|
+
if (inv_std != Qnil) {
|
82
|
+
inv_std_ptr = cumo_na_get_offset_pointer_for_write(inv_std);
|
83
|
+
}
|
84
|
+
if (eps != Qnil) {
|
85
|
+
double_eps = NUM2DBL(eps);
|
86
|
+
}
|
87
|
+
if (decay != Qnil) {
|
88
|
+
double_decay = NUM2DBL(decay);
|
89
|
+
}
|
90
|
+
if (axis != Qnil) {
|
91
|
+
Check_Type(axis, T_ARRAY);
|
92
|
+
axis_ndim = (size_t)(RARRAY_LEN(axis));
|
93
|
+
for (size_t idim = 0; idim < axis_ndim; ++idim) {
|
94
|
+
int_axis[idim] = NUM2INT(rb_ary_entry(axis, (long)idim));
|
95
|
+
}
|
96
|
+
// TODO: check axis is sorted
|
97
|
+
}
|
98
|
+
|
99
|
+
CumoGetNArray(x, nx);
|
100
|
+
CumoGetNArray(gamma, ngamma);
|
101
|
+
CumoGetNArray(beta, nbeta);
|
102
|
+
x_ndim = nx->ndim;
|
103
|
+
x_shape = nx->shape;
|
104
|
+
gamma_ndim = ngamma->ndim;
|
105
|
+
gamma_shape = ngamma->shape;
|
106
|
+
beta_ndim = nbeta->ndim;
|
107
|
+
beta_shape = nbeta->shape;
|
108
|
+
|
109
|
+
// TODO: Size check of gammma, beta, running_mean, running_var, mean, inv_std
|
110
|
+
// are equivalent with either of reduced_shape(keepdims: false) or reduced_shape(keepdims: true)
|
111
|
+
reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1);
|
112
|
+
// CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim);
|
113
|
+
// CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, beta_ndim);
|
114
|
+
// for (size_t idim = 0; idim < reduced_ndim; ++idim) {
|
115
|
+
// CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]);
|
116
|
+
// CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], beta_shape[idim]);
|
117
|
+
// }
|
118
|
+
|
119
|
+
CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT);
|
120
|
+
CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT);
|
121
|
+
CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(beta, cT);
|
122
|
+
if (running_mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_mean, cT);
|
123
|
+
if (running_var != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(running_var, cT);
|
124
|
+
if (mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT);
|
125
|
+
if (inv_std != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(inv_std, cT);
|
126
|
+
|
127
|
+
x_cont = cumo_na_as_contiguous_array(x);
|
128
|
+
gamma_cont = cumo_na_as_contiguous_array(gamma);
|
129
|
+
beta_cont = cumo_na_as_contiguous_array(beta);
|
130
|
+
if (running_mean != Qnil && cumo_na_check_contiguous(running_mean) != Qtrue) {
|
131
|
+
rb_raise(rb_eRuntimeError, "running_mean must be contiguous");
|
132
|
+
}
|
133
|
+
if (running_var != Qnil && cumo_na_check_contiguous(running_var) != Qtrue) {
|
134
|
+
rb_raise(rb_eRuntimeError, "running_var must be contiguous");
|
135
|
+
}
|
136
|
+
if (mean != Qnil && cumo_na_check_contiguous(mean) != Qtrue) {
|
137
|
+
rb_raise(rb_eRuntimeError, "mean must be contiguous");
|
138
|
+
}
|
139
|
+
if (inv_std != Qnil && cumo_na_check_contiguous(inv_std) != Qtrue) {
|
140
|
+
rb_raise(rb_eRuntimeError, "inv_std must be contiguous");
|
141
|
+
}
|
142
|
+
|
143
|
+
x_cont_ptr = cumo_na_get_offset_pointer_for_read(x_cont);
|
144
|
+
gamma_cont_ptr = cumo_na_get_offset_pointer_for_read(gamma_cont);
|
145
|
+
beta_cont_ptr = cumo_na_get_offset_pointer_for_read(beta_cont);
|
146
|
+
|
147
|
+
// TODO: type and shape check
|
148
|
+
if (y == Qnil) y = cumo_na_new(cT, x_ndim, x_shape);
|
149
|
+
y_ptr = cumo_na_get_offset_pointer_for_write(y);
|
150
|
+
|
151
|
+
status = cumo_cuda_cudnn_CreateTensorDescriptor(&x_desc, x_cont, cudnn_dtype);
|
152
|
+
if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
|
153
|
+
|
154
|
+
mode = cumo_cuda_cudnn_GetBatchNormMode(axis_ndim, int_axis);
|
155
|
+
status = cumo_cuda_cudnn_CreateBNTensorDescriptor(&bn_desc, x_desc, mode);
|
156
|
+
if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
|
157
|
+
// TODO: bn_desc may return another type, and may need to cast gamma, beta, mean, var
|
158
|
+
|
159
|
+
handle = cumo_cuda_cudnn_handle();
|
160
|
+
|
161
|
+
status = cudnnBatchNormalizationForwardTraining(
|
162
|
+
handle,
|
163
|
+
mode,
|
164
|
+
(void*)&coef_alpha,
|
165
|
+
(void*)&coef_beta,
|
166
|
+
x_desc,
|
167
|
+
x_cont_ptr,
|
168
|
+
x_desc,
|
169
|
+
y_ptr,
|
170
|
+
bn_desc,
|
171
|
+
gamma_cont_ptr,
|
172
|
+
beta_cont_ptr,
|
173
|
+
1.0 - double_decay,
|
174
|
+
running_mean_ptr,
|
175
|
+
running_var_ptr,
|
176
|
+
double_eps,
|
177
|
+
mean_ptr,
|
178
|
+
inv_std_ptr);
|
179
|
+
if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
|
180
|
+
|
181
|
+
BATCH_NORM_ERROR:
|
182
|
+
if (x_desc) cudnnDestroyTensorDescriptor(x_desc);
|
183
|
+
if (bn_desc) cudnnDestroyTensorDescriptor(bn_desc);
|
184
|
+
cumo_cuda_cudnn_check_status(status);
|
185
|
+
|
186
|
+
return y;
|
187
|
+
}
|
188
|
+
|
189
|
+
#else // CUDNN_FOUND
|
190
|
+
VALUE cumo_cuda_eCUDNNError;
|
191
|
+
|
192
|
+
static VALUE
|
193
|
+
<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
|
194
|
+
{
|
195
|
+
rb_raise(cumo_cuda_eCUDNNError, "cuDNN is not available");
|
196
|
+
}
|
197
|
+
#endif // CUDNN_FOUND
|