red-chainer 0.2.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
data/lib/chainer/functions/math/identity.rb
@@ -0,0 +1,26 @@
+module Chainer
+  module Functions
+    module Math
+      # Identity function.
+      class Identity < Chainer::Function
+        def check_type_forward(in_types)
+          # pass
+        end
+
+        def forward(xs)
+          retain_inputs([])
+          return xs
+        end
+
+        def backward(xs, gys)
+          return gys
+        end
+
+        # Just returns input variables.
+        def self.identity(*inputs)
+          self.new.(*inputs)
+        end
+      end
+    end
+  end
+end
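Identity is mostly plumbing: check_backward in the new gradient_check.rb (last hunk below) routes every output through it so each output variable has a creator to backpropagate from. A minimal sketch of calling it directly, assuming only the class API added above and a Numo-backed Chainer::Variable:

require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
# Forward returns the inputs unchanged; backward passes gradients straight through.
y = Chainer::Functions::Math::Identity.identity(x)
p y.data  # same values as x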
data/lib/chainer/functions/noise/dropout.rb
@@ -0,0 +1,45 @@
+module Chainer
+  module Functions
+    module Noise
+      class Dropout < Chainer::Function
+        # Drops elements of input variable randomly.
+        #
+        # This function drops input elements randomly with probability `ratio` and
+        # scales the remaining elements by factor `1 / (1 - ratio)`.
+        # In testing mode, it does nothing and just returns `x`.
+        #
+        # @param [Chainer::Variable] x Input variable.
+        # @param [float] ratio Dropout ratio. The ``ratio`` must be `0.0 <= ratio < 1.0`.
+        # @return [Chainer::Variable] Output variable.
+        def self.dropout(x, ratio: 0.5)
+          Chainer.configuration.train ? self.new(ratio).(x) : x
+        end
+
+        def initialize(dropout_ratio)
+          if dropout_ratio < 0 || dropout_ratio >= 1.0
+            raise 'dropout_ratio must be in the range [0, 1)'
+          end
+          @dropout_ratio = dropout_ratio
+        end
+
+        def forward(x)
+          retain_inputs([])
+          unless self.instance_variable_defined?(:@mask)
+            scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
+            flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
+
+            @mask = x[0].class.zeros(*x[0].shape)
+            @mask[flag] = 1
+            @mask *= scale
+          end
+          [x[0] * @mask]
+        end
+
+        def backward(x, gy)
+          [gy[0] * @mask]
+        end
+      end
+    end
+  end
+end
+
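The train/test switch lives in Chainer.configuration.train, so the same call site drops elements during training and becomes a no-op at evaluation time. A minimal sketch, assuming the dropout API added above and a Numo-backed variable:

require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat.new(2, 3).rand)

Chainer.configuration.train = true
# Roughly half the elements are zeroed; survivors are scaled by 1 / (1 - ratio).
y_train = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)

Chainer.configuration.train = false
# Dropout is skipped entirely and x itself is returned.
y_eval = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)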
data/lib/chainer/functions/normalization/batch_normalization.rb
@@ -0,0 +1,136 @@
+module Chainer
+  module Functions
+    module Normalization
+      class BatchNormalizationFunction < Chainer::Function
+        attr_reader :running_mean, :running_var
+        # Batch normalization function with fixed statistics.
+        # This is a variant of batch normalization, where the mean and variance
+        # statistics are given by the caller as fixed variables. This is
+        # used in testing mode of the batch normalization layer, where batch
+        # statistics cannot be used for prediction consistency.
+        #
+        # @param [Chainer::Variable] x Input variable.
+        # @param [Chainer::Variable] gamma Scaling parameter of normalized data.
+        # @param [Chainer::Variable] beta Shifting parameter of scaled normalized data.
+        # @param [Chainer::Variable] mean Shifting parameter of input.
+        # @param [Chainer::Variable] var Square of scaling parameter of input.
+        # @param [float] eps Epsilon value for numerical stability.
+        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+          old_train = Chainer.configuration.train
+          Chainer.configuration.train = false
+          norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
+          Chainer.configuration.train = old_train
+          norm
+        end
+
+        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+          @running_mean = mean
+          @running_var = var
+          @eps = eps
+          @mean_cache = nil
+          @decay = decay
+        end
+
+        def forward(inputs)
+          x, gamma, beta = inputs[0], inputs[1], inputs[2]
+          if Chainer.configuration.train
+            if @running_mean.nil?
+              @running_mean = Numo::NArray[*gamma].new_zeros
+              @running_var = Numo::NArray[*gamma].new_zeros
+            else
+              @running_mean = Numo::NArray[*@running_mean]
+              @running_var = Numo::NArray[*@running_var]
+            end
+          elsif inputs.size == 5
+            @fixed_mean = inputs[3]
+            @fixed_var = inputs[4]
+          end
+
+          head_ndim = gamma.ndim + 1
+          gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
+          gamma = gamma.reshape(*gamma_expander)
+          beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
+          beta = beta.reshape(*beta_expander)
+
+          if Chainer.configuration.train
+            axis = [0] + (head_ndim...(x.ndim)).to_a
+            mean = x.mean(axis: axis)
+            # FIXME: numpy.var
+            var = x.var(axis: axis)
+            var += @eps
+          else
+            mean = @fixed_mean
+            var = @fixed_var + @eps
+          end
+
+          @std = Numo::NMath.sqrt(var)
+
+          mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+          x_mu = x - mean.reshape(*mean_expander)
+          std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
+          x_mu /= @std.reshape(*std_expander)
+          @x_hat = x_mu
+          y = gamma * @x_hat
+          y += beta
+
+          if Chainer.configuration.train
+            m = x.size.div(gamma.size)
+            adjust = m / [m - 1.0, 1.0].max
+            @running_mean *= @decay
+            temp_ar = Numo::NArray[*mean]
+            temp_ar *= (1 - @decay)
+            @running_mean += temp_ar
+
+            @running_var *= @decay
+            temp_ar = Numo::NArray[*var]
+            temp_ar *= ((1 - @decay) * adjust)
+            @running_var += temp_ar
+          end
+
+          [y,]
+        end
+
+        def backward(inputs, grad_outputs)
+          x, gamma = inputs[0], inputs[1]
+          gy = grad_outputs[0]
+          head_ndim = gamma.ndim + 1
+          m = gamma.class[x.size.div(gamma.size)][0]
+          axis = [0] + (head_ndim...(x.ndim)).to_a
+
+          if inputs.size == 5
+            mean = inputs[3]
+            var = inputs[4]
+            std = Numo::NMath.sqrt(var)
+            gs = gamma / std
+            gbeta = gy.sum(axis: axis)
+
+            mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+            x_mu = x - mean.reshape(*mean_expander)
+            std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
+            x_mu /= std.reshape(*std_expander)
+            x_hat = x_mu
+            ggamma = (gy * x_hat).sum(axis: axis)
+            gmean = -gs * gbeta
+            gvar = -0.5 * gamma / var * ggamma
+            gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
+            gx = gs.reshape(*gs_expander)
+            return [gx, ggamma, gbeta, gmean, gvar]
+          end
+
+          gbeta = gy.sum(axis: axis)
+          ggamma = (gy * @x_hat).sum(axis: axis)
+          tmp = (gamma / @std)
+          tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
+          tmp = tmp.reshape(*tmp_expander)
+
+          ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
+          gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
+
+          gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m )
+
+          [gx, ggamma, gbeta]
+        end
+      end
+    end
+  end
+end
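At inference time the statistics come from the caller rather than the batch: fixed_batch_normalization temporarily flips configuration.train off, runs the function with the supplied mean and variance, and restores the flag. A minimal sketch, assuming an NCHW input and per-channel parameters wrapped in Chainer::Variable as the doc comments above describe:

require 'chainer'
require 'numo/narray'

# (batch, channel, height, width) input; per-channel scale/shift and recorded statistics.
x     = Chainer::Variable.new(Numo::SFloat.new(2, 3, 4, 4).rand)
gamma = Chainer::Variable.new(Numo::SFloat.ones(3))
beta  = Chainer::Variable.new(Numo::SFloat.zeros(3))
mean  = Chainer::Variable.new(Numo::SFloat.zeros(3))
var   = Chainer::Variable.new(Numo::SFloat.ones(3))

y = Chainer::Functions::Normalization::BatchNormalizationFunction
      .fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)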
data/lib/chainer/functions/pooling/max_pooling_2d.rb
@@ -0,0 +1,57 @@
+module Chainer
+  module Functions
+    module Pooling
+      class MaxPooling2D < Pooling2D
+        # Spatial max pooling function
+        #
+        # @param [Chainer::Variable] x Input variable
+        # @param [integer || 2D integer array] Size of pooling window
+        # @param [integer || 2D integer array] Stride of pooling applications
+        # @param [integer || 2D integer array] Spatial padding width for the input array
+        # @param [boolean] If `true`, all spatial locations are pooled into some output pixels
+        # @return [Chainer::Variable] Output variable
+        def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
+          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+        end
+
+        def forward_cpu(x)
+          retain_inputs([])
+          @in_shape = x[0].shape
+          @in_dtype = x[0].class
+
+          col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+          n, c, kh, kw, out_h, out_w = col.shape
+          col = col.reshape(n, c, kh * kw, out_h, out_w)
+
+          # TODO: numpy.argmax(axis=2)
+          d = col.shape[3..-1].reduce(:*) || 1
+          dx = col.shape[2..-1].reduce(:*) || 1
+          max_index = col.max_index(2)
+          @indexes = max_index.flatten.map_with_index { |val, idx| (val - (dx * (idx / d))) / d }.reshape(*max_index.shape)
+
+          y = col.max(axis: 2)
+          [y]
+        end
+
+        def backward_cpu(x, gy)
+          n, c, out_h, out_w = gy[0].shape
+          h, w = @in_shape[2..-1]
+          kh, kw = @kh, @kw
+
+          gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)
+
+          indexes = @indexes.flatten
+          indexes += Numo::Int64.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
+          gcol[indexes] = gy[0].flatten.dup
+          gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
+          gcol = gcol.swapaxes(2, 4)
+          gcol = gcol.swapaxes(3, 5)
+
+          gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+          [gx]
+        end
+      end
+    end
+  end
+end
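The heavy lifting is delegated to Chainer::Utils::Conv.im2col_cpu/col2im_cpu (also new in this release); the function itself takes the maximum over each window and remembers the argmax indexes for the backward pass. A minimal usage sketch, assuming a 4-d NCHW input; the shape comment follows the usual convolution output-size arithmetic:

require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::SFloat.new(1, 1, 4, 4).seq)  # a single 4x4 feature map

# 2x2 windows; stride defaults to ksize when omitted, so this is non-overlapping pooling.
y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2)
p y.data.shape  # => [1, 1, 2, 2]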
data/lib/chainer/functions/pooling/pooling_2d.rb
@@ -0,0 +1,20 @@
+module Chainer
+  module Functions
+    module Pooling
+      # Base class of pooling function over a set of 2d planes
+      class Pooling2D < Chainer::Function
+        def initialize(ksize, stride: nil, pad: 0, cover_all: true)
+          if stride.nil?
+            stride = ksize
+          end
+
+          @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
+          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+
+          @cover_all = cover_all
+        end
+      end
+    end
+  end
+end
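The base class normalizes ksize, stride, and pad so subclasses always see separate height/width values: a scalar is duplicated into a pair, and an omitted stride falls back to ksize. A small sketch of the pair form exercised through MaxPooling2D, assuming the same 0.3.0 API as above:

require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::SFloat.new(1, 1, 8, 6).rand)

# Rectangular 3x2 windows, stride 2 vertically / 1 horizontally, 1-pixel vertical padding.
y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, [3, 2], stride: [2, 1], pad: [1, 0])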
data/lib/chainer/gradient_check.rb
@@ -0,0 +1,240 @@
+module Chainer
+  def _copy_arrays(xs)
+    xp = Chainer::get_array_module(*xs)
+    xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+  end
+
+  # Computes numerical gradient by finite differences.
+  #
+  # This function is used to implement gradient check. For usage example, see
+  # unit tests of Chainer::Functions.
+  #
+  # @param [function] f Ruby function with no arguments that runs forward
+  #   computation and returns the result.
+  # @param [Array<Arrays>] inputs Array of arrays that should be treated as
+  #   inputs. Each element of them is slightly modified to realize numerical
+  #   gradient by finite differences.
+  # @param [Array<Arrays>] grad_outputs Array of arrays that are treated as
+  #   output gradients.
+  # @param [Float] eps Epsilon value of finite differences.
+  # @return [Array] Numerical gradient arrays corresponding to +inputs+.
+  #
+  def numerical_grad(f, inputs, grad_outputs, eps=0.001)
+    raise unless eps > 0
+    inputs = inputs.to_a
+    grad_outputs = grad_outputs.to_a
+    xp = Numo::NArray
+    grads = inputs.map{|x| x.new_zeros()}
+
+    if inputs[0].ndim < 2
+      tmp = [[inputs[0], grads[0]]]
+    else
+      tmp = (0...inputs[0].shape[0]).map{|i|[inputs[0][i, false], grads[0][i, false]]}
+    end
+
+    tmp.each do |x, gx|
+      x.each_with_index{|xx, *i|
+        orig = x[*i] # hold original value
+        x[*i] = orig + eps
+        ys1 = _copy_arrays(f.call(x))
+        x[*i] = orig - eps
+        ys2 = _copy_arrays(f.call(x))
+        x[*i] = orig
+
+        ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
+          if !gy.nil?
+            if ((y1 - y2) * gy).is_a? Numo::NArray
+              dot = ((y1 - y2) * gy).sum()
+            else
+              dot = ((y1 - y2) * gy).inject(:+)
+            end
+            gx[*i] += dot / (2*eps).to_f
+          end
+        end
+      }
+    end
+
+    return grads
+  end
+
+  def _as_tuple(x)
+    if x.is_a? Array
+      return x
+    else
+      return [x]
+    end
+  end
+
+  # Test backward procedure of a given function.
+  #
+  # This function automatically checks the backward process of a given function.
+  # For example, when you have a +Chainer::Function+ class +MyFunc+,
+  # that gets two arguments and returns one value, you can make its test like this:
+  #
+  #   def test_my_func(self):
+  #     func = MyFunc()
+  #     x1_data = Numo::NArray[...]
+  #     x2_data = Numo::NArray[...]
+  #     gy_data = Numo::NArray[...]
+  #     check_backward(func, [x1_data, x2_data], gy_data)
+  #
+  # This method creates +Chainer::Variable+ objects with +x_data+
+  # and calls +func+ with the +Chainer::Variable+ s to get its result
+  # as +Chainer::Variable+.
+  # Then, it sets +y_grad+ array to +grad+ attribute of the result and
+  # calls +backward+ method to get gradients of the inputs.
+  # To check correctness of the gradients, the function calls
+  # +numerical_grad+ to numerically calculate the gradients and compares
+  # the types of gradients with +Chainer::Testing.assert_allclose+.
+  # If input objects (+x1_data+ or/and +x2_data+ in this example) represent
+  # integer variables, their gradients are ignored.
+  #
+  # You can simplify a test when +MyFunc+ gets only one argument:
+  #
+  #   check_backward(func, x1_data, gy_data)
+  #
+  # If +MyFunc+ is a loss function which returns a zero-dimensional
+  # array, pass +nil+ to +gy_data+. In this case, it sets +1+ to
+  # +grad+ attribute of the result:
+  #
+  #   check_backward(my_loss_func, [x1_data, x2_data], nil)
+  #
+  # If +MyFunc+ returns multiple outputs, pass all gradients for outputs as an Array:
+  #
+  #   gy1_data = Numo::NArray[...]
+  #   gy2_data = Numo::NArray[...]
+  #   check_backward(func, x1_data, [gy1_data, gy2_data])
+  #
+  # You can also test a +Chainer::Link+.
+  # To check gradients of parameters of the link, set an Array of the parameters
+  # to +params+ arguments:
+  #
+  #   check_backward(my_link, [x1_data, x2_data], gy_data, [my_link.W, my_link.b])
+  #
+  # Note that +params+ are not +Numo::NArray+ s,
+  # but +Chainer::Variable+ s.
+  #
+  # Function objects are acceptable as +func+ argument:
+  #
+  #   check_backward(lambda{|x1, x2| f(x1, x2)}, [x1_data, x2_data], gy_data)
+  #
+  # @note
+  #   +func+ is called many times to get numerical gradients for all inputs.
+  #   This function doesn't work correctly when +func+ behaves randomly as
+  #   it gets different gradients.
+  # @param [Method, Proc] func A function which gets +Chainer::Variable+ s
+  #   and returns +Chainer::Variable+ s. +func+ must return
+  #   an Array of +Chainer::Variable+ s or one
+  #   +Chainer::Variable+. You can use +Chainer::Function+
+  #   object, +Chainer::Link+ object or a function satisfying the
+  #   condition.
+  # @param [Numo::NArray or Array<Numo::NArray>] x_data A set of +Numo::NArray+ s to be
+  #   passed to +func+. If +x_data+ is one +Numo::NArray+ object, it is
+  #   treated as +(x_data,)+.
+  # @param [Numo::NArray or Array<Numo::NArray> or nil] y_grad A set of +Numo::NArray+ s representing gradients of return-values of
+  #   +func+. If +y_grad+ is one +Numo::NArray+ object, it is
+  #   treated as +(y_grad,)+. If +func+ is a loss-function,
+  #   +y_grad+ should be set to +nil+.
+  # @param [Chainer::Variable or Array<Chainer::Variable>] params A set of +Chainer::Variable+ s whose gradients are checked.
+  #   When +func+ is a +Chainer::Link+ object,
+  #   set its parameters as +params+.
+  #   If +params+ is one +Chainer::Variable+ object,
+  #   it is treated as +(params,)+.
+  # @param [Float] eps Epsilon value to be passed to +numerical_grad+.
+  # @param [Float] atol Absolute tolerance to be passed to +Chainer::Testing.assert_allclose+.
+  # @param [Float] rtol Relative tolerance to be passed to +Chainer::Testing.assert_allclose+.
+  # @param [Array<Boolean>] no_grads Flag to skip variable for gradient assertion.
+  #   It should be the same length as +x_data+.
+  # @param [Numo::NArray.class] dtype +x_data+ and +y_grad+ are cast to this
+  #   dtype when calculating numerical gradients. Only float types and
+  #   +nil+ are allowed.
+  # @see
+  #   .numerical_grad
+  #
+  def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
+    x_data = _as_tuple(x_data)
+    if !y_grad.nil?
+      y_grad = _as_tuple(y_grad)
+    end
+
+    params = _as_tuple(params)
+    xs = x_data.map{|x| Chainer::Variable.new(x)}
+    y = func.(*xs)
+    y = _as_tuple(y)
+    y = Chainer::Functions::Math::Identity.identity(*y)
+    y = _as_tuple(y)
+
+    if !y_grad.nil?
+      if (y).size != (y_grad).size
+        raise TypeError, "`y_grad` must have the same length of output values"
+      end
+
+      y.zip(y_grad).each do |iy, igy|
+        iy.grad = igy
+      end
+    else
+      if (y).size != 1
+        raise TypeError, "When `y_grad` is `nil`, the function must return a zero-dimensional array"
+      end
+      y_grad = [1]
+    end
+    # We only need to call `backward` for one result `Chainer::Variable`.
+    # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
+    y[0].backward()
+
+    if dtype.nil?
+      casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+    else
+      if (dtype != Numo::DFloat) and (dtype != Numo::SFloat)
+        raise TypeError, "`dtype` is allowed only float type"
+      end
+      if (params).size > 0
+        raise TypeError, "`dtype` is available only if `params` is empty"
+      end
+      casted_xs = x_data.map{|x|
+        if x.class == Numo::DFloat or x.class == Numo::SFloat
+          Chainer::Variable.new(dtype.cast(x))
+        else
+          Chainer::Variable.new(x)
+        end
+      }
+    end
+
+    f = lambda do |_|
+      ys = func.(*casted_xs)
+      ys = _as_tuple(ys)
+      return ys.map{|y| y.data}.to_a
+    end
+
+    if no_grads.nil?
+      no_grads = xs.map{|x| (x.dtype != Numo::DFloat) and (x.dtype != Numo::SFloat)}
+    else
+      if no_grads.size != xs.size
+        raise TypeError, "Length of no_grads param and xs should be same."
+      end
+    end
+
+    no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+      if skip
+        raise unless x.grad.nil?
+        next
+      end
+      gx, = numerical_grad(f, [cx.data], y_grad, eps)
+      Chainer::Testing.assert_allclose(x.grad, gx, atol: atol, rtol: rtol)
+      if dtype.nil?
+        raise unless gx.class == x.grad.class
+      else
+        if ((gx.class != Numo::DFloat) and (gx.class != Numo::SFloat)) and (gx.class != dtype)
+          raise
+        end
+      end
+    end
+
+    params.each do |p|
+      gp, = numerical_grad(f, [p.data], y_grad, eps)
+      Chainer::Testing.assert_allclose(p.grad, gp, atol: atol, rtol: rtol)
+      raise unless gp.dtype === p.grad.dtype
+    end
+  end
+  module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward
+end
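check_backward ties the new pieces together: it runs the forward pass, routes the outputs through Identity so every output has a creator, calls backward, and compares the analytic gradients against numerical_grad. A minimal self-check sketch, assuming the module functions declared above and reusing the Identity function from this release (whose gradient trivially matches the finite-difference estimate):

require 'chainer'
require 'numo/narray'

x_data  = Numo::DFloat.new(3, 2).rand - 0.5  # input array
gy_data = Numo::DFloat.new(3, 2).rand - 0.5  # upstream gradient

# Raises if the analytic and numerical gradients disagree beyond the tolerances.
Chainer.check_backward(
  lambda { |v| Chainer::Functions::Math::Identity.identity(v) },
  x_data, gy_data, eps: 0.001, atol: 1e-5, rtol: 1e-4
)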