red-chainer 0.2.1 → 0.3.0

Files changed (52)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/examples/cifar/models/vgg.rb +84 -0
  4. data/examples/cifar/train_cifar.rb +70 -0
  5. data/examples/iris.rb +103 -0
  6. data/lib/chainer.rb +17 -0
  7. data/lib/chainer/configuration.rb +2 -1
  8. data/lib/chainer/cuda.rb +18 -0
  9. data/lib/chainer/dataset/convert.rb +30 -9
  10. data/lib/chainer/datasets/cifar.rb +56 -0
  11. data/lib/chainer/datasets/mnist.rb +3 -3
  12. data/lib/chainer/datasets/tuple_dataset.rb +3 -1
  13. data/lib/chainer/function.rb +1 -0
  14. data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
  15. data/lib/chainer/functions/activation/log_softmax.rb +4 -4
  16. data/lib/chainer/functions/activation/relu.rb +3 -4
  17. data/lib/chainer/functions/activation/sigmoid.rb +4 -4
  18. data/lib/chainer/functions/activation/tanh.rb +5 -5
  19. data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
  20. data/lib/chainer/functions/connection/linear.rb +1 -1
  21. data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
  22. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
  23. data/lib/chainer/functions/math/identity.rb +26 -0
  24. data/lib/chainer/functions/noise/dropout.rb +45 -0
  25. data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
  26. data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
  27. data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
  28. data/lib/chainer/gradient_check.rb +240 -0
  29. data/lib/chainer/initializer.rb +2 -0
  30. data/lib/chainer/initializers/constant.rb +1 -1
  31. data/lib/chainer/initializers/init.rb +5 -1
  32. data/lib/chainer/initializers/normal.rb +1 -1
  33. data/lib/chainer/iterators/serial_iterator.rb +1 -1
  34. data/lib/chainer/link.rb +11 -0
  35. data/lib/chainer/links/connection/convolution_2d.rb +98 -0
  36. data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
  37. data/lib/chainer/optimizer.rb +40 -1
  38. data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
  39. data/lib/chainer/parameter.rb +1 -1
  40. data/lib/chainer/serializers/marshal.rb +7 -3
  41. data/lib/chainer/testing/array.rb +32 -0
  42. data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
  43. data/lib/chainer/training/extensions/snapshot.rb +1 -1
  44. data/lib/chainer/training/standard_updater.rb +4 -0
  45. data/lib/chainer/training/trainer.rb +1 -1
  46. data/lib/chainer/utils/array.rb +13 -2
  47. data/lib/chainer/utils/conv.rb +59 -0
  48. data/lib/chainer/utils/math.rb +72 -0
  49. data/lib/chainer/utils/variable.rb +7 -3
  50. data/lib/chainer/version.rb +1 -1
  51. data/red-chainer.gemspec +1 -0
  52. metadata +37 -3
data/lib/chainer/functions/math/identity.rb
@@ -0,0 +1,26 @@
+ module Chainer
+   module Functions
+     module Math
+       # Identity function.
+       class Identity < Chainer::Function
+         def check_type_forward(in_types)
+           # pass
+         end
+
+         def forward(xs)
+           retain_inputs([])
+           return xs
+         end
+
+         def backward(xs, gys)
+           return gys
+         end
+
+         # Just returns input variables.
+         def self.identity(*inputs)
+           self.new.(*inputs)
+         end
+       end
+     end
+   end
+ end
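The snippet below is not part of the diff; it is a small hedged sketch of how this new Identity function is used later in this same release: check_backward (see gradient_check.rb further down) wraps a function's outputs in Identity.identity so that a gradient can be attached to each output variable before calling backward. It assumes, as in Chainer, that a single output comes back as one Chainer::Variable.

    # Hypothetical usage sketch (not from the gem); relies only on the
    # Chainer::Variable API shown elsewhere in this diff.
    x = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
    y = Chainer::Functions::Math::Identity.identity(x)  # same values as x; gradients pass through unchanged on backward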
data/lib/chainer/functions/noise/dropout.rb
@@ -0,0 +1,45 @@
+ module Chainer
+   module Functions
+     module Noise
+       class Dropout < Chainer::Function
+         # Drops elements of input variable randomly.
+         #
+         # This function drops input elements randomly with probability `ratio` and
+         # scales the remaining elements by factor `1 / (1 - ratio)`.
+         # In testing mode, it does nothing and just returns `x`.
+         #
+         # @param [Chainer::Variable] x Input variable.
+         # @param [float] ratio Dropout ratio. The `ratio` must be `0.0 <= ratio < 1.0`.
+         # @return [Chainer::Variable] Output variable.
+         def self.dropout(x, ratio: 0.5)
+           Chainer.configuration.train ? self.new(ratio).(x) : x
+         end
+
+         def initialize(dropout_ratio)
+           if dropout_ratio < 0 || dropout_ratio >= 1.0
+             raise 'dropout_ratio must be in the range [0, 1)'
+           end
+           @dropout_ratio = dropout_ratio
+         end
+
+         def forward(x)
+           retain_inputs([])
+           unless self.instance_variable_defined?(:@mask)
+             scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
+             flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
+
+             @mask = x[0].class.zeros(*x[0].shape)
+             @mask[flag] = 1
+             @mask *= scale
+           end
+           [x[0] * @mask]
+         end
+
+         def backward(x, gy)
+           [gy[0] * @mask]
+         end
+       end
+     end
+   end
+ end
+
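As a quick illustration of the train/test switch described in the comment above, here is a hedged usage sketch (not part of the diff), assuming only the configuration and Variable APIs that appear elsewhere in this release:

    x = Chainer::Variable.new(Numo::DFloat.new(2, 3).rand)

    Chainer.configuration.train = true
    y_train = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)  # roughly half the elements zeroed, the rest scaled by 1 / (1 - 0.5) = 2.0

    Chainer.configuration.train = false
    y_test = Chainer::Functions::Noise::Dropout.dropout(x)               # no-op in test mode; returns x unchanged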
data/lib/chainer/functions/normalization/batch_normalization.rb
@@ -0,0 +1,136 @@
+ module Chainer
+   module Functions
+     module Normalization
+       class BatchNormalizationFunction < Chainer::Function
+         attr_reader :running_mean, :running_var
+         # Batch normalization function with fixed statistics.
+         # This is a variant of batch normalization, where the mean and variance
+         # statistics are given by the caller as fixed variables. This is
+         # used in testing mode of the batch normalization layer, where batch
+         # statistics cannot be used for prediction consistency.
+         #
+         # @param [Chainer::Variable] x Input variable.
+         # @param [Chainer::Variable] gamma Scaling parameter of normalized data.
+         # @param [Chainer::Variable] beta Shifting parameter of scaled normalized data.
+         # @param [Chainer::Variable] mean Shifting parameter of input.
+         # @param [Chainer::Variable] var Square of scaling parameter of input.
+         # @param [float] eps Epsilon value for numerical stability.
+         def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+           old_train = Chainer.configuration.train
+           Chainer.configuration.train = false
+           norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
+           Chainer.configuration.train = old_train
+           norm
+         end
+
+         def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+           @running_mean = mean
+           @running_var = var
+           @eps = eps
+           @mean_cache = nil
+           @decay = decay
+         end
+
+         def forward(inputs)
+           x, gamma, beta = inputs[0], inputs[1], inputs[2]
+           if Chainer.configuration.train
+             if @running_mean.nil?
+               @running_mean = Numo::NArray[*gamma].new_zeros
+               @running_var = Numo::NArray[*gamma].new_zeros
+             else
+               @running_mean = Numo::NArray[*@running_mean]
+               @running_var = Numo::NArray[*@running_var]
+             end
+           elsif inputs.size == 5
+             @fixed_mean = inputs[3]
+             @fixed_var = inputs[4]
+           end
+
+           head_ndim = gamma.ndim + 1
+           gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
+           gamma = gamma.reshape(*gamma_expander)
+           beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
+           beta = beta.reshape(*beta_expander)
+
+           if Chainer.configuration.train
+             axis = [0] + (head_ndim...(x.ndim)).to_a
+             mean = x.mean(axis: axis)
+             # FIXME: numpy.var
+             var = x.var(axis: axis)
+             var += @eps
+           else
+             mean = @fixed_mean
+             var = @fixed_var + @eps
+           end
+
+           @std = Numo::NMath.sqrt(var)
+
+           mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+           x_mu = x - mean.reshape(*mean_expander)
+           std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
+           x_mu /= @std.reshape(*std_expander)
+           @x_hat = x_mu
+           y = gamma * @x_hat
+           y += beta
+
+           if Chainer.configuration.train
+             m = x.size.div(gamma.size)
+             adjust = m / [m - 1.0, 1.0].max
+             @running_mean *= @decay
+             temp_ar = Numo::NArray[*mean]
+             temp_ar *= (1 - @decay)
+             @running_mean += temp_ar
+
+             @running_var *= @decay
+             temp_ar = Numo::NArray[*var]
+             temp_ar *= ((1 - @decay) * adjust)
+             @running_var += temp_ar
+           end
+
+           [y,]
+         end
+
+         def backward(inputs, grad_outputs)
+           x, gamma = inputs[0], inputs[1]
+           gy = grad_outputs[0]
+           head_ndim = gamma.ndim + 1
+           m = gamma.class[x.size.div(gamma.size)][0]
+           axis = [0] + (head_ndim...(x.ndim)).to_a
+
+           if inputs.size == 5
+             mean = inputs[3]
+             var = inputs[4]
+             std = Numo::NMath.sqrt(var)
+             gs = gamma / std
+             gbeta = gy.sum(axis: axis)
+
+             mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+             x_mu = x - mean.reshape(*mean_expander)
+             std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
+             x_mu /= std.reshape(*std_expander)
+             x_hat = x_mu
+             ggamma = (gy * x_hat).sum(axis: axis)
+             gmean = -gs * gbeta
+             gvar = -0.5 * gamma / var * ggamma
+             gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
+             gx = gs.reshape(*gs_expander)
+             return [gx, ggamma, gbeta, gmean, gvar]
+           end
+
+           gbeta = gy.sum(axis: axis)
+           ggamma = (gy * @x_hat).sum(axis: axis)
+           tmp = (gamma / @std)
+           tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
+           tmp = tmp.reshape(*tmp_expander)
+
+           ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
+           gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
+
+           gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m)
+
+           [gx, ggamma, gbeta]
+         end
+       end
+     end
+   end
+ end
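For orientation, a hedged sketch of calling the fixed-statistics entry point directly (normally this is reached through the new Chainer::Links::Normalization::BatchNormalization link listed above). The per-channel shapes are an assumption based on the expander logic in forward, and the sketch assumes, as in Chainer, that raw Numo arrays passed to a Function are wrapped into Variables:

    x     = Numo::SFloat.new(10, 3, 5, 5).rand   # (batch, channel, height, width)
    gamma = Numo::SFloat.ones(3)                 # scale per channel
    beta  = Numo::SFloat.zeros(3)                # shift per channel
    mean  = Numo::SFloat.zeros(3)                # fixed mean, e.g. a stored running mean
    var   = Numo::SFloat.ones(3)                 # fixed variance, e.g. a stored running variance

    y = Chainer::Functions::Normalization::BatchNormalizationFunction
          .fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)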
data/lib/chainer/functions/pooling/max_pooling_2d.rb
@@ -0,0 +1,57 @@
+ module Chainer
+   module Functions
+     module Pooling
+       class MaxPooling2D < Pooling2D
+         # Spatial max pooling function
+         #
+         # @param [Chainer::Variable] x Input variable
+         # @param [integer || 2D integer array] ksize Size of pooling window
+         # @param [integer || 2D integer array] stride Stride of pooling applications
+         # @param [integer || 2D integer array] pad Spatial padding width for the input array
+         # @param [boolean] cover_all If `true`, all spatial locations are pooled into some output pixels
+         # @return [Chainer::Variable] Output variable
+         def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
+           self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+         end
+
+         def forward_cpu(x)
+           retain_inputs([])
+           @in_shape = x[0].shape
+           @in_dtype = x[0].class
+
+           col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+           n, c, kh, kw, out_h, out_w = col.shape
+           col = col.reshape(n, c, kh * kw, out_h, out_w)
+
+           # TODO: numpy.argmax(axis=2)
+           d = col.shape[3..-1].reduce(:*) || 1
+           dx = col.shape[2..-1].reduce(:*) || 1
+           max_index = col.max_index(2)
+           @indexes = max_index.flatten.map_with_index { |val, idx| (val - (dx * (idx / d))) / d }.reshape(*max_index.shape)
+
+           y = col.max(axis: 2)
+           [y]
+         end
+
+         def backward_cpu(x, gy)
+           n, c, out_h, out_w = gy[0].shape
+           h, w = @in_shape[2..-1]
+           kh, kw = @kh, @kw
+
+           gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)
+
+           indexes = @indexes.flatten
+           indexes += Numo::Int64.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
+           gcol[indexes] = gy[0].flatten.dup
+           gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
+           gcol = gcol.swapaxes(2, 4)
+           gcol = gcol.swapaxes(3, 5)
+
+           gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+           [gx]
+         end
+       end
+     end
+   end
+ end
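A minimal hedged sketch (not in the diff) of the new pooling helper. The shapes are assumptions; with ksize 2 and stride 2 an 8x8 input plane shrinks to 4x4:

    x = Chainer::Variable.new(Numo::SFloat.new(1, 3, 8, 8).rand)  # NCHW input
    y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2, stride: 2)
    # y.data.shape => [1, 3, 4, 4]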
data/lib/chainer/functions/pooling/pooling_2d.rb
@@ -0,0 +1,20 @@
+ module Chainer
+   module Functions
+     module Pooling
+       # Base class of pooling function over a set of 2d planes
+       class Pooling2D < Chainer::Function
+         def initialize(ksize, stride: nil, pad: 0, cover_all: true)
+           if stride.nil?
+             stride = ksize
+           end
+
+           @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
+           @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+           @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+
+           @cover_all = cover_all
+         end
+       end
+     end
+   end
+ end
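The constructor above only normalizes the pooling hyperparameters. A hedged illustration of that normalization, using hypothetical values rather than anything from the diff:

    # A scalar ksize yields square window/stride/padding pairs, and a nil stride
    # defaults to the window size.
    Chainer::Functions::Pooling::MaxPooling2D.new(3)                        # kh = kw = 3, sy = sx = 3, ph = pw = 0
    # ksize and pad given explicitly as [height, width] pairs, stride as a scalar.
    Chainer::Functions::Pooling::MaxPooling2D.new([3, 2], stride: 1, pad: [1, 0])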
data/lib/chainer/gradient_check.rb
@@ -0,0 +1,240 @@
+ module Chainer
+   def _copy_arrays(xs)
+     xp = Chainer::get_array_module(*xs)
+     xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+   end
+
+   # Computes numerical gradient by finite differences.
+   #
+   # This function is used to implement gradient check. For usage example, see
+   # unit tests of Chainer::Functions.
+   #
+   # @param [function] f Ruby function with no arguments that runs forward
+   #   computation and returns the result.
+   # @param [Array<Arrays>] inputs Array of arrays that should be treated as
+   #   inputs. Each element of them is slightly modified to realize numerical
+   #   gradient by finite differences.
+   # @param [Array<Arrays>] grad_outputs Array of arrays that are treated as
+   #   output gradients.
+   # @param [Float] eps Epsilon value of finite differences.
+   # @return [Array] Numerical gradient arrays corresponding to +inputs+.
+   #
+   def numerical_grad(f, inputs, grad_outputs, eps=0.001)
+     raise unless eps > 0
+     inputs = inputs.to_a
+     grad_outputs = grad_outputs.to_a
+     xp = Numo::NArray
+     grads = inputs.map{|x| x.new_zeros()}
+
+     if inputs[0].ndim < 2
+       tmp = [[inputs[0], grads[0]]]
+     else
+       tmp = (0...inputs[0].shape[0]).map{|i| [inputs[0][i, false], grads[0][i, false]]}
+     end
+
+     tmp.each do |x, gx|
+       x.each_with_index{|xx, *i|
+         orig = x[*i] # hold original value
+         x[*i] = orig + eps
+         ys1 = _copy_arrays(f.call(x))
+         x[*i] = orig - eps
+         ys2 = _copy_arrays(f.call(x))
+         x[*i] = orig
+
+         ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
+           if !gy.nil?
+             if ((y1 - y2) * gy).is_a? Numo::NArray
+               dot = ((y1 - y2) * gy).sum()
+             else
+               dot = ((y1 - y2) * gy).inject(:+)
+             end
+             gx[*i] += dot / (2*eps).to_f
+           end
+         end
+       }
+     end
+
+     return grads
+   end
+
+   def _as_tuple(x)
+     if x.is_a? Array
+       return x
+     else
+       return [x]
+     end
+   end
+
+   # Tests the backward procedure of a given function.
+   #
+   # This function automatically checks the backward process of a given function.
+   # For example, when you have a +Chainer::Function+ class +MyFunc+
+   # that takes two arguments and returns one value, you can write its test like this:
+   #
+   #   def test_my_func
+   #     func = MyFunc.new
+   #     x1_data = Numo::NArray[...]
+   #     x2_data = Numo::NArray[...]
+   #     gy_data = Numo::NArray[...]
+   #     check_backward(func, [x1_data, x2_data], gy_data)
+   #
+   # This method creates +Chainer::Variable+ objects with +x_data+
+   # and calls +func+ with the +Chainer::Variable+ s to get its result
+   # as +Chainer::Variable+.
+   # Then, it sets the +y_grad+ array to the +grad+ attribute of the result and
+   # calls the +backward+ method to get gradients of the inputs.
+   # To check correctness of the gradients, the function calls
+   # +numerical_grad+ to calculate the gradients numerically and compares
+   # them using +Chainer::Testing.assert_allclose+.
+   # If input objects (+x1_data+ or/and +x2_data+ in this example) represent
+   # integer variables, their gradients are ignored.
+   #
+   # You can simplify the test when +MyFunc+ takes only one argument:
+   #
+   #   check_backward(func, x1_data, gy_data)
+   #
+   # If +MyFunc+ is a loss function which returns a zero-dimensional
+   # array, pass +nil+ to +gy_data+. In this case, it sets +1+ to the
+   # +grad+ attribute of the result:
+   #
+   #   check_backward(my_loss_func, [x1_data, x2_data], nil)
+   #
+   # If +MyFunc+ returns multiple outputs, pass all gradients for the outputs as an Array:
+   #
+   #   gy1_data = Numo::NArray[...]
+   #   gy2_data = Numo::NArray[...]
+   #   check_backward(func, x1_data, [gy1_data, gy2_data])
+   #
+   # You can also test a +Chainer::Link+.
+   # To check gradients of parameters of the link, pass an Array of the parameters
+   # as the +params+ argument:
+   #
+   #   check_backward(my_link, [x1_data, x2_data], gy_data, [my_link.W, my_link.b])
+   #
+   # Note that +params+ are not +Numo::NArray+ s,
+   # but +Chainer::Variable+ s.
+   #
+   # Function objects are acceptable as the +func+ argument:
+   #
+   #   check_backward(lambda{|x1, x2| f(x1, x2)}, [x1_data, x2_data], gy_data)
+   #
+   # @note
+   #   +func+ is called many times to get numerical gradients for all inputs.
+   #   This function doesn't work correctly when +func+ behaves randomly as
+   #   it gets different gradients.
+   # @param [Method, Proc] func A function which takes +Chainer::Variable+ s
+   #   and returns +Chainer::Variable+ s. +func+ must return
+   #   an Array of +Chainer::Variable+ s or one
+   #   +Chainer::Variable+. You can use a +Chainer::Function+
+   #   object, a +Chainer::Link+ object or a function satisfying the
+   #   condition.
+   # @param [Numo::NArray or Array<Numo::NArray>] x_data A set of +Numo::NArray+ s to be
+   #   passed to +func+. If +x_data+ is one +Numo::NArray+ object, it is
+   #   treated as +(x_data,)+.
+   # @param [Numo::NArray or Array<Numo::NArray> or nil] y_grad A set of +Numo::NArray+ s representing gradients of return-values of
+   #   +func+. If +y_grad+ is one +Numo::NArray+ object, it is
+   #   treated as +(y_grad,)+. If +func+ is a loss-function,
+   #   +y_grad+ should be set to +nil+.
+   # @param [Chainer::Variable or Array<Chainer::Variable>] params A set of +Chainer::Variable+ s whose gradients are checked.
+   #   When +func+ is a +Chainer::Link+ object,
+   #   set its parameters as +params+.
+   #   If +params+ is one +Chainer::Variable+ object,
+   #   it is treated as +(params,)+.
+   # @param [Float] eps Epsilon value to be passed to +numerical_grad+.
+   # @param [Float] atol Absolute tolerance to be passed to +Chainer::Testing.assert_allclose+.
+   # @param [Float] rtol Relative tolerance to be passed to +Chainer::Testing.assert_allclose+.
+   # @param [Array<Boolean>] no_grads Flags to skip variables for gradient assertion.
+   #   It should be the same length as +x_data+.
+   # @param [Numo::NArray.class] dtype +x_data+ and +y_grad+ are cast to this
+   #   dtype when calculating numerical gradients. Only float types and
+   #   +nil+ are allowed.
+   # @see
+   #   .numerical_grad
+   #
+   def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
+     x_data = _as_tuple(x_data)
+     if !y_grad.nil?
+       y_grad = _as_tuple(y_grad)
+     end
+
+     params = _as_tuple(params)
+     xs = x_data.map{|x| Chainer::Variable.new(x)}
+     y = func.(*xs)
+     y = _as_tuple(y)
+     y = Chainer::Functions::Math::Identity.identity(*y)
+     y = _as_tuple(y)
+
+     if !y_grad.nil?
+       if y.size != y_grad.size
+         raise TypeError, "`y_grad` must have the same length of output values"
+       end
+
+       y.zip(y_grad).each do |iy, igy|
+         iy.grad = igy
+       end
+     else
+       if y.size != 1
+         raise TypeError, "When `y_grad` is `nil`, the function must return a zero-dimensional array"
+       end
+       y_grad = [1]
+     end
+     # We only need to call `backward` for one result `Chainer::Variable`.
+     # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
+     y[0].backward()
+
+     if dtype.nil?
+       casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+     else
+       if (dtype != Numo::DFloat) and (dtype != Numo::SFloat)
+         raise TypeError, "`dtype` is allowed only float type"
+       end
+       if params.size > 0
+         raise TypeError, "`dtype` is available only if `params` is empty"
+       end
+       casted_xs = x_data.map{|x|
+         if x.class == Numo::DFloat or x.class == Numo::SFloat
+           Chainer::Variable.new(dtype.cast(x))
+         else
+           Chainer::Variable.new(x)
+         end
+       }
+     end
+
+     f = lambda do |_|
+       ys = func.(*casted_xs)
+       ys = _as_tuple(ys)
+       return ys.map{|y| y.data}.to_a
+     end
+
+     if no_grads.nil?
+       no_grads = xs.map{|x| (x.dtype != Numo::DFloat) and (x.dtype != Numo::SFloat)}
+     else
+       if no_grads.size != xs.size
+         raise TypeError, "Length of no_grads param and xs should be same."
+       end
+     end
+
+     no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+       if skip
+         raise unless x.grad.nil?
+         next
+       end
+       gx, = numerical_grad(f, [cx.data], y_grad, eps)
+       Chainer::Testing.assert_allclose(x.grad, gx, atol: atol, rtol: rtol)
+       if dtype.nil?
+         raise unless gx.class == x.grad.class
+       else
+         if ((gx.class != Numo::DFloat) and (gx.class != Numo::SFloat)) and (gx.class != dtype)
+           raise
+         end
+       end
+     end
+
+     params.each do |p|
+       gp, = numerical_grad(f, [p.data], y_grad, eps)
+       Chainer::Testing.assert_allclose(p.grad, gp, atol: atol, rtol: rtol)
+       raise unless gp.dtype === p.grad.dtype
+     end
+   end
+   module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward
+ end
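To make the docstring above concrete, here is a hedged, self-contained sketch (not part of the diff) that gradient-checks the Identity function added earlier in this release. Identity's analytic gradient is simply the upstream gradient, so the numerical and backprop gradients should agree within the default tolerances:

    require 'chainer'

    x  = Numo::DFloat.new(3, 2).rand - 0.5   # input data
    gy = Numo::DFloat.new(3, 2).rand - 0.5   # gradient with respect to the output

    Chainer.check_backward(
      Chainer::Functions::Math::Identity.new,  # function under test
      x,                                       # x_data
      gy,                                      # y_grad
      eps: 1e-3, atol: 1e-5, rtol: 1e-4
    )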