red-chainer 0.3.2 → 0.4.0

Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/pooling/average_pooling_2d.rb
@@ -2,6 +2,8 @@ module Chainer
   module Functions
     module Pooling
       class AveragePooling2D < Pooling2D
+        attr_reader :in_shape, :in_dtype
+
         # Spatial average pooling function.
         #
         # This function acts similarly to :class:`Convolution2D`,
@@ -14,31 +16,52 @@ module Chainer
         # @param [integer] pad Spatial padding width for the input array. `pad=p` and `pad=[p, p]` are equivalent.
         # @return [Chainer::Variable] Output variable
         def self.average_pooling_2d(x, ksize, stride: nil, pad: 0)
-          self.new(ksize, stride: stride, pad: pad, cover_all: false).(x)
+          self.new(ksize, stride: stride, pad: pad, cover_all: false).apply([x])[0]
         end

         # Average pooling over a set of 2d planes.
-        def forward_cpu(x)
-          retain_inputs([])
+        def forward(x)
           @in_shape = x[0].shape
           @in_dtype = x[0].class

-          col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
           y = col.mean(axis: [2, 3])

           [y]
         end

-        def backward_cpu(x, gy)
+        def backward(indexes, gy)
+          AveragePooling2DGrad.new(self).apply(gy)
+        end
+      end
+
+      class AveragePooling2DGrad < FunctionNode
+        def initialize(apool2d)
+          @kh = apool2d.kh
+          @kw = apool2d.kw
+          @sy = apool2d.sy
+          @sx = apool2d.sx
+          @ph = apool2d.ph
+          @pw = apool2d.pw
+          @in_shape = apool2d.in_shape
+          @in_dtype = apool2d.in_dtype
+          @apool2d = apool2d
+        end
+
+        def forward(gy)
           h, w = @in_shape[2..-1]
           shape = gy[0].shape
           shape.insert(2, 1, 1)
           gcol = gy[0].reshape(*shape).tile(1, 1, @kh, @kw, 1, 1)

-          gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+          gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
           gx /= @kh * @kw
           [gx]
         end
+
+        def backward(indexes, grad_outputs)
+          AveragePooling2D.new([@kh, @kw], stride: [@sy, @sx], pad: [@ph, @pw], cover_all: false).apply(grad_outputs)
+        end
       end
     end
   end
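For orientation (not part of the diff): a minimal usage sketch of the migrated average pooling. The public helper keeps its old signature; only the internal dispatch changes from `.(x)` to `.apply([x])[0]`, and the backward pass is now produced by `AveragePooling2DGrad`. The input shape and values below are illustrative.

require 'chainer'

# 1 sample, 1 channel, 4x4 feature map filled with 0..15.
x = Chainer::Variable.new(Numo::SFloat.new(1, 1, 4, 4).seq)

# Same call as in 0.3.x; internally it now goes through FunctionNode#apply.
y = Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(x, 2)

# y.data is a [1, 1, 2, 2] array; each element is the mean of one
# non-overlapping 2x2 window of x.
p y.data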
data/lib/chainer/functions/pooling/max_pooling_2d.rb
@@ -2,24 +2,24 @@ module Chainer
   module Functions
     module Pooling
       class MaxPooling2D < Pooling2D
+        attr_reader :in_shape, :in_dtype, :indexes
         # Spatial max pooling function
         #
         # @param [Chainer::Variable] x Input variable
-        # @param [integer || 2D integer array] Size of pooling window
-        # @param [integer || 2D integer array] Stride of pooling applications
-        # @param [integer || 2D integer array] Spatial padding width for the input array
-        # @param [boolean] If `true`, all spatial locations are pooled int some output pixels
+        # @param [integer || 2D integer array] ksize Size of pooling window
+        # @param [integer || 2D integer array] stride Stride of pooling applications
+        # @param [integer || 2D integer array] pad Spatial padding width for the input array
+        # @param [boolean] cover_all If `true`, all spatial locations are pooled int some output pixels
         # @return [Chainer::Variable] Output variable
         def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
-          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).apply([x]).first
         end

-        def forward_cpu(x)
-          retain_inputs([])
+        def forward(x)
           @in_shape = x[0].shape
           @in_dtype = x[0].class

-          col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
           n, c, kh, kw, out_h, out_w = col.shape
           col = col.reshape(n , c, kh * kw, out_h, out_w)

@@ -33,7 +33,27 @@ module Chainer
           [y]
         end

-        def backward_cpu(x, gy)
+        def backward(indexes, gy)
+          MaxPooling2DGrad.new(self).apply(gy)
+        end
+      end
+
+      class MaxPooling2DGrad < FunctionNode
+        def initialize(mpool2d)
+          @kh = mpool2d.kh
+          @kw = mpool2d.kw
+          @sy = mpool2d.sy
+          @sx = mpool2d.sx
+          @ph = mpool2d.ph
+          @pw = mpool2d.pw
+          @cover_all = mpool2d.cover_all
+          @indexes = mpool2d.indexes
+          @in_shape = mpool2d.in_shape
+          @in_dtype = mpool2d.in_dtype
+          @mpool2d = mpool2d
+        end
+
+        def forward(gy)
           n, c, out_h, out_w = gy[0].shape
           h, w = @in_shape[2..-1]
           kh, kw = @kh, @kw
@@ -41,16 +61,51 @@ module Chainer
           gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)

           indexes = @indexes.flatten
-          indexes += Numo::Int64.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
-
+          indexes += indexes.class.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
           gcol[indexes] = gy[0].flatten.dup
           gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
           gcol = gcol.swapaxes(2, 4)
           gcol = gcol.swapaxes(3, 5)

-          gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+          gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
           [gx]
         end
+
+        def backward(indexes, ggx)
+          MaxPooling2DWithIndexes.new(@mpool2d).apply(ggx)
+        end
+      end
+
+      class MaxPooling2DWithIndexes < FunctionNode
+        def initialize(mpool2d)
+          @kh = mpool2d.kh
+          @kw = mpool2d.kw
+          @sy = mpool2d.sy
+          @sx = mpool2d.sx
+          @ph = mpool2d.ph
+          @pw = mpool2d.pw
+          @cover_all = mpool2d.cover_all
+          @indexes = mpool2d.indexes
+        end
+
+        def forward(x)
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+          n, c, kh, kw, out_h, out_w = col.shape
+          col = col.reshape(n, c, kh * kw, out_h, out_w)
+          col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)
+
+          indexes = @indexes.flatten.dup
+
+          # TODO: col = col[numpy.arange(len(indexes)), indexes]
+          new_col = col.class.zeros(indexes.size)
+          x[0].class.new(indexes.size).seq.each_with_index do |v, i|
+            new_col[i] = col[v, indexes[i]]
+          end
+          col = new_col
+
+          [col.reshape(n, c, out_h, out_w)]
+        end
       end
     end
   end
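Likewise, a small illustrative sketch (not part of the diff) for the migrated max pooling. The helper's signature is unchanged; gradients now flow through MaxPooling2DGrad, whose own backward uses MaxPooling2DWithIndexes, which is what enables the double-backprop path.

x = Chainer::Variable.new(Numo::SFloat.new(1, 1, 4, 4).seq)
y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2, stride: 2)

# y.data is a [1, 1, 2, 2] array; each element is the maximum of one 2x2
# window, and the argmax positions (@indexes) are reused by the backward pass.
p y.data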
data/lib/chainer/functions/pooling/pooling_2d.rb
@@ -2,15 +2,17 @@ module Chainer
   module Functions
     module Pooling
       # Base class of pooling function over a set of 2d planes
-      class Pooling2D < Chainer::Function
+      class Pooling2D < Chainer::FunctionNode
+        attr_reader :kh, :kw, :sy, :sx, :ph, :pw, :cover_all
+
         def initialize(ksize, stride: nil, pad: 0, cover_all: true)
           if stride.nil?
             stride = ksize
           end

-          @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
-          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
-          @ph, @pw = pad.is_a?(Array) ? pad: [pad, pad]
+          @kh, @kw = ksize.is_a?(::Array) ? ksize : [ksize, ksize]
+          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(::Array) ? pad: [pad, pad]

           @cover_all = cover_all
         end
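The base class now exposes its hyperparameters via attr_reader so the new *Grad nodes can copy them, and it keeps normalizing scalar arguments into [height, width] pairs. A standalone sketch of that normalization (values illustrative); note that `::Array` pins the check to Ruby's top-level Array class regardless of any constants defined inside the Chainer namespaces.

ksize  = 3        # a single integer or a [kh, kw] pair
stride = nil      # defaults to ksize when omitted
pad    = [1, 0]

stride = ksize if stride.nil?

kh, kw = ksize.is_a?(::Array)  ? ksize  : [ksize, ksize]
sy, sx = stride.is_a?(::Array) ? stride : [stride, stride]
ph, pw = pad.is_a?(::Array)    ? pad    : [pad, pad]

p [kh, kw, sy, sx, ph, pw]  # => [3, 3, 3, 3, 1, 0]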
data/lib/chainer/gradient_check.rb
@@ -1,7 +1,6 @@
 module Chainer
   def _copy_arrays(xs)
-    xp = Chainer::get_array_module(*xs)
-    xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+    xs.map{|x| Chainer.array?(x) ? x.dup : x}
   end

   # Computes numerical gradient by finite differences.
@@ -19,37 +18,31 @@ module Chainer
   # @param [Float] eps Epsilon value of finite differences.
   # @return [Array] Numerical gradient arrays corresponding to +inputs+.
   #
-  def numerical_grad(f, inputs, grad_outputs, eps=0.001)
+  def numerical_grad(f, inputs, grad_outputs, eps=1e-3)
     raise unless eps > 0
     inputs = inputs.to_a
     grad_outputs = grad_outputs.to_a
-    xp = Numo::NArray
     grads = inputs.map{|x| x.new_zeros()}

-    if inputs[0].ndim < 2
-      tmp = [[inputs[0], grads[0]]]
-    else
-      tmp = (0...inputs[0].shape[0]).map{|i|[inputs[0][i, false], grads[0][i, false]]}
-    end
-
-    tmp.each do |x, gx|
-      x.each_with_index{|xx, *i|
-        orig = x[*i] # hold original value
+    inputs.zip(grads).each do |x, gx|
+      orig_x = x.dup # hold original value
+      x.each_with_index{|_, *i|
+        orig = orig_x[*i]
         x[*i] = orig + eps
-        ys1 = _copy_arrays(f.call(x))
+        ys1 = _copy_arrays(f.())
         x[*i] = orig - eps
-        ys2 = _copy_arrays(f.call(x))
+        ys2 = _copy_arrays(f.())
        x[*i] = orig

         ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
-          if !gy.nil?
-            if ((y1 - y2) * gy).is_a? Numo::NArray
-              dot = ((y1 - y2) * gy).sum()
-            else
-              dot = ((y1 - y2) * gy).inject(:+)
-            end
-            gx[*i] += dot / (2*eps).to_f
+          next if gy.nil?
+          diff = y1 - y2
+          if Chainer.array?(diff) && diff.empty?
+            dot = 0
+          else
+            dot = (diff * gy).sum
           end
+          gx[*i] += dot / (2 * eps)
         end
       }
     end
@@ -153,6 +146,7 @@ module Chainer
   #
   def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
     x_data = _as_tuple(x_data)
+    xm = Chainer.get_array_module(*x_data)
     if !y_grad.nil?
       y_grad = _as_tuple(y_grad)
     end
@@ -161,80 +155,170 @@ module Chainer
     xs = x_data.map{|x| Chainer::Variable.new(x)}
     y = func.(*xs)
     y = _as_tuple(y)
-    y = Chainer::Functions::Math::Identity.identity(*y)
-    y = _as_tuple(y)
+    y = Chainer::Functions::Math::Identity.new.apply(y)

-    if !y_grad.nil?
-      if (y).size != (y_grad).size
-        raise TypeError, "`y_grad` must have the same length of output values"
-      end
+    y_grad = set_y_grad(y, y_grad)
+
+    # Clear gradients which may exist if func calls backward inside of itself.
+    clear_grads(xs)
+    clear_grads(params)

-      y.zip(y_grad).each do |iy, igy|
-        iy.grad = igy
-      end
-    else
-      if (y).size != 1
-        raise TypeError, "When `y_grad` is `nil`, the function must return azero-dimentional array"
-      end
-      y_grad = [1]
-    end
     # We only need to call `backward` for one result `Chainer::Variable`.
     # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
     y[0].backward()

+    param_data = params.map { |p| p.data }
     if dtype.nil?
-      casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+      casted_xs = x_data.map { |x| Chainer::Variable.new(x) }
     else
-      if (dtype != Numo::DFloat) and (dtype != Numo::SFloat)
-        raise TypeError, "`dtype` is allowed only float type"
-      end
-      if (params).size > 0
-        raise TypeError, "`dtype` is available only if `params` is empty"
+      raise '`dtype` is allowed only float type' if dtype != xm::DFloat && dtype != xm::SFloat
+      casted_xs = x_data.map { |x| x.is_a?(Numo::NArray) ? Chainer::Variable.new(x.cast_to(dtype)) : x }
+    end
+
+    if no_grads.nil?
+      no_grads = xs.map { |x| x.dtype != Numo::SFloat && x.dtype != Numo::DFloat }
+    else
+      raise "Length of no_grads param and xs should be same." if no_grads.size != xs.size
+    end
+
+    casted_data = casted_xs.map { |x| x.data.dup }
+
+    no_grads.zip(xs).each do |skip, x|
+      if skip
+        raise "x.grad is not nil" if x.grad != nil
+      else
+        raise 'gradients of some arguments are not calculated' if x.grad.nil?
       end
-      casted_xs = x_data.map{|x|
-        if x.class == Numo::DFloat or x.class == Numo::SFloat
-          Chainer::Variable.new(dtype.cast(x))
-        else
-          Chainer::Variable.new(x)
-        end
-      }
     end

-    f = lambda do |_|
+    # Keep the gradient arrays of params which may be overwritten by func
+    params_grad = params.map(&:grad)
+
+    if dtype.nil?
+      one = Numo::DFloat.new().fill(1.0)
+    else
+      one = dtype.new().fill(1.0)
+    end
+
+    g = lambda do
+      # This functions is called twice in `numerical_grad`.
+      # `one` is `1 + epsilon` or `1 - epsilon` in these calls.
+      # See the document of `numerical_grad`.
+      no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+        next if skip || cx.data.empty?
+        # astype is require to store data with the given type
+        data = (one * data).cast_to(data.class)
+        cx.data = data
+      end
+
+      params.zip(param_data).each do |param, data|
+        if !dtype.nil?
+          param_dtype = dtype
+        else
+          param_dtype = param.dtype
+        end
+        # The inner astype is required to calculates __mul__ in
+        # `param_type` when data is low accuracy float.
+        # The outer one is require to store data with the given type.
+        param.data = (one * data.cast_to(param_dtype)).cast_to(param_dtype)
+      end
+
+      # Clear gradients to support func that calls backward inside of itself.
+      clear_grads(casted_xs)
+      clear_grads(params)
+
       ys = func.(*casted_xs)
       ys = _as_tuple(ys)
-      return ys.map{|y| y.data}.to_a
+      ys_data = ys.map { |y| y.data }
+      no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+        next if skip
+        cx.data = data
+      end
+      params.zip(param_data).each do |param, data|
+        param.data = data
+      end
+      ys_data
     end

-    if no_grads.nil?
-      no_grads = xs.map{|x| (x.dtype != Numo::DFloat) and (x.dtype != Numo::SFloat)}
-    else
-      if no_grads.size != xs.size
-        raise TypeError, "Length of no_grads param and xs should be same."
+    gx, = numerical_grad(g, [one], y_grad, eps)
+    gx_accum = 0
+
+    no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+      next if skip
+      gxi = x.grad.flatten.dup
+      cxi = cx.data.flatten.dup
+      unless dtype.nil?
+        gxi = gxi.cast_to(dtype)
+        cxi = cxi.cast_to(dtype)
       end
+      gx_accum += gxi.empty? ? 0 : gxi.dot(cxi)
     end

-    no_grads.zip(xs, casted_xs).each do |skip, x, cx|
-      if skip
-        raise unless x.grad.nil?
-        next
+    params.zip(params_grad).each do |p, gpi|
+      gpi =gpi.flatten.dup
+      pi = p.data.flatten.dup
+      unless dtype.nil?
+        gpi = gpi.cast_to(dtype)
+        pi = pi.cast_to(dtype)
       end
-      gx, = numerical_grad(f, [cx.data], y_grad, eps)
-      Chainer::Testing.assert_allclose(x.grad, gx, atol: atol, rtol: rtol)
-      if dtype.nil?
-        raise unless gx.class == x.grad.class
-      else
-        if ((gx.class != Numo::DFloat) and (gx.class != Numo::SFloat)) and (gx.class != dtype)
-          raise
+      gx_accum += gpi.dot(pi)
+    end
+
+    Chainer::Testing.assert_allclose(gx, gx_accum, atol: atol, rtol: rtol)
+  end
+
+  def check_double_backward(func, x_data, y_grad, x_grad_grad, params=[], params_grad_grad=[], eps: 1e-3, atol: 1e-4, rtol: 1e-3, no_grads: nil, dtype: nil)
+    x_data = _as_tuple(x_data)
+    params = _as_tuple(params)
+    n_x = x_data.size
+
+    first_order_grad = -> *inputs do
+      xs = inputs[0...n_x]
+      gys = inputs[n_x..-1]
+
+      y = _as_tuple(func.(*xs))
+      # Let all elements of y share the same creator.
+      # See the comment in check_backward.
+      y = Chainer::Functions::Math::Identity.new.apply(y)
+      set_y_grad(y, gys)
+      y[0].backward(enable_double_backprop: true)
+
+      xs.map(&:grad_var) + params.map(&:grad_var)
+    end
+
+    inputs = x_data + _as_tuple(y_grad)
+    grad_grad = _as_tuple(x_grad_grad) + _as_tuple(params_grad_grad)
+    check_backward(first_order_grad, inputs, grad_grad, params=params, eps: eps, atol: atol, rtol: rtol, no_grads: no_grads, dtype: dtype)
+  end
+
+  def set_y_grad(y, y_grad)
+    if y_grad.nil?
+      if y.size != 1
+        raise TypeError, 'When `y_grad` is `None`, the function must return a zero-dimentional array'
+      end
+      y_grad = [1]
+    else
+      if y.size != y_grad.size
+        raise TypeError, '`y_grad` must have the same length of output values'
+      end
+      y.zip(y_grad).each do |iy, igy|
+        if igy.is_a?(Chainer::Variable)
+          iy.grad_var = igy
+        else
+          iy.grad = igy
         end
       end
     end

-    params.each do |p|
-      gp, = numerical_grad(f, [p.data], y_grad, eps)
-      Chainer::Testing.assert_allclose(p.grad, gp, atol: atol, rtol: rtol)
-      raise unless gp.dtype === p.grad.dtype
+    y_grad
+  end
+
+  def clear_grads(xs)
+    xs.each do |x|
+      x.grad_var = nil
     end
   end
-  module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward
+
+  module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward, :check_double_backward, :set_y_grad, :clear_grads
+  private_class_method :set_y_grad, :clear_grads
 end
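A sketch (not part of the diff) of how the reworked checkers are typically driven, reusing the pooling function migrated above. The shapes are illustrative and the keyword arguments simply restate the default tolerances from the new signatures.

x   = Numo::DFloat.new(1, 1, 4, 4).rand
gy  = Numo::DFloat.new(1, 1, 2, 2).rand   # upstream gradient for the pooled output
ggx = Numo::DFloat.new(1, 1, 4, 4).rand   # second-order gradient w.r.t. x

func = -> (v) { Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(v, 2) }

# First-order check: analytic gradients against numerical_grad.
Chainer.check_backward(func, x, gy, eps: 1e-3, atol: 1e-5, rtol: 1e-4)

# Second-order check, made possible by the FunctionNode-based double backprop.
Chainer.check_double_backward(func, x, gy, ggx, eps: 1e-3, atol: 1e-4, rtol: 1e-3)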