red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/pooling/average_pooling_2d.rb (+29 -6)
@@ -2,6 +2,8 @@ module Chainer
  module Functions
  module Pooling
  class AveragePooling2D < Pooling2D
+ attr_reader :in_shape, :in_dtype
+
  # Spatial average pooling function.
  #
  # This function acts similarly to :class:`Convolution2D`,
@@ -14,31 +16,52 @@ module Chainer
  # @param [integer] pad Spatial padding width for the input array. `pad=p` and `pad=[p, p]` are equivalent.
  # @return [Chainer::Variable] Output variable
  def self.average_pooling_2d(x, ksize, stride: nil, pad: 0)
- self.new(ksize, stride: stride, pad: pad, cover_all: false).(x)
+ self.new(ksize, stride: stride, pad: pad, cover_all: false).apply([x])[0]
  end

  # Average pooling over a set of 2d planes.
- def forward_cpu(x)
- retain_inputs([])
+ def forward(x)
  @in_shape = x[0].shape
  @in_dtype = x[0].class

- col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
+ col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
  y = col.mean(axis: [2, 3])

  [y]
  end

- def backward_cpu(x, gy)
+ def backward(indexes, gy)
+ AveragePooling2DGrad.new(self).apply(gy)
+ end
+ end
+
+ class AveragePooling2DGrad < FunctionNode
+ def initialize(apool2d)
+ @kh = apool2d.kh
+ @kw = apool2d.kw
+ @sy = apool2d.sy
+ @sx = apool2d.sx
+ @ph = apool2d.ph
+ @pw = apool2d.pw
+ @in_shape = apool2d.in_shape
+ @in_dtype = apool2d.in_dtype
+ @apool2d = apool2d
+ end
+
+ def forward(gy)
  h, w = @in_shape[2..-1]
  shape = gy[0].shape
  shape.insert(2, 1, 1)
  gcol = gy[0].reshape(*shape).tile(1, 1, @kh, @kw, 1, 1)

- gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+ gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
  gx /= @kh * @kw
  [gx]
  end
+
+ def backward(indexes, grad_outputs)
+ AveragePooling2D.new([@kh, @kw], stride: [@sy, @sx], pad: [@ph, @pw], cover_all: false).apply(grad_outputs)
+ end
  end
  end
  end
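
A minimal usage sketch of the FunctionNode-based path above (the input shape and seed gradient are illustrative, not taken from the diff):

    x = Chainer::Variable.new(Numo::SFloat.new(1, 3, 4, 4).seq)
    # average_pooling_2d now routes through apply([x])[0] instead of calling the function directly
    y = Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(x, 2)
    # 4x4 planes pooled with ksize 2 (stride defaults to ksize) give 2x2 outputs
    y.grad = Numo::SFloat.ones(1, 3, 2, 2)
    # backward is built from AveragePooling2DGrad, itself a FunctionNode,
    # so the gradient computation can be differentiated again
    y.backward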
data/lib/chainer/functions/pooling/max_pooling_2d.rb (+67 -12)
@@ -2,24 +2,24 @@ module Chainer
  module Functions
  module Pooling
  class MaxPooling2D < Pooling2D
+ attr_reader :in_shape, :in_dtype, :indexes
  # Spatial max pooling function
  #
  # @param [Chainer::Variable] x Input variable
- # @param [integer || 2D integer array] Size of pooling window
- # @param [integer || 2D integer array] Stride of pooling applications
- # @param [integer || 2D integer array] Spatial padding width for the input array
- # @param [boolean] If `true`, all spatial locations are pooled int some output pixels
+ # @param [integer || 2D integer array] ksize Size of pooling window
+ # @param [integer || 2D integer array] stride Stride of pooling applications
+ # @param [integer || 2D integer array] pad Spatial padding width for the input array
+ # @param [boolean] cover_all If `true`, all spatial locations are pooled int some output pixels
  # @return [Chainer::Variable] Output variable
  def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
- self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+ self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).apply([x]).first
  end

- def forward_cpu(x)
- retain_inputs([])
+ def forward(x)
  @in_shape = x[0].shape
  @in_dtype = x[0].class

- col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+ col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
  n, c, kh, kw, out_h, out_w = col.shape
  col = col.reshape(n , c, kh * kw, out_h, out_w)

@@ -33,7 +33,27 @@ module Chainer
  [y]
  end

- def backward_cpu(x, gy)
+ def backward(indexes, gy)
+ MaxPooling2DGrad.new(self).apply(gy)
+ end
+ end
+
+ class MaxPooling2DGrad < FunctionNode
+ def initialize(mpool2d)
+ @kh = mpool2d.kh
+ @kw = mpool2d.kw
+ @sy = mpool2d.sy
+ @sx = mpool2d.sx
+ @ph = mpool2d.ph
+ @pw = mpool2d.pw
+ @cover_all = mpool2d.cover_all
+ @indexes = mpool2d.indexes
+ @in_shape = mpool2d.in_shape
+ @in_dtype = mpool2d.in_dtype
+ @mpool2d = mpool2d
+ end
+
+ def forward(gy)
  n, c, out_h, out_w = gy[0].shape
  h, w = @in_shape[2..-1]
  kh, kw = @kh, @kw
@@ -41,16 +61,51 @@ module Chainer
  gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)

  indexes = @indexes.flatten
- indexes += Numo::Int64.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
-
+ indexes += indexes.class.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
  gcol[indexes] = gy[0].flatten.dup
  gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
  gcol = gcol.swapaxes(2, 4)
  gcol = gcol.swapaxes(3, 5)

- gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+ gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
  [gx]
  end
+
+ def backward(indexes, ggx)
+ MaxPooling2DWithIndexes.new(@mpool2d).apply(ggx)
+ end
+ end
+
+ class MaxPooling2DWithIndexes < FunctionNode
+ def initialize(mpool2d)
+ @kh = mpool2d.kh
+ @kw = mpool2d.kw
+ @sy = mpool2d.sy
+ @sx = mpool2d.sx
+ @ph = mpool2d.ph
+ @pw = mpool2d.pw
+ @cover_all = mpool2d.cover_all
+ @indexes = mpool2d.indexes
+ end
+
+ def forward(x)
+ col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+ n, c, kh, kw, out_h, out_w = col.shape
+ col = col.reshape(n, c, kh * kw, out_h, out_w)
+ col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)
+
+ indexes = @indexes.flatten.dup
+
+ # TODO: col = col[numpy.arange(len(indexes)), indexes]
+ new_col = col.class.zeros(indexes.size)
+ x[0].class.new(indexes.size).seq.each_with_index do |v, i|
+ new_col[i] = col[v, indexes[i]]
+ end
+ col = new_col
+
+ [col.reshape(n, c, out_h, out_w)]
+ end
  end
  end
  end
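
The max pooling rewrite follows the same pattern; a short sketch of how it might be exercised (shapes are illustrative, not from the diff):

    x = Chainer::Variable.new(Numo::SFloat.new(1, 1, 4, 4).seq)
    y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2, stride: 2)
    y.grad = Numo::SFloat.ones(1, 1, 2, 2)
    y.backward
    # x.grad holds the seed gradient at the argmax position recorded in @indexes
    # for each window and zero elsewhere; the second-order pass replays those
    # stored indexes through MaxPooling2DWithIndexes instead of recomputing the argmax.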
data/lib/chainer/functions/pooling/pooling_2d.rb (+6 -4)
@@ -2,15 +2,17 @@ module Chainer
  module Functions
  module Pooling
  # Base class of pooling function over a set of 2d planes
- class Pooling2D < Chainer::Function
+ class Pooling2D < Chainer::FunctionNode
+ attr_reader :kh, :kw, :sy, :sx, :ph, :pw, :cover_all
+
  def initialize(ksize, stride: nil, pad: 0, cover_all: true)
  if stride.nil?
  stride = ksize
  end

- @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
- @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
- @ph, @pw = pad.is_a?(Array) ? pad: [pad, pad]
+ @kh, @kw = ksize.is_a?(::Array) ? ksize : [ksize, ksize]
+ @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+ @ph, @pw = pad.is_a?(::Array) ? pad: [pad, pad]

  @cover_all = cover_all
  end
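
Since stride defaults to ksize and scalar arguments are expanded to pairs, the two constructor calls below are equivalent (illustrative only):

    Chainer::Functions::Pooling::MaxPooling2D.new(3, pad: 1)
    Chainer::Functions::Pooling::MaxPooling2D.new([3, 3], stride: [3, 3], pad: [1, 1])

The switch to ::Array pins the check to Ruby's built-in Array class rather than any Array constant that might be defined under the enclosing Chainer namespace.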
data/lib/chainer/gradient_check.rb (+157 -73)
@@ -1,7 +1,6 @@
  module Chainer
  def _copy_arrays(xs)
- xp = Chainer::get_array_module(*xs)
- xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+ xs.map{|x| Chainer.array?(x) ? x.dup : x}
  end

  # Computes numerical gradient by finite differences.
@@ -19,37 +18,31 @@ module Chainer
  # @param [Float] eps Epsilon value of finite differences.
  # @return [Array] Numerical gradient arrays corresponding to +inputs+.
  #
- def numerical_grad(f, inputs, grad_outputs, eps=0.001)
+ def numerical_grad(f, inputs, grad_outputs, eps=1e-3)
  raise unless eps > 0
  inputs = inputs.to_a
  grad_outputs = grad_outputs.to_a
- xp = Numo::NArray
  grads = inputs.map{|x| x.new_zeros()}

- if inputs[0].ndim < 2
- tmp = [[inputs[0], grads[0]]]
- else
- tmp = (0...inputs[0].shape[0]).map{|i|[inputs[0][i, false], grads[0][i, false]]}
- end
-
- tmp.each do |x, gx|
- x.each_with_index{|xx, *i|
- orig = x[*i] # hold original value
+ inputs.zip(grads).each do |x, gx|
+ orig_x = x.dup # hold original value
+ x.each_with_index{|_, *i|
+ orig = orig_x[*i]
  x[*i] = orig + eps
- ys1 = _copy_arrays(f.call(x))
+ ys1 = _copy_arrays(f.())
  x[*i] = orig - eps
- ys2 = _copy_arrays(f.call(x))
+ ys2 = _copy_arrays(f.())
  x[*i] = orig

  ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
- if !gy.nil?
- if ((y1 - y2) * gy).is_a? Numo::NArray
- dot = ((y1 - y2) * gy).sum()
- else
- dot = ((y1 - y2) * gy).inject(:+)
- end
- gx[*i] += dot / (2*eps).to_f
+ next if gy.nil?
+ diff = y1 - y2
+ if Chainer.array?(diff) && diff.empty?
+ dot = 0
+ else
+ dot = (diff * gy).sum
  end
+ gx[*i] += dot / (2 * eps)
  end
  }
  end
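
For reference, numerical_grad above is a central-difference estimate: each input element is nudged to x+eps and x-eps and the directional derivative accumulates ((f(x+eps) - f(x-eps)) * gy).sum / (2 * eps). A standalone sketch of the same idea in plain Numo (the toy function and values are illustrative, nothing here is taken from the diff):

    f   = ->(x) { x * x }                  # elementwise square, analytic gradient is 2x
    x   = Numo::DFloat[1.0, 2.0, 3.0]
    gy  = Numo::DFloat[1.0, 1.0, 1.0]      # seed gradient of the output
    eps = 1e-3
    gx  = Numo::DFloat.zeros(3)
    x.size.times do |i|
      orig = x[i]
      x[i] = orig + eps; y1 = f.(x).dup
      x[i] = orig - eps; y2 = f.(x).dup
      x[i] = orig
      gx[i] = ((y1 - y2) * gy).sum / (2 * eps)
    end
    # gx is approximately [2.0, 4.0, 6.0], matching the analytic gradient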
@@ -153,6 +146,7 @@ module Chainer
  #
  def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
  x_data = _as_tuple(x_data)
+ xm = Chainer.get_array_module(*x_data)
  if !y_grad.nil?
  y_grad = _as_tuple(y_grad)
  end
@@ -161,80 +155,170 @@ module Chainer
  xs = x_data.map{|x| Chainer::Variable.new(x)}
  y = func.(*xs)
  y = _as_tuple(y)
- y = Chainer::Functions::Math::Identity.identity(*y)
- y = _as_tuple(y)
+ y = Chainer::Functions::Math::Identity.new.apply(y)

- if !y_grad.nil?
- if (y).size != (y_grad).size
- raise TypeError, "`y_grad` must have the same length of output values"
- end
+ y_grad = set_y_grad(y, y_grad)
+
+ # Clear gradients which may exist if func calls backward inside of itself.
+ clear_grads(xs)
+ clear_grads(params)

- y.zip(y_grad).each do |iy, igy|
- iy.grad = igy
- end
- else
- if (y).size != 1
- raise TypeError, "When `y_grad` is `nil`, the function must return azero-dimentional array"
- end
- y_grad = [1]
- end
  # We only need to call `backward` for one result `Chainer::Variable`.
  # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
  y[0].backward()

+ param_data = params.map { |p| p.data }
  if dtype.nil?
- casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+ casted_xs = x_data.map { |x| Chainer::Variable.new(x) }
  else
- if (dtype != Numo::DFloat) and (dtype != Numo::SFloat)
- raise TypeError, "`dtype` is allowed only float type"
- end
- if (params).size > 0
- raise TypeError, "`dtype` is available only if `params` is empty"
+ raise '`dtype` is allowed only float type' if dtype != xm::DFloat && dtype != xm::SFloat
+ casted_xs = x_data.map { |x| x.is_a?(Numo::NArray) ? Chainer::Variable.new(x.cast_to(dtype)) : x }
+ end
+
+ if no_grads.nil?
+ no_grads = xs.map { |x| x.dtype != Numo::SFloat && x.dtype != Numo::DFloat }
+ else
+ raise "Length of no_grads param and xs should be same." if no_grads.size != xs.size
+ end
+
+ casted_data = casted_xs.map { |x| x.data.dup }
+
+ no_grads.zip(xs).each do |skip, x|
+ if skip
+ raise "x.grad is not nil" if x.grad != nil
+ else
+ raise 'gradients of some arguments are not calculated' if x.grad.nil?
  end
- casted_xs = x_data.map{|x|
- if x.class == Numo::DFloat or x.class == Numo::SFloat
- Chainer::Variable.new(dtype.cast(x))
- else
- Chainer::Variable.new(x)
- end
- }
  end

- f = lambda do |_|
+ # Keep the gradient arrays of params which may be overwritten by func
+ params_grad = params.map(&:grad)
+
+ if dtype.nil?
+ one = Numo::DFloat.new().fill(1.0)
+ else
+ one = dtype.new().fill(1.0)
+ end
+
+ g = lambda do
+ # This functions is called twice in `numerical_grad`.
+ # `one` is `1 + epsilon` or `1 - epsilon` in these calls.
+ # See the document of `numerical_grad`.
+ no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+ next if skip || cx.data.empty?
+ # astype is require to store data with the given type
+ data = (one * data).cast_to(data.class)
+ cx.data = data
+ end
+
+ params.zip(param_data).each do |param, data|
+ if !dtype.nil?
+ param_dtype = dtype
+ else
+ param_dtype = param.dtype
+ end
+ # The inner astype is required to calculates __mul__ in
+ # `param_type` when data is low accuracy float.
+ # The outer one is require to store data with the given type.
+ param.data = (one * data.cast_to(param_dtype)).cast_to(param_dtype)
+ end
+
+ # Clear gradients to support func that calls backward inside of itself.
+ clear_grads(casted_xs)
+ clear_grads(params)
+
  ys = func.(*casted_xs)
  ys = _as_tuple(ys)
- return ys.map{|y| y.data}.to_a
+ ys_data = ys.map { |y| y.data }
+ no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+ next if skip
+ cx.data = data
+ end
+ params.zip(param_data).each do |param, data|
+ param.data = data
+ end
+ ys_data
  end

- if no_grads.nil?
- no_grads = xs.map{|x| (x.dtype != Numo::DFloat) and (x.dtype != Numo::SFloat)}
- else
- if no_grads.size != xs.size
- raise TypeError, "Length of no_grads param and xs should be same."
+ gx, = numerical_grad(g, [one], y_grad, eps)
+ gx_accum = 0
+
+ no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+ next if skip
+ gxi = x.grad.flatten.dup
+ cxi = cx.data.flatten.dup
+ unless dtype.nil?
+ gxi = gxi.cast_to(dtype)
+ cxi = cxi.cast_to(dtype)
  end
+ gx_accum += gxi.empty? ? 0 : gxi.dot(cxi)
  end

- no_grads.zip(xs, casted_xs).each do |skip, x, cx|
- if skip
- raise unless x.grad.nil?
- next
+ params.zip(params_grad).each do |p, gpi|
+ gpi =gpi.flatten.dup
+ pi = p.data.flatten.dup
+ unless dtype.nil?
+ gpi = gpi.cast_to(dtype)
+ pi = pi.cast_to(dtype)
  end
- gx, = numerical_grad(f, [cx.data], y_grad, eps)
- Chainer::Testing.assert_allclose(x.grad, gx, atol: atol, rtol: rtol)
- if dtype.nil?
- raise unless gx.class == x.grad.class
- else
- if ((gx.class != Numo::DFloat) and (gx.class != Numo::SFloat)) and (gx.class != dtype)
- raise
+ gx_accum += gpi.dot(pi)
+ end
+
+ Chainer::Testing.assert_allclose(gx, gx_accum, atol: atol, rtol: rtol)
+ end
+
+ def check_double_backward(func, x_data, y_grad, x_grad_grad, params=[], params_grad_grad=[], eps: 1e-3, atol: 1e-4, rtol: 1e-3, no_grads: nil, dtype: nil)
+ x_data = _as_tuple(x_data)
+ params = _as_tuple(params)
+ n_x = x_data.size
+
+ first_order_grad = -> *inputs do
+ xs = inputs[0...n_x]
+ gys = inputs[n_x..-1]
+
+ y = _as_tuple(func.(*xs))
+ # Let all elements of y share the same creator.
+ # See the comment in check_backward.
+ y = Chainer::Functions::Math::Identity.new.apply(y)
+ set_y_grad(y, gys)
+ y[0].backward(enable_double_backprop: true)
+
+ xs.map(&:grad_var) + params.map(&:grad_var)
+ end
+
+ inputs = x_data + _as_tuple(y_grad)
+ grad_grad = _as_tuple(x_grad_grad) + _as_tuple(params_grad_grad)
+ check_backward(first_order_grad, inputs, grad_grad, params=params, eps: eps, atol: atol, rtol: rtol, no_grads: no_grads, dtype: dtype)
+ end
+
+ def set_y_grad(y, y_grad)
+ if y_grad.nil?
+ if y.size != 1
+ raise TypeError, 'When `y_grad` is `None`, the function must return a zero-dimentional array'
+ end
+ y_grad = [1]
+ else
+ if y.size != y_grad.size
+ raise TypeError, '`y_grad` must have the same length of output values'
+ end
+ y.zip(y_grad).each do |iy, igy|
+ if igy.is_a?(Chainer::Variable)
+ iy.grad_var = igy
+ else
+ iy.grad = igy
  end
  end
  end

- params.each do |p|
- gp, = numerical_grad(f, [p.data], y_grad, eps)
- Chainer::Testing.assert_allclose(p.grad, gp, atol: atol, rtol: rtol)
- raise unless gp.dtype === p.grad.dtype
+ y_grad
+ end
+
+ def clear_grads(xs)
+ xs.each do |x|
+ x.grad_var = nil
  end
  end
- module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward
+
+ module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward, :check_double_backward, :set_y_grad, :clear_grads
+ private_class_method :set_y_grad, :clear_grads
  end
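
A hedged sketch of how the reworked checks might be driven from a test, using the average pooling function changed earlier in this diff (shapes, random seeds, and tolerances are illustrative):

    f = ->(x) { Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(x, 2) }

    x   = Numo::DFloat.new(2, 3, 4, 4).rand - 0.5   # input sample
    gy  = Numo::DFloat.new(2, 3, 2, 2).rand - 0.5   # seed gradient for y
    ggx = Numo::DFloat.new(2, 3, 4, 4).rand - 0.5   # seed for the second-order pass

    # First-order check: compares the backprop gradient of x against numerical_grad.
    Chainer.check_backward(f, x, gy, eps: 1e-3, atol: 1e-4, rtol: 1e-3)

    # Second-order check: backpropagates the first-order gradients again
    # (backward with enable_double_backprop: true) and verifies them numerically.
    Chainer.check_double_backward(f, x, gy, ggx)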