cumo 0.3.3 → 0.3.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: febb3beb76c3e4994bee998b7d9911765a8f0279a7c0db550e45601a1c31c9da
4
- data.tar.gz: 60c87eb1810387de34667847b354eaa7e3d8d09ca31e09bd06bf4bbba1190578
3
+ metadata.gz: 143344d9aa774213541e1e44bbb1f2d65348a2e9a410df67435c83a890db6c14
4
+ data.tar.gz: c59b67b28e70ed1421ddfceb6a7928c373634278e3f57a23ad631375e5e3111d
5
5
  SHA512:
6
- metadata.gz: c7fe6ade2ebf7af02cae19cd403161f369cb65aaf13918597230f35f1ec7148f05b485c6d3db9063a02a60a9a197156de08b74a35abf51c65e4dd1e7e069243b
7
- data.tar.gz: d2e291b33102290ed4f22fd83c7ab3b69f9aae3f58e789577b41c9492d1fea761adbcd3efeaa46f7d05534e50d0a8ba033aaab65934a4627f8afa56c8efea305
6
+ metadata.gz: 20dc89053b605cbc3f272fffb28a185885ea00c9d4be8646397e9d51703411118810961e00546b3f83f8984ec18518e796328305437eb1f22c9d4880fd1b4b16
7
+ data.tar.gz: 06acfc5740b18994aa4ea1276157098729c57768ef8b7eb1ff65f998959dee7d0f160697127392a02bdcfa9090f4cefa91b53729ec0cdcf0b17c7212a9143da4
@@ -1,3 +1,9 @@
1
+ # 0.3.4 (2019-05-04)
2
+
3
+ Enhancements:
4
+
5
+ * Support cuDNN fixed\_batch\_norm (cudnnBatchNormalizationForwardInference)
6
+
1
7
  # 0.3.3 (2019-05-02)
2
8
 
3
9
  Fixes:
@@ -26,7 +32,7 @@ Enhancements:
26
32
  * conv (cudnnConvolution)
27
33
  * conv\_transpose (cudnnConvolutionBackwardData)
28
34
  * conv\_grad\_w (cudnnConvolutionBackwardFilter)
29
- * batch\_norm (cudnnBatchNormalization)
35
+ * batch\_norm (cudnnBatchNormalizationForwardTraining)
30
36
  * batch\_norm\_backward (cudnnBatchNormalizationBackward)
31
37
  * avg\_pool and max\_pool (cudnnPoolingForward)
32
38
  * avg\_pool\_backward and max\_pool\_backward (cudnnPoolingBackward)
