red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56

data/lib/chainer/functions/pooling/average_pooling_2d.rb

@@ -2,6 +2,8 @@ module Chainer
   module Functions
     module Pooling
       class AveragePooling2D < Pooling2D
+        attr_reader :in_shape, :in_dtype
+
         # Spatial average pooling function.
         #
         # This function acts similarly to :class:`Convolution2D`,
@@ -14,31 +16,52 @@ module Chainer
         # @param [integer] pad Spatial padding width for the input array. `pad=p` and `pad=[p, p]` are equivalent.
         # @return [Chainer::Variable] Output variable
         def self.average_pooling_2d(x, ksize, stride: nil, pad: 0)
-          self.new(ksize, stride: stride, pad: pad, cover_all: false).(x)
+          self.new(ksize, stride: stride, pad: pad, cover_all: false).apply([x])[0]
         end

         # Average pooling over a set of 2d planes.
-        def
-          retain_inputs([])
+        def forward(x)
           @in_shape = x[0].shape
           @in_dtype = x[0].class

-          col = Chainer::Utils::Conv.
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
           y = col.mean(axis: [2, 3])

           [y]
         end

-        def
+        def backward(indexes, gy)
+          AveragePooling2DGrad.new(self).apply(gy)
+        end
+      end
+
+      class AveragePooling2DGrad < FunctionNode
+        def initialize(apool2d)
+          @kh = apool2d.kh
+          @kw = apool2d.kw
+          @sy = apool2d.sy
+          @sx = apool2d.sx
+          @ph = apool2d.ph
+          @pw = apool2d.pw
+          @in_shape = apool2d.in_shape
+          @in_dtype = apool2d.in_dtype
+          @apool2d = apool2d
+        end
+
+        def forward(gy)
           h, w = @in_shape[2..-1]
           shape = gy[0].shape
           shape.insert(2, 1, 1)
           gcol = gy[0].reshape(*shape).tile(1, 1, @kh, @kw, 1, 1)

-          gx = Chainer::Utils::Conv.
+          gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
           gx /= @kh * @kw
           [gx]
         end
+
+        def backward(indexes, grad_outputs)
+          AveragePooling2D.new([@kh, @kw], stride: [@sy, @sx], pad: [@ph, @pw], cover_all: false).apply(grad_outputs)
+        end
       end
     end
   end

data/lib/chainer/functions/pooling/max_pooling_2d.rb

@@ -2,24 +2,24 @@ module Chainer
   module Functions
     module Pooling
       class MaxPooling2D < Pooling2D
+        attr_reader :in_shape, :in_dtype, :indexes
         # Spatial max pooling function
         #
         # @param [Chainer::Variable] x Input variable
-        # @param [integer || 2D integer array] Size of pooling window
-        # @param [integer || 2D integer array] Stride of pooling applications
-        # @param [integer || 2D integer array] Spatial padding width for the input array
-        # @param [boolean] If `true`, all spatial locations are pooled int some output pixels
+        # @param [integer || 2D integer array] ksize Size of pooling window
+        # @param [integer || 2D integer array] stride Stride of pooling applications
+        # @param [integer || 2D integer array] pad Spatial padding width for the input array
+        # @param [boolean] cover_all If `true`, all spatial locations are pooled int some output pixels
        # @return [Chainer::Variable] Output variable
         def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
-          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).apply([x]).first
         end

-        def
-          retain_inputs([])
+        def forward(x)
           @in_shape = x[0].shape
           @in_dtype = x[0].class

-          col = Chainer::Utils::Conv.
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
           n, c, kh, kw, out_h, out_w = col.shape
           col = col.reshape(n , c, kh * kw, out_h, out_w)

@@ -33,7 +33,27 @@ module Chainer
           [y]
         end

