red-chainer 0.3.2 → 0.4.0

Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/math/basic_math.rb
@@ -1,66 +1,71 @@
 module Chainer
   module Functions
     module Math
-
-      class Neg < ::Chainer::Function
+      class Neg < ::Chainer::FunctionNode
+        def label
+          '__neg__'
+        end
+
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(-x[0])]
         end
 
-        def backward(x, gy)
-          [Utils::Array.force_array(-gy[0])]
+        def backward(indexes, gy)
+          [-gy[0]]
         end
       end
 
-      class Add < ::Chainer::Function
+      class Add < ::Chainer::FunctionNode
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] + x[1])]
         end
 
-        def backward(x, gy)
+        def backward(indexes, gy)
           [gy[0], gy[0]]
         end
       end
 
-      class AddConstant < ::Chainer::Function
+      class AddConstant < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
 
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] + @value)]
         end
 
-        def backward(x, gy)
+        def backward(indexes, gy)
           [gy[0]]
         end
       end
-
-      class Sub < ::Chainer::Function
+
+      class Sub < ::Chainer::FunctionNode
+        def label
+          '_ - _'
+        end
+
         def forward(x)
-          retain_inputs([])
           [Utils::Array.force_array(x[0] - x[1])]
         end
 
-        def backward(x, gy)
-          [gy[0], Utils::Array.force_array(-gy[0])]
+        def backward(indexes, gy)
+          [gy[0], -gy[0]]
         end
       end
 
-      class Mul < ::Chainer::Function
+      class Mul < ::Chainer::FunctionNode
         def forward(x)
+          retain_inputs([0, 1])
           [Utils::Array.force_array(x[0] * x[1])]
         end
 
-        def backward(x, gy)
-          [Utils::Array.force_array(gy[0] * x[1]), Utils::Array.force_array(gy[0] * x[0])]
+        def backward(indexes, gy)
+          xs = get_retained_inputs
+          indexes.map { |i| gy[0] * xs[1 - i] }
         end
       end
 
-      class MulConstant < ::Chainer::Function
+      class MulConstant < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
@@ -69,23 +74,23 @@ module Chainer
           [Utils::Array.force_array(@value * x[0])]
         end
 
-        def backward(x, gy)
-          [Utils::Array.force_array(@value * gy[0])]
+        def backward(indexes, gy)
+          [gy[0] * @value]
         end
       end
-
-      class Div < ::Chainer::Function
+
+      class Div < ::Chainer::FunctionNode
         def forward(x)
           [Utils::Array.force_array(x[0] / x[1])]
         end
 
-        def backward(x, gy)
+        def backward(indexes, gy)
           gx0 = Utils::Array.force_array(gy[0] / x[1])
           [gx0, Utils::Array.force_array(-1 * gx0 * x[0] / x[1])]
         end
       end
-
-      class PowVarVar < ::Chainer::Function
+
+      class PowVarVar < ::Chainer::FunctionNode
        def forward(x)
          @y = Utils::Array.force_array(x[0] ** x[1])
          [@y]
@@ -94,12 +99,13 @@ module Chainer
         def backward(x, gy)
           one = x[1].class.ones[0]
           gx0 = Utils::Array.force_array(x[1] * (x[0] ** (x[1] - one)) * gy[0])
-          gx1 = Utils::Array.force_array(Numo::NMath.log(x[0]) * @y * gy[0])
+          xm = Chainer.get_array_module(x[0])
+          gx1 = Utils::Array.force_array(xm::NMath.log(x[0]) * @y * gy[0])
           [gx0, gx1]
         end
       end
 
-      class PowVarConst < ::Chainer::Function
+      class PowVarConst < ::Chainer::FunctionNode
         def initialize(value)
           @value = value
         end
@@ -113,7 +119,7 @@ module Chainer
           gx = @value * (x[0] ** val_1) * gy[0]
           [Utils::Array.force_array(gx)]
         end
-      end
+      end
     end
   end
 end
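
Note (not part of the diff): these arithmetic nodes back the operator overloads on Chainer::Variable, so the move from Function to FunctionNode is mostly invisible to user code. A minimal sketch, assuming the usual Variable operator overloads dispatch to Mul and AddConstant as in earlier releases (array values are illustrative):

  require 'chainer'

  x = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
  y = Chainer::Variable.new(Numo::DFloat[4.0, 5.0, 6.0])
  z = x * y + 2.0                 # builds Mul and AddConstant nodes
  z.grad = Numo::DFloat.ones(3)   # seed the output gradient
  z.backward                      # backward(indexes, gy) now receives Variable gradients
  p x.grad
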
data/lib/chainer/functions/math/exp.rb
@@ -0,0 +1,28 @@
+module Chainer
+  module Functions
+    module Math
+      class Exp < Chainer::FunctionNode
+        # Elementwise exponential function.
+        def self.exp(x)
+          self.new.apply([x]).first
+        end
+
+        def label
+          'exp'
+        end
+
+        def forward(x)
+          retain_inputs([])
+          retain_outputs([0])
+          xm = Chainer.get_array_module(x.first)
+          [Utils::Array.force_array(xm::NMath.exp(x.first))]
+        end
+
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          [y * gy.first]
+        end
+      end
+    end
+  end
+end
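
Note (not part of the diff): a minimal usage sketch of the new Exp.exp entry point, with illustrative values:

  x = Chainer::Variable.new(Numo::DFloat[0.0, 1.0, 2.0])
  y = Chainer::Functions::Math::Exp.exp(x)
  y.grad = Numo::DFloat.ones(3)
  y.backward
  p x.grad   # should equal exp(x) element-wise, taken from the retained output
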
data/lib/chainer/functions/math/identity.rb
@@ -2,7 +2,7 @@ module Chainer
   module Functions
     module Math
       # Identity function.
-      class Identity < Chainer::Function
+      class Identity < Chainer::FunctionNode
         def check_type_forward(in_types)
           # pass
         end
@@ -12,13 +12,14 @@ module Chainer
           return xs
         end
 
-        def backward(xs, gys)
+        def backward(indexes, gys)
           return gys
         end
 
         # Just returns input variables.
        def self.identity(*inputs)
-          self.new.(*inputs)
+          ret = self.new.apply(inputs)
+          ret.size == 1 ? ret[0] : ret
        end
      end
    end
data/lib/chainer/functions/math/sum.rb
@@ -0,0 +1,52 @@
+module Chainer
+  module Functions
+    module Math
+      # Sum of array elements over a given axis.
+      class Sum < Chainer::FunctionNode
+        # Sum of array elements over a given axis
+        #
+        # @param [Chainer::Variable] x Elements to sum
+        # @param [nil, Integer, Array<Integer>] axis Axis which a sum is performed
+        # @param [boolean] keepdims If `true`, the specified axes are remained as axes of length one
+        # @return [Chainer::Variable] Output variable
+        def self.sum(x, axis: nil, keepdims: false)
+          Sum.new(axis: axis, keepdims: keepdims).apply([x]).first
+        end
+
+        def initialize(axis: nil, keepdims: false)
+          if axis.nil?
+            @axis = nil
+          elsif axis.is_a?(Integer)
+            @axis = [axis]
+          elsif axis.is_a?(::Array) && axis.all? { |e| e.is_a?(Integer) }
+            raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
+            @axis = axis
+          else
+            raise TypeError, 'nil, Integer or Array of int are required'
+          end
+
+          @keepdims = keepdims
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          ret = x.sum(axis: @axis, keepdims: @keepdims)
+          ret = x.class.cast(ret)
+          [ret]
+        end
+
+        def backward(indexes, grad_outputs)
+          gy = grad_outputs.first
+          ndim = @inputs.first.shape.size
+          unless ndim == 0 || @axis.nil? || @keepdims
+            actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
+            shape = gy.shape
+            actual_axis.sort.each { |axis| shape.insert(axis, 1) }
+            gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
+          end
+          [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, @inputs.first.shape)]
+        end
+      end
+    end
+  end
+end
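
Note (not part of the diff): a minimal usage sketch of Sum.sum; the backward pass reshapes the upstream gradient and broadcasts it back to the input shape, as implemented above (shapes are illustrative):

  x = Chainer::Variable.new(Numo::DFloat.new(2, 3).seq)
  s = Chainer::Functions::Math::Sum.sum(x, axis: 1)
  s.grad = Numo::DFloat.ones(2)
  s.backward
  p x.grad   # ones broadcast back to the (2, 3) input shape
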
data/lib/chainer/functions/noise/dropout.rb
@@ -1,7 +1,8 @@
 module Chainer
   module Functions
     module Noise
-      class Dropout < Chainer::Function
+      class Dropout < Chainer::FunctionNode
+        attr_reader :mask
         # Drops elements of input variable randomly.
         #
         # This function drops input elements randomly with probability `ratio` and
@@ -12,7 +13,7 @@ module Chainer
         # @param [float] ratio Dropout ratio. The ``ratio`` must be `0.0 <= ratio < 1.0`.
         # @return [Chainer::Variable] Output variable.
         def self.dropout(x, ratio: 0.5)
-          Chainer.configuration.train ? self.new(ratio).(x) : x
+          Chainer.configuration.train ? self.new(ratio).apply([x])[0] : Chainer::Variable.as_variable(x)
         end
 
         def initialize(dropout_ratio)
@@ -23,7 +24,6 @@ module Chainer
         end
 
         def forward(x)
-          retain_inputs([])
           unless self.instance_variable_defined?(:@mask)
             scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
             flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
@@ -36,7 +36,23 @@
         end
 
         def backward(x, gy)
-          [gy[0] * @mask]
+          DropoutGrad.new(@mask).apply(gy)
+        end
+      end
+
+      # Computes the gradient of the Dropout function.
+      class DropoutGrad < Chainer::FunctionNode
+        def initialize(mask)
+          @mask = mask
+        end
+
+        def forward(inputs)
+          y = inputs.first * @mask
+          [y]
+        end
+
+        def backward(indexes, gy)
+          DropoutGrad.new(@mask).apply(gy)
         end
       end
     end
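
Note (not part of the diff): a minimal sketch of the changed Dropout.dropout behaviour; in training mode elements are zeroed and the rest scaled by 1 / (1 - ratio), while in test mode the input now comes back wrapped as a Variable instead of being returned raw (values are illustrative):

  Chainer.configuration.train = true
  x = Chainer::Variable.new(Numo::DFloat.new(1000).fill(1.0))
  y = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.3)

  Chainer.configuration.train = false
  Chainer::Functions::Noise::Dropout.dropout(x)   # returns Chainer::Variable.as_variable(x)
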
data/lib/chainer/functions/normalization/batch_normalization.rb
@@ -1,134 +1,287 @@
 module Chainer
   module Functions
     module Normalization
-      class BatchNormalizationFunction < Chainer::Function
+      module Calculation
+        def apply_bn_fwd(xp, x, mean, inv_std, gamma, beta)
+          # NOTE: all arguments should be broadcasted to x.shape
+          # (mean, inv_std, gamma, and beta have to already be expanded)
+          x_hat = x_hat(x, mean, inv_std)
+          y = gamma * x_hat
+          y += beta
+          y
+        end
+
+        def x_hat(x, mean, inv_std)
+          x_mu = x - mean
+          x_mu *= inv_std
+          x_mu
+        end
+
+        def zero_if_none(xp, x, shape, dtype)
+          # TODO: Return broadcasted 0 instead of a zeroed array.
+          x.nil? ? dtype.zeros(*shape) : x
+        end
+      end
+
+      class BatchNormalization < Chainer::FunctionNode
+        include Calculation
         attr_reader :running_mean, :running_var
-        # Batch normalization function with fixed statistics.
-        # This is a variant of batch normalization, where the mean and variance
-        # statistics are given by the caller as fixed variables. This is
-        # used on testing mode of the batch normalization layer, where batch
-        # statistics cannot be used for prediction consistency.
-        #
-        # @param [Chainer::Variable] x Input variable.
-        # @param [Chainer::Variable] gamma Scaling parameter of normalized data.
-        # @param [Chainer::Variable] beta Shifting parameter of scaled normalized data.
-        # @param [Chainer::Variable] mean Shifting parameter of input.
-        # @param [Chainer::Variable] var Square of scaling parameter of input.
-        # @param [float] eps Epsilon value for numerical stability.
-        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
-          old_train = Chainer.configuration.train
-          Chainer.configuration.train = false
-          norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
-          Chainer.configuration.train = old_train
-          norm
-        end
-
-        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+
+        def self.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
+          BatchNormalization.new(eps: eps, mean: running_mean, var: running_var, decay: decay).apply([x, gamma, beta])[0]
+        end
+
+        def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+          @mean = nil
+          @inv_std = nil
+
           @running_mean = mean
           @running_var = var
           @eps = eps
-          @mean_cache = nil
           @decay = decay
         end
 
         def forward(inputs)
-          x, gamma, beta = inputs[0], inputs[1], inputs[2]
-          if Chainer.configuration.train
-            if @running_mean.nil?
-              @running_mean = Numo::NArray[*gamma].new_zeros
-              @running_var = Numo::NArray[*gamma].new_zeros
-            else
-              @running_mean = Numo::NArray[*@running_mean]
-              @running_var = Numo::NArray[*@running_var]
-            end
-          elsif inputs.size == 5
-            @fixed_mean = inputs[3]
-            @fixed_var = inputs[4]
+          retain_inputs([0, 1])
+          x, gamma, beta = inputs
+          xp = Chainer.get_array_module(x)
+
+          if @running_mean.nil?
+            @running_mean = xp::NArray[*gamma].new_zeros
+            @running_var = xp::NArray[*gamma].new_zeros
          end
 
+          # expander inserts singleton dimensions to gamma and beta so that they
+          # can be broadcasted with x.
          head_ndim = gamma.ndim + 1
-          gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
-          gamma = gamma.reshape(*gamma_expander)
-          beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
-          beta = beta.reshape(*beta_expander)
-
-          if Chainer.configuration.train
-            axis = [0] + (head_ndim...(x.ndim)).to_a
-            mean = x.mean(axis: axis)
-            # FIXME: numpy.var
-            var = x.var(axis: axis)
-            var += @eps
-          else
-            mean = @fixed_mean
-            var = @fixed_var + @eps
+          # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+          suffix = [1] * (x.ndim - head_ndim)
+          expander = -> (arr) do
+            shape = [1] + arr.shape + suffix
+            arr.reshape(*shape)
          end
+          @expander = expander
+          @axis = [0] + (head_ndim...(x.ndim)).to_a
 
-          @std = Numo::NMath.sqrt(var)
+          gamma = expander.(gamma)
+          beta = expander.(beta)
+          @mean = x.mean(axis: @axis)
 
-          mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
-          x_mu = x - mean.reshape(*mean_expander)
-          std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
-          x_mu /= @std.reshape(*std_expander)
-          @x_hat = x_mu
-          y = gamma * @x_hat
-          y += beta
+          # TODO: Numo::Array can not be specified standard deviation
+          var = ((x - x.mean(axis: @axis, keepdims: true)) ** 2).mean(axis: @axis)
 
-          if Chainer.configuration.train
-            m = x.size.div(gamma.size)
-            adjust = m / [m - 1.0, 1.0].max
-            @running_mean *= @decay
-            temp_ar = Numo::NArray[*mean]
-            temp_ar *= (1 - @decay)
-            @running_mean += temp_ar
-
-            @running_var *= @decay
-            temp_ar = Numo::NArray[*var]
-            temp_ar *= ((1 - @decay) * adjust)
-            @running_var += temp_ar
-          end
+          var += @eps
+          @inv_std = var ** (-0.5)
+
+          y = apply_bn_fwd(xp, x, expander.(@mean), expander.(@inv_std), gamma, beta)
+          # Update running statistics
+          m = x.size.div(gamma.size)
+          adjust = m / [m - 1.0, 1.0].max
+          @running_mean *= @decay
+          @running_mean += (1 - @decay) * @mean
+          @running_var *= @decay
+          @running_var += (1 - @decay) * adjust * var
 
-          [y,]
+          [y]
+        end
+
+        def backward(indexes, grad_outputs)
+          x, gamma = get_retained_inputs
+          gy, = grad_outputs
+
+          # hatappi debug
+          #@mean = @mean.class.new(@mean.shape).seq
+          #@inv_std = @inv_std.class.new(@inv_std.shape).seq
+          #x.data = x.data.class.new(x.shape).seq
+          #gamma.data = gamma.data.class.new(gamma.shape).seq
+          #gy.data = gy.data.class.new(gy.shape).seq
+
+          f = BatchNormalizationGrad.new(@eps, @expander, @axis, @mean, @inv_std)
+          f.(x, gamma, gy)
+        end
+      end
+
+      class BatchNormalizationGrad < Function
+        include Calculation
+
+        def initialize(eps, expander, axis, mean, inv_std)
+          @eps = eps
+          @expander = expander
+          @axis = axis
+          @mean = mean
+          @inv_std = inv_std
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 2])
+          x, gamma, gy = inputs
+          expander = @expander
+
+          inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+          xp = Chainer.get_array_module(x)
+
+          gbeta = gy.sum(axis: @axis)
+          x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+          ggamma = (gy * x_hat).sum(axis: @axis)
+          gx = expander.(gamma * @inv_std) * (gy - (x_hat * expander.(ggamma) + expander.(gbeta)) * inv_m)
+
+          retain_outputs([0, 1])
+          [gx, ggamma, gbeta]
        end
 
        def backward(inputs, grad_outputs)
-          x, gamma = inputs[0], inputs[1]
-          gy = grad_outputs[0]
+          expander = @expander
+
+          x, gamma, gy = inputs
+          gx1, ggamma1, = output_data
+          ggx1, gggamma1, ggbeta1 = grad_outputs
+          xp = Chainer.get_array_module(x)
+
+          # auxiliary values
+          inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+          r = ggx1.nil? ? 0 : (gx1 * ggx1).sum(axis: @axis)
+          coeff = gamma * @inv_std
+          coeff_m = coeff * inv_m
+          x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+
+          # handle None in output gradients
+          ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+          gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+          ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+
+          gggamma2 = gggamma1 - coeff_m * (x_hat * ggx1).sum(axis: @axis)
+          ggbeta2 = ggbeta1 - coeff_m * ggx1.sum(axis: @axis)
+
+          ggamma2 = r / gamma
+
+          gx_hat2 = (expander.(gggamma2) * gy - expander.(coeff_m * ggamma1) * ggx1)
+          gstd2 = -@inv_std * (r + (x_hat * gx_hat2).sum(axis: @axis))
+          gmean2 = -@inv_std * gx_hat2.sum(axis: @axis)
+          gx2 = expander.(@inv_std) * gx_hat2 + inv_m * (expander.(gmean2) + x_hat * expander.(gstd2))
+          ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(coeff) * ggx1)
+
+          [gx2, ggamma2, ggy2]
+        end
+      end
+
+      class FixedBatchNormalization < FunctionNode
+        include Calculation
+
+        attr_reader :inv_var
+
+        def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+          FixedBatchNormalization.new(eps: eps).apply([x, gamma, beta, mean, var]).first
+        end
+
+        def initialize(eps: 2e-5)
+          @inv_std = nil
+          @inv_var = nil
+          @eps = eps
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 3, 4])
+          x, gamma, beta, mean, var = inputs
+          xp = Chainer.get_array_module(x)
+
+          # expander inserts singleton dimensions to gamma and beta so that they
+          # can be broadcasted with x.
          head_ndim = gamma.ndim + 1
-          m = gamma.class[x.size.div(gamma.size)][0]
-          axis = [0] + (head_ndim...(x.ndim)).to_a
-
-          if inputs.size == 5
-            mean = inputs[3]
-            var = inputs[4]
-            std = Numo::NMath.sqrt(var)
-            gs = gamma / std
-            gbeta = gy.sum(axis: axis)
-
-            mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
-            x_mu = x - mean.reshape(*mean_expander)
-            std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
-            x_mu /= std.reshape(*std_expander)
-            x_hat = x_mu
-            ggamma = (gy * x_hat).sum(axis: axis)
-            gmean = -gs * gbeta
-            gvar = -0.5 * gamma / var * ggamma
-            gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
-            gx = gs.reshape(*gs_expander)
-            return [gx, ggamma, gbeta, gmean, gvar]
+          # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+          suffix = [1] * (x.ndim - head_ndim)
+          expander = -> (arr) do
+            shape = [1] + arr.shape + suffix
+            arr.reshape(*shape)
          end
+          @expander = expander
+          @axis = [0] + (head_ndim...(x.ndim)).to_a
+
+          gamma = expander.(gamma)
+          beta = expander.(beta)
+          var += @eps
+          @inv_var = var.reciprocal
+          @inv_std = xp::NMath.sqrt(@inv_var)
 
-          gbeta = gy.sum(axis: axis)
-          ggamma = (gy * @x_hat).sum(axis: axis)
-          tmp = (gamma / @std)
-          tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
-          tmp = tmp.reshape(*tmp_expander)
+          y = apply_bn_fwd(xp, x, expander.(mean), expander.(@inv_std), gamma, beta)
+          [y]
+        end
 
-          ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
-          gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
-
-          gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m )
+        def backward(indexes, grad_outputs)
+          x, gamma, mean, var = get_retained_inputs
+          gy, = grad_outputs
+          f = FixedBatchNormalizationGrad.new(@eps, @expander, @axis, @inv_std, @inv_var)
+          f.(x, gamma, mean, var, gy)
+        end
+      end
 
-          [gx, ggamma, gbeta]
+      class FixedBatchNormalizationGrad < Function
+        include Calculation
+
+        def initialize(eps, expander, axis, inv_std, inv_var)
+          @eps = eps
+          @expander = expander
+          @axis = axis
+          @inv_std = inv_std
+          @inv_var = inv_var
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1, 2, 4])
+          x, gamma, mean, var, gy = inputs
+          expander = @expander
+          xp = Chainer.get_array_module(x)
+
+          if @inv_std.nil? || @inv_var.nil?
+            @inv_var = (var + @eps).reciprocal
+            @inv_std = xp::NMath.sqrt(@inv_var)
+          end
+
+          @gamma_over_std = gamma * @inv_std
+          x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+
+          gx = expander.(@gamma_over_std) * gy
+          gbeta = gy.sum(axis: @axis)
+          ggamma = (x_hat * gy).sum(axis: @axis)
+          gmean = -@gamma_over_std * gbeta
+          gvar = -0.5 * gamma * @inv_var * ggamma
+
+          retain_outputs([0, 1, 2, 3, 4])
+          [gx, ggamma, gbeta, gmean, gvar]
+        end
+
+        def backward(inputs, grad_outputs)
+          x, gamma, mean, _, gy = inputs
+          ggx1, gggamma1, ggbeta1, ggmean1, ggvar1 = grad_outputs
+          gx1, ggamma1, gbeta1, gmean1, gvar1 = output_data
+
+          # Handle None in output gradients.
+          xp = Chainer.get_array_module(x)
+          ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+          gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+          ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+          ggmean1 = zero_if_none(xp, ggmean1, mean.shape, mean.class)
+          ggvar1 = zero_if_none(xp, ggvar1, mean.shape, mean.class)
+
+          expander = @expander
+          x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+          tmp = -0.5 * ggvar1
+
+          gamma_over_var = gamma * @inv_var
+          g_gamma_over_var = tmp * ggamma1
+
+          gggamma2 = gggamma1 + tmp * gamma_over_var
+          gx_hat = gy * expander.(gggamma2)
+          gx2 = expander.(@inv_std) * gx_hat
+          gmean2 = -@inv_std * gx_hat.sum(axis: @axis)
+
+          g_gamma_over_std = (ggx1 * gy).sum(axis: @axis) - ggmean1 * gbeta1
+          ggbeta2 = ggbeta1 - ggmean1 * @gamma_over_std
+          ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(@gamma_over_std) * ggx1)
+
+          ggamma2 = (@inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std)
+          gvar2 = -(ggamma2 * gamma_over_var + 0.5 * @inv_var * ((x_hat * gx_hat).sum(axis: @axis) - @gamma_over_std * g_gamma_over_std))
+
+          [gx2, ggamma2, gmean2, gvar2, ggy2]
        end
      end
    end
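
Note (not part of the diff): the old single BatchNormalizationFunction handled both training and fixed-statistics modes internally; 0.4.0 splits them into the two entry points added above. A minimal sketch of calling each path, with illustrative shapes:

  F = Chainer::Functions::Normalization

  x     = Chainer::Variable.new(Numo::DFloat.new(10, 3).rand)
  gamma = Chainer::Variable.new(Numo::DFloat.ones(3))
  beta  = Chainer::Variable.new(Numo::DFloat.zeros(3))

  # Training path: batch statistics are computed and the supplied
  # running_mean / running_var arrays are updated in place.
  running_mean = Numo::DFloat.zeros(3)
  running_var  = Numo::DFloat.zeros(3)
  y = F::BatchNormalization.batch_normalization(
    x, gamma, beta, running_mean: running_mean, running_var: running_var)

  # Inference path: the caller supplies fixed statistics instead.
  y_test = F::FixedBatchNormalization.fixed_batch_normalization(
    x, gamma, beta,
    Chainer::Variable.new(running_mean), Chainer::Variable.new(running_var))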