red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/math/basic_math.rb

@@ -1,66 +1,71 @@
  module Chainer
  module Functions
  module Math
-
- class Neg < ::Chainer::Function
+ class Neg < ::Chainer::FunctionNode
+ def label
+ '__neg__'
+ end
+
  def forward(x)
- retain_inputs([])
  [Utils::Array.force_array(-x[0])]
  end

- def backward(x, gy)
- [Utils::Array.force_array(-gy[0])]
+ def backward(indexes, gy)
+ [-gy[0]]
  end
  end

- class Add < ::Chainer::Function
+ class Add < ::Chainer::FunctionNode
  def forward(x)
- retain_inputs([])
  [Utils::Array.force_array(x[0] + x[1])]
  end

- def backward(x, gy)
+ def backward(indexes, gy)
  [gy[0], gy[0]]
  end
  end

- class AddConstant < ::Chainer::Function
+ class AddConstant < ::Chainer::FunctionNode
  def initialize(value)
  @value = value
  end

  def forward(x)
- retain_inputs([])
  [Utils::Array.force_array(x[0] + @value)]
  end

- def backward(x, gy)
+ def backward(indexes, gy)
  [gy[0]]
  end
  end
-
- class Sub < ::Chainer::Function
+
+ class Sub < ::Chainer::FunctionNode
+ def label
+ '_ - _'
+ end
+
  def forward(x)
- retain_inputs([])
  [Utils::Array.force_array(x[0] - x[1])]
  end

- def backward(x, gy)
- [gy[0], Utils::Array.force_array(-gy[0])]
+ def backward(indexes, gy)
+ [gy[0], -gy[0]]
  end
  end

- class Mul < ::Chainer::Function
+ class Mul < ::Chainer::FunctionNode
  def forward(x)
+ retain_inputs([0, 1])
  [Utils::Array.force_array(x[0] * x[1])]
  end

- def backward(x, gy)
- [Utils::Array.force_array(gy[0] * x[1]), Utils::Array.force_array(gy[0] * x[0])]
+ def backward(indexes, gy)
+ xs = get_retained_inputs
+ indexes.map { |i| gy[0] * xs[1 - i] }
  end
  end

- class MulConstant < ::Chainer::Function
+ class MulConstant < ::Chainer::FunctionNode
  def initialize(value)
  @value = value
  end
@@ -69,23 +74,23 @@ module Chainer
  [Utils::Array.force_array(@value * x[0])]
  end

- def backward(x, gy)
- [Utils::Array.force_array(@value * gy[0])]
+ def backward(indexes, gy)
+ [gy[0] * @value]
  end
  end
-
- class Div < ::Chainer::Function
+
+ class Div < ::Chainer::FunctionNode
  def forward(x)
  [Utils::Array.force_array(x[0] / x[1])]
  end

- def backward(x, gy)
+ def backward(indexes, gy)
  gx0 = Utils::Array.force_array(gy[0] / x[1])
  [gx0, Utils::Array.force_array(-1 * gx0 * x[0] / x[1])]
  end
  end
-
- class PowVarVar < ::Chainer::Function
+
+ class PowVarVar < ::Chainer::FunctionNode
  def forward(x)
  @y = Utils::Array.force_array(x[0] ** x[1])
  [@y]
@@ -94,12 +99,13 @@ module Chainer
  def backward(x, gy)
  one = x[1].class.ones[0]
  gx0 = Utils::Array.force_array(x[1] * (x[0] ** (x[1] - one)) * gy[0])
- gx1 = Utils::Array.force_array(Numo::NMath.log(x[0]) * @y * gy[0])
+ xm = Chainer.get_array_module(x[0])
+ gx1 = Utils::Array.force_array(xm::NMath.log(x[0]) * @y * gy[0])
  [gx0, gx1]
  end
  end

- class PowVarConst < ::Chainer::Function
+ class PowVarConst < ::Chainer::FunctionNode
  def initialize(value)
  @value = value
  end
@@ -113,7 +119,7 @@
  gx = @value * (x[0] ** val_1) * gy[0]
  [Utils::Array.force_array(gx)]
  end
- end
+ end
  end
  end
  end
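The pattern above recurs throughout this release: arithmetic ops move from Chainer::Function to Chainer::FunctionNode, callers invoke them through apply on an array of inputs, and backward(indexes, gy) receives the indexes of inputs that need gradients plus gradient Variables instead of raw arrays, fetching retained inputs via get_retained_inputs. A minimal sketch of a user-defined op written against the new base class; the Square class and the sample values are illustrative, not part of the package:

    require 'chainer'

    # Hypothetical FunctionNode modeled on the Neg/Mul changes above.
    class Square < ::Chainer::FunctionNode
      def forward(x)
        retain_inputs([0])                             # keep x[0] for the backward pass
        [Chainer::Utils::Array.force_array(x[0] ** 2)]
      end

      def backward(indexes, gy)
        x, = get_retained_inputs                       # a Chainer::Variable, not a raw array
        [x * gy[0] * 2.0]                              # the gradient graph is built from Variables
      end
    end

    y = Square.new.apply([Chainer::Variable.new(Numo::DFloat[1, 2, 3])]).first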
data/lib/chainer/functions/math/exp.rb (new file)

@@ -0,0 +1,28 @@
+ module Chainer
+ module Functions
+ module Math
+ class Exp < Chainer::FunctionNode
+ # Elementwise exponential function.
+ def self.exp(x)
+ self.new.apply([x]).first
+ end
+
+ def label
+ 'exp'
+ end
+
+ def forward(x)
+ retain_inputs([])
+ retain_outputs([0])
+ xm = Chainer.get_array_module(x.first)
+ [Utils::Array.force_array(xm::NMath.exp(x.first))]
+ end
+
+ def backward(indexes, gy)
+ y = get_retained_outputs.first
+ [y * gy.first]
+ end
+ end
+ end
+ end
+ end
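A quick usage sketch for the new Exp node; the input values and the ones-gradient are arbitrary, chosen only to show how backward reuses the retained output:

    x = Chainer::Variable.new(Numo::DFloat[0.0, 1.0, 2.0])
    y = Chainer::Functions::Math::Exp.exp(x)   # forward delegates to xm::NMath.exp
    y.grad = Numo::DFloat.ones(3)
    y.backward                                 # gx = y * gy, using the retained output y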
data/lib/chainer/functions/math/identity.rb

@@ -2,7 +2,7 @@ module Chainer
  module Functions
  module Math
  # Identity function.
- class Identity < Chainer::Function
+ class Identity < Chainer::FunctionNode
  def check_type_forward(in_types)
  # pass
  end
@@ -12,13 +12,14 @@ module Chainer
  return xs
  end

- def backward(xs, gys)
+ def backward(indexes, gys)
  return gys
  end

  # Just returns input variables.
  def self.identity(*inputs)
- self.new.(*inputs)
+ ret = self.new.apply(inputs)
+ ret.size == 1 ? ret[0] : ret
  end
  end
  end
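The call-site change here mirrors the rest of the migration: Identity.identity now routes through apply and unwraps a single output. A small sketch with arbitrary inputs:

    a = Chainer::Variable.new(Numo::DFloat[1, 2])
    b = Chainer::Variable.new(Numo::DFloat[3, 4])

    Chainer::Functions::Math::Identity.identity(a)     # => one Variable
    Chainer::Functions::Math::Identity.identity(a, b)  # => an array of two Variables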
data/lib/chainer/functions/math/sum.rb (new file)

@@ -0,0 +1,52 @@
+ module Chainer
+ module Functions
+ module Math
+ # Sum of array elements over a given axis.
+ class Sum < Chainer::FunctionNode
+ # Sum of array elements over a given axis
+ #
+ # @param [Chainer::Variable] x Elements to sum
+ # @param [nil, Integer, Array<Integer>] axis Axis which a sum is performed
+ # @param[boolean] keepdims If `true`, the specified axes are remained as axes of length one
+ # @return [Chainer::Variable] Output variable
+ def self.sum(x, axis: nil, keepdims: false)
+ Sum.new(axis: axis, keepdims: keepdims).apply([x]).first
+ end
+
+ def initialize(axis: nil, keepdims: false)
+ if axis.nil?
+ @axis = nil
+ elsif axis.is_a?(Integer)
+ @axis = [axis]
+ elsif axis.is_a?(::Array) && axis.all? { |e| e.is_a?(Integer) }
+ raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
+ @axis = axis
+ else
+ raise TypeError, 'nil, Integer or Array of int are required'
+ end
+
+ @keepdims = keepdims
+ end
+
+ def forward(inputs)
+ x = inputs.first
+ ret = x.sum(axis: @axis, keepdims: @keepdims)
+ ret = x.class.cast(ret)
+ [ret]
+ end
+
+ def backward(indexes, grad_outputs)
+ gy = grad_outputs.first
+ ndim = @inputs.first.shape.size
+ unless ndim == 0 || @axis.nil? || @keepdims
+ actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
+ shape = gy.shape
+ actual_axis.sort.each { |axis| shape.insert(axis, 1) }
+ gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
+ end
+ [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, @inputs.first.shape)]
+ end
+ end
+ end
+ end
+ end
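A usage sketch for the new Sum node, showing the axis and keepdims keywords; the values are arbitrary:

    x = Chainer::Variable.new(Numo::DFloat.new(2, 3).seq)          # [[0, 1, 2], [3, 4, 5]]
    Chainer::Functions::Math::Sum.sum(x, axis: 1)                  # row sums => [3.0, 12.0]
    Chainer::Functions::Math::Sum.sum(x, axis: 1, keepdims: true)  # same values, shape [2, 1]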
data/lib/chainer/functions/noise/dropout.rb

@@ -1,7 +1,8 @@
  module Chainer
  module Functions
  module Noise
- class Dropout < Chainer::Function
+ class Dropout < Chainer::FunctionNode
+ attr_reader :mask
  # Drops elements of input variable randomly.
  #
  # This function drops input elements randomly with probability `ratio` and
@@ -12,7 +13,7 @@ module Chainer
  # @param [float] ratio Dropout ratio. The ``ratio`` must be `0.0 <= ratio < 1.0`.
  # @return [Chainer::Variable] Output variable.
  def self.dropout(x, ratio: 0.5)
- Chainer.configuration.train ? self.new(ratio).(x) : x
+ Chainer.configuration.train ? self.new(ratio).apply([x])[0] : Chainer::Variable.as_variable(x)
  end

  def initialize(dropout_ratio)
@@ -23,7 +24,6 @@ module Chainer
  end

  def forward(x)
- retain_inputs([])
  unless self.instance_variable_defined?(:@mask)
  scale = x[0].class[*[1.0 / (1 - @dropout_ratio)]][0]
  flag = x[0].class.new(*x[0].shape).rand >= @dropout_ratio
@@ -36,7 +36,23 @@ module Chainer
  end

  def backward(x, gy)
- [gy[0] * @mask]
+ DropoutGrad.new(@mask).apply(gy)
+ end
+ end
+
+ # Computes the gradient of the Dropout function.
+ class DropoutGrad < Chainer::FunctionNode
+ def initialize(mask)
+ @mask = mask
+ end
+
+ def forward(inputs)
+ y = inputs.first * @mask
+ [y]
+ end
+
+ def backward(indexes, gy)
+ DropoutGrad.new(@mask).apply(gy)
  end
  end
  end
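A behavior sketch for the updated dropout entry point; note that the test-mode branch now wraps the raw input in a Variable rather than returning it untouched. The values below are arbitrary:

    x = Chainer::Variable.new(Numo::DFloat.ones(5))

    # train mode: roughly `ratio` of the elements are zeroed, the rest scaled by 1 / (1 - ratio)
    y_train = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)

    # test mode: the input passes through, wrapped via Chainer::Variable.as_variable
    Chainer.configuration.train = false
    y_test = Chainer::Functions::Noise::Dropout.dropout(x, ratio: 0.5)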
data/lib/chainer/functions/normalization/batch_normalization.rb

@@ -1,134 +1,287 @@
  module Chainer
  module Functions
  module Normalization
- class BatchNormalizationFunction < Chainer::Function
+ module Calculation
+ def apply_bn_fwd(xp, x, mean, inv_std, gamma, beta)
+ # NOTE: all arguments should be broadcasted to x.shape
+ # (mean, inv_std, gamma, and beta have to already be expanded)
+ x_hat = x_hat(x, mean, inv_std)
+ y = gamma * x_hat
+ y += beta
+ y
+ end
+
+ def x_hat(x, mean, inv_std)
+ x_mu = x - mean
+ x_mu *= inv_std
+ x_mu
+ end
+
+ def zero_if_none(xp, x, shape, dtype)
+ # TODO: Return broadcasted 0 instead of a zeroed array.
+ x.nil? ? dtype.zeros(*shape) : x
+ end
+ end
+
+ class BatchNormalization < Chainer::FunctionNode
+ include Calculation
  attr_reader :running_mean, :running_var
- # Batch normalization function with fixed statistics.
- # This is a variant of batch normalization, where the mean and variance
- # statistics are given by the caller as fixed variables. This is
- # used on testing mode of the batch normalization layer, where batch
- # statistics cannot be used for prediction consistency.
- #
- # @param [Chainer::Variable] x Input variable.
- # @param [Chainer::Variable] gamma Scaling parameter of normalized data.
- # @param [Chainer::Variable] beta Shifting parameter of scaled normalized data.
- # @param [Chainer::Variable] mean Shifting parameter of input.
- # @param [Chainer::Variable] var Square of scaling parameter of input.
- # @param [float] eps Epsilon value for numerical stability.
- def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
- old_train = Chainer.configuration.train
- Chainer.configuration.train = false
- norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
- Chainer.configuration.train = old_train
- norm
- end
-
- def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+
+ def self.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
+ BatchNormalization.new(eps: eps, mean: running_mean, var: running_var, decay: decay).apply([x, gamma, beta])[0]
+ end
+
+ def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
+ @mean = nil
+ @inv_std = nil
+
  @running_mean = mean
  @running_var = var
  @eps = eps
- @mean_cache = nil
  @decay = decay
  end

  def forward(inputs)
- x, gamma, beta = inputs[0], inputs[1], inputs[2]
- if Chainer.configuration.train
- if @running_mean.nil?
- @running_mean = Numo::NArray[*gamma].new_zeros
- @running_var = Numo::NArray[*gamma].new_zeros
- else
- @running_mean = Numo::NArray[*@running_mean]
- @running_var = Numo::NArray[*@running_var]
- end
- elsif inputs.size == 5
- @fixed_mean = inputs[3]
- @fixed_var = inputs[4]
+ retain_inputs([0, 1])
+ x, gamma, beta = inputs
+ xp = Chainer.get_array_module(x)
+
+ if @running_mean.nil?
+ @running_mean = xp::NArray[*gamma].new_zeros
+ @running_var = xp::NArray[*gamma].new_zeros
  end

+ # expander inserts singleton dimensions to gamma and beta so that they
+ # can be broadcasted with x.
  head_ndim = gamma.ndim + 1
- gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
- gamma = gamma.reshape(*gamma_expander)
- beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
- beta = beta.reshape(*beta_expander)
-
- if Chainer.configuration.train
- axis = [0] + (head_ndim...(x.ndim)).to_a
- mean = x.mean(axis: axis)
- # FIXME: numpy.var
- var = x.var(axis: axis)
- var += @eps
- else
- mean = @fixed_mean
- var = @fixed_var + @eps
+ # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+ suffix = [1] * (x.ndim - head_ndim)
+ expander = -> (arr) do
+ shape = [1] + arr.shape + suffix
+ arr.reshape(*shape)
  end
+ @expander = expander
+ @axis = [0] + (head_ndim...(x.ndim)).to_a

- @std = Numo::NMath.sqrt(var)
+ gamma = expander.(gamma)
+ beta = expander.(beta)
+ @mean = x.mean(axis: @axis)

- mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
- x_mu = x - mean.reshape(*mean_expander)
- std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
- x_mu /= @std.reshape(*std_expander)
- @x_hat = x_mu
- y = gamma * @x_hat
- y += beta
+ # TODO: Numo::Array can not be specified standard deviation
+ var = ((x - x.mean(axis: @axis, keepdims: true)) ** 2).mean(axis: @axis)

- if Chainer.configuration.train
- m = x.size.div(gamma.size)
- adjust = m / [m - 1.0, 1.0].max
- @running_mean *= @decay
- temp_ar = Numo::NArray[*mean]
- temp_ar *= (1 - @decay)
- @running_mean += temp_ar
-
- @running_var *= @decay
- temp_ar = Numo::NArray[*var]
- temp_ar *= ((1 - @decay) * adjust)
- @running_var += temp_ar
- end
+ var += @eps
+ @inv_std = var ** (-0.5)
+
+ y = apply_bn_fwd(xp, x, expander.(@mean), expander.(@inv_std), gamma, beta)
+ # Update running statistics
+ m = x.size.div(gamma.size)
+ adjust = m / [m - 1.0, 1.0].max
+ @running_mean *= @decay
+ @running_mean += (1 - @decay) * @mean
+ @running_var *= @decay
+ @running_var += (1 - @decay) * adjust * var

- [y,]
+ [y]
+ end
+
+ def backward(indexes, grad_outputs)
+ x, gamma = get_retained_inputs
+ gy, = grad_outputs
+
+ # hatappi debug
+ #@mean = @mean.class.new(@mean.shape).seq
+ #@inv_std = @inv_std.class.new(@inv_std.shape).seq
+ #x.data = x.data.class.new(x.shape).seq
+ #gamma.data = gamma.data.class.new(gamma.shape).seq
+ #gy.data = gy.data.class.new(gy.shape).seq
+
+ f = BatchNormalizationGrad.new(@eps, @expander, @axis, @mean, @inv_std)
+ f.(x, gamma, gy)
+ end
+ end
+
+ class BatchNormalizationGrad < Function
+ include Calculation
+
+ def initialize(eps, expander, axis, mean, inv_std)
+ @eps = eps
+ @expander = expander
+ @axis = axis
+ @mean = mean
+ @inv_std = inv_std
+ end
+
+ def forward(inputs)
+ retain_inputs([0, 1, 2])
+ x, gamma, gy = inputs
+ expander = @expander
+
+ inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+ xp = Chainer.get_array_module(x)
+
+ gbeta = gy.sum(axis: @axis)
+ x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+ ggamma = (gy * x_hat).sum(axis: @axis)
+ gx = expander.(gamma * @inv_std) * (gy - (x_hat * expander.(ggamma) + expander.(gbeta)) * inv_m)
+
+ retain_outputs([0, 1])
+ [gx, ggamma, gbeta]
  end

  def backward(inputs, grad_outputs)
- x, gamma = inputs[0], inputs[1]
- gy = grad_outputs[0]
+ expander = @expander
+
+ x, gamma, gy = inputs
+ gx1, ggamma1, = output_data
+ ggx1, gggamma1, ggbeta1 = grad_outputs
+ xp = Chainer.get_array_module(x)
+
+ # auxiliary values
+ inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
+ r = ggx1.nil? ? 0 : (gx1 * ggx1).sum(axis: @axis)
+ coeff = gamma * @inv_std
+ coeff_m = coeff * inv_m
+ x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
+
+ # handle None in output gradients
+ ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+ gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+ ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+
+ gggamma2 = gggamma1 - coeff_m * (x_hat * ggx1).sum(axis: @axis)
+ ggbeta2 = ggbeta1 - coeff_m * ggx1.sum(axis: @axis)
+
+ ggamma2 = r / gamma
+
+ gx_hat2 = (expander.(gggamma2) * gy - expander.(coeff_m * ggamma1) * ggx1)
+ gstd2 = -@inv_std * (r + (x_hat * gx_hat2).sum(axis: @axis))
+ gmean2 = -@inv_std * gx_hat2.sum(axis: @axis)
+ gx2 = expander.(@inv_std) * gx_hat2 + inv_m * (expander.(gmean2) + x_hat * expander.(gstd2))
+ ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(coeff) * ggx1)
+
+ [gx2, ggamma2, ggy2]
+ end
+ end
+
+ class FixedBatchNormalization < FunctionNode
+ include Calculation
+
+ attr_reader :inv_var
+
+ def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
+ FixedBatchNormalization.new(eps: eps).apply([x, gamma, beta, mean, var]).first
+ end
+
+ def initialize(eps: 2e-5)
+ @inv_std = nil
+ @inv_var = nil
+ @eps = eps
+ end
+
+ def forward(inputs)
+ retain_inputs([0, 1, 3, 4])
+ x, gamma, beta, mean, var = inputs
+ xp = Chainer.get_array_module(x)
+
+ # expander inserts singleton dimensions to gamma and beta so that they
+ # can be broadcasted with x.
  head_ndim = gamma.ndim + 1
- m = gamma.class[x.size.div(gamma.size)][0]
- axis = [0] + (head_ndim...(x.ndim)).to_a
-
- if inputs.size == 5
- mean = inputs[3]
- var = inputs[4]
- std = Numo::NMath.sqrt(var)
- gs = gamma / std
- gbeta = gy.sum(axis: axis)
-
- mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
- x_mu = x - mean.reshape(*mean_expander)
- std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
- x_mu /= std.reshape(*std_expander)
- x_hat = x_mu
- ggamma = (gy * x_hat).sum(axis: axis)
- gmean = -gs * gbeta
- gvar = -0.5 * gamma / var * ggamma
- gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
- gx = gs.reshape(*gs_expander)
- return [gx, ggamma, gbeta, gmean, gvar]
+ # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
+ suffix = [1] * (x.ndim - head_ndim)
+ expander = -> (arr) do
+ shape = [1] + arr.shape + suffix
+ arr.reshape(*shape)
  end
+ @expander = expander
+ @axis = [0] + (head_ndim...(x.ndim)).to_a
+
+ gamma = expander.(gamma)
+ beta = expander.(beta)
+ var += @eps
+ @inv_var = var.reciprocal
+ @inv_std = xp::NMath.sqrt(@inv_var)

- gbeta = gy.sum(axis: axis)
- ggamma = (gy * @x_hat).sum(axis: axis)
- tmp = (gamma / @std)
- tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
- tmp = tmp.reshape(*tmp_expander)
+ y = apply_bn_fwd(xp, x, expander.(mean), expander.(@inv_std), gamma, beta)
+ [y]
+ end

- ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
- gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
-
- gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m )
+ def backward(indexes, grad_outputs)
+ x, gamma, mean, var = get_retained_inputs
+ gy, = grad_outputs
+ f = FixedBatchNormalizationGrad.new(@eps, @expander, @axis, @inv_std, @inv_var)
+ f.(x, gamma, mean, var, gy)
+ end
+ end

- [gx, ggamma, gbeta]
+ class FixedBatchNormalizationGrad < Function
+ include Calculation
+
+ def initialize(eps, expander, axis, inv_std, inv_var)
+ @eps = eps
+ @expander = expander
+ @axis = axis
+ @inv_std = inv_std
+ @inv_var = inv_var
+ end
+
+ def forward(inputs)
+ retain_inputs([0, 1, 2, 4])
+ x, gamma, mean, var, gy = inputs
+ expander = @expander
+ xp = Chainer.get_array_module(x)
+
+ if @inv_std.nil? || @inv_var.nil?
+ @inv_var = (var + @eps).reciprocal
+ @inv_std = xp::NMath.sqrt(@inv_var)
+ end
+
+ @gamma_over_std = gamma * @inv_std
+ x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+
+ gx = expander.(@gamma_over_std) * gy
+ gbeta = gy.sum(axis: @axis)
+ ggamma = (x_hat * gy).sum(axis: @axis)
+ gmean = -@gamma_over_std * gbeta
+ gvar = -0.5 * gamma * @inv_var * ggamma
+
+ retain_outputs([0, 1, 2, 3, 4])
+ [gx, ggamma, gbeta, gmean, gvar]
+ end
+
+ def backward(inputs, grad_outputs)
+ x, gamma, mean, _, gy = inputs
+ ggx1, gggamma1, ggbeta1, ggmean1, ggvar1 = grad_outputs
+ gx1, ggamma1, gbeta1, gmean1, gvar1 = output_data
+
+ # Handle None in output gradients.
+ xp = Chainer.get_array_module(x)
+ ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
+ gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
+ ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
+ ggmean1 = zero_if_none(xp, ggmean1, mean.shape, mean.class)
+ ggvar1 = zero_if_none(xp, ggvar1, mean.shape, mean.class)
+
+ expander = @expander
+ x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
+ tmp = -0.5 * ggvar1
+
+ gamma_over_var = gamma * @inv_var
+ g_gamma_over_var = tmp * ggamma1
+
+ gggamma2 = gggamma1 + tmp * gamma_over_var
+ gx_hat = gy * expander.(gggamma2)
+ gx2 = expander.(@inv_std) * gx_hat
+ gmean2 = -@inv_std * gx_hat.sum(axis: @axis)
+
+ g_gamma_over_std = (ggx1 * gy).sum(axis: @axis) - ggmean1 * gbeta1
+ ggbeta2 = ggbeta1 - ggmean1 * @gamma_over_std
+ ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) + expander.(@gamma_over_std) * ggx1)
+
+ ggamma2 = (@inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std)
+ gvar2 = -(ggamma2 * gamma_over_var + 0.5 * @inv_var * ((x_hat * gx_hat).sum(axis: @axis) - @gamma_over_std * g_gamma_over_std))
+
+ [gx2, ggamma2, gmean2, gvar2, ggy2]
  end
  end
  end
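For orientation after this large rewrite: batch normalization is now split into a training-time BatchNormalization node, which computes batch statistics and maintains the running averages, and an inference-time FixedBatchNormalization node, which takes the statistics from the caller, each paired with a dedicated grad Function for double backprop. A minimal usage sketch with arbitrary shapes and values, based only on the entry points visible in the diff above:

    norm = Chainer::Functions::Normalization

    x     = Chainer::Variable.new(Numo::DFloat.new(10, 3).rand)
    gamma = Chainer::Variable.new(Numo::DFloat.ones(3))
    beta  = Chainer::Variable.new(Numo::DFloat.zeros(3))

    # training path: statistics are computed from the batch itself
    y = norm::BatchNormalization.batch_normalization(x, gamma, beta, eps: 2e-5)

    # inference path: the caller supplies fixed mean/var, e.g. accumulated running statistics
    mean = Numo::DFloat.zeros(3)
    var  = Numo::DFloat.ones(3)
    y = norm::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var)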