-        def
+        def backward(indexes, gy)
+          MaxPooling2DGrad.new(self).apply(gy)
+        end
+      end
+
+      class MaxPooling2DGrad < FunctionNode
+        def initialize(mpool2d)
+          @kh = mpool2d.kh
+          @kw = mpool2d.kw
+          @sy = mpool2d.sy
+          @sx = mpool2d.sx
+          @ph = mpool2d.ph
+          @pw = mpool2d.pw
+          @cover_all = mpool2d.cover_all
+          @indexes = mpool2d.indexes
+          @in_shape = mpool2d.in_shape
+          @in_dtype = mpool2d.in_dtype
+          @mpool2d = mpool2d
+        end
+
+        def forward(gy)
           n, c, out_h, out_w = gy[0].shape
           h, w = @in_shape[2..-1]
           kh, kw = @kh, @kw
@@ -41,16 +61,51 @@ module Chainer
           gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)

           indexes = @indexes.flatten
-          indexes +=
-
+          indexes += indexes.class.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
           gcol[indexes] = gy[0].flatten.dup
           gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
           gcol = gcol.swapaxes(2, 4)
           gcol = gcol.swapaxes(3, 5)

-          gx = Chainer::Utils::Conv.
+          gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
           [gx]
         end
+
+        def backward(indexes, ggx)
+          MaxPooling2DWithIndexes.new(@mpool2d).apply(ggx)
+        end
+      end
+
+      class MaxPooling2DWithIndexes < FunctionNode
+        def initialize(mpool2d)
+          @kh = mpool2d.kh
+          @kw = mpool2d.kw
+          @sy = mpool2d.sy
+          @sx = mpool2d.sx
+          @ph = mpool2d.ph
+          @pw = mpool2d.pw
+          @cover_all = mpool2d.cover_all
+          @indexes = mpool2d.indexes
+        end
+
+        def forward(x)
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+          n, c, kh, kw, out_h, out_w = col.shape
+          col = col.reshape(n, c, kh * kw, out_h, out_w)
+          col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)
+
+          indexes = @indexes.flatten.dup
+
+          # TODO: col = col[numpy.arange(len(indexes)), indexes]
+          new_col = col.class.zeros(indexes.size)
+          x[0].class.new(indexes.size).seq.each_with_index do |v, i|
+            new_col[i] = col[v, indexes[i]]
+          end
+          col = new_col
+
+          [col.reshape(n, c, out_h, out_w)]
+        end
       end
     end
   end

data/lib/chainer/functions/pooling/pooling_2d.rb

@@ -2,15 +2,17 @@ module Chainer
   module Functions
     module Pooling
       # Base class of pooling function over a set of 2d planes
-      class Pooling2D < Chainer::
+      class Pooling2D < Chainer::FunctionNode
+        attr_reader :kh, :kw, :sy, :sx, :ph, :pw, :cover_all
+
         def initialize(ksize, stride: nil, pad: 0, cover_all: true)
           if stride.nil?
             stride = ksize
           end

-          @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
-          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
-          @ph, @pw = pad.is_a?(Array) ? pad: [pad, pad]
+          @kh, @kw = ksize.is_a?(::Array) ? ksize : [ksize, ksize]
+          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(::Array) ? pad: [pad, pad]

           @cover_all = cover_all
         end

data/lib/chainer/gradient_check.rb

@@ -1,7 +1,6 @@
 module Chainer
   def _copy_arrays(xs)
-
-    xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+    xs.map{|x| Chainer.array?(x) ? x.dup : x}
   end

   # Computes numerical gradient by finite differences.
@@ -19,37 +18,31 @@ module Chainer
   # @param [Float] eps Epsilon value of finite differences.
   # @return [Array] Numerical gradient arrays corresponding to +inputs+.
   #
