red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/activation/leaky_relu.rb:

@@ -2,7 +2,7 @@ module Chainer
   module Functions
     module Activation
       # Leaky rectifier unit.
-      class LeakyReLU < Function
+      class LeakyReLU < FunctionNode
         # Leaky Rectified Linear Unit function.
         #
         # This function is expressed as
@@ -13,7 +13,7 @@ module Chainer
         #
         # where $a$ is a configurable slope value.
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @param [float] slope Slope value $a$.
         # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @example
@@ -31,32 +31,57 @@ module Chainer
         # [-0.4, 1]]
         #
         def self.leaky_relu(x, slope: 0.2)
-          self.new(slope: slope).(x)
+          self.new(slope: slope).apply([x])[0]
         end

         def initialize(slope:0.2)
           @slope = slope
         end

-        def
-
-          y
+        def forward(inputs)
+          x, = inputs
+          y = x.dup
+          y[x < 0] *= @slope
           if @slope >= 0
-            retain_inputs([])
             retain_outputs([0])
+          else
+            retain_inputs([0])
           end
           [y]
         end

-        def
-
+        def backward(indexes, grad_outputs)
+          if @slope >= 0
+            x = nil
+            y = get_retained_outputs.first.data
+          else
+            x = get_retained_inputs.first.data
+            y = nil
+          end
+          LeakyReLUGrad.new(x, y, @slope).apply(grad_outputs)
+        end
+      end
+
+      class LeakyReLUGrad < FunctionNode
+        def initialize(x, y, slope)
+          @x = x
+          @y = y
+          @slope = slope
+        end
+
+        def forward(inputs)
+          gy, = inputs
+          gy = gy.dup
           if @slope >= 0
-            y
-            gx[y[0] < 0] *= @slope
+            gy[@y < 0] *= @slope
           else
-
+            gy[@x < 0] *= @slope
           end
-          [
+          [gy]
+        end
+
+        def backward(indexes, grad_outputs)
+          LeakyReLUGrad.new(@x, @y, @slope).apply(grad_outputs)
         end
       end
     end
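The public helper keeps its old signature; only the internals moved from `Function` to `FunctionNode#apply`. A minimal usage sketch, assuming red-chainer 0.4.0 and numo-narray are installed (the expected values follow the `@example` block quoted above):

```ruby
require 'chainer'
require 'numo/narray'

# Negative entries are scaled by the slope (0.2 by default).
x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
y = Chainer::Functions::Activation::LeakyReLU.leaky_relu(x, slope: 0.2)
p y.data  # => [[-0.2, 0], [2, -0.6], [-0.4, 1]]
```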
data/lib/chainer/functions/activation/log_softmax.rb:

@@ -2,11 +2,12 @@ module Chainer
   module Functions
     module Activation
       def self.logsumexp(x)
+        xm = Chainer.get_array_module(x)
         m = x.max(axis: 1, keepdims: true)
         y = x - m
-        y =
+        y = xm::NMath.exp(y)
         s = y.sum(axis: 1, keepdims: true)
-        s =
+        s = xm::NMath.log(s)
         m + s
       end

@@ -16,7 +17,7 @@ module Chainer
       end

       # Log-softmax activation function.
-      class LogSoftmax < Function
+      class LogSoftmax < FunctionNode
         # Channel-wise log-softmax function.
         #
         # This function computes its logarithm of softmax along the second axis.
@@ -36,7 +37,7 @@ module Chainer
         # because +softmax(x)+ may returns +0+.
         # +log_softmax+ method is more stable.
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
         # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
         #
         # @see Chainer::Functions::Softmax
@@ -56,23 +57,59 @@ module Chainer
         # => true
         #
         def self.log_softmax(x)
-          self.new.(x)
+          self.new.apply([x]).first
         end

         def forward(xs)
           y = Chainer::Functions::Activation._log_softmax(xs[0])
           @x_shape = xs[0].shape
           @x_dtype = xs[0].class
-          retain_inputs([])
           retain_outputs([0])
           [y]
         end

-        def backward(
-          y =
-
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          LogSoftmaxGrad.new(@x_shape, @x_dtype).apply([y, gy[0]])
+        end
+      end
+
+      class LogSoftmaxGrad < FunctionNode
+        def initialize(x_shape, x_dtype)
+          @x_shape = x_shape
+          @x_dtype = x_dtype
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          y, gy = inputs
+
+          xm = Chainer.get_array_module(y)
+          gx = gy - xm::NMath.exp(y) * gy.sum(axis: 1, keepdims: true)
           [gx]
         end
+
+        def backward(indexes, ggx)
+          y, gy = get_retained_inputs
+          ret = []
+          exp_y = Chainer::Functions::Math::Exp.exp(y)
+
+          if indexes.include?(0)
+            gy_sum = Chainer::Functions::Math::Sum.sum(gy, axis: 1, keepdims: true)
+            gy_sum = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy_sum, gy.shape)
+
+            g0 = -ggx.first * exp_y * gy_sum
+            ret << g0
+          end
+          if indexes.include?(1)
+            a = Chainer::Functions::Math::Sum.sum(ggx.first * exp_y, axis: 1, keepdims: true)
+            a = Chainer::Functions::Array::BroadcastTo.broadcast_to(a, gy.shape)
+            g1 = ggx.first - a
+            ret << g1
+          end
+
+          ret
+        end
       end
     end
   end
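The rewritten `logsumexp` keeps `log_softmax` numerically stable (the per-row maximum is subtracted before exponentiating), and the new `LogSoftmaxGrad` gives the function a differentiable backward pass. A small sanity-check sketch, assuming red-chainer 0.4.0 is installed:

```ruby
require 'chainer'
require 'numo/narray'

x = Numo::SFloat[[0.0, 1.0, 2.0], [0.0, 2.0, 4.0]]
y = Chainer::Functions::Activation::LogSoftmax.log_softmax(x)

# exp(log_softmax(x)) is a softmax, so each row should sum to ~1.
p Numo::NMath.exp(y.data).sum(axis: 1)
```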
data/lib/chainer/functions/activation/relu.rb:

@@ -2,14 +2,14 @@ module Chainer
   module Functions
     module Activation
       # Rectified Linear Unit.
-      class Relu < Function
+      class Relu < FunctionNode
        # Rectified Linear Unit function.
        #
        # $$
        # f(x)=\\max(0, x).
        # $$
        #
-       # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+       # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example
        # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
@@ -23,18 +23,18 @@ module Chainer
        # => [3, 2]
        #
        def self.relu(x)
-          self.new.(x)
+          y, = self.new.apply([x])
+          y
        end

-        def
-          retain_inputs([])
+        def forward(x)
          retain_outputs([0])
          [Utils::Array.force_array(x[0].class.maximum(x[0], 0))]
        end

-        def
-          y =
-
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          ReLUGrad2.new(y).apply([gy[0]])
        end
      end
    end
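`relu` now returns the first output of `apply`, but the call site is unchanged. A minimal sketch using the shapes from the `@example` block above, assuming red-chainer 0.4.0 is installed:

```ruby
require 'chainer'
require 'numo/narray'

x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
y = Chainer::Functions::Activation::Relu.relu(x)

p y.data        # negative entries are clamped to 0
p y.data.shape  # => [3, 2]
```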
data/lib/chainer/functions/activation/relu_grad2.rb (new file):

@@ -0,0 +1,34 @@
+module Chainer
+  module Functions
+    module Activation
+      # Computes the gradient of the ReLU function.
+      #
+      # This function takes 2 variables b and c, and
+      # computes f(b, c) = sign(b) * c with backpropagation
+      # where operations are dones in elementwise manner
+      # and sign(x) = 1 when x > 0 is positive and 0 otherwise.
+      # As the gradient of f with respect to b is 0,
+      # we do not backpropagate errors toward b for computational efficiency.<Paste>
+      class ReLUGrad2 < FunctionNode
+        def initialize(b)
+          @b = b.data
+        end
+
+        def forward(inputs)
+          y = inputs[0] * (@b > 0)
+          [Utils::Array.force_array(y, y.class)]
+        end
+
+        def backward(indexes, gy)
+          [gy[0] * heaviside(@b)]
+        end
+
+        private
+
+        def heaviside(x)
+          (x > 0).cast_to(x.class)
+        end
+      end
+    end
+  end
+end
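`ReLUGrad2` is an internal helper that `Relu#backward` applies to the upstream gradient; its `heaviside` mask zeroes the gradient wherever the retained output was non-positive. A standalone Numo sketch of the same masking trick (illustrative only; it does not call the internal class):

```ruby
require 'numo/narray'

b  = Numo::SFloat[-1, 0, 2, -3]   # stands in for the retained ReLU output
gy = Numo::SFloat[10, 10, 10, 10] # stands in for the incoming gradient

# heaviside(b): 1.0 where b > 0, else 0.0 -- the same cast used in ReLUGrad2.
mask = (b > 0).cast_to(b.class)
p gy * mask  # => [0, 0, 10, 0]
```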
data/lib/chainer/functions/activation/sigmoid.rb:

@@ -2,14 +2,14 @@ module Chainer
   module Functions
     module Activation
       # Logistic sigmoid function.
-      class Sigmoid < Function
+      class Sigmoid < FunctionNode
        # Element-wise sigmoid logistic function.
        #
        # $$
        # f(x)=(1 + \\exp(-x))^ { -1 }.
        # $$
        #
-       # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+       # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example It maps the input values into the range of $`[0, 1]`$.
        # > x = Numo::SFloat.new(3).seq(-2, 2)
@@ -21,21 +21,23 @@ module Chainer
        # [0.119203, 0.5, 0.880797]
        #
        def self.sigmoid(x)
-          self.new.(x)
+          self.new.apply([x]).first
        end

-        def
+        def forward(inputs)
+          x, = inputs
          half = 0.5
-
-
+          xm = Chainer.get_array_module(x)
+          y = Utils::Array.force_array((xm::NMath.tanh(x * half) * half)+ half)
          retain_outputs([0])
-
+          [y]
        end

-        def
-
-          y =
-
+        def backward(indexes, grad_outputs)
+          x = nil
+          y = get_retained_outputs.first
+          gy, = grad_outputs
+          Chainer::Functions::Activation::SigmoidGrad.new([x]).apply([y, gy])
        end
      end
    end
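`Sigmoid#forward` computes the sigmoid through `tanh` (`0.5 * tanh(0.5 * x) + 0.5`), which avoids overflow for large negative inputs. A usage sketch with the values from the `@example` block above, assuming red-chainer 0.4.0 is installed:

```ruby
require 'chainer'
require 'numo/narray'

x = Numo::SFloat.new(3).seq(-2, 2)  # [-2, 0, 2]
y = Chainer::Functions::Activation::Sigmoid.sigmoid(x)
p y.data  # ~ [0.119203, 0.5, 0.880797]
```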
data/lib/chainer/functions/activation/sigmoid_grad.rb (new file):

@@ -0,0 +1,25 @@
+module Chainer
+  module Functions
+    module Activation
+      # Logistic sigmoid gradient function.
+      class SigmoidGrad < FunctionNode
+        def initialize(inputs)
+          @x, = inputs
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          y, gy = inputs
+          one = 1
+          [Utils::Array.force_array(gy * y * (one - y))]
+        end
+
+        def backward(indexes, grad_outputs)
+          y, gy = get_retained_inputs
+          g, = grad_outputs
+          [g * gy * ( 1 -2 * y), g * y * (1 - y)]
+        end
+      end
+    end
+  end
+end
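`SigmoidGrad#forward` returns `gy * y * (1 - y)`, i.e. the sigmoid derivative expressed through the retained output `y` rather than the original input. A plain Numo check of that identity (illustrative only; it does not touch the internal class):

```ruby
require 'numo/narray'

x  = Numo::DFloat[-2.0, 0.0, 2.0]
y  = Numo::NMath.tanh(x * 0.5) * 0.5 + 0.5  # sigmoid(x), same formula as Sigmoid#forward
gy = Numo::DFloat.ones(3)                   # upstream gradient of ones

# d/dx sigmoid(x) = y * (1 - y); this is what SigmoidGrad#forward propagates.
p gy * y * (1 - y)  # ~ [0.105, 0.25, 0.105]
```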
data/lib/chainer/functions/activation/tanh.rb:

@@ -2,14 +2,14 @@ module Chainer
   module Functions
     module Activation
       # Hyperbolic tangent function.
-      class Tanh < Function
+      class Tanh < FunctionNode
        # Elementwise hyperbolic tangent function.
        #
        # $$
        # f(x)=\\tanh(x).
        # $$
        #
-       # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+       # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example
        # > x = Numo::SFloat.new(3).seq(-1, 2)
@@ -21,20 +21,57 @@ module Chainer
        # [-0.761594, 0.761594, 0.995055]
        #
        def self.tanh(x)
-          self.new.(x)
+          self.new.apply([x]).first
        end

-        def
-
-
+        def forward(x)
+          xm = Chainer.get_array_module(x[0])
+          y = Utils::Array.force_array(xm::NMath.tanh(x[0]))
          retain_outputs([0])
-
+          @use_cudnn = false
+          [y]
        end

-        def
-
-
-
+        def backward(indexes, grad_outputs)
+          if @use_cudnn
+            x = get_retained_inputs.first.data
+          else
+            x = nil
+          end
+
+          y = get_retained_outputs.first
+          gy = grad_outputs.first
+          TanhGrad.new(x).apply([y, gy])
+        end
+      end
+
+      class TanhGrad < FunctionNode
+        def initialize(x)
+          super()
+
+          # The original input `x` is only required for cuDNN.
+          # If it is None, this class does not use cuDNN.
+          # Note that x must be c-contiguous and it is checked
+          # in Tanh.forward_gpu.
+          @x = x
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          y, gy = inputs
+
+          one = y.class.new.fill(1)
+          [Utils::Array.force_array(gy * (one - y * y))]
+        end
+
+        def backward(indexes, grad_outputs)
+          y, gy = get_retained_inputs
+          g = grad_outputs[0]
+
+          y_mul_g = y * g
+          grad_y = -2 * gy * y_mul_g
+          ggy = g - y * y_mul_g
+          [grad_y, ggy]
        end
      end
    end
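`Tanh` now retains its output, and `TanhGrad` multiplies the upstream gradient by `1 - y**2`; its own `backward` makes that gradient differentiable in turn. A usage sketch with the values from the `@example` block above, assuming red-chainer 0.4.0 is installed:

```ruby
require 'chainer'
require 'numo/narray'

x = Numo::SFloat.new(3).seq(-1, 2)  # [-1, 1, 3]
y = Chainer::Functions::Activation::Tanh.tanh(x)

p y.data                                  # ~ [-0.761594, 0.761594, 0.995055]
p Numo::SFloat.ones(3) - y.data * y.data  # tanh'(x) = 1 - y**2, the factor TanhGrad applies
```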
data/lib/chainer/functions/array/broadcast_to.rb (new file):

@@ -0,0 +1,56 @@
+module Chainer
+  module Functions
+    module Array
+      # Function that broadcasts an array to a new shape.
+      class BroadcastTo < FunctionNode
+        def initialize(shape)
+          @shape = shape
+        end
+
+        def self.broadcast_to(x, shape)
+          return Chainer::Variable.as_variable(x) if x.shape == shape
+          self.new(shape).apply([x]).first
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          [Chainer::Utils::Array.broadcast_to(x, @shape)]
+        end
+
+        def backward(indexes, grad_outputs)
+          gx = grad_outputs.first
+          shape = @inputs.first.shape
+          ndim = shape.size
+          lead = gx.ndim - ndim
+          lead_axis = lead.times.to_a
+          axis = shape.each_with_object([]).with_index do |(sx, res), i|
+            next unless sx == 1
+            res << i + lead
+          end
+          gx = Chainer::Functions::Math::Sum.sum(gx, axis: lead_axis + axis, keepdims: true)
+          return [Chainer::Functions::Array::Squeeze.squeeze(gx, axis: lead_axis)] if lead > 0
+          [gx]
+        end
+
+        private
+
+        def backward_one(shape, dtype, g)
+          return dtype.zeros(shape) unless g
+
+          ndim = shape.size
+          if g.ndim != ndim
+            g = g.sum(axis: 0...(g.ndim - ndim))
+          end
+
+          axis = shape.each_with_index.select{|sx, i| sx == 1 }.map{|sx, i| i }
+          if axis.size > 0
+            g.sum(keepdims: true, axis: axis)
+          else
+            g
+          end
+        end
+      end
+    end
+  end
+end
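`BroadcastTo` is one of the new array functions backing the double-backpropagation code above (`LogSoftmaxGrad#backward` uses it to expand a row-wise sum back to the input shape), and its own backward sums the gradient over the broadcast axes. A minimal usage sketch, assuming red-chainer 0.4.0 is installed:

```ruby
require 'chainer'
require 'numo/narray'

x = Numo::SFloat[[1], [2], [3]]  # shape [3, 1]
y = Chainer::Functions::Array::BroadcastTo.broadcast_to(x, [3, 4])

p y.data.shape  # => [3, 4]; each input column is repeated along the new axis
```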