red-chainer 0.2.1 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/examples/cifar/models/vgg.rb +84 -0
  4. data/examples/cifar/train_cifar.rb +70 -0
  5. data/examples/iris.rb +103 -0
  6. data/lib/chainer.rb +17 -0
  7. data/lib/chainer/configuration.rb +2 -1
  8. data/lib/chainer/cuda.rb +18 -0
  9. data/lib/chainer/dataset/convert.rb +30 -9
  10. data/lib/chainer/datasets/cifar.rb +56 -0
  11. data/lib/chainer/datasets/mnist.rb +3 -3
  12. data/lib/chainer/datasets/tuple_dataset.rb +3 -1
  13. data/lib/chainer/function.rb +1 -0
  14. data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
  15. data/lib/chainer/functions/activation/log_softmax.rb +4 -4
  16. data/lib/chainer/functions/activation/relu.rb +3 -4
  17. data/lib/chainer/functions/activation/sigmoid.rb +4 -4
  18. data/lib/chainer/functions/activation/tanh.rb +5 -5
  19. data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
  20. data/lib/chainer/functions/connection/linear.rb +1 -1
  21. data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
  22. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
  23. data/lib/chainer/functions/math/identity.rb +26 -0
  24. data/lib/chainer/functions/noise/dropout.rb +45 -0
  25. data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
  26. data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
  27. data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
  28. data/lib/chainer/gradient_check.rb +240 -0
  29. data/lib/chainer/initializer.rb +2 -0
  30. data/lib/chainer/initializers/constant.rb +1 -1
  31. data/lib/chainer/initializers/init.rb +5 -1
  32. data/lib/chainer/initializers/normal.rb +1 -1
  33. data/lib/chainer/iterators/serial_iterator.rb +1 -1
  34. data/lib/chainer/link.rb +11 -0
  35. data/lib/chainer/links/connection/convolution_2d.rb +98 -0
  36. data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
  37. data/lib/chainer/optimizer.rb +40 -1
  38. data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
  39. data/lib/chainer/parameter.rb +1 -1
  40. data/lib/chainer/serializers/marshal.rb +7 -3
  41. data/lib/chainer/testing/array.rb +32 -0
  42. data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
  43. data/lib/chainer/training/extensions/snapshot.rb +1 -1
  44. data/lib/chainer/training/standard_updater.rb +4 -0
  45. data/lib/chainer/training/trainer.rb +1 -1
  46. data/lib/chainer/utils/array.rb +13 -2
  47. data/lib/chainer/utils/conv.rb +59 -0
  48. data/lib/chainer/utils/math.rb +72 -0
  49. data/lib/chainer/utils/variable.rb +7 -3
  50. data/lib/chainer/version.rb +1 -1
  51. data/red-chainer.gemspec +1 -0
  52. metadata +37 -3
data/lib/chainer/functions/math/identity.rb
@@ -0,0 +1,26 @@
+ module Chainer
+   module Functions
+     module Math
+       # Identity function.
+       class Identity < Chainer::Function
+         def check_type_forward(in_types)
+           # pass
+         end
+
+         def forward(xs)
+           retain_inputs([])
+           return xs
+         end
+
+         def backward(xs, gys)
+           return gys
+         end
+
+         # Just returns input variables.
+         def self.identity(*inputs)
+           self.new.(*inputs)
+         end
+       end
+     end
+   end
+ end
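The new Identity function passes its inputs straight through in forward and returns the output gradients unchanged in backward; check_backward (added later in this diff) routes outputs through it so a single backward call covers every output. A minimal usage sketch, assuming Numo::DFloat data (the names and values here are illustrative, not taken from the diff):

  x = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
  y = Chainer::Functions::Math::Identity.identity(x)  # y.data equals x.data
  y.grad = Numo::DFloat[0.1, 0.2, 0.3]
  y.backward                                          # gradients pass through unchanged
  # x.grad is now [0.1, 0.2, 0.3]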
data/lib/chainer/functions/noise/dropout.rb
@@ -0,0 +1,45 @@
+ module Chainer
+   module Functions
+     module Noise
+       class Dropout < Chainer::Function
+         # Drops elements of input variable randomly.
+         #
+         # This function drops input elements randomly with probability `ratio` and
+         # scales the remaining elements by factor `1 / (1 - ratio)`.
+         # In testing mode, it does nothing and just returns `x`.
+         #
+         # @param [Chainer::Variable] x Input variable.
+         # @param [float] ratio Dropout ratio. The ``ratio`` must be `0.0 <= ratio < 1.0`.
+         # @return [Chainer::Variable] Output variable.
+         def self.dropout(x, ratio: 0.5)
+           Chainer.configuration.train ? self.new(ratio).(x) : x
+         end
+
+         def initialize(dropout_ratio)
+           if dropout_ratio < 0 || dropout_ratio >= 1.0
+             raise 'dropout_ratio must be in the range [0, 1)'
+           end
+           @dropout_ratio = dropout_ratio
+         end
+
+         def forward(x)
+           retain_inputs([])
+           unless self.instance_variable_defined?(:@mask)
+             scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
+             flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
+
+             @mask = x[0].class.zeros(*x[0].shape)
+             @mask[flag] = 1
+             @mask *= scale
+           end
+           [x[0] * @mask]
+         end
+
+         def backward(x, gy)
+           [gy[0] * @mask]
+         end
+       end
+     end
+   end
+ end
+
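Dropout zeroes each element with probability `ratio` and scales the survivors by `1 / (1 - ratio)`, so the expected value of the output matches the input; when Chainer.configuration.train is false it is a no-op. A hedged sketch of both modes (input values are illustrative):

  x = Chainer::Variable.new(Numo::DFloat.new(1, 4).seq(1))       # [[1, 2, 3, 4]]

  Chainer.configuration.train = true
  y = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
  # surviving elements are doubled (1 / (1 - 0.5)), dropped ones become 0

  Chainer.configuration.train = false
  y = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
  # in test mode dropout returns x itself, untouched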
data/lib/chainer/functions/normalization/batch_normalization.rb
@@ -0,0 +1,136 @@
+ module Chainer
+   module Functions
+     module Normalization
+       class BatchNormalizationFunction < Chainer::Function
+         attr_reader :running_mean, :running_var
+         # Batch normalization function with fixed statistics.
+         # This is a variant of batch normalization, where the mean and variance
+         # statistics are given by the caller as fixed variables. This is
+         # used in testing mode of the batch normalization layer, where batch
+         # statistics cannot be used for prediction consistency.
+         #
+         # @param [Chainer::Variable] x Input variable.
+         # @param [Chainer::Variable] gamma Scaling parameter of normalized data.
+         # @param [Chainer::Variable] beta Shifting parameter of scaled normalized data.
+         # @param [Chainer::Variable] mean Shifting parameter of input.
+         # @param [Chainer::Variable] var Square of scaling parameter of input.
+         # @param [float] eps Epsilon value for numerical stability.
+         def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+           old_train = Chainer.configuration.train
+           Chainer.configuration.train = false
+           norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
+           Chainer.configuration.train = old_train
+           norm
+         end
+
+         def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+           @running_mean = mean
+           @running_var = var
+           @eps = eps
+           @mean_cache = nil
+           @decay = decay
+         end
+
+         def forward(inputs)
+           x, gamma, beta = inputs[0], inputs[1], inputs[2]
+           if Chainer.configuration.train
+             if @running_mean.nil?
+               @running_mean = Numo::NArray[*gamma].new_zeros
+               @running_var = Numo::NArray[*gamma].new_zeros
+             else
+               @running_mean = Numo::NArray[*@running_mean]
+               @running_var = Numo::NArray[*@running_var]
+             end
+           elsif inputs.size == 5
+             @fixed_mean = inputs[3]
+             @fixed_var = inputs[4]
+           end
+
+           head_ndim = gamma.ndim + 1
+           gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
+           gamma = gamma.reshape(*gamma_expander)
+           beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
+           beta = beta.reshape(*beta_expander)
+
+           if Chainer.configuration.train
+             axis = [0] + (head_ndim...(x.ndim)).to_a
+             mean = x.mean(axis: axis)
+             # FIXME: numpy.var
+             var = x.var(axis: axis)
+             var += @eps
+           else
+             mean = @fixed_mean
+             var = @fixed_var + @eps
+           end
+
+           @std = Numo::NMath.sqrt(var)
+
+           mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+           x_mu = x - mean.reshape(*mean_expander)
+           std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
+           x_mu /= @std.reshape(*std_expander)
+           @x_hat = x_mu
+           y = gamma * @x_hat
+           y += beta
+
+           if Chainer.configuration.train
+             m = x.size.div(gamma.size)
+             adjust = m / [m - 1.0, 1.0].max
+             @running_mean *= @decay
+             temp_ar = Numo::NArray[*mean]
+             temp_ar *= (1 - @decay)
+             @running_mean += temp_ar
+
+             @running_var *= @decay
+             temp_ar = Numo::NArray[*var]
+             temp_ar *= ((1 - @decay) * adjust)
+             @running_var += temp_ar
+           end
+
+           [y,]
+         end
+
+         def backward(inputs, grad_outputs)
+           x, gamma = inputs[0], inputs[1]
+           gy = grad_outputs[0]
+           head_ndim = gamma.ndim + 1
+           m = gamma.class[x.size.div(gamma.size)][0]
+           axis = [0] + (head_ndim...(x.ndim)).to_a
+
+           if inputs.size == 5
+             mean = inputs[3]
+             var = inputs[4]
+             std = Numo::NMath.sqrt(var)
+             gs = gamma / std
+             gbeta = gy.sum(axis: axis)
+
+             mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+             x_mu = x - mean.reshape(*mean_expander)
+             std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
+             x_mu /= std.reshape(*std_expander)
+             x_hat = x_mu
+             ggamma = (gy * x_hat).sum(axis: axis)
+             gmean = -gs * gbeta
+             gvar = -0.5 * gamma / var * ggamma
+             gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
+             gx = gs.reshape(*gs_expander)
+             return [gx, ggamma, gbeta, gmean, gvar]
+           end
+
+           gbeta = gy.sum(axis: axis)
+           ggamma = (gy * @x_hat).sum(axis: axis)
+           tmp = (gamma / @std)
+           tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
+           tmp = tmp.reshape(*tmp_expander)
+
+           ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
+           gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
+
+           gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m)
+
+           [gx, ggamma, gbeta]
+         end
+       end
+     end
+   end
+ end
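In training mode the function normalizes with the batch statistics, y = gamma * (x - mean) / sqrt(var + eps) + beta, and folds those statistics into running_mean and running_var with the given decay; fixed_batch_normalization instead applies caller-supplied statistics with training temporarily switched off. A rough sketch of the fixed-statistics path (shapes and values are illustrative assumptions, not from the diff):

  x     = Numo::DFloat.new(2, 3).rand   # 2 samples, 3 channels
  gamma = Numo::DFloat.ones(3)
  beta  = Numo::DFloat.zeros(3)
  mean  = Numo::DFloat.zeros(3)
  var   = Numo::DFloat.ones(3)

  y = Chainer::Functions::Normalization::BatchNormalizationFunction
        .fixed_batch_normalization(Chainer::Variable.new(x),
                                   Chainer::Variable.new(gamma),
                                   Chainer::Variable.new(beta),
                                   Chainer::Variable.new(mean),
                                   Chainer::Variable.new(var))
  # with zero mean, unit variance, gamma = 1 and beta = 0 this reduces to
  # y ≈ x / Math.sqrt(1.0 + 2e-5)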
data/lib/chainer/functions/pooling/max_pooling_2d.rb
@@ -0,0 +1,57 @@
+ module Chainer
+   module Functions
+     module Pooling
+       class MaxPooling2D < Pooling2D
+         # Spatial max pooling function
+         #
+         # @param [Chainer::Variable] x Input variable
+         # @param [integer || 2D integer array] ksize Size of pooling window
+         # @param [integer || 2D integer array] stride Stride of pooling applications
+         # @param [integer || 2D integer array] pad Spatial padding width for the input array
+         # @param [boolean] cover_all If `true`, all spatial locations are pooled into some output pixels
+         # @return [Chainer::Variable] Output variable
+         def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
+           self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+         end
+
+         def forward_cpu(x)
+           retain_inputs([])
+           @in_shape = x[0].shape
+           @in_dtype = x[0].class
+
+           col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+           n, c, kh, kw, out_h, out_w = col.shape
+           col = col.reshape(n, c, kh * kw, out_h, out_w)
+
+           # TODO: numpy.argmax(axis=2)
+           d = col.shape[3..-1].reduce(:*) || 1
+           dx = col.shape[2..-1].reduce(:*) || 1
+           max_index = col.max_index(2)
+           @indexes = max_index.flatten.map_with_index { |val, idx| (val - (dx * (idx / d))) / d }.reshape(*max_index.shape)
+
+           y = col.max(axis: 2)
+           [y]
+         end
+
+         def backward_cpu(x, gy)
+           n, c, out_h, out_w = gy[0].shape
+           h, w = @in_shape[2..-1]
+           kh, kw = @kh, @kw
+
+           gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)
+
+           indexes = @indexes.flatten
+           indexes += Numo::Int64.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
+           gcol[indexes] = gy[0].flatten.dup
+           gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
+           gcol = gcol.swapaxes(2, 4)
+           gcol = gcol.swapaxes(3, 5)
+
+           gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+           [gx]
+         end
+       end
+     end
+   end
+ end
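forward_cpu unfolds the input with im2col_cpu (padding with -Infinity so padded positions never win), takes the maximum over each window, and remembers the argmax indexes so backward_cpu can scatter the gradient back through col2im_cpu. A small sketch of the forward result, assuming a 1x1x4x4 input with values chosen for illustration:

  x = Chainer::Variable.new(Numo::DFloat.new(1, 1, 4, 4).seq(1))   # 1..16, row by row
  y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2, stride: 2)
  # each non-overlapping 2x2 window keeps its maximum:
  # y.data == [[[[ 6.0,  8.0],
  #              [14.0, 16.0]]]]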
data/lib/chainer/functions/pooling/pooling_2d.rb
@@ -0,0 +1,20 @@
+ module Chainer
+   module Functions
+     module Pooling
+       # Base class of pooling function over a set of 2d planes
+       class Pooling2D < Chainer::Function
+         def initialize(ksize, stride: nil, pad: 0, cover_all: true)
+           if stride.nil?
+             stride = ksize
+           end
+
+           @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
+           @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+           @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+
+           @cover_all = cover_all
+         end
+       end
+     end
+   end
+ end
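Pooling2D only normalizes its constructor arguments: a scalar ksize, stride or pad is expanded to a [height, width] pair, and stride falls back to ksize when omitted. A sketch of the effect (attribute names as defined above):

  pool = Chainer::Functions::Pooling::MaxPooling2D.new(3, pad: 1)
  # => @kh == @kw == 3, @sy == @sx == 3 (stride defaulted to ksize), @ph == @pw == 1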
data/lib/chainer/gradient_check.rb
@@ -0,0 +1,240 @@
+ module Chainer
+   def _copy_arrays(xs)
+     xp = Chainer::get_array_module(*xs)
+     xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+   end
+
+   # Computes numerical gradient by finite differences.
+   #
+   # This function is used to implement gradient check. For usage example, see
+   # unit tests of Chainer::Functions.
+   #
+   # @param [function] f Ruby function with no arguments that runs forward
+   #   computation and returns the result.
+   # @param [Array<Arrays>] inputs Array of arrays that should be treated as
+   #   inputs. Each element of them is slightly modified to realize numerical
+   #   gradient by finite differences.
+   # @param [Array<Arrays>] grad_outputs Array of arrays that are treated as
+   #   output gradients.
+   # @param [Float] eps Epsilon value of finite differences.
+   # @return [Array] Numerical gradient arrays corresponding to +inputs+.
+   #
+   def numerical_grad(f, inputs, grad_outputs, eps=0.001)
+     raise unless eps > 0
+     inputs = inputs.to_a
+     grad_outputs = grad_outputs.to_a
+     xp = Numo::NArray
+     grads = inputs.map{|x| x.new_zeros()}
+
+     if inputs[0].ndim < 2
+       tmp = [[inputs[0], grads[0]]]
+     else
+       tmp = (0...inputs[0].shape[0]).map{|i| [inputs[0][i, false], grads[0][i, false]]}
+     end
+
+     tmp.each do |x, gx|
+       x.each_with_index{|xx, *i|
+         orig = x[*i] # hold original value
+         x[*i] = orig + eps
+         ys1 = _copy_arrays(f.call(x))
+         x[*i] = orig - eps
+         ys2 = _copy_arrays(f.call(x))
+         x[*i] = orig
+
+         ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
+           if !gy.nil?
+             if ((y1 - y2) * gy).is_a? Numo::NArray
+               dot = ((y1 - y2) * gy).sum()
+             else
+               dot = ((y1 - y2) * gy).inject(:+)
+             end
+             gx[*i] += dot / (2*eps).to_f
+           end
+         end
+       }
+     end
+
+     return grads
+   end
+
+   def _as_tuple(x)
+     if x.is_a? Array
+       return x
+     else
+       return [x]
+     end
+   end
+
+   # Test backward procedure of a given function.
+   #
+   # This function automatically checks the backward process of a given function.
+   # For example, when you have a +Chainer::Function+ class +MyFunc+,
+   # that gets two arguments and returns one value, you can make its test like this:
+   #
+   #   def test_my_func
+   #     func = MyFunc.new
+   #     x1_data = Numo::NArray[...]
+   #     x2_data = Numo::NArray[...]
+   #     gy_data = Numo::NArray[...]
+   #     check_backward(func, [x1_data, x2_data], gy_data)
+   #
+   # This method creates +Chainer::Variable+ objects with +x_data+
+   # and calls +func+ with the +Chainer::Variable+ s to get its result
+   # as +Chainer::Variable+.
+   # Then, it sets +y_grad+ array to +grad+ attribute of the result and
+   # calls +backward+ method to get gradients of the inputs.
+   # To check correctness of the gradients, the function calls
+   # +numerical_grad+ to calculate numerically the gradients and compares
+   # the types of gradients with +Chainer::Testing.assert_allclose+.
+   # If input objects (+x1_data+ and/or +x2_data+ in this example) represent
+   # integer variables, their gradients are ignored.
+   #
+   # You can simplify a test when +MyFunc+ gets only one argument:
+   #
+   #   check_backward(func, x1_data, gy_data)
+   #
+   # If +MyFunc+ is a loss function which returns a zero-dimensional
+   # array, pass +nil+ to +gy_data+. In this case, it sets +1+ to
+   # +grad+ attribute of the result:
+   #
+   #   check_backward(my_loss_func, [x1_data, x2_data], nil)
+   #
+   # If +MyFunc+ returns multiple outputs, pass all gradients for outputs as an Array:
+   #
+   #   gy1_data = Numo::NArray[...]
+   #   gy2_data = Numo::NArray[...]
+   #   check_backward(func, x1_data, [gy1_data, gy2_data])
+   #
+   # You can also test a +Chainer::Link+.
+   # To check gradients of parameters of the link, set an Array of the parameters
+   # to the +params+ argument:
+   #
+   #   check_backward(my_link, [x1_data, x2_data], gy_data, [my_link.W, my_link.b])
+   #
+   # Note that +params+ are not +Numo::NArray+ s,
+   # but +Chainer::Variable+ s.
+   #
+   # Function objects are acceptable as +func+ argument:
+   #
+   #   check_backward(lambda{|x1, x2| f(x1, x2)}, [x1_data, x2_data], gy_data)
+   #
+   # @note
+   #   +func+ is called many times to get numerical gradients for all inputs.
+   #   This function doesn't work correctly when +func+ behaves randomly as
+   #   it gets different gradients.
+   # @param [Method, Proc] func A function which gets +Chainer::Variable+ s
+   #   and returns +Chainer::Variable+ s. +func+ must return
+   #   an Array of +Chainer::Variable+ s or one
+   #   +Chainer::Variable+. You can use a +Chainer::Function+
+   #   object, a +Chainer::Link+ object or a function satisfying the
+   #   condition.
+   # @param [Numo::NArray or Array<Numo::NArray>] x_data A set of +Numo::NArray+ s to be
+   #   passed to +func+. If +x_data+ is one +Numo::NArray+ object, it is
+   #   treated as +(x_data,)+.
+   # @param [Numo::NArray or Array<Numo::NArray> or nil] y_grad A set of +Numo::NArray+ s representing gradients of return-values of
+   #   +func+. If +y_grad+ is one +Numo::NArray+ object, it is
+   #   treated as +(y_grad,)+. If +func+ is a loss-function,
+   #   +y_grad+ should be set to +nil+.
+   # @param [Chainer::Variable or Array<Chainer::Variable>] params A set of +Chainer::Variable+ s whose gradients are checked.
+   #   When +func+ is a +Chainer::Link+ object,
+   #   set its parameters as +params+.
+   #   If +params+ is one +Chainer::Variable+ object,
+   #   it is treated as +(params,)+.
+   # @param [Float] eps Epsilon value to be passed to +numerical_grad+.
+   # @param [Float] atol Absolute tolerance to be passed to +Chainer::Testing.assert_allclose+.
+   # @param [Float] rtol Relative tolerance to be passed to +Chainer::Testing.assert_allclose+.
+   # @param [Array<Boolean>] no_grads Flag to skip variables for gradient assertion.
+   #   It should be the same length as +x_data+.
+   # @param [Numo::NArray.class] dtype +x_data+ and +y_grad+ are cast to this
+   #   dtype when calculating numerical gradients. Only float types and
+   #   +nil+ are allowed.
+   # @see
+   #   .numerical_grad
+   #
+   def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
+     x_data = _as_tuple(x_data)
+     if !y_grad.nil?
+       y_grad = _as_tuple(y_grad)
+     end
+
+     params = _as_tuple(params)
+     xs = x_data.map{|x| Chainer::Variable.new(x)}
+     y = func.(*xs)
+     y = _as_tuple(y)
+     y = Chainer::Functions::Math::Identity.identity(*y)
+     y = _as_tuple(y)
+
+     if !y_grad.nil?
+       if (y).size != (y_grad).size
+         raise TypeError, "`y_grad` must have the same length of output values"
+       end
+
+       y.zip(y_grad).each do |iy, igy|
+         iy.grad = igy
+       end
+     else
+       if (y).size != 1
+         raise TypeError, "When `y_grad` is `nil`, the function must return a zero-dimensional array"
+       end
+       y_grad = [1]
+     end
+     # We only need to call `backward` for one result `Chainer::Variable`.
+     # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
+     y[0].backward()
+
+     if dtype.nil?
+       casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+     else
+       if (dtype != Numo::DFloat) and (dtype != Numo::SFloat)
+         raise TypeError, "`dtype` is allowed only float type"
+       end
+       if (params).size > 0
+         raise TypeError, "`dtype` is available only if `params` is empty"
+       end
+       casted_xs = x_data.map{|x|
+         if x.class == Numo::DFloat or x.class == Numo::SFloat
+           Chainer::Variable.new(dtype.cast(x))
+         else
+           Chainer::Variable.new(x)
+         end
+       }
+     end
+
+     f = lambda do |_|
+       ys = func.(*casted_xs)
+       ys = _as_tuple(ys)
+       return ys.map{|y| y.data}.to_a
+     end
+
+     if no_grads.nil?
+       no_grads = xs.map{|x| (x.dtype != Numo::DFloat) and (x.dtype != Numo::SFloat)}
+     else
+       if no_grads.size != xs.size
+         raise TypeError, "Length of no_grads param and xs should be same."
+       end
+     end
+
+     no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+       if skip
+         raise unless x.grad.nil?
+         next
+       end
+       gx, = numerical_grad(f, [cx.data], y_grad, eps)
+       Chainer::Testing.assert_allclose(x.grad, gx, atol: atol, rtol: rtol)
+       if dtype.nil?
+         raise unless gx.class == x.grad.class
+       else
+         if ((gx.class != Numo::DFloat) and (gx.class != Numo::SFloat)) and (gx.class != dtype)
+           raise
+         end
+       end
+     end
+
+     params.each do |p|
+       gp, = numerical_grad(f, [p.data], y_grad, eps)
+       Chainer::Testing.assert_allclose(p.grad, gp, atol: atol, rtol: rtol)
+       raise unless gp.dtype === p.grad.dtype
+     end
+   end
+   module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward
+ end
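check_backward compares the analytic gradients produced by a function's backward pass against central finite differences, (f(x + eps) - f(x - eps)) / (2 * eps), computed by numerical_grad. A hypothetical test sketch for the MeanSquaredError loss that also ships in this release (input shapes and tolerances are illustrative assumptions):

  x0 = Numo::DFloat.new(4, 3).rand - 0.5
  x1 = Numo::DFloat.new(4, 3).rand - 0.5
  func = Chainer::Functions::Loss::MeanSquaredError.new
  # the loss returns a zero-dimensional array, so y_grad is nil (grad is set to 1)
  Chainer.check_backward(func, [x0, x1], nil, eps: 1e-2, atol: 1e-4, rtol: 1e-4)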