@@ -10,8 +10,8 @@ extern "C" {
10
10
  #endif
11
11
  #endif
12
12
 
13
- #define CUMO_VERSION "0.3.3"
14
- #define CUMO_VERSION_CODE 33
13
+ #define CUMO_VERSION "0.3.4"
14
+ #define CUMO_VERSION_CODE 34
15
15
 
16
16
  bool cumo_compatible_mode_enabled_p();
17
17
  bool cumo_show_warning_enabled_p();
@@ -60,6 +60,7 @@ if is_float && !is_complex && !is_object
60
60
  def_id "conv_grad_w"
61
61
  def_id "batch_norm"
62
62
  def_id "batch_norm_backward"
63
+ def_id "fixed_batch_norm"
63
64
  def_id "pooling_forward"
64
65
  def_id "pooling_backward"
65
66
  end
@@ -348,6 +349,7 @@ if is_float && !is_complex && !is_object
348
349
  def_method "conv_grad_w" # conv_backward_filter
349
350
  def_method "batch_norm"
350
351
  def_method "batch_norm_backward"
352
+ def_method "fixed_batch_norm"
351
353
  def_method "pooling_forward" # max_pool, avg_pool
352
354
  def_method "pooling_backward"
353
355
  end
@@ -0,0 +1,149 @@
1
+ #ifdef CUDNN_FOUND
2
+
3
+ <%
4
+ cudnn_dtype =
5
+ case type_name
6
+ when 'sfloat'
7
+ 'CUDNN_DATA_FLOAT'
8
+ when 'dfloat'
9
+ 'CUDNN_DATA_DOUBLE'
10
+ else
11
+ # CUDNN_DATA_HALF
12
+ raise 'not supported'
13
+ end
14
+ %>
15
+
16
+ // y = x.fixed_batch_norm(gamma, beta, mean, var, eps:, axis:, y:)
17
+ static VALUE
18
+ <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
19
+ {
20
+ cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>;
21
+ cudnnStatus_t status = 0;
22
+ cudnnHandle_t handle = 0;
23
+ dtype coef_one = 1;
24
+ dtype coef_zero = 0;
25
+
26
+ VALUE x=self, gamma, beta, mean, var, eps, axis, y;
27
+ VALUE kw_hash = Qnil;
28
+ ID kw_table[] = {
29
+ rb_intern("eps"),
30
+ rb_intern("axis"),
31
+ rb_intern("y")
32
+ };
33
+ VALUE opts[] = {Qundef, Qundef, Qundef};
34
+
35
+ cumo_narray_t *nx; // , *ngamma, *nbeta;
36
+ size_t *x_shape; // *gamma_shape, *beta_shape, reduced_shape[CUMO_NA_MAX_DIMENSION];
37
+ size_t x_ndim;
38
+
39
+ VALUE x_cont, gamma_cont, beta_cont, mean_cont, var_cont;
40
+ cudnnTensorDescriptor_t x_desc = 0;
41
+ cudnnTensorDescriptor_t bn_desc = 0;
42
+ char *x_cont_ptr, *gamma_cont_ptr, *beta_cont_ptr, *mean_cont_ptr, *var_cont_ptr, *y_ptr;
43
+
44
+ cudnnBatchNormMode_t mode;
45
+
46
+ // default values
47
+ double double_eps = 2e-5;
48
+ int int_axis[CUMO_NA_MAX_DIMENSION] = {0};
49
+ size_t axis_ndim = 1;
50
+
51
+ rb_scan_args(argc, argv, "4:", &gamma, &beta, &mean, &var, &kw_hash);
52
+ rb_get_kwargs(kw_hash, kw_table, 0, 3, opts);
53
+ eps = cumo_cuda_cudnn_option_value(opts[0], Qnil);
54
+ axis = cumo_cuda_cudnn_option_value(opts[1], Qnil);
55
+ y = cumo_cuda_cudnn_option_value(opts[2], Qnil);
56
+
57
+ if (eps != Qnil) {
58
+ double_eps = NUM2DBL(eps);
59
+ }
60
+ if (axis != Qnil) {
61
+ axis_ndim = cumo_cuda_cudnn_get_int_axis(int_axis, axis);
62
+ }
63
+
64
+ CumoGetNArray(x, nx);
65
+ // CumoGetNArray(gamma, ngamma);
66
+ // CumoGetNArray(beta, nbeta);
67
+ x_ndim = nx->ndim;
68
+ x_shape = nx->shape;
69
+ // gamma_ndim = ngamma->ndim;
70
+ // gamma_shape = ngamma->shape;
71
+ // beta_ndim = nbeta->ndim;
72
+ // beta_shape = nbeta->shape;
73
+
74
+ // TODO: Size check of gamma, beta, running_mean, running_var, mean, inv_std
75
+ // are equivalent with either of reduced_shape(keepdims: false) or reduced_shape(keepdims: true)
76
+ // reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1);
77
+ // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim);
78
+ // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, beta_ndim);
79
+ // for (size_t idim = 0; idim < reduced_ndim; ++idim) {
80
+ // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]);
81
+ // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], beta_shape[idim]);
82
+ // }
83
+
84
+ CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT);
85
+ CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT);
86
+ CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(beta, cT);
87
+ CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT);
88
+ CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(var, cT);
89
+
90
+ x_cont = cumo_na_as_contiguous_array(x);
91
+ gamma_cont = cumo_na_as_contiguous_array(gamma);
92
+ beta_cont = cumo_na_as_contiguous_array(beta);
93
+ mean_cont = cumo_na_as_contiguous_array(mean);
94
+ var_cont = cumo_na_as_contiguous_array(var);
95
+
96
+ x_cont_ptr = cumo_na_get_offset_pointer_for_read(x_cont);
97
+ gamma_cont_ptr = cumo_na_get_offset_pointer_for_read(gamma_cont);
98
+ beta_cont_ptr = cumo_na_get_offset_pointer_for_read(beta_cont);
99
+ mean_cont_ptr = cumo_na_get_offset_pointer_for_read(mean_cont);
100
+ var_cont_ptr = cumo_na_get_offset_pointer_for_read(var_cont);
101
+
102
+ // TODO: type and shape check
103
+ if (y == Qnil) y = cumo_na_new(cT, x_ndim, x_shape);
104
+ y_ptr = cumo_na_get_offset_pointer_for_write(y);
105
+
106
+ status = cumo_cuda_cudnn_CreateTensorDescriptor(&x_desc, x_cont, cudnn_dtype);
107
+ if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
108
+
109
+ mode = cumo_cuda_cudnn_GetBatchNormMode(axis_ndim, int_axis);
110
+ status = cumo_cuda_cudnn_CreateBNTensorDescriptor(&bn_desc, x_desc, mode);
111
+ if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
112
+ // TODO: bn_desc may return another type, and may need to cast gamma, beta, mean, var
113
+
114
+ handle = cumo_cuda_cudnn_handle();
115
+
116
+ status = cudnnBatchNormalizationForwardInference(
117
+ handle,
118
+ mode,
119
+ (void*)&coef_one,
120
+ (void*)&coef_zero,
121
+ x_desc,
122
+ x_cont_ptr,
123
+ x_desc,
124
+ y_ptr,
125
+ bn_desc,
126
+ gamma_cont_ptr,
127
+ beta_cont_ptr,
128
+ mean_cont_ptr,
129
+ var_cont_ptr,
130
+ double_eps);
131
+ if (status != CUDNN_STATUS_SUCCESS) goto BATCH_NORM_ERROR;
132
+
133
+ BATCH_NORM_ERROR:
134
+ if (x_desc) cudnnDestroyTensorDescriptor(x_desc);
135
+ if (bn_desc) cudnnDestroyTensorDescriptor(bn_desc);
136
+ cumo_cuda_cudnn_check_status(status);
137
+
138
+ return y;
139
+ }
140
+
141
+ #else // CUDNN_FOUND
142
+ VALUE cumo_cuda_eCUDNNError;
143
+
144
+ static VALUE
145
+ <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
146
+ {
147
+ rb_raise(cumo_cuda_eCUDNNError, "cuDNN is not available");
148
+ }
149
+ #endif // CUDNN_FOUND
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: cumo
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.3
4
+ version: 0.3.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Naotoshi Seo
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-05-02 00:00:00.000000000 Z
11
+ date: 2019-05-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -240,6 +240,7 @@ files:
240
240
  - ext/cumo/narray/gen/tmpl/eye_kernel.cu
241
241
  - ext/cumo/narray/gen/tmpl/fill.c
242
242
  - ext/cumo/narray/gen/tmpl/fill_kernel.cu
243
+ - ext/cumo/narray/gen/tmpl/fixed_batch_norm.c
243
244
  - ext/cumo/narray/gen/tmpl/float_accum_kernel.cu
244
245
  - ext/cumo/narray/gen/tmpl/format.c
245
246
  - ext/cumo/narray/gen/tmpl/format_to_a.c