red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/pooling/average_pooling_2d.rb:

```diff
@@ -2,6 +2,8 @@ module Chainer
   module Functions
     module Pooling
       class AveragePooling2D < Pooling2D
+        attr_reader :in_shape, :in_dtype
+
         # Spatial average pooling function.
         #
         # This function acts similarly to :class:`Convolution2D`,
@@ -14,31 +16,52 @@ module Chainer
         # @param [integer] pad Spatial padding width for the input array. `pad=p` and `pad=[p, p]` are equivalent.
         # @return [Chainer::Variable] Output variable
         def self.average_pooling_2d(x, ksize, stride: nil, pad: 0)
-          self.new(ksize, stride: stride, pad: pad, cover_all: false).(x)
+          self.new(ksize, stride: stride, pad: pad, cover_all: false).apply([x])[0]
         end

         # Average pooling over a set of 2d planes.
-        def
-          retain_inputs([])
+        def forward(x)
           @in_shape = x[0].shape
           @in_dtype = x[0].class

-          col = Chainer::Utils::Conv.
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
           y = col.mean(axis: [2, 3])

           [y]
         end

-        def
+        def backward(indexes, gy)
+          AveragePooling2DGrad.new(self).apply(gy)
+        end
+      end
+
+      class AveragePooling2DGrad < FunctionNode
+        def initialize(apool2d)
+          @kh = apool2d.kh
+          @kw = apool2d.kw
+          @sy = apool2d.sy
+          @sx = apool2d.sx
+          @ph = apool2d.ph
+          @pw = apool2d.pw
+          @in_shape = apool2d.in_shape
+          @in_dtype = apool2d.in_dtype
+          @apool2d = apool2d
+        end
+
+        def forward(gy)
           h, w = @in_shape[2..-1]
           shape = gy[0].shape
           shape.insert(2, 1, 1)
           gcol = gy[0].reshape(*shape).tile(1, 1, @kh, @kw, 1, 1)

-          gx = Chainer::Utils::Conv.
+          gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
           gx /= @kh * @kw
           [gx]
         end
+
+        def backward(indexes, grad_outputs)
+          AveragePooling2D.new([@kh, @kw], stride: [@sy, @sx], pad: [@ph, @pw], cover_all: false).apply(grad_outputs)
+        end
       end
     end
   end
```
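The hunks above move AveragePooling2D onto the new `FunctionNode` protocol: the public entry point now calls `apply([x])[0]` instead of `.(x)`, and the backward pass is a separate `AveragePooling2DGrad` node so it is itself differentiable. A minimal usage sketch of the 0.4.0 entry point (input shape and gradient values are illustrative assumptions, not taken from the gem's documentation):

```ruby
require 'chainer'
require 'numo/narray'

# 1x1x4x4 input; with ksize = 2 (stride defaults to ksize) the output is 1x1x2x2.
x = Chainer::Variable.new(Numo::DFloat.new(1, 1, 4, 4).seq)
y = Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(x, 2)

y.grad = Numo::DFloat.ones(1, 1, 2, 2)
y.backward
p x.grad  # each input cell receives gy / (kh * kw) from its 2x2 window
```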
data/lib/chainer/functions/pooling/max_pooling_2d.rb:

```diff
@@ -2,24 +2,24 @@ module Chainer
   module Functions
     module Pooling
       class MaxPooling2D < Pooling2D
+        attr_reader :in_shape, :in_dtype, :indexes
         # Spatial max pooling function
         #
         # @param [Chainer::Variable] x Input variable
-        # @param [integer || 2D integer array] Size of pooling window
-        # @param [integer || 2D integer array] Stride of pooling applications
-        # @param [integer || 2D integer array] Spatial padding width for the input array
-        # @param [boolean] If `true`, all spatial locations are pooled int some output pixels
+        # @param [integer || 2D integer array] ksize Size of pooling window
+        # @param [integer || 2D integer array] stride Stride of pooling applications
+        # @param [integer || 2D integer array] pad Spatial padding width for the input array
+        # @param [boolean] cover_all If `true`, all spatial locations are pooled int some output pixels
         # @return [Chainer::Variable] Output variable
         def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
-          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).apply([x]).first
         end

-        def
-          retain_inputs([])
+        def forward(x)
           @in_shape = x[0].shape
           @in_dtype = x[0].class

-          col = Chainer::Utils::Conv.
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
           n, c, kh, kw, out_h, out_w = col.shape
           col = col.reshape(n , c, kh * kw, out_h, out_w)

@@ -33,7 +33,27 @@ module Chainer
           [y]
         end

-        def
+        def backward(indexes, gy)
+          MaxPooling2DGrad.new(self).apply(gy)
+        end
+      end
+
+      class MaxPooling2DGrad < FunctionNode
+        def initialize(mpool2d)
+          @kh = mpool2d.kh
+          @kw = mpool2d.kw
+          @sy = mpool2d.sy
+          @sx = mpool2d.sx
+          @ph = mpool2d.ph
+          @pw = mpool2d.pw
+          @cover_all = mpool2d.cover_all
+          @indexes = mpool2d.indexes
+          @in_shape = mpool2d.in_shape
+          @in_dtype = mpool2d.in_dtype
+          @mpool2d = mpool2d
+        end
+
+        def forward(gy)
           n, c, out_h, out_w = gy[0].shape
           h, w = @in_shape[2..-1]
           kh, kw = @kh, @kw
@@ -41,16 +61,51 @@ module Chainer
           gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)

           indexes = @indexes.flatten
-          indexes +=
-
+          indexes += indexes.class.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
           gcol[indexes] = gy[0].flatten.dup
           gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
           gcol = gcol.swapaxes(2, 4)
           gcol = gcol.swapaxes(3, 5)

-          gx = Chainer::Utils::Conv.
+          gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
           [gx]
         end
+
+        def backward(indexes, ggx)
+          MaxPooling2DWithIndexes.new(@mpool2d).apply(ggx)
+        end
+      end
+
+      class MaxPooling2DWithIndexes < FunctionNode
+        def initialize(mpool2d)
+          @kh = mpool2d.kh
+          @kw = mpool2d.kw
+          @sy = mpool2d.sy
+          @sx = mpool2d.sx
+          @ph = mpool2d.ph
+          @pw = mpool2d.pw
+          @cover_all = mpool2d.cover_all
+          @indexes = mpool2d.indexes
+        end
+
+        def forward(x)
+          col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+          n, c, kh, kw, out_h, out_w = col.shape
+          col = col.reshape(n, c, kh * kw, out_h, out_w)
+          col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)
+
+          indexes = @indexes.flatten.dup
+
+          # TODO: col = col[numpy.arange(len(indexes)), indexes]
+          new_col = col.class.zeros(indexes.size)
+          x[0].class.new(indexes.size).seq.each_with_index do |v, i|
+            new_col[i] = col[v, indexes[i]]
+          end
+          col = new_col
+
+          [col.reshape(n, c, out_h, out_w)]
+        end
       end
     end
   end
```
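MaxPooling2D now records the per-window argmax positions in `indexes` during the forward pass; `MaxPooling2DGrad` scatters the upstream gradient back to exactly those positions, and `MaxPooling2DWithIndexes` reuses the stored indexes for the second-order (double-backward) pass. A usage sketch of the public entry point (shapes and values are illustrative assumptions):

```ruby
require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat.new(1, 1, 4, 4).seq)
# With ksize = 2 the 4x4 plane tiles exactly; cover_all (default true) only adds
# extra output pixels when the windows do not cover the input evenly.
y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2)

y.grad = Numo::DFloat.ones(1, 1, 2, 2)
y.backward
p x.grad  # nonzero only at the argmax position of each pooling window
```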
data/lib/chainer/functions/pooling/pooling_2d.rb:

```diff
@@ -2,15 +2,17 @@ module Chainer
   module Functions
     module Pooling
       # Base class of pooling function over a set of 2d planes
-      class Pooling2D < Chainer::
+      class Pooling2D < Chainer::FunctionNode
+        attr_reader :kh, :kw, :sy, :sx, :ph, :pw, :cover_all
+
         def initialize(ksize, stride: nil, pad: 0, cover_all: true)
           if stride.nil?
             stride = ksize
           end

-          @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
-          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
-          @ph, @pw = pad.is_a?(Array) ? pad: [pad, pad]
+          @kh, @kw = ksize.is_a?(::Array) ? ksize : [ksize, ksize]
+          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(::Array) ? pad: [pad, pad]

           @cover_all = cover_all
         end
```
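Pooling2D is the shared base class: it expands a scalar `ksize`, `stride`, or `pad` into an `[h, w]` pair, and the change from `Array` to `::Array` pins the type check to Ruby's top-level Array class so it cannot resolve to a constant nested inside the Chainer namespaces. A standalone sketch of that normalization rule (illustrative helper, not library code):

```ruby
# Expand a scalar into an [h, w] pair, mirroring what Pooling2D#initialize does.
# `::Array` forces constant lookup to start at the top level.
def normalize_pair(value)
  value.is_a?(::Array) ? value : [value, value]
end

p normalize_pair(3)       # => [3, 3]
p normalize_pair([2, 4])  # => [2, 4]
```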
data/lib/chainer/gradient_check.rb:

```diff
@@ -1,7 +1,6 @@
 module Chainer
   def _copy_arrays(xs)
-
-    xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+    xs.map{|x| Chainer.array?(x) ? x.dup : x}
   end

   # Computes numerical gradient by finite differences.
@@ -19,37 +18,31 @@ module Chainer
   # @param [Float] eps Epsilon value of finite differences.
   # @return [Array] Numerical gradient arrays corresponding to +inputs+.
   #
-  def numerical_grad(f, inputs, grad_outputs, eps=
+  def numerical_grad(f, inputs, grad_outputs, eps=1e-3)
     raise unless eps > 0
     inputs = inputs.to_a
     grad_outputs = grad_outputs.to_a
-    xp = Numo::NArray
     grads = inputs.map{|x| x.new_zeros()}

-
-
-
-
-    end
-
-    tmp.each do |x, gx|
-      x.each_with_index{|xx, *i|
-        orig = x[*i] # hold original value
+    inputs.zip(grads).each do |x, gx|
+      orig_x = x.dup # hold original value
+      x.each_with_index{|_, *i|
+        orig = orig_x[*i]
         x[*i] = orig + eps
-        ys1 = _copy_arrays(f.
+        ys1 = _copy_arrays(f.())
         x[*i] = orig - eps
-        ys2 = _copy_arrays(f.
+        ys2 = _copy_arrays(f.())
         x[*i] = orig

         ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
-          if
-
-
-
-
-
-          gx[*i] += dot / (2*eps).to_f
+          next if gy.nil?
+          diff = y1 - y2
+          if Chainer.array?(diff) && diff.empty?
+            dot = 0
+          else
+            dot = (diff * gy).sum
           end
+          gx[*i] += dot / (2 * eps)
         end
       }
     end
@@ -153,6 +146,7 @@ module Chainer
   #
   def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
     x_data = _as_tuple(x_data)
+    xm = Chainer.get_array_module(*x_data)
     if !y_grad.nil?
       y_grad = _as_tuple(y_grad)
     end
@@ -161,80 +155,170 @@ module Chainer
     xs = x_data.map{|x| Chainer::Variable.new(x)}
     y = func.(*xs)
     y = _as_tuple(y)
-    y = Chainer::Functions::Math::Identity.
-    y = _as_tuple(y)
+    y = Chainer::Functions::Math::Identity.new.apply(y)

-
-
-
-
+    y_grad = set_y_grad(y, y_grad)
+
+    # Clear gradients which may exist if func calls backward inside of itself.
+    clear_grads(xs)
+    clear_grads(params)

-    y.zip(y_grad).each do |iy, igy|
-      iy.grad = igy
-    end
-    else
-      if (y).size != 1
-        raise TypeError, "When `y_grad` is `nil`, the function must return azero-dimentional array"
-      end
-      y_grad = [1]
-    end
     # We only need to call `backward` for one result `Chainer::Variable`.
     # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
     y[0].backward()

+    param_data = params.map { |p| p.data }
     if dtype.nil?
-      casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+      casted_xs = x_data.map { |x| Chainer::Variable.new(x) }
     else
-      if
-
-
-
-
+      raise '`dtype` is allowed only float type' if dtype != xm::DFloat && dtype != xm::SFloat
+      casted_xs = x_data.map { |x| x.is_a?(Numo::NArray) ? Chainer::Variable.new(x.cast_to(dtype)) : x }
+    end
+
+    if no_grads.nil?
+      no_grads = xs.map { |x| x.dtype != Numo::SFloat && x.dtype != Numo::DFloat }
+    else
+      raise "Length of no_grads param and xs should be same." if no_grads.size != xs.size
+    end
+
+    casted_data = casted_xs.map { |x| x.data.dup }
+
+    no_grads.zip(xs).each do |skip, x|
+      if skip
+        raise "x.grad is not nil" if x.grad != nil
+      else
+        raise 'gradients of some arguments are not calculated' if x.grad.nil?
       end
-      casted_xs = x_data.map{|x|
-        if x.class == Numo::DFloat or x.class == Numo::SFloat
-          Chainer::Variable.new(dtype.cast(x))
-        else
-          Chainer::Variable.new(x)
-        end
-      }
     end

-
+    # Keep the gradient arrays of params which may be overwritten by func
+    params_grad = params.map(&:grad)
+
+    if dtype.nil?
+      one = Numo::DFloat.new().fill(1.0)
+    else
+      one = dtype.new().fill(1.0)
+    end
+
+    g = lambda do
+      # This functions is called twice in `numerical_grad`.
+      # `one` is `1 + epsilon` or `1 - epsilon` in these calls.
+      # See the document of `numerical_grad`.
+      no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+        next if skip || cx.data.empty?
+        # astype is require to store data with the given type
+        data = (one * data).cast_to(data.class)
+        cx.data = data
+      end
+
+      params.zip(param_data).each do |param, data|
+        if !dtype.nil?
+          param_dtype = dtype
+        else
+          param_dtype = param.dtype
+        end
+        # The inner astype is required to calculates __mul__ in
+        # `param_type` when data is low accuracy float.
+        # The outer one is require to store data with the given type.
+        param.data = (one * data.cast_to(param_dtype)).cast_to(param_dtype)
+      end
+
+      # Clear gradients to support func that calls backward inside of itself.
+      clear_grads(casted_xs)
+      clear_grads(params)
+
       ys = func.(*casted_xs)
       ys = _as_tuple(ys)
-
+      ys_data = ys.map { |y| y.data }
+      no_grads.zip(casted_xs, casted_data).each do |skip, cx, data|
+        next if skip
+        cx.data = data
+      end
+      params.zip(param_data).each do |param, data|
+        param.data = data
+      end
+      ys_data
     end

-
-
-
-
-
+    gx, = numerical_grad(g, [one], y_grad, eps)
+    gx_accum = 0
+
+    no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+      next if skip
+      gxi = x.grad.flatten.dup
+      cxi = cx.data.flatten.dup
+      unless dtype.nil?
+        gxi = gxi.cast_to(dtype)
+        cxi = cxi.cast_to(dtype)
       end
+      gx_accum += gxi.empty? ? 0 : gxi.dot(cxi)
     end

-
-
-
-
+    params.zip(params_grad).each do |p, gpi|
+      gpi =gpi.flatten.dup
+      pi = p.data.flatten.dup
+      unless dtype.nil?
+        gpi = gpi.cast_to(dtype)
+        pi = pi.cast_to(dtype)
       end
-
-
-
-
-
-
-
+      gx_accum += gpi.dot(pi)
+    end
+
+    Chainer::Testing.assert_allclose(gx, gx_accum, atol: atol, rtol: rtol)
+  end
+
+  def check_double_backward(func, x_data, y_grad, x_grad_grad, params=[], params_grad_grad=[], eps: 1e-3, atol: 1e-4, rtol: 1e-3, no_grads: nil, dtype: nil)
+    x_data = _as_tuple(x_data)
+    params = _as_tuple(params)
+    n_x = x_data.size
+
+    first_order_grad = -> *inputs do
+      xs = inputs[0...n_x]
+      gys = inputs[n_x..-1]
+
+      y = _as_tuple(func.(*xs))
+      # Let all elements of y share the same creator.
+      # See the comment in check_backward.
+      y = Chainer::Functions::Math::Identity.new.apply(y)
+      set_y_grad(y, gys)
+      y[0].backward(enable_double_backprop: true)
+
+      xs.map(&:grad_var) + params.map(&:grad_var)
+    end
+
+    inputs = x_data + _as_tuple(y_grad)
+    grad_grad = _as_tuple(x_grad_grad) + _as_tuple(params_grad_grad)
+    check_backward(first_order_grad, inputs, grad_grad, params=params, eps: eps, atol: atol, rtol: rtol, no_grads: no_grads, dtype: dtype)
+  end
+
+  def set_y_grad(y, y_grad)
+    if y_grad.nil?
+      if y.size != 1
+        raise TypeError, 'When `y_grad` is `None`, the function must return a zero-dimentional array'
+      end
+      y_grad = [1]
+    else
+      if y.size != y_grad.size
+        raise TypeError, '`y_grad` must have the same length of output values'
+      end
+      y.zip(y_grad).each do |iy, igy|
+        if igy.is_a?(Chainer::Variable)
+          iy.grad_var = igy
+        else
+          iy.grad = igy
         end
       end
     end

-
-
-
-
+    y_grad
+  end
+
+  def clear_grads(xs)
+    xs.each do |x|
+      x.grad_var = nil
     end
   end
-
+
+  module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward, :check_double_backward, :set_y_grad, :clear_grads
+  private_class_method :set_y_grad, :clear_grads
 end
```
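`check_backward` runs `func` forward, backpropagates `y_grad`, and compares the resulting analytic gradients against `numerical_grad` through a scalar projection onto the perturbed `one` value; `check_double_backward` wraps the first-order gradient computation in a lambda and pushes it back through `check_backward`. A usage sketch against the pooling function changed in this same release (shapes, tolerances, and the choice of function are illustrative assumptions):

```ruby
require 'chainer'
require 'numo/narray'

# Gradient check for average pooling: ksize 2 on a 1x1x4x4 input gives a 1x1x2x2 output.
x  = Numo::DFloat.new(1, 1, 4, 4).rand
gy = Numo::DFloat.new(1, 1, 2, 2).rand

func = -> (v) { Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(v, 2) }

# Compares Variable#backward gradients against finite differences (numerical_grad).
Chainer.check_backward(func, x, gy, eps: 1e-3, atol: 1e-4, rtol: 1e-3)
```

For a second-order check, `check_double_backward` additionally takes `x_grad_grad` arrays with the same shapes as the inputs, e.g. `Chainer.check_double_backward(func, x, gy, ggx)`.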