red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/math/basic_math.rb (+36 -30)

@@ -1,66 +1,71 @@
 module Chainer
   module Functions
     module Math
-      class Neg < ::Chainer::Function
-
+      class Neg < ::Chainer::FunctionNode
+        def label
+          '__neg__'
+        end
+
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(-x[0])]
         end
 
-        def backward(x, gy)
-          [
+        def backward(indexes, gy)
+          [-gy[0]]
         end
       end
 
-      class Add < ::Chainer::Function
+      class Add < ::Chainer::FunctionNode
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] + x[1])]
         end
 
-        def backward(x, gy)
+        def backward(indexes, gy)
           [gy[0], gy[0]]
         end
       end
 
-      class AddConstant < ::Chainer::Function
+      class AddConstant < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
 
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] + @value)]
         end
 
-        def backward(x, gy)
+        def backward(indexes, gy)
           [gy[0]]
         end
       end
-
-      class Sub < ::Chainer::Function
+
+      class Sub < ::Chainer::FunctionNode
+        def label
+          '_ - _'
+        end
+
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] - x[1])]
         end
 
-        def backward(x, gy)
-          [gy[0],
+        def backward(indexes, gy)
+          [gy[0], -gy[0]]
         end
       end
 
-      class Mul < ::Chainer::Function
+      class Mul < ::Chainer::FunctionNode
         def forward(x)
+          retain_inputs([0, 1])
           [Utils::Array.force_array(x[0] * x[1])]
         end
 
-        def backward(x, gy)
-
+        def backward(indexes, gy)
+          xs = get_retained_inputs
+          indexes.map { |i| gy[0] * xs[1 - i] }
         end
       end
 
-      class MulConstant < ::Chainer::Function
+      class MulConstant < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
@@ -69,23 +74,23 @@ module Chainer
           [Utils::Array.force_array(@value * x[0])]
         end
 
-        def backward(x, gy)
-          [
+        def backward(indexes, gy)
+          [gy[0] * @value]
         end
       end
-
-      class Div < ::Chainer::Function
+
+      class Div < ::Chainer::FunctionNode
         def forward(x)
           [Utils::Array.force_array(x[0] / x[1])]
         end
 
-        def backward(x, gy)
+        def backward(indexes, gy)
           gx0 = Utils::Array.force_array(gy[0] / x[1])
           [gx0, Utils::Array.force_array(-1 * gx0 * x[0] / x[1])]
         end
       end
-
-      class PowVarVar < ::Chainer::Function
+
+      class PowVarVar < ::Chainer::FunctionNode
         def forward(x)
           @y = Utils::Array.force_array(x[0] ** x[1])
           [@y]
@@ -94,12 +99,13 @@ module Chainer
         def backward(x, gy)
           one = x[1].class.ones[0]
           gx0 = Utils::Array.force_array(x[1] * (x[0] ** (x[1] - one)) * gy[0])
-
+          xm = Chainer.get_array_module(x[0])
+          gx1 = Utils::Array.force_array(xm::NMath.log(x[0]) * @y * gy[0])
           [gx0, gx1]
         end
       end
 
-      class PowVarConst < ::Chainer::Function
+      class PowVarConst < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
@@ -113,7 +119,7 @@ module Chainer
           gx = @value * (x[0] ** val_1) * gy[0]
           [Utils::Array.force_array(gx)]
         end
-      end
+      end
     end
   end
 end
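The conversion above is the template for the rest of the release: `forward` still maps arrays to arrays, but `backward` now receives the indexes of the inputs that need gradients plus the output-gradient `Chainer::Variable`s, and callers go through `FunctionNode#apply` instead of invoking the node directly. Below is a minimal sketch of driving the converted `Add` node by hand, assuming red-chainer 0.4.0 with a Numo backend; the values and the explicit `grad` seeding are illustrative and not taken from the diff.

```ruby
require 'chainer'        # red-chainer 0.4.0
require 'numo/narray'

a = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
b = Chainer::Variable.new(Numo::DFloat[4.0, 5.0, 6.0])

# FunctionNode#apply takes an array of inputs and returns an array of outputs.
y = Chainer::Functions::Math::Add.new.apply([a, b]).first

# Seed the output gradient and backpropagate; Add#backward returns [gy[0], gy[0]],
# so both inputs should receive the seeded gradient unchanged.
y.grad = Numo::DFloat.ones(3)
y.backward
p a.grad
p b.grad
```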
data/lib/chainer/functions/math/exp.rb (+28 -0)

@@ -0,0 +1,28 @@
+module Chainer
+  module Functions
+    module Math
+      class Exp < Chainer::FunctionNode
+        # Elementwise exponential function.
+        def self.exp(x)
+          self.new.apply([x]).first
+        end
+
+        def label
+          'exp'
+        end
+
+        def forward(x)
+          retain_inputs([])
+          retain_outputs([0])
+          xm = Chainer.get_array_module(x.first)
+          [Utils::Array.force_array(xm::NMath.exp(x.first))]
+        end
+
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          [y * gy.first]
+        end
+      end
+    end
+  end
+end
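`Exp` retains its output rather than its input, so the backward pass can reuse `y` directly (d/dx e^x = e^x). A small usage sketch under the same assumptions as the previous example:

```ruby
x = Chainer::Variable.new(Numo::DFloat[0.0, 1.0, 2.0])
y = Chainer::Functions::Math::Exp.exp(x)

y.grad = Numo::DFloat.ones(3)
y.backward
p y.data  # elementwise e**x
p x.grad  # same values as y.data, since backward returns y * gy
```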
data/lib/chainer/functions/math/identity.rb (+4 -3)

@@ -2,7 +2,7 @@ module Chainer
   module Functions
     module Math
       # Identity function.
-      class Identity < Chainer::Function
+      class Identity < Chainer::FunctionNode
         def check_type_forward(in_types)
           # pass
         end
@@ -12,13 +12,14 @@ module Chainer
           return xs
         end
 
-        def backward(
+        def backward(indexes, gys)
           return gys
         end
 
        # Just returns input variables.
        def self.identity(*inputs)
-          self.new.(
+          ret = self.new.apply(inputs)
+          ret.size == 1 ? ret[0] : ret
         end
       end
     end
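`Identity.identity` now routes through `apply` and unwraps single-output calls. A two-line sketch of the two return shapes (values are illustrative):

```ruby
a = Chainer::Variable.new(Numo::DFloat[1.0])
b = Chainer::Variable.new(Numo::DFloat[2.0])

Chainer::Functions::Math::Identity.identity(a)     # => a single Chainer::Variable
Chainer::Functions::Math::Identity.identity(a, b)  # => an array of two Chainer::Variables
```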
data/lib/chainer/functions/math/sum.rb (+52 -0)

@@ -0,0 +1,52 @@
+module Chainer
+  module Functions
+    module Math
+      # Sum of array elements over a given axis.
+      class Sum < Chainer::FunctionNode
+        # Sum of array elements over a given axis
+        #
+        # @param [Chainer::Variable] x Elements to sum
+        # @param [nil, Integer, Array<Integer>] axis Axis which a sum is performed
+        # @param[boolean] keepdims If `true`, the specified axes are remained as axes of length one
+        # @return [Chainer::Variable] Output variable
+        def self.sum(x, axis: nil, keepdims: false)
+          Sum.new(axis: axis, keepdims: keepdims).apply([x]).first
+        end
+
+        def initialize(axis: nil, keepdims: false)
+          if axis.nil?
+            @axis = nil
+          elsif axis.is_a?(Integer)
+            @axis = [axis]
+          elsif axis.is_a?(::Array) && axis.all? { |e| e.is_a?(Integer) }
+            raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
+            @axis = axis
+          else
+            raise TypeError, 'nil, Integer or Array of int are required'
+          end
+
+          @keepdims = keepdims
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          ret = x.sum(axis: @axis, keepdims: @keepdims)
+          ret = x.class.cast(ret)
+          [ret]
+        end
+
+        def backward(indexes, grad_outputs)
+          gy = grad_outputs.first
+          ndim = @inputs.first.shape.size
+          unless ndim == 0 || @axis.nil? || @keepdims
+            actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
+            shape = gy.shape
+            actual_axis.sort.each { |axis| shape.insert(axis, 1) }
+            gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
+          end
+          [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, @inputs.first.shape)]
+        end
+      end
+    end
+  end
+end
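`Sum` reduces over the requested axes in `forward`; in `backward` it re-inserts the collapsed axes (when `keepdims` is false) and broadcasts the gradient back to the input shape via the new `Reshape` and `BroadcastTo` nodes. A sketch, assuming a Numo-backed variable:

```ruby
x = Chainer::Variable.new(Numo::DFloat.new(2, 3).seq)  # [[0, 1, 2], [3, 4, 5]]

s = Chainer::Functions::Math::Sum.sum(x, axis: 1)                  # data shape [2]
k = Chainer::Functions::Math::Sum.sum(x, axis: 1, keepdims: true)  # data shape [2, 1]
p k.data.shape

s.grad = Numo::DFloat.ones(2)
s.backward
p x.grad.shape  # => [2, 3]; the gradient is broadcast back over the summed axis
```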
data/lib/chainer/functions/noise/dropout.rb (+20 -4)

@@ -1,7 +1,8 @@
 module Chainer
   module Functions
     module Noise
-      class Dropout < Chainer::Function
+      class Dropout < Chainer::FunctionNode
+        attr_reader :mask
         # Drops elements of input variable randomly.
         #
         # This function drops input elements randomly with probability `ratio` and
@@ -12,7 +13,7 @@ module Chainer
         # @param [float] ratio Dropout ratio. The ``ratio`` must be `0.0 <= ratio < 1.0`.
         # @return [Chainer::Variable] Output variable.
         def self.dropout(x, ratio: 0.5)
-          Chainer.configuration.train ? self.new(ratio).(x) : x
+          Chainer.configuration.train ? self.new(ratio).apply([x])[0] : Chainer::Variable.as_variable(x)
         end
 
         def initialize(dropout_ratio)
@@ -23,7 +24,6 @@ module Chainer
         end
 
         def forward(x)
-          retain_inputs([])
           unless self.instance_variable_defined?(:@mask)
             scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
             flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
@@ -36,7 +36,23 @@ module Chainer
         end
 
         def backward(x, gy)
-
+          DropoutGrad.new(@mask).apply(gy)
+        end
+      end
+
+      # Computes the gradient of the Dropout function.
+      class DropoutGrad < Chainer::FunctionNode
+        def initialize(mask)
+          @mask = mask
+        end
+
+        def forward(inputs)
+          y = inputs.first * @mask
+          [y]
+        end
+
+        def backward(indexes, gy)
+          DropoutGrad.new(@mask).apply(gy)
         end
       end
     end
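`Dropout` now keeps the sampled mask and backpropagates through a dedicated `DropoutGrad` node, while the evaluation path wraps the input with `Chainer::Variable.as_variable` instead of returning it untouched. A sketch of both modes, assuming `Chainer.configuration.train` defaults to true and is writable as in the code above:

```ruby
x = Chainer::Variable.new(Numo::DFloat.ones(8))

y_train = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
# Dropped elements become zero; kept elements are scaled by 1 / (1 - ratio), here 2.0.
p y_train.data

Chainer.configuration.train = false
y_eval = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)
p y_eval.data  # identical to x.data: dropout is the identity at evaluation time
Chainer.configuration.train = true
```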
data/lib/chainer/functions/normalization/batch_normalization.rb (+257 -104)

@@ -1,134 +1,287 @@
 module Chainer
   module Functions
     module Normalization
-      class BatchNormalization < Chainer::Function
+      module Calculation
+        def apply_bn_fwd(xp, x, mean, inv_std, gamma, beta)
+          # NOTE: all arguments should be broadcasted to x.shape
+          # (mean, inv_std, gamma, and beta have to already be expanded)
+          x_hat = x_hat(x, mean, inv_std)
+          y = gamma * x_hat
+          y += beta
+          y
+        end
+
+        def x_hat(x, mean, inv_std)
+          x_mu = x - mean
+          x_mu *= inv_std
+          x_mu
+        end
+
+        def zero_if_none(xp, x, shape, dtype)
+          # TODO: Return broadcasted 0 instead of a zeroed array.
+          x.nil? ? dtype.zeros(*shape) : x
+        end
+      end
+
+      class BatchNormalization < Chainer::FunctionNode
+        include Calculation
         attr_reader :running_mean, :running_var
-
-
-
-
-
-
-
-
-
-        # @param [Chainer::Variable] mean Shifting parameter of input.
-        # @param [Chainer::Variable] var Square of scaling parameter of input.
-        # @param [float] eps Epsilon value for numerical stability.
-        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
-          old_train = Chainer.configuration.train
-          Chainer.configuration.train = false
-          norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
-          Chainer.configuration.train = old_train
-          norm
-        end
-
-        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+
+        def self.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
+          BatchNormalization.new(eps: eps, mean: running_mean, var: running_var, decay: decay).apply([x, gamma, beta])[0]
+        end
+
+        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+          @mean = nil
+          @inv_std = nil
+
           @running_mean = mean
           @running_var = var
           @eps = eps
-          @mean_cache = nil
           @decay = decay
         end
 
         def forward(inputs)
-
-
-
-
-
-
-
-            @running_var = Numo::NArray[*@running_var]
-            end
-          elsif inputs.size == 5
-            @fixed_mean = inputs[3]
-            @fixed_var = inputs[4]
+          retain_inputs([0, 1])
+          x, gamma, beta = inputs
+          xp = Chainer.get_array_module(x)
+
+          if @running_mean.nil?
+            @running_mean = xp::NArray[*gamma].new_zeros
+            @running_var = xp::NArray[*gamma].new_zeros
           end
 
+          # expander inserts singleton dimensions to gamma and beta so that they
+          # can be broadcasted with x.
           head_ndim = gamma.ndim + 1
-
-
-
-
-
-          if Chainer.configuration.train
-            axis = [0] + (head_ndim...(x.ndim)).to_a
-            mean = x.mean(axis: axis)
-            # FIXME: numpy.var
-            var = x.var(axis: axis)
-            var += @eps
-          else
-            mean = @fixed_mean
-            var = @fixed_var + @eps
+          # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+          suffix = [1] * (x.ndim - head_ndim)
+          expander = -> (arr) do
+            shape = [1] + arr.shape + suffix
+            arr.reshape(*shape)
           end
+          @expander = expander
+          @axis = [0] + (head_ndim...(x.ndim)).to_a
 
-
+          gamma = expander.(gamma)
+          beta = expander.(beta)
+          @mean = x.mean(axis: @axis)
 
-
-
-          std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
-          x_mu /= @std.reshape(*std_expander)
-          @x_hat = x_mu
-          y = gamma * @x_hat
-          y += beta
+          # TODO: Numo::Array can not be specified standard deviation
+          var = ((x - x.mean(axis: @axis, keepdims: true)) ** 2).mean(axis: @axis)
 
-
-
-
-
-
-
-
-
-
-
-
-          @running_var += temp_ar
-          end
+          var += @eps
+          @inv_std = var ** (-0.5)
+
+          y = apply_bn_fwd(xp, x, expander.(@mean), expander.(@inv_std), gamma, beta)
+          # Update running statistics
+          m = x.size.div(gamma.size)
+          adjust = m / [m - 1.0, 1.0].max
+          @running_mean *= @decay
+          @running_mean += (1 - @decay) * @mean
+          @running_var *= @decay
+          @running_var += (1 - @decay) * adjust * var
 
-          [y
+          [y]
+        end
+
+        def backward(indexes, grad_outputs)
+          x, gamma = get_retained_inputs
+          gy, = grad_outputs
+
+          # hatappi debug
+          #@mean = @mean.class.new(@mean.shape).seq
+          #@inv_std = @inv_std.class.new(@inv_std.shape).seq
+          #x.data = x.data.class.new(x.shape).seq
+          #gamma.data = gamma.data.class.new(gamma.shape).seq
+          #gy.data = gy.data.class.new(gy.shape).seq
+
+          f = BatchNormalizationGrad.new(@eps, @expander, @axis, @mean, @inv_std)
+          f.(x, gamma, gy)
+        end
+      end
+
+      class BatchNormalizationGrad < Function
+        include Calculation
+
+        def initialize(eps, expander, axis, mean, inv_std)
+          @eps = eps
+          @expander = expander
+          @axis = axis
+          @mean = mean
+          @inv_std = inv_std
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 2])
+          x, gamma, gy = inputs
+          expander = @expander
+
+          inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+          xp = Chainer.get_array_module(x)
+
+          gbeta = gy.sum(axis: @axis)
+          x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+          ggamma = (gy * x_hat).sum(axis: @axis)
+          gx = expander.(gamma * @inv_std) * (gy - (x_hat * expander.(ggamma) + expander.(gbeta)) * inv_m)
+
+          retain_outputs([0, 1])
+          [gx, ggamma, gbeta]
         end
 
         def backward(inputs, grad_outputs)
-
-
+          expander = @expander
+
+          x, gamma, gy = inputs
+          gx1, ggamma1, = output_data
+          ggx1, gggamma1, ggbeta1 = grad_outputs
+          xp = Chainer.get_array_module(x)
+
+          # auxiliary values
+          inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+          r = ggx1.nil? ? 0 : (gx1 * ggx1).sum(axis: @axis)
+          coeff = gamma * @inv_std
+          coeff_m = coeff * inv_m
+          x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+
+          # handle None in output gradients
+          ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+          gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+          ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+
+          gggamma2 = gggamma1 - coeff_m * (x_hat * ggx1).sum(axis: @axis)
+          ggbeta2 = ggbeta1 - coeff_m * ggx1.sum(axis: @axis)
+
+          ggamma2 = r / gamma
+
+          gx_hat2 = (expander.(gggamma2) * gy - expander.(coeff_m * ggamma1) * ggx1)
+          gstd2 = -@inv_std * (r + (x_hat * gx_hat2).sum(axis: @axis))
+          gmean2 = -@inv_std * gx_hat2.sum(axis: @axis)
+          gx2 = expander.(@inv_std) * gx_hat2 + inv_m * (expander.(gmean2) + x_hat * expander.(gstd2))
+          ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(coeff) * ggx1)
+
+          [gx2, ggamma2, ggy2]
+        end
+      end
+
+      class FixedBatchNormalization < FunctionNode
+        include Calculation
+
+        attr_reader :inv_var
+
+        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+          FixedBatchNormalization.new(eps: eps).apply([x, gamma, beta, mean, var]).first
+        end
+
+        def initialize(eps: 2e-5)
+          @inv_std = nil
+          @inv_var = nil
+          @eps = eps
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 3, 4])
+          x, gamma, beta, mean, var = inputs
+          xp = Chainer.get_array_module(x)
+
+          # expander inserts singleton dimensions to gamma and beta so that they
+          # can be broadcasted with x.
           head_ndim = gamma.ndim + 1
-
-
-
-
-
-          var = inputs[4]
-          std = Numo::NMath.sqrt(var)
-          gs = gamma / std
-          gbeta = gy.sum(axis: axis)
-
-          mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
-          x_mu = x - mean.reshape(*mean_expander)
-          std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
-          x_mu /= std.reshape(*std_expander)
-          x_hat = x_mu
-          ggamma = (gy * x_hat).sum(axis: axis)
-          gmean = -gs * gbeta
-          gvar = -0.5 * gamma / var * ggamma
-          gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
-          gx = gs.reshape(*gs_expander)
-          return [gx, ggamma, gbeta, gmean, gvar]
+          # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+          suffix = [1] * (x.ndim - head_ndim)
+          expander = -> (arr) do
+            shape = [1] + arr.shape + suffix
+            arr.reshape(*shape)
           end
+          @expander = expander
+          @axis = [0] + (head_ndim...(x.ndim)).to_a
+
+          gamma = expander.(gamma)
+          beta = expander.(beta)
+          var += @eps
+          @inv_var = var.reciprocal
+          @inv_std = xp::NMath.sqrt(@inv_var)
 
-
-
-          tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
-          tmp = tmp.reshape(*tmp_expander)
+          y = apply_bn_fwd(xp, x, expander.(mean), expander.(@inv_std), gamma, beta)
+          [y]
+        end
 
-
-
-
-
+        def backward(indexes, grad_outputs)
+          x, gamma, mean, var = get_retained_inputs
+          gy, = grad_outputs
+          f = FixedBatchNormalizationGrad.new(@eps, @expander, @axis, @inv_std, @inv_var)
+          f.(x, gamma, mean, var, gy)
+        end
+      end
 
-
+      class FixedBatchNormalizationGrad < Function
+        include Calculation
+
+        def initialize(eps, expander, axis, inv_std, inv_var)
+          @eps = eps
+          @expander = expander
+          @axis = axis
+          @inv_std = inv_std
+          @inv_var = inv_var
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 2, 4])
+          x, gamma, mean, var, gy = inputs
+          expander = @expander
+          xp = Chainer.get_array_module(x)
+
+          if @inv_std.nil? || @inv_var.nil?
+            @inv_var = (var + @eps).reciprocal
+            @inv_std = xp::NMath.sqrt(@inv_var)
+          end
+
+          @gamma_over_std = gamma * @inv_std
+          x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+
+          gx = expander.(@gamma_over_std) * gy
+          gbeta = gy.sum(axis: @axis)
+          ggamma = (x_hat * gy).sum(axis: @axis)
+          gmean = -@gamma_over_std * gbeta
+          gvar = -0.5 * gamma * @inv_var * ggamma
+
+          retain_outputs([0, 1, 2, 3, 4])
+          [gx, ggamma, gbeta, gmean, gvar]
+        end
+
+        def backward(inputs, grad_outputs)
+          x, gamma, mean, _, gy = inputs
+          ggx1, gggamma1, ggbeta1, ggmean1, ggvar1 = grad_outputs
+          gx1, ggamma1, gbeta1, gmean1, gvar1 = output_data
+
+          # Handle None in output gradients.
+          xp = Chainer.get_array_module(x)
+          ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+          gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+          ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+          ggmean1 = zero_if_none(xp, ggmean1, mean.shape, mean.class)
+          ggvar1 = zero_if_none(xp, ggvar1, mean.shape, mean.class)
+
+          expander = @expander
+          x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+          tmp = -0.5 * ggvar1
+
+          gamma_over_var = gamma * @inv_var
+          g_gamma_over_var = tmp * ggamma1
+
+          gggamma2 = gggamma1 + tmp * gamma_over_var
+          gx_hat = gy * expander.(gggamma2)
+          gx2 = expander.(@inv_std) * gx_hat
+          gmean2 = -@inv_std * gx_hat.sum(axis: @axis)
+
+          g_gamma_over_std = (ggx1 * gy).sum(axis: @axis) - ggmean1 * gbeta1
+          ggbeta2 = ggbeta1 - ggmean1 * @gamma_over_std
+          ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(@gamma_over_std) * ggx1)
+
+          ggamma2 = (@inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std)
+          gvar2 = -(ggamma2 * gamma_over_var + 0.5 * @inv_var * ((x_hat * gx_hat).sum(axis: @axis) - @gamma_over_std * g_gamma_over_std))
+
+          [gx2, ggamma2, gmean2, gvar2, ggy2]
         end
       end
     end
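The rewrite splits the layer into a training-time `BatchNormalization` (which also folds the batch statistics into `running_mean`/`running_var`) and an inference-time `FixedBatchNormalization`, each backed by a `*Grad` Function so the gradient itself is differentiable. A sketch of the two entry points, assuming 2-D activations with three channels; shapes and values here are illustrative, not from the diff:

```ruby
x     = Chainer::Variable.new(Numo::DFloat.new(10, 3).rand)
gamma = Chainer::Variable.new(Numo::DFloat.ones(3))
beta  = Chainer::Variable.new(Numo::DFloat.zeros(3))

# Training path: statistics come from the batch and are folded into the
# node's running averages (decay: 0.9 by default).
y = Chainer::Functions::Normalization::BatchNormalization
      .batch_normalization(x, gamma, beta, eps: 2e-5)

# Inference path: normalize with externally supplied statistics.
mean = Numo::DFloat.zeros(3)
var  = Numo::DFloat.ones(3)
y_fixed = Chainer::Functions::Normalization::FixedBatchNormalization
            .fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
```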