red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/math/basic_math.rb (+36 -30)

@@ -1,66 +1,71 @@
 module Chainer
   module Functions
     module Math
-      class Neg < ::Chainer::Function
+      class Neg < ::Chainer::FunctionNode
+        def label
+          '__neg__'
+        end
+
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(-x[0])]
         end

-        def backward(x, gy)
-          [Utils::Array.force_array(-gy[0])]
+        def backward(indexes, gy)
+          [-gy[0]]
         end
       end

-      class Add < ::Chainer::Function
+      class Add < ::Chainer::FunctionNode
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] + x[1])]
         end

-        def backward(x, gy)
+        def backward(indexes, gy)
           [gy[0], gy[0]]
         end
       end

-      class AddConstant < ::Chainer::Function
+      class AddConstant < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end

         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] + @value)]
         end

-        def backward(x, gy)
+        def backward(indexes, gy)
           [gy[0]]
         end
       end

-      class Sub < ::Chainer::Function
+      class Sub < ::Chainer::FunctionNode
+        def label
+          '_ - _'
+        end
+
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] - x[1])]
         end

-        def backward(x, gy)
-          [gy[0], Utils::Array.force_array(-gy[0])]
+        def backward(indexes, gy)
+          [gy[0], -gy[0]]
         end
       end

-      class Mul < ::Chainer::Function
+      class Mul < ::Chainer::FunctionNode
         def forward(x)
+          retain_inputs([0, 1])
           [Utils::Array.force_array(x[0] * x[1])]
         end

-        def backward(x, gy)
-          [Utils::Array.force_array(gy[0] * x[1]), Utils::Array.force_array(gy[0] * x[0])]
+        def backward(indexes, gy)
+          xs = get_retained_inputs
+          indexes.map { |i| gy[0] * xs[1 - i] }
         end
       end

-      class MulConstant < ::Chainer::Function
+      class MulConstant < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
@@ -69,23 +74,23 @@ module Chainer
           [Utils::Array.force_array(@value * x[0])]
         end

-        def backward(x, gy)
-          [Utils::Array.force_array(@value * gy[0])]
+        def backward(indexes, gy)
+          [gy[0] * @value]
         end
       end

-      class Div < ::Chainer::Function
+      class Div < ::Chainer::FunctionNode
         def forward(x)
           [Utils::Array.force_array(x[0] / x[1])]
         end

-        def backward(x, gy)
+        def backward(indexes, gy)
           gx0 = Utils::Array.force_array(gy[0] / x[1])
           [gx0, Utils::Array.force_array(-1 * gx0 * x[0] / x[1])]
         end
       end

-      class PowVarVar < ::Chainer::Function
+      class PowVarVar < ::Chainer::FunctionNode
         def forward(x)
           @y = Utils::Array.force_array(x[0] ** x[1])
           [@y]
@@ -94,12 +99,13 @@ module Chainer
         def backward(x, gy)
           one = x[1].class.ones[0]
           gx0 = Utils::Array.force_array(x[1] * (x[0] ** (x[1] - one)) * gy[0])
-          gx1 = Utils::Array.force_array(Numo::NMath.log(x[0]) * @y * gy[0])
+          xm = Chainer.get_array_module(x[0])
+          gx1 = Utils::Array.force_array(xm::NMath.log(x[0]) * @y * gy[0])
           [gx0, gx1]
         end
       end

-      class PowVarConst < ::Chainer::Function
+      class PowVarConst < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
@@ -113,7 +119,7 @@ module Chainer
           gx = @value * (x[0] ** val_1) * gy[0]
           [Utils::Array.force_array(gx)]
         end
-      end
+      end
     end
   end
 end
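The pattern across this file is the same for every operator: the class now derives from `::Chainer::FunctionNode`, and `backward` receives input indexes plus gradient `Chainer::Variable`s instead of raw input and gradient arrays, which is what makes higher-order gradients possible. A minimal usage sketch (not part of the diff; it assumes `Chainer::Variable#grad=` and `#backward` behave as in upstream Chainer) of how `Mul`'s new `backward(indexes, gy)` is exercised through `Variable` operator overloading:

```ruby
require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
y = Chainer::Variable.new(Numo::DFloat[4.0, 5.0, 6.0])

z = x * y                      # dispatches to Functions::Math::Mul via Variable's * operator
z.grad = Numo::DFloat.ones(3)  # seed the output gradient (assumed Variable#grad= as in Chainer)
z.backward

p x.grad  # expected [4.0, 5.0, 6.0]: Mul#backward returns gy[0] * xs[1 - i]
p y.grad  # expected [1.0, 2.0, 3.0]
```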
data/lib/chainer/functions/math/exp.rb (new file, +28)

@@ -0,0 +1,28 @@
+module Chainer
+  module Functions
+    module Math
+      class Exp < Chainer::FunctionNode
+        # Elementwise exponential function.
+        def self.exp(x)
+          self.new.apply([x]).first
+        end
+
+        def label
+          'exp'
+        end
+
+        def forward(x)
+          retain_inputs([])
+          retain_outputs([0])
+          xm = Chainer.get_array_module(x.first)
+          [Utils::Array.force_array(xm::NMath.exp(x.first))]
+        end
+
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          [y * gy.first]
+        end
+      end
+    end
+  end
+end
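Because `Exp` retains its output, the backward pass can reuse the forward result (d/dx exp(x) = exp(x)) without recomputing it. A small usage sketch, under the same assumption that `Variable#grad=` and `#backward` mirror upstream Chainer:

```ruby
require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat[0.0, 1.0, 2.0])
y = Chainer::Functions::Math::Exp.exp(x)

y.grad = Numo::DFloat.ones(3)
y.backward

p y.data  # [1.0, Math::E, Math::E ** 2]
p x.grad  # same values as y.data: the retained output doubles as the gradient
```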
data/lib/chainer/functions/math/identity.rb (+4 -3)

@@ -2,7 +2,7 @@ module Chainer
   module Functions
     module Math
       # Identity function.
-      class Identity < Chainer::Function
+      class Identity < Chainer::FunctionNode
         def check_type_forward(in_types)
           # pass
         end
@@ -12,13 +12,14 @@
           return xs
         end

-        def backward(xs, gys)
+        def backward(indexes, gys)
           return gys
         end

         # Just returns input variables.
         def self.identity(*inputs)
-          self.new.(*inputs)
+          ret = self.new.apply(inputs)
+          ret.size == 1 ? ret[0] : ret
         end
       end
     end
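Because `FunctionNode#apply` always returns an array of outputs, `self.identity` now unwraps single-element results itself. A tiny sketch of the resulting behaviour (hypothetical usage, not taken from the diff):

```ruby
require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])

y = Chainer::Functions::Math::Identity.identity(x)        # one input  -> a single Variable
pair = Chainer::Functions::Math::Identity.identity(x, x)  # two inputs -> an Array of Variables

p y.data     # same values as x.data
p pair.size  # 2
```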
data/lib/chainer/functions/math/sum.rb (new file, +52)

@@ -0,0 +1,52 @@
+module Chainer
+  module Functions
+    module Math
+      # Sum of array elements over a given axis.
+      class Sum < Chainer::FunctionNode
+        # Sum of array elements over a given axis
+        #
+        # @param [Chainer::Variable] x Elements to sum
+        # @param [nil, Integer, Array<Integer>] axis Axis which a sum is performed
+        # @param [boolean] keepdims If `true`, the specified axes are remained as axes of length one
+        # @return [Chainer::Variable] Output variable
+        def self.sum(x, axis: nil, keepdims: false)
+          Sum.new(axis: axis, keepdims: keepdims).apply([x]).first
+        end
+
+        def initialize(axis: nil, keepdims: false)
+          if axis.nil?
+            @axis = nil
+          elsif axis.is_a?(Integer)
+            @axis = [axis]
+          elsif axis.is_a?(::Array) && axis.all? { |e| e.is_a?(Integer) }
+            raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
+            @axis = axis
+          else
+            raise TypeError, 'nil, Integer or Array of int are required'
+          end
+
+          @keepdims = keepdims
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          ret = x.sum(axis: @axis, keepdims: @keepdims)
+          ret = x.class.cast(ret)
+          [ret]
+        end
+
+        def backward(indexes, grad_outputs)
+          gy = grad_outputs.first
+          ndim = @inputs.first.shape.size
+          unless ndim == 0 || @axis.nil? || @keepdims
+            actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
+            shape = gy.shape
+            actual_axis.sort.each { |axis| shape.insert(axis, 1) }
+            gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
+          end
+          [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, @inputs.first.shape)]
+        end
+      end
+    end
+  end
+end
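`Sum#backward` re-inserts the reduced axes as length-1 dimensions and then broadcasts the incoming gradient back to the input shape, so every input element receives the gradient of the sum it contributed to. A usage sketch under the same `grad=`/`backward` assumptions as above:

```ruby
require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
s = Chainer::Functions::Math::Sum.sum(x, axis: 1)  # data: [6.0, 15.0]

s.grad = Numo::DFloat[1.0, 1.0]
s.backward

p x.grad  # a 2x3 array of ones: each element feeds exactly one row sum
```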
data/lib/chainer/functions/noise/dropout.rb (+20 -4)

@@ -1,7 +1,8 @@
 module Chainer
   module Functions
     module Noise
-      class Dropout < Chainer::Function
+      class Dropout < Chainer::FunctionNode
+        attr_reader :mask
        # Drops elements of input variable randomly.
        #
        # This function drops input elements randomly with probability `ratio` and
@@ -12,7 +13,7 @@ module Chainer
        # @param [float] ratio Dropout ratio. The ``ratio`` must be `0.0 <= ratio < 1.0`.
        # @return [Chainer::Variable] Output variable.
        def self.dropout(x, ratio: 0.5)
-          Chainer.configuration.train ? self.new(ratio).(x) : x
+          Chainer.configuration.train ? self.new(ratio).apply([x])[0] : Chainer::Variable.as_variable(x)
        end

        def initialize(dropout_ratio)
@@ -23,7 +24,6 @@ module Chainer
        end

        def forward(x)
-          retain_inputs([])
          unless self.instance_variable_defined?(:@mask)
            scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
            flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
@@ -36,7 +36,23 @@ module Chainer
        end

        def backward(x, gy)
-          [gy[0] * @mask]
+          DropoutGrad.new(@mask).apply(gy)
+        end
+      end
+
+      # Computes the gradient of the Dropout function.
+      class DropoutGrad < Chainer::FunctionNode
+        def initialize(mask)
+          @mask = mask
+        end
+
+        def forward(inputs)
+          y = inputs.first * @mask
+          [y]
+        end
+
+        def backward(indexes, gy)
+          DropoutGrad.new(@mask).apply(gy)
        end
      end
    end
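Dropout's gradient is now computed by a dedicated `DropoutGrad` node that reuses the mask built in the forward pass, and outside training mode the input is simply wrapped with `Chainer::Variable.as_variable`. A forward-only sketch (the `Chainer.configuration.train` toggle is taken from the old code shown above; the rest is assumed usage):

```ruby
require 'chainer'
require 'numo/narray'

x = Chainer::Variable.new(Numo::DFloat.new(2, 5).seq)

Chainer.configuration.train = true
y_train = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
# roughly half the elements are zeroed, the survivors scaled by 1 / (1 - 0.5)

Chainer.configuration.train = false
y_eval = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
p y_eval.data  # identical to x.data: dropout is a no-op outside training
```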
data/lib/chainer/functions/normalization/batch_normalization.rb (+257 -104)

@@ -1,134 +1,287 @@
 module Chainer
   module Functions
     module Normalization
+      module Calculation
+        def apply_bn_fwd(xp, x, mean, inv_std, gamma, beta)
+          # NOTE: all arguments should be broadcasted to x.shape
+          # (mean, inv_std, gamma, and beta have to already be expanded)
+          x_hat = x_hat(x, mean, inv_std)
+          y = gamma * x_hat
+          y += beta
+          y
+        end
+
+        def x_hat(x, mean, inv_std)
+          x_mu = x - mean
+          x_mu *= inv_std
+          x_mu
+        end
+
+        def zero_if_none(xp, x, shape, dtype)
+          # TODO: Return broadcasted 0 instead of a zeroed array.
+          x.nil? ? dtype.zeros(*shape) : x
+        end
+      end
+
+      class BatchNormalization < Chainer::FunctionNode
+        include Calculation
         attr_reader :running_mean, :running_var
-        # @param [Chainer::Variable] mean Shifting parameter of input.
-        # @param [Chainer::Variable] var Square of scaling parameter of input.
-        # @param [float] eps Epsilon value for numerical stability.
-        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
-          old_train = Chainer.configuration.train
-          Chainer.configuration.train = false
-          norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
-          Chainer.configuration.train = old_train
-          norm
-        end
-
-        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+
+        def self.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
+          BatchNormalization.new(eps: eps, mean: running_mean, var: running_var, decay: decay).apply([x, gamma, beta])[0]
+        end
+
+        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+          @mean = nil
+          @inv_std = nil
+
           @running_mean = mean
           @running_var = var
           @eps = eps
-          @mean_cache = nil
           @decay = decay
         end

         def forward(inputs)
-            @running_var = Numo::NArray[*@running_var]
-          end
-          elsif inputs.size == 5
-            @fixed_mean = inputs[3]
-            @fixed_var = inputs[4]
+          retain_inputs([0, 1])
+          x, gamma, beta = inputs
+          xp = Chainer.get_array_module(x)
+
+          if @running_mean.nil?
+            @running_mean = xp::NArray[*gamma].new_zeros
+            @running_var = xp::NArray[*gamma].new_zeros
           end

+          # expander inserts singleton dimensions to gamma and beta so that they
+          # can be broadcasted with x.
           head_ndim = gamma.ndim + 1
-          if Chainer.configuration.train
-            axis = [0] + (head_ndim...(x.ndim)).to_a
-            mean = x.mean(axis: axis)
-            # FIXME: numpy.var
-            var = x.var(axis: axis)
-            var += @eps
-          else
-            mean = @fixed_mean
-            var = @fixed_var + @eps
+          # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+          suffix = [1] * (x.ndim - head_ndim)
+          expander = -> (arr) do
+            shape = [1] + arr.shape + suffix
+            arr.reshape(*shape)
           end
+          @expander = expander
+          @axis = [0] + (head_ndim...(x.ndim)).to_a

+          gamma = expander.(gamma)
+          beta = expander.(beta)
+          @mean = x.mean(axis: @axis)

-          std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
-          x_mu /= @std.reshape(*std_expander)
-          @x_hat = x_mu
-          y = gamma * @x_hat
-          y += beta
+          # TODO: Numo::Array can not be specified standard deviation
+          var = ((x - x.mean(axis: @axis, keepdims: true)) ** 2).mean(axis: @axis)

-          @running_var += temp_ar
-        end
+          var += @eps
+          @inv_std = var ** (-0.5)
+
+          y = apply_bn_fwd(xp, x, expander.(@mean), expander.(@inv_std), gamma, beta)
+          # Update running statistics
+          m = x.size.div(gamma.size)
+          adjust = m / [m - 1.0, 1.0].max
+          @running_mean *= @decay
+          @running_mean += (1 - @decay) * @mean
+          @running_var *= @decay
+          @running_var += (1 - @decay) * adjust * var

+          [y]
+        end
+
+        def backward(indexes, grad_outputs)
+          x, gamma = get_retained_inputs
+          gy, = grad_outputs
+
+          # hatappi debug
+          #@mean = @mean.class.new(@mean.shape).seq
+          #@inv_std = @inv_std.class.new(@inv_std.shape).seq
+          #x.data = x.data.class.new(x.shape).seq
+          #gamma.data = gamma.data.class.new(gamma.shape).seq
+          #gy.data = gy.data.class.new(gy.shape).seq
+
+          f = BatchNormalizationGrad.new(@eps, @expander, @axis, @mean, @inv_std)
+          f.(x, gamma, gy)
+        end
+      end
+
+      class BatchNormalizationGrad < Function
+        include Calculation
+
+        def initialize(eps, expander, axis, mean, inv_std)
+          @eps = eps
+          @expander = expander
+          @axis = axis
+          @mean = mean
+          @inv_std = inv_std
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 2])
+          x, gamma, gy = inputs
+          expander = @expander
+
+          inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+          xp = Chainer.get_array_module(x)
+
+          gbeta = gy.sum(axis: @axis)
+          x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+          ggamma = (gy * x_hat).sum(axis: @axis)
+          gx = expander.(gamma * @inv_std) * (gy - (x_hat * expander.(ggamma) + expander.(gbeta)) * inv_m)
+
+          retain_outputs([0, 1])
+          [gx, ggamma, gbeta]
         end

         def backward(inputs, grad_outputs)
+          expander = @expander
+
+          x, gamma, gy = inputs
+          gx1, ggamma1, = output_data
+          ggx1, gggamma1, ggbeta1 = grad_outputs
+          xp = Chainer.get_array_module(x)
+
+          # auxiliary values
+          inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+          r = ggx1.nil? ? 0 : (gx1 * ggx1).sum(axis: @axis)
+          coeff = gamma * @inv_std
+          coeff_m = coeff * inv_m
+          x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+
+          # handle None in output gradients
+          ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+          gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+          ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+
+          gggamma2 = gggamma1 - coeff_m * (x_hat * ggx1).sum(axis: @axis)
+          ggbeta2 = ggbeta1 - coeff_m * ggx1.sum(axis: @axis)
+
+          ggamma2 = r / gamma
+
+          gx_hat2 = (expander.(gggamma2) * gy - expander.(coeff_m * ggamma1) * ggx1)
+          gstd2 = -@inv_std * (r + (x_hat * gx_hat2).sum(axis: @axis))
+          gmean2 = -@inv_std * gx_hat2.sum(axis: @axis)
+          gx2 = expander.(@inv_std) * gx_hat2 + inv_m * (expander.(gmean2) + x_hat * expander.(gstd2))
+          ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(coeff) * ggx1)
+
+          [gx2, ggamma2, ggy2]
+        end
+      end
+
+      class FixedBatchNormalization < FunctionNode
+        include Calculation
+
+        attr_reader :inv_var
+
+        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+          FixedBatchNormalization.new(eps: eps).apply([x, gamma, beta, mean, var]).first
+        end
+
+        def initialize(eps: 2e-5)
+          @inv_std = nil
+          @inv_var = nil
+          @eps = eps
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 3, 4])
+          x, gamma, beta, mean, var = inputs
+          xp = Chainer.get_array_module(x)
+
+          # expander inserts singleton dimensions to gamma and beta so that they
+          # can be broadcasted with x.
           head_ndim = gamma.ndim + 1
-          var = inputs[4]
-          std = Numo::NMath.sqrt(var)
-          gs = gamma / std
-          gbeta = gy.sum(axis: axis)
-
-          mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
-          x_mu = x - mean.reshape(*mean_expander)
-          std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
-          x_mu /= std.reshape(*std_expander)
-          x_hat = x_mu
-          ggamma = (gy * x_hat).sum(axis: axis)
-          gmean = -gs * gbeta
-          gvar = -0.5 * gamma / var * ggamma
-          gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
-          gx = gs.reshape(*gs_expander)
-          return [gx, ggamma, gbeta, gmean, gvar]
+          # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+          suffix = [1] * (x.ndim - head_ndim)
+          expander = -> (arr) do
+            shape = [1] + arr.shape + suffix
+            arr.reshape(*shape)
           end
+          @expander = expander
+          @axis = [0] + (head_ndim...(x.ndim)).to_a
+
+          gamma = expander.(gamma)
+          beta = expander.(beta)
+          var += @eps
+          @inv_var = var.reciprocal
+          @inv_std = xp::NMath.sqrt(@inv_var)

-          tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
-          tmp = tmp.reshape(*tmp_expander)
+          y = apply_bn_fwd(xp, x, expander.(mean), expander.(@inv_std), gamma, beta)
+          [y]
+        end

+        def backward(indexes, grad_outputs)
+          x, gamma, mean, var = get_retained_inputs
+          gy, = grad_outputs
+          f = FixedBatchNormalizationGrad.new(@eps, @expander, @axis, @inv_std, @inv_var)
+          f.(x, gamma, mean, var, gy)
+        end
+      end

+      class FixedBatchNormalizationGrad < Function
+        include Calculation
+
+        def initialize(eps, expander, axis, inv_std, inv_var)
+          @eps = eps
+          @expander = expander
+          @axis = axis
+          @inv_std = inv_std
+          @inv_var = inv_var
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 2, 4])
+          x, gamma, mean, var, gy = inputs
+          expander = @expander
+          xp = Chainer.get_array_module(x)
+
+          if @inv_std.nil? || @inv_var.nil?
+            @inv_var = (var + @eps).reciprocal
+            @inv_std = xp::NMath.sqrt(@inv_var)
+          end
+
+          @gamma_over_std = gamma * @inv_std
+          x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+
+          gx = expander.(@gamma_over_std) * gy
+          gbeta = gy.sum(axis: @axis)
+          ggamma = (x_hat * gy).sum(axis: @axis)
+          gmean = -@gamma_over_std * gbeta
+          gvar = -0.5 * gamma * @inv_var * ggamma
+
+          retain_outputs([0, 1, 2, 3, 4])
+          [gx, ggamma, gbeta, gmean, gvar]
+        end
+
+        def backward(inputs, grad_outputs)
+          x, gamma, mean, _, gy = inputs
+          ggx1, gggamma1, ggbeta1, ggmean1, ggvar1 = grad_outputs
+          gx1, ggamma1, gbeta1, gmean1, gvar1 = output_data
+
+          # Handle None in output gradients.
+          xp = Chainer.get_array_module(x)
+          ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+          gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+          ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+          ggmean1 = zero_if_none(xp, ggmean1, mean.shape, mean.class)
+          ggvar1 = zero_if_none(xp, ggvar1, mean.shape, mean.class)
+
+          expander = @expander
+          x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+          tmp = -0.5 * ggvar1
+
+          gamma_over_var = gamma * @inv_var
+          g_gamma_over_var = tmp * ggamma1
+
+          gggamma2 = gggamma1 + tmp * gamma_over_var
+          gx_hat = gy * expander.(gggamma2)
+          gx2 = expander.(@inv_std) * gx_hat
+          gmean2 = -@inv_std * gx_hat.sum(axis: @axis)
+
+          g_gamma_over_std = (ggx1 * gy).sum(axis: @axis) - ggmean1 * gbeta1
+          ggbeta2 = ggbeta1 - ggmean1 * @gamma_over_std
+          ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(@gamma_over_std) * ggx1)
+
+          ggamma2 = (@inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std)
+          gvar2 = -(ggamma2 * gamma_over_var + 0.5 * @inv_var * ((x_hat * gx_hat).sum(axis: @axis) - @gamma_over_std * g_gamma_over_std))
+
+          [gx2, ggamma2, gmean2, gvar2, ggy2]
+        end
+      end
     end
   end
 end
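The rewrite splits the old single function into a training-time `BatchNormalization` (batch statistics, running averages updated with `decay`) and an inference-time `FixedBatchNormalization` (caller-supplied `mean`/`var`), each with a dedicated grad class so second-order gradients work. A forward-only usage sketch; how the statistics are wrapped is an assumption based on the `apply` calls above:

```ruby
require 'chainer'
require 'numo/narray'

norm = Chainer::Functions::Normalization

x     = Chainer::Variable.new(Numo::DFloat.new(8, 3).rand)
gamma = Chainer::Variable.new(Numo::DFloat.ones(3))
beta  = Chainer::Variable.new(Numo::DFloat.zeros(3))

# Training path: batch statistics are used and running_mean / running_var are updated.
y_train = norm::BatchNormalization.batch_normalization(x, gamma, beta, eps: 2e-5)

# Inference path: caller-supplied statistics, nothing is updated.
mean = Numo::DFloat.zeros(3)
var  = Numo::DFloat.ones(3)
y_eval = norm::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var)
```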