-  def numerical_grad(f, inputs, grad_outputs, eps=
+  def numerical_grad(f, inputs, grad_outputs, eps=1e-3)
     raise unless eps > 0
     inputs = inputs.to_a
     grad_outputs = grad_outputs.to_a
-    xp = Numo::NArray
     grads = inputs.map{|x| x.new_zeros()}

-
-
-
-
-    end
-
-    tmp.each do |x, gx|
-      x.each_with_index{|xx, *i|
-        orig = x[*i] # hold original value
+    inputs.zip(grads).each do |x, gx|
+      orig_x = x.dup # hold original value
+      x.each_with_index{|_, *i|
+        orig = orig_x[*i]
         x[*i] = orig + eps
-        ys1 = _copy_arrays(f.
+        ys1 = _copy_arrays(f.())
         x[*i] = orig - eps
-        ys2 = _copy_arrays(f.
+        ys2 = _copy_arrays(f.())
         x[*i] = orig

         ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
-          if
-
-
-
-
-
-          gx[*i] += dot / (2*eps).to_f
+          next if gy.nil?
+          diff = y1 - y2
+          if Chainer.array?(diff) && diff.empty?
+            dot = 0
+          else
+            dot = (diff * gy).sum
           end
+          gx[*i] += dot / (2 * eps)
         end
       }
     end
@@ -153,6 +146,7 @@ module Chainer
   #
   def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
     x_data = _as_tuple(x_data)
+    xm = Chainer.get_array_module(*x_data)
     if !y_grad.nil?
       y_grad = _as_tuple(y_grad)
     end
@@ -161,80 +155,170 @@ module Chainer
     xs = x_data.map{|x| Chainer::Variable.new(x)}
     y = func.(*xs)
     y = _as_tuple(y)
-    y = Chainer::Functions::Math::Identity.
-    y = _as_tuple(y)
+    y = Chainer::Functions::Math::Identity.new.apply(y)

-
-
-
-
+    y_grad = set_y_grad(y, y_grad)
+
+    # Clear gradients which may exist if func calls backward inside of itself.
+    clear_grads(xs)
+    clear_grads(params)

-      y.zip(y_grad).each do |iy, igy|
-        iy.grad = igy
-      end
-    else
-      if (y).size != 1
-        raise TypeError, "When `y_grad` is `nil`, the function must return azero-dimentional array"
-      end
-      y_grad = [1]
-    end
     # We only need to call `backward` for one result `Chainer::Variable`.
     # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
     y[0].backward()

+    param_data = params.map { |p| p.data }
     if dtype.nil?
-      casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+      casted_xs = x_data.map { |x| Chainer::Variable.new(x) }
     else
-      if
-
-
-
-
+      raise '`dtype` is allowed only float type' if dtype != xm::DFloat && dtype != xm::SFloat
+      casted_xs = x_data.map { |x| x.is_a?(Numo::NArray) ? Chainer::Variable.new(x.cast_to(dtype)) : x }
+    end
+
+    if no_grads.nil?
+      no_grads = xs.map { |x| x.dtype != Numo::SFloat && x.dtype != Numo::DFloat }
+    else
+      raise "Length of no_grads param and xs should be same." if no_grads.size != xs.size
+    end
+
+    casted_data = casted_xs.map { |x| x.data.dup }
+
+    no_grads.zip(xs).each do |skip, x|
+      if skip
+        raise "x.grad is not nil" if x.grad != nil
+      else
+        raise 'gradients of some arguments are not calculated' if x.grad.nil?
       end
-      casted_xs = x_data.map{|x|
-        if x.class == Numo::DFloat or x.class == Numo::SFloat
-          Chainer::Variable.new(dtype.cast(x))
-        else
-          Chainer::Variable.new(x)
-        end
-      }
     end

-
+    # Keep the gradient arrays of params which may be overwritten by func
+    params_grad = params.map(&:grad)
+
+    if dtype.nil?
+      one = Numo::DFloat.new().fill(1.0)
+    else
+      one = dtype.new().fill(1.0)
+    end
+
+    g = lambda do
+      # This functions is called twice in `numerical_grad`.
+      # `one` is `1 + epsilon` or `1 - epsilon` in these calls.
+      # See the document of `numerical_grad`.
+      no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+        next if skip || cx.data.empty?
+        # astype is require to store data with the given type
+        data = (one * data).cast_to(data.class)
+        cx.data = data
+      end
+
+      params.zip(param_data).each do |param, data|
+        if !dtype.nil?
+          param_dtype = dtype
+        else
+          param_dtype = param.dtype
+        end
+        # The inner astype is required to calculates __mul__ in
+        # `param_type` when data is low accuracy float.
+        # The outer one is require to store data with the given type.
+        param.data = (one * data.cast_to(param_dtype)).cast_to(param_dtype)
+      end
+
+      # Clear gradients to support func that calls backward inside of itself.
+      clear_grads(casted_xs)
+      clear_grads(params)
+
       ys = func.(*casted_xs)
       ys = _as_tuple(ys)
-
+      ys_data = ys.map { |y| y.data }
+      no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+        next if skip
+        cx.data = data
+      end
+      params.zip(param_data).each do |param, data|
+        param.data = data
+      end
+      ys_data
     end

-
-
-
-
-
+    gx, = numerical_grad(g, [one], y_grad, eps)
+    gx_accum = 0
+
+    no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+      next if skip
+      gxi = x.grad.flatten.dup
+      cxi = cx.data.flatten.dup
+      unless dtype.nil?
+        gxi = gxi.cast_to(dtype)
+        cxi = cxi.cast_to(dtype)
       end
+      gx_accum += gxi.empty? ? 0 : gxi.dot(cxi)
     end

-
-
-
-
+    params.zip(params_grad).each do |p, gpi|
+      gpi =gpi.flatten.dup
+      pi = p.data.flatten.dup
+      unless dtype.nil?
+        gpi = gpi.cast_to(dtype)
+        pi = pi.cast_to(dtype)
       end
-
-
-
-
-
-
-
+      gx_accum += gpi.dot(pi)
+    end
+
+    Chainer::Testing.assert_allclose(gx, gx_accum, atol: atol, rtol: rtol)
+  end
+
+  def check_double_backward(func, x_data, y_grad, x_grad_grad, params=[], params_grad_grad=[], eps: 1e-3, atol: 1e-4, rtol: 1e-3, no_grads: nil, dtype: nil)
+    x_data = _as_tuple(x_data)
+    params = _as_tuple(params)
+    n_x = x_data.size
+
+    first_order_grad = -> *inputs do
+      xs = inputs[0...n_x]
+      gys = inputs[n_x..-1]
+
+      y = _as_tuple(func.(*xs))
+      # Let all elements of y share the same creator.
+      # See the comment in check_backward.
+      y = Chainer::Functions::Math::Identity.new.apply(y)
+      set_y_grad(y, gys)
+      y[0].backward(enable_double_backprop: true)
+
+      xs.map(&:grad_var) + params.map(&:grad_var)
+    end
+
+    inputs = x_data + _as_tuple(y_grad)
+    grad_grad = _as_tuple(x_grad_grad) + _as_tuple(params_grad_grad)
+    check_backward(first_order_grad, inputs, grad_grad, params=params, eps: eps, atol: atol, rtol: rtol, no_grads: no_grads, dtype: dtype)
+  end
+
+  def set_y_grad(y, y_grad)
+    if y_grad.nil?
+      if y.size != 1
+        raise TypeError, 'When `y_grad` is `None`, the function must return a zero-dimentional array'
+      end
+      y_grad = [1]
+    else
+      if y.size != y_grad.size
+        raise TypeError, '`y_grad` must have the same length of output values'
+      end
+      y.zip(y_grad).each do |iy, igy|
+        if igy.is_a?(Chainer::Variable)
+          iy.grad_var = igy
+        else
+          iy.grad = igy
         end
       end
     end

-
-
-
-
+    y_grad
+  end
+
+  def clear_grads(xs)
+    xs.each do |x|
+      x.grad_var = nil
     end
   end
-
+
+  module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward, :check_double_backward, :set_y_grad, :clear_grads
+  private_class_method :set_y_grad, :clear_grads
 end