red-chainer 0.2.1 → 0.3.0
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
data/lib/chainer/functions/math/identity.rb
@@ -0,0 +1,26 @@
+module Chainer
+  module Functions
+    module Math
+      # Identity function.
+      class Identity < Chainer::Function
+        def check_type_forward(in_types)
+          # pass
+        end
+
+        def forward(xs)
+          retain_inputs([])
+          return xs
+        end
+
+        def backward(xs, gys)
+          return gys
+        end
+
+        # Just returns input variables.
+        def self.identity(*inputs)
+          self.new.(*inputs)
+        end
+      end
+    end
+  end
+end
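Identity is mostly plumbing for the gradient checker added below: check_backward wraps each output in Identity.identity so that it can attach output gradients to fresh variables. A minimal sketch of the behaviour, not part of the diff (it only assumes Chainer::Variable#grad= and #backward, which the checker itself relies on; values are illustrative):

x = Chainer::Variable.new(Numo::DFloat[1, 2, 3])
y, = Chainer::Functions::Math::Identity.identity(x)  # forward returns its inputs unchanged
y.grad = Numo::DFloat[10, 20, 30]
y.backward
# x.grad should now be [10, 20, 30]: backward passes the output gradients straight through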
data/lib/chainer/functions/noise/dropout.rb
@@ -0,0 +1,45 @@
+module Chainer
+  module Functions
+    module Noise
+      class Dropout < Chainer::Function
+        # Drops elements of input variable randomly.
+        #
+        # This function drops input elements randomly with probability `ratio` and
+        # scales the remaining elements by factor `1 / (1 - ratio)`.
+        # In testing mode, it does nothing and just returns `x`.
+        #
+        # @param [Chainer::Variable] x Input variable.
+        # @param [float] ratio Dropout ratio. The `ratio` must be `0.0 <= ratio < 1.0`.
+        # @return [Chainer::Variable] Output variable.
+        def self.dropout(x, ratio: 0.5)
+          Chainer.configuration.train ? self.new(ratio).(x) : x
+        end
+
+        def initialize(dropout_ratio)
+          if dropout_ratio < 0 || dropout_ratio >= 1.0
+            raise 'dropout_ratio must be in the range [0, 1)'
+          end
+          @dropout_ratio = dropout_ratio
+        end
+
+        def forward(x)
+          retain_inputs([])
+          unless self.instance_variable_defined?(:@mask)
+            scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
+            flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
+
+            @mask = x[0].class.zeros(*x[0].shape)
+            @mask[flag] = 1
+            @mask *= scale
+          end
+          [x[0] * @mask]
+        end
+
+        def backward(x, gy)
+          [gy[0] * @mask]
+        end
+      end
+    end
+  end
+end
+
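Whether the mask is applied is decided entirely by Chainer.configuration.train, so the same call site behaves differently during training and inference. A rough usage sketch, not part of the diff (the configuration toggle is the same one used by the batch normalization code below; values are illustrative):

x = Chainer::Variable.new(Numo::DFloat.new(2, 3).rand)

Chainer.configuration.train = true
y_train = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
# roughly half of the elements are zeroed; the survivors are scaled by 1 / (1 - 0.5) = 2.0

Chainer.configuration.train = false
y_test = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
# test mode: the input variable is returned untouched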
data/lib/chainer/functions/normalization/batch_normalization.rb
@@ -0,0 +1,136 @@
+module Chainer
+  module Functions
+    module Normalization
+      class BatchNormalizationFunction < Chainer::Function
+        attr_reader :running_mean, :running_var
+        # Batch normalization function with fixed statistics.
+        # This is a variant of batch normalization, where the mean and variance
+        # statistics are given by the caller as fixed variables. This is
+        # used in testing mode of the batch normalization layer, where batch
+        # statistics cannot be used for prediction consistency.
+        #
+        # @param [Chainer::Variable] x Input variable.
+        # @param [Chainer::Variable] gamma Scaling parameter of normalized data.
+        # @param [Chainer::Variable] beta Shifting parameter of scaled normalized data.
+        # @param [Chainer::Variable] mean Shifting parameter of input.
+        # @param [Chainer::Variable] var Square of scaling parameter of input.
+        # @param [float] eps Epsilon value for numerical stability.
+        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+          old_train = Chainer.configuration.train
+          Chainer.configuration.train = false
+          norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
+          Chainer.configuration.train = old_train
+          norm
+        end
+
+        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+          @running_mean = mean
+          @running_var = var
+          @eps = eps
+          @mean_cache = nil
+          @decay = decay
+        end
+
+        def forward(inputs)
+          x, gamma, beta = inputs[0], inputs[1], inputs[2]
+          if Chainer.configuration.train
+            if @running_mean.nil?
+              @running_mean = Numo::NArray[*gamma].new_zeros
+              @running_var = Numo::NArray[*gamma].new_zeros
+            else
+              @running_mean = Numo::NArray[*@running_mean]
+              @running_var = Numo::NArray[*@running_var]
+            end
+          elsif inputs.size == 5
+            @fixed_mean = inputs[3]
+            @fixed_var = inputs[4]
+          end
+
+          head_ndim = gamma.ndim + 1
+          gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
+          gamma = gamma.reshape(*gamma_expander)
+          beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
+          beta = beta.reshape(*beta_expander)
+
+          if Chainer.configuration.train
+            axis = [0] + (head_ndim...(x.ndim)).to_a
+            mean = x.mean(axis: axis)
+            # FIXME: numpy.var
+            var = x.var(axis: axis)
+            var += @eps
+          else
+            mean = @fixed_mean
+            var = @fixed_var + @eps
+          end
+
+          @std = Numo::NMath.sqrt(var)
+
+          mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+          x_mu = x - mean.reshape(*mean_expander)
+          std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
+          x_mu /= @std.reshape(*std_expander)
+          @x_hat = x_mu
+          y = gamma * @x_hat
+          y += beta
+
+          if Chainer.configuration.train
+            m = x.size.div(gamma.size)
+            adjust = m / [m - 1.0, 1.0].max
+            @running_mean *= @decay
+            temp_ar = Numo::NArray[*mean]
+            temp_ar *= (1 - @decay)
+            @running_mean += temp_ar
+
+            @running_var *= @decay
+            temp_ar = Numo::NArray[*var]
+            temp_ar *= ((1 - @decay) * adjust)
+            @running_var += temp_ar
+          end
+
+          [y,]
+        end
+
+        def backward(inputs, grad_outputs)
+          x, gamma = inputs[0], inputs[1]
+          gy = grad_outputs[0]
+          head_ndim = gamma.ndim + 1
+          m = gamma.class[x.size.div(gamma.size)][0]
+          axis = [0] + (head_ndim...(x.ndim)).to_a
+
+          if inputs.size == 5
+            mean = inputs[3]
+            var = inputs[4]
+            std = Numo::NMath.sqrt(var)
+            gs = gamma / std
+            gbeta = gy.sum(axis: axis)
+
+            mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
+            x_mu = x - mean.reshape(*mean_expander)
+            std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
+            x_mu /= std.reshape(*std_expander)
+            x_hat = x_mu
+            ggamma = (gy * x_hat).sum(axis: axis)
+            gmean = -gs * gbeta
+            gvar = -0.5 * gamma / var * ggamma
+            gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
+            gx = gs.reshape(*gs_expander) * gy
+            return [gx, ggamma, gbeta, gmean, gvar]
+          end
+
+          gbeta = gy.sum(axis: axis)
+          ggamma = (gy * @x_hat).sum(axis: axis)
+          tmp = (gamma / @std)
+          tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
+          tmp = tmp.reshape(*tmp_expander)
+
+          ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
+          gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
+
+          gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m)
+
+          [gx, ggamma, gbeta]
+        end
+      end
+    end
+  end
+end
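Underneath the reshaping, the forward pass is just y = gamma * (x - mean) / sqrt(var + eps) + beta computed per channel, with the running statistics updated as an exponential moving average. A standalone Numo sketch of that arithmetic for a (batch, channel) input, independent of the class above (shapes and values are illustrative):

require 'numo/narray'

x     = Numo::DFloat[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # (batch, channel)
eps   = 2e-5
gamma = Numo::DFloat.ones(2)
beta  = Numo::DFloat.zeros(2)

mean  = x.mean(axis: 0)
var   = x.var(axis: 0) + eps
x_hat = (x - mean) / Numo::NMath.sqrt(var)   # zero mean, unit variance per channel
y     = gamma * x_hat + beta

decay = 0.9
running_mean = Numo::DFloat.zeros(2)
running_mean = running_mean * decay + mean * (1 - decay)  # moving-average update, as in forward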
data/lib/chainer/functions/pooling/max_pooling_2d.rb
@@ -0,0 +1,57 @@
+module Chainer
+  module Functions
+    module Pooling
+      class MaxPooling2D < Pooling2D
+        # Spatial max pooling function
+        #
+        # @param [Chainer::Variable] x Input variable
+        # @param [integer || 2D integer array] ksize Size of pooling window
+        # @param [integer || 2D integer array] stride Stride of pooling applications
+        # @param [integer || 2D integer array] pad Spatial padding width for the input array
+        # @param [boolean] cover_all If `true`, all spatial locations are pooled into some output pixels
+        # @return [Chainer::Variable] Output variable
+        def self.max_pooling_2d(x, ksize, stride: nil, pad: 0, cover_all: true)
+          self.new(ksize, stride: stride, pad: pad, cover_all: cover_all).(x)
+        end
+
+        def forward_cpu(x)
+          retain_inputs([])
+          @in_shape = x[0].shape
+          @in_dtype = x[0].class
+
+          col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
+          n, c, kh, kw, out_h, out_w = col.shape
+          col = col.reshape(n, c, kh * kw, out_h, out_w)
+
+          # TODO: numpy.argmax(axis=2)
+          d = col.shape[3..-1].reduce(:*) || 1
+          dx = col.shape[2..-1].reduce(:*) || 1
+          max_index = col.max_index(2)
+          @indexes = max_index.flatten.map_with_index { |val, idx| (val - (dx * (idx / d))) / d }.reshape(*max_index.shape)
+
+          y = col.max(axis: 2)
+          [y]
+        end
+
+        def backward_cpu(x, gy)
+          n, c, out_h, out_w = gy[0].shape
+          h, w = @in_shape[2..-1]
+          kh, kw = @kh, @kw
+
+          gcol = @in_dtype.zeros(n * c * out_h * out_w * kh * kw)
+
+          indexes = @indexes.flatten
+          indexes += Numo::Int64.new((indexes.size * kh * kw) / (kh * kw)).seq(0, kh * kw)
+
+          gcol[indexes] = gy[0].flatten.dup
+          gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
+          gcol = gcol.swapaxes(2, 4)
+          gcol = gcol.swapaxes(3, 5)
+
+          gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
+          [gx]
+        end
+      end
+    end
+  end
+end
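A small usage sketch for the function on an NCHW input, not part of the diff (wrapping the array in Chainer::Variable follows the same pattern as the other functions; the expected output is worked out by hand):

# one sample, one channel, 4x4 feature map holding 0..15 row by row
x = Chainer::Variable.new(Numo::DFloat.new(1, 1, 4, 4).seq)

y = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(x, 2, stride: 2)
# each output pixel is the max over a non-overlapping 2x2 window,
# so y.data should be [[[[5, 7], [13, 15]]]] with shape [1, 1, 2, 2]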
data/lib/chainer/functions/pooling/pooling_2d.rb
@@ -0,0 +1,20 @@
+module Chainer
+  module Functions
+    module Pooling
+      # Base class of pooling function over a set of 2d planes
+      class Pooling2D < Chainer::Function
+        def initialize(ksize, stride: nil, pad: 0, cover_all: true)
+          if stride.nil?
+            stride = ksize
+          end
+
+          @kh, @kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]
+          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+
+          @cover_all = cover_all
+        end
+      end
+    end
+  end
+end
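The constructor only normalizes scalar hyperparameters into height/width pairs, which the subclasses (such as MaxPooling2D above) then read back as @kh/@kw, @sy/@sx and @ph/@pw. The same idiom in isolation, with illustrative values:

ksize = 3
kh, kw = ksize.is_a?(Array) ? ksize : [ksize, ksize]      # => [3, 3]

stride = [2, 1]
sy, sx = stride.is_a?(Array) ? stride : [stride, stride]  # => [2, 1]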
data/lib/chainer/gradient_check.rb
@@ -0,0 +1,240 @@
+module Chainer
+  def _copy_arrays(xs)
+    xp = Chainer::get_array_module(*xs)
+    xs.map{|x| (x.is_a? Numo::NArray) ? x.dup : x}
+  end
+
+  # Computes numerical gradient by finite differences.
+  #
+  # This function is used to implement gradient check. For a usage example, see
+  # the unit tests of Chainer::Functions.
+  #
+  # @param [function] f Ruby function with no arguments that runs forward
+  #   computation and returns the result.
+  # @param [Array<Arrays>] inputs Array of arrays that should be treated as
+  #   inputs. Each element of them is slightly modified to realize numerical
+  #   gradient by finite differences.
+  # @param [Array<Arrays>] grad_outputs Array of arrays that are treated as
+  #   output gradients.
+  # @param [Float] eps Epsilon value of finite differences.
+  # @return [Array] Numerical gradient arrays corresponding to +inputs+.
+  #
+  def numerical_grad(f, inputs, grad_outputs, eps=0.001)
+    raise unless eps > 0
+    inputs = inputs.to_a
+    grad_outputs = grad_outputs.to_a
+    xp = Numo::NArray
+    grads = inputs.map{|x| x.new_zeros()}
+
+    if inputs[0].ndim < 2
+      tmp = [[inputs[0], grads[0]]]
+    else
+      tmp = (0...inputs[0].shape[0]).map{|i| [inputs[0][i, false], grads[0][i, false]]}
+    end
+
+    tmp.each do |x, gx|
+      x.each_with_index{|xx, *i|
+        orig = x[*i] # hold original value
+        x[*i] = orig + eps
+        ys1 = _copy_arrays(f.call(x))
+        x[*i] = orig - eps
+        ys2 = _copy_arrays(f.call(x))
+        x[*i] = orig
+
+        ys1.zip(ys2, grad_outputs).each do |y1, y2, gy|
+          if !gy.nil?
+            if ((y1 - y2) * gy).is_a? Numo::NArray
+              dot = ((y1 - y2) * gy).sum()
+            else
+              dot = ((y1 - y2) * gy).inject(:+)
+            end
+            gx[*i] += dot / (2*eps).to_f
+          end
+        end
+      }
+    end
+
+    return grads
+  end
+
+  def _as_tuple(x)
+    if x.is_a? Array
+      return x
+    else
+      return [x]
+    end
+  end
+
+  # Tests the backward procedure of a given function.
+  #
+  # This function automatically checks the backward process of a given function.
+  # For example, when you have a +Chainer::Function+ class +MyFunc+
+  # that takes two arguments and returns one value, you can write its test like this:
+  #
+  #   def test_my_func
+  #     func = MyFunc.new
+  #     x1_data = Numo::NArray[...]
+  #     x2_data = Numo::NArray[...]
+  #     gy_data = Numo::NArray[...]
+  #     check_backward(func, [x1_data, x2_data], gy_data)
+  #
+  # This method creates +Chainer::Variable+ objects with +x_data+
+  # and calls +func+ with the +Chainer::Variable+ s to get its result
+  # as +Chainer::Variable+.
+  # Then, it sets the +y_grad+ array to the +grad+ attribute of the result and
+  # calls the +backward+ method to get gradients of the inputs.
+  # To check correctness of the gradients, the function calls
+  # +numerical_grad+ to compute the gradients numerically and compares
+  # them with the backprop gradients using +Chainer::Testing.assert_allclose+.
+  # If input objects (+x1_data+ and/or +x2_data+ in this example) represent
+  # integer variables, their gradients are ignored.
+  #
+  # You can simplify a test when +MyFunc+ takes only one argument:
+  #
+  #   check_backward(func, x1_data, gy_data)
+  #
+  # If +MyFunc+ is a loss function which returns a zero-dimensional
+  # array, pass +nil+ to +gy_data+. In this case, it sets +1+ to
+  # the +grad+ attribute of the result:
+  #
+  #   check_backward(my_loss_func, [x1_data, x2_data], nil)
+  #
+  # If +MyFunc+ returns multiple outputs, pass all gradients for the outputs as an Array:
+  #
+  #   gy1_data = Numo::NArray[...]
+  #   gy2_data = Numo::NArray[...]
+  #   check_backward(func, x1_data, [gy1_data, gy2_data])
+  #
+  # You can also test a +Chainer::Link+.
+  # To check gradients of parameters of the link, set an Array of the parameters
+  # to the +params+ argument:
+  #
+  #   check_backward(my_link, [x1_data, x2_data], gy_data, [my_link.W, my_link.b])
+  #
+  # Note that +params+ are not +Numo::NArray+ s,
+  # but +Chainer::Variable+ s.
+  #
+  # Function objects are acceptable as the +func+ argument:
+  #
+  #   check_backward(lambda{|x1, x2| f(x1, x2)}, [x1_data, x2_data], gy_data)
+  #
+  # @note
+  #   +func+ is called many times to get numerical gradients for all inputs.
+  #   This function doesn't work correctly when +func+ behaves randomly, because
+  #   it would get different gradients on each call.
+  # @param [Method, Proc] func A function which takes +Chainer::Variable+ s
+  #   and returns +Chainer::Variable+ s. +func+ must return
+  #   an Array of +Chainer::Variable+ s or one
+  #   +Chainer::Variable+. You can use a +Chainer::Function+
+  #   object, a +Chainer::Link+ object, or a function satisfying the
+  #   condition.
+  # @param [Numo::NArray or Array<Numo::NArray>] x_data A set of +Numo::NArray+ s to be
+  #   passed to +func+. If +x_data+ is one +Numo::NArray+ object, it is
+  #   treated as +(x_data,)+.
+  # @param [Numo::NArray or Array<Numo::NArray> or nil] y_grad A set of +Numo::NArray+ s representing gradients of return-values of
+  #   +func+. If +y_grad+ is one +Numo::NArray+ object, it is
+  #   treated as +(y_grad,)+. If +func+ is a loss-function,
+  #   +y_grad+ should be set to +nil+.
+  # @param [Chainer::Variable or Array<Chainer::Variable>] params A set of +Chainer::Variable+ s whose gradients are checked.
+  #   When +func+ is a +Chainer::Link+ object,
+  #   set its parameters as +params+.
+  #   If +params+ is one +Chainer::Variable+ object,
+  #   it is treated as +(params,)+.
+  # @param [Float] eps Epsilon value to be passed to +numerical_grad+.
+  # @param [Float] atol Absolute tolerance to be passed to +Chainer::Testing.assert_allclose+.
+  # @param [Float] rtol Relative tolerance to be passed to +Chainer::Testing.assert_allclose+.
+  # @param [Array<Boolean>] no_grads Flags to skip variables for gradient assertion.
+  #   It should be the same length as +x_data+.
+  # @param [Numo::NArray.class] dtype +x_data+ and +y_grad+ are cast to this
+  #   dtype when calculating numerical gradients. Only float types and
+  #   +nil+ are allowed.
+  # @see
+  #   .numerical_grad
+  #
+  def check_backward(func, x_data, y_grad, params=[], eps: 0.001, atol: 1e-5, rtol: 1e-4, no_grads: nil, dtype: nil)
+    x_data = _as_tuple(x_data)
+    if !y_grad.nil?
+      y_grad = _as_tuple(y_grad)
+    end
+
+    params = _as_tuple(params)
+    xs = x_data.map{|x| Chainer::Variable.new(x)}
+    y = func.(*xs)
+    y = _as_tuple(y)
+    y = Chainer::Functions::Math::Identity.identity(*y)
+    y = _as_tuple(y)
+
+    if !y_grad.nil?
+      if y.size != y_grad.size
+        raise TypeError, "`y_grad` must have the same length as the output values"
+      end
+
+      y.zip(y_grad).each do |iy, igy|
+        iy.grad = igy
+      end
+    else
+      if y.size != 1
+        raise TypeError, "When `y_grad` is `nil`, the function must return a zero-dimensional array"
+      end
+      y_grad = [1]
+    end
+    # We only need to call `backward` for one result `Chainer::Variable`.
+    # `Chainer::Variable.backward` method calls `Chainer::Function.backward` of its creator.
+    y[0].backward()
+
+    if dtype.nil?
+      casted_xs = x_data.map{|x| Chainer::Variable.new(x)}
+    else
+      if (dtype != Numo::DFloat) and (dtype != Numo::SFloat)
+        raise TypeError, "`dtype` must be a float type"
+      end
+      if params.size > 0
+        raise TypeError, "`dtype` is available only if `params` is empty"
+      end
+      casted_xs = x_data.map{|x|
+        if x.class == Numo::DFloat or x.class == Numo::SFloat
+          Chainer::Variable.new(dtype.cast(x))
+        else
+          Chainer::Variable.new(x)
+        end
+      }
+    end
+
+    f = lambda do |_|
+      ys = func.(*casted_xs)
+      ys = _as_tuple(ys)
+      return ys.map{|y| y.data}.to_a
+    end
+
+    if no_grads.nil?
+      no_grads = xs.map{|x| (x.dtype != Numo::DFloat) and (x.dtype != Numo::SFloat)}
+    else
+      if no_grads.size != xs.size
+        raise TypeError, "Length of no_grads and xs should be the same."
+      end
+    end
+
+    no_grads.zip(xs, casted_xs).each do |skip, x, cx|
+      if skip
+        raise unless x.grad.nil?
+        next
+      end
+      gx, = numerical_grad(f, [cx.data], y_grad, eps)
+      Chainer::Testing.assert_allclose(x.grad, gx, atol: atol, rtol: rtol)
+      if dtype.nil?
+        raise unless gx.class == x.grad.class
+      else
+        if ((gx.class != Numo::DFloat) and (gx.class != Numo::SFloat)) and (gx.class != dtype)
+          raise
+        end
+      end
+    end
+
+    params.each do |p|
+      gp, = numerical_grad(f, [p.data], y_grad, eps)
+      Chainer::Testing.assert_allclose(p.grad, gp, atol: atol, rtol: rtol)
+      raise unless gp.dtype === p.grad.dtype
+    end
+  end
+  module_function :_copy_arrays, :numerical_grad, :_as_tuple, :check_backward
+end
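A sketch of the loss-function form described in the docstring above, using the mean squared error function added in this release (the Chainer::Functions::Loss::MeanSquaredError.mean_squared_error call is assumed from the file list rather than shown in this hunk; tolerances and shapes are illustrative):

x0 = Numo::DFloat.new(4, 3).rand
x1 = Numo::DFloat.new(4, 3).rand

# y_grad is nil because the loss is zero-dimensional; check_backward then
# compares the backprop gradients of x0 and x1 against numerical ones.
Chainer.check_backward(
  lambda {|a, b| Chainer::Functions::Loss::MeanSquaredError.mean_squared_error(a, b) },
  [x0, x1],
  nil,
  eps: 1e-2, atol: 1e-4, rtol: 1e-3
)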