red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
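Two idioms recur throughout the diffs below: `FunctionNode#apply([...])` replaces the old `Function#call` (`.()`) entry point, and the new `Chainer.get_array_module` helper (added alongside `data/lib/chainer/backend.rb`) picks the `Numo` or `Cumo` namespace for an array so the same code can run on CPU or GPU. A minimal sketch of that backend lookup, assuming a Numo-only install; the `get_array_module` shown here is a simplified stand-in, not the gem's implementation:

```ruby
require 'numo/narray'

# Simplified stand-in for Chainer.get_array_module as it is used in the diffs
# below: return the array namespace (Numo or Cumo) that owns the given array.
def get_array_module(a)
  # Cumo is only defined when the CUDA backend is available; fall back to Numo.
  defined?(Cumo::NArray) && a.is_a?(Cumo::NArray) ? Cumo : Numo
end

x  = Numo::SFloat[[-1.0, 0.5], [2.0, -3.0]]
xm = get_array_module(x)   # => Numo
xm::NMath.exp(x)           # the same line would dispatch to Cumo::NMath for GPU arrays
```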
data/lib/chainer/functions/activation/leaky_relu.rb CHANGED

```diff
@@ -2,7 +2,7 @@ module Chainer
   module Functions
     module Activation
       # Leaky rectifier unit.
-      class LeakyReLU < Function
+      class LeakyReLU < FunctionNode
         # Leaky Rectified Linear Unit function.
         #
         # This function is expressed as
@@ -13,7 +13,7 @@ module Chainer
         #
         # where $a$ is a configurable slope value.
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @param [float] slope Slope value $a$.
         # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @example
@@ -31,32 +31,57 @@ module Chainer
         # [-0.4, 1]]
         #
         def self.leaky_relu(x, slope: 0.2)
-          self.new(slope: slope).(x)
+          self.new(slope: slope).apply([x])[0]
         end
 
         def initialize(slope:0.2)
           @slope = slope
         end
 
-        def ...
+        def forward(inputs)
+          x, = inputs
+          y = x.dup
+          y[x < 0] *= @slope
           if @slope >= 0
-            retain_inputs([])
             retain_outputs([0])
+          else
+            retain_inputs([0])
           end
           [y]
         end
 
-        def ...
+        def backward(indexes, grad_outputs)
+          if @slope >= 0
+            x = nil
+            y = get_retained_outputs.first.data
+          else
+            x = get_retained_inputs.first.data
+            y = nil
+          end
+          LeakyReLUGrad.new(x, y, @slope).apply(grad_outputs)
+        end
+      end
+
+      class LeakyReLUGrad < FunctionNode
+        def initialize(x, y, slope)
+          @x = x
+          @y = y
+          @slope = slope
+        end
+
+        def forward(inputs)
+          gy, = inputs
+          gy = gy.dup
           if @slope >= 0
-            gx[y[0] < 0] *= @slope
+            gy[@y < 0] *= @slope
           else
+            gy[@x < 0] *= @slope
           end
-          [gx]
+          [gy]
+        end
+
+        def backward(indexes, grad_outputs)
+          LeakyReLUGrad.new(@x, @y, @slope).apply(grad_outputs)
         end
       end
     end
```
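The public helper keeps its signature, so caller code is unchanged; only the internal call style moves from `.()` to `apply`. A short usage sketch, assuming the standard red-chainer `Variable` API and the Numo backend (values follow the docstring example above):

```ruby
require 'chainer'

x = Chainer::Variable.new(Numo::SFloat[[-1, 0], [2, -3], [-2, 1]])
y = Chainer::Functions::Activation::LeakyReLU.leaky_relu(x, slope: 0.2)
# y.data => [[-0.2, 0], [2, -0.6], [-0.4, 1]]

y.grad = Numo::SFloat.ones(3, 2)
y.backward
# x.grad is 1 where x >= 0 and 0.2 (the slope) where x < 0
```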
data/lib/chainer/functions/activation/log_softmax.rb CHANGED

```diff
@@ -2,11 +2,12 @@ module Chainer
   module Functions
     module Activation
       def self.logsumexp(x)
+        xm = Chainer.get_array_module(x)
         m = x.max(axis: 1, keepdims: true)
         y = x - m
-        y = Numo::NMath.exp(y)
+        y = xm::NMath.exp(y)
         s = y.sum(axis: 1, keepdims: true)
-        s = Numo::NMath.log(s)
+        s = xm::NMath.log(s)
         m + s
       end
 
@@ -16,7 +17,7 @@ module Chainer
       end
 
       # Log-softmax activation function.
-      class LogSoftmax < Function
+      class LogSoftmax < FunctionNode
         # Channel-wise log-softmax function.
         #
         # This function computes its logarithm of softmax along the second axis.
@@ -36,7 +37,7 @@ module Chainer
         # because +softmax(x)+ may returns +0+.
         # +log_softmax+ method is more stable.
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
         # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
         #
         # @see Chainer::Functions::Softmax
@@ -56,23 +57,59 @@ module Chainer
         # => true
         #
         def self.log_softmax(x)
-          self.new.(x)
+          self.new.apply([x]).first
         end
 
         def forward(xs)
           y = Chainer::Functions::Activation._log_softmax(xs[0])
           @x_shape = xs[0].shape
           @x_dtype = xs[0].class
-          retain_inputs([])
           retain_outputs([0])
           [y]
         end
 
-        def backward(...)
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          LogSoftmaxGrad.new(@x_shape, @x_dtype).apply([y, gy[0]])
+        end
+      end
+
+      class LogSoftmaxGrad < FunctionNode
+        def initialize(x_shape, x_dtype)
+          @x_shape = x_shape
+          @x_dtype = x_dtype
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          y, gy = inputs
+
+          xm = Chainer.get_array_module(y)
+          gx = gy - xm::NMath.exp(y) * gy.sum(axis: 1, keepdims: true)
           [gx]
         end
+
+        def backward(indexes, ggx)
+          y, gy = get_retained_inputs
+          ret = []
+          exp_y = Chainer::Functions::Math::Exp.exp(y)
+
+          if indexes.include?(0)
+            gy_sum = Chainer::Functions::Math::Sum.sum(gy, axis: 1, keepdims: true)
+            gy_sum = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy_sum, gy.shape)
+
+            g0 = -ggx.first * exp_y * gy_sum
+            ret << g0
+          end
+          if indexes.include?(1)
+            a = Chainer::Functions::Math::Sum.sum(ggx.first * exp_y, axis: 1, keepdims: true)
+            a = Chainer::Functions::Array::BroadcastTo.broadcast_to(a, gy.shape)
+            g1 = ggx.first - a
+            ret << g1
+          end
+
+          ret
+        end
       end
     end
   end
```
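The `logsumexp` helper above is the usual max-shift trick; it is what keeps `log_softmax` finite where a naive `log(softmax(x))` overflows. A plain-Numo illustration of the same computation (illustrative only, not the gem's API):

```ruby
require 'numo/narray'

x = Numo::DFloat[[1000.0, 1.0, -1000.0]]

# Shifted (stable) form: subtract the row maximum before exponentiating.
m = x.max(axis: 1, keepdims: true)
s = Numo::NMath.log(Numo::NMath.exp(x - m).sum(axis: 1, keepdims: true))
log_softmax = x - (m + s)   # finite values, approximately [0.0, -999.0, -2000.0]

# Naive form: exp(1000) overflows to Infinity, producing NaN / -Infinity entries.
naive = Numo::NMath.log(Numo::NMath.exp(x) / Numo::NMath.exp(x).sum)
```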
data/lib/chainer/functions/activation/relu.rb CHANGED

```diff
@@ -2,14 +2,14 @@ module Chainer
   module Functions
     module Activation
       # Rectified Linear Unit.
-      class Relu < Function
+      class Relu < FunctionNode
         # Rectified Linear Unit function.
         #
         # $$
         # f(x)=\\max(0, x).
         # $$
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @example
         # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
@@ -23,18 +23,18 @@ module Chainer
         # => [3, 2]
         #
         def self.relu(x)
-          self.new.(x)
+          y, = self.new.apply([x])
+          y
         end
 
-        def ...
-          retain_inputs([])
+        def forward(x)
           retain_outputs([0])
           [Utils::Array.force_array(x[0].class.maximum(x[0], 0))]
         end
 
-        def ...
+        def backward(indexes, gy)
+          y = get_retained_outputs.first
+          ReLUGrad2.new(y).apply([gy[0]])
         end
       end
     end
```
data/lib/chainer/functions/activation/relu_grad2.rb ADDED

```ruby
module Chainer
  module Functions
    module Activation
      # Computes the gradient of the ReLU function.
      #
      # This function takes 2 variables b and c, and
      # computes f(b, c) = sign(b) * c with backpropagation
      # where operations are dones in elementwise manner
      # and sign(x) = 1 when x > 0 is positive and 0 otherwise.
      # As the gradient of f with respect to b is 0,
      # we do not backpropagate errors toward b for computational efficiency.<Paste>
      class ReLUGrad2 < FunctionNode
        def initialize(b)
          @b = b.data
        end

        def forward(inputs)
          y = inputs[0] * (@b > 0)
          [Utils::Array.force_array(y, y.class)]
        end

        def backward(indexes, gy)
          [gy[0] * heaviside(@b)]
        end

        private

        def heaviside(x)
          (x > 0).cast_to(x.class)
        end
      end
    end
  end
end
```
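Routing ReLU's backward pass through `ReLUGrad2`, itself a `FunctionNode`, is what makes the gradient differentiable a second time. Reduced to plain Numo, using the same element-wise operations as the class above (illustrative only):

```ruby
require 'numo/narray'

b = Numo::SFloat[-1.0, 0.0, 2.0]   # ReLU forward output ("b" in the docstring)
c = Numo::SFloat[0.5, 0.5, 0.5]    # incoming gradient ("c" in the docstring)

mask = (b > 0)                     # Numo::Bit mask, 1 where b is positive
gx   = c * mask                    # what ReLUGrad2#forward computes: [0.0, 0.0, 0.5]
heav = mask.cast_to(b.class)       # ReLUGrad2#backward's heaviside(b): [0.0, 0.0, 1.0]
```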
data/lib/chainer/functions/activation/sigmoid.rb CHANGED

```diff
@@ -2,14 +2,14 @@ module Chainer
   module Functions
     module Activation
       # Logistic sigmoid function.
-      class Sigmoid < Function
+      class Sigmoid < FunctionNode
         # Element-wise sigmoid logistic function.
         #
         # $$
         # f(x)=(1 + \\exp(-x))^ { -1 }.
         # $$
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @example It maps the input values into the range of $`[0, 1]`$.
         # > x = Numo::SFloat.new(3).seq(-2, 2)
@@ -21,21 +21,23 @@ module Chainer
         # [0.119203, 0.5, 0.880797]
         #
         def self.sigmoid(x)
-          self.new.(x)
+          self.new.apply([x]).first
         end
 
-        def ...
+        def forward(inputs)
+          x, = inputs
           half = 0.5
+          xm = Chainer.get_array_module(x)
+          y = Utils::Array.force_array((xm::NMath.tanh(x * half) * half)+ half)
           retain_outputs([0])
+          [y]
         end
 
-        def ...
+        def backward(indexes, grad_outputs)
+          x = nil
+          y = get_retained_outputs.first
+          gy, = grad_outputs
+          Chainer::Functions::Activation::SigmoidGrad.new([x]).apply([y, gy])
         end
       end
     end
```
data/lib/chainer/functions/activation/sigmoid_grad.rb ADDED

```ruby
module Chainer
  module Functions
    module Activation
      # Logistic sigmoid gradient function.
      class SigmoidGrad < FunctionNode
        def initialize(inputs)
          @x, = inputs
        end

        def forward(inputs)
          retain_inputs([0, 1])
          y, gy = inputs
          one = 1
          [Utils::Array.force_array(gy * y * (one - y))]
        end

        def backward(indexes, grad_outputs)
          y, gy = get_retained_inputs
          g, = grad_outputs
          [g * gy * ( 1 -2 * y), g * y * (1 - y)]
        end
      end
    end
  end
end
```
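`SigmoidGrad#forward` relies on the identity sigmoid'(x) = y * (1 - y) with y = sigmoid(x), which is why `Sigmoid` only needs to retain its output, not its input. A quick plain-Numo check of that identity against a finite difference (illustrative; the sigmoid here uses the same tanh formulation as the forward pass above):

```ruby
require 'numo/narray'

# sigmoid via the same tanh identity the gem's forward pass uses
sigmoid = ->(v) { Numo::NMath.tanh(v * 0.5) * 0.5 + 0.5 }

x = Numo::DFloat.new(5).seq(-2, 1)   # [-2, -1, 0, 1, 2]
y = sigmoid.call(x)

analytic = y * (1.0 - y)             # what SigmoidGrad#forward computes (with gy = 1)

eps = 1e-6
numeric = (sigmoid.call(x + eps) - sigmoid.call(x - eps)) / (2 * eps)
# (analytic - numeric).abs.max is on the order of 1e-10
```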
data/lib/chainer/functions/activation/tanh.rb CHANGED

```diff
@@ -2,14 +2,14 @@ module Chainer
   module Functions
     module Activation
       # Hyperbolic tangent function.
-      class Tanh < Function
+      class Tanh < FunctionNode
         # Elementwise hyperbolic tangent function.
         #
         # $$
         # f(x)=\\tanh(x).
         # $$
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
         # @example
         # > x = Numo::SFloat.new(3).seq(-1, 2)
@@ -21,20 +21,57 @@ module Chainer
         # [-0.761594, 0.761594, 0.995055]
         #
         def self.tanh(x)
-          self.new.(x)
+          self.new.apply([x]).first
         end
 
-        def ...
+        def forward(x)
+          xm = Chainer.get_array_module(x[0])
+          y = Utils::Array.force_array(xm::NMath.tanh(x[0]))
           retain_outputs([0])
+          @use_cudnn = false
+          [y]
         end
 
-        def ...
+        def backward(indexes, grad_outputs)
+          if @use_cudnn
+            x = get_retained_inputs.first.data
+          else
+            x = nil
+          end
+
+          y = get_retained_outputs.first
+          gy = grad_outputs.first
+          TanhGrad.new(x).apply([y, gy])
+        end
+      end
+
+      class TanhGrad < FunctionNode
+        def initialize(x)
+          super()
+
+          # The original input `x` is only required for cuDNN.
+          # If it is None, this class does not use cuDNN.
+          # Note that x must be c-contiguous and it is checked
+          # in Tanh.forward_gpu.
+          @x = x
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          y, gy = inputs
+
+          one = y.class.new.fill(1)
+          [Utils::Array.force_array(gy * (one - y * y))]
+        end
+
+        def backward(indexes, grad_outputs)
+          y, gy = get_retained_inputs
+          g = grad_outputs[0]
+
+          y_mul_g = y * g
+          grad_y = -2 * gy * y_mul_g
+          ggy = g - y * y_mul_g
+          [grad_y, ggy]
         end
       end
     end
```
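`TanhGrad` computes $g_y (1 - y^2)$ in its forward pass; its backward pass returns the two partial derivatives of that expression, scaled by the upstream gradient $g$:

$$
\frac{\partial}{\partial y}\, g_y (1 - y^2) = -2 y g_y, \qquad
\frac{\partial}{\partial g_y}\, g_y (1 - y^2) = 1 - y^2
$$

so `grad_y = -2 * gy * (y * g)` and `ggy = g - y * (y * g)`, which is exactly the `y_mul_g` factoring in the diff above.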
data/lib/chainer/functions/array/broadcast_to.rb ADDED

```ruby
module Chainer
  module Functions
    module Array
      # Function that broadcasts an array to a new shape.
      class BroadcastTo < FunctionNode
        def initialize(shape)
          @shape = shape
        end

        def self.broadcast_to(x, shape)
          return Chainer::Variable.as_variable(x) if x.shape == shape
          self.new(shape).apply([x]).first
        end

        def forward(inputs)
          x = inputs.first
          [Chainer::Utils::Array.broadcast_to(x, @shape)]
        end

        def backward(indexes, grad_outputs)
          gx = grad_outputs.first
          shape = @inputs.first.shape
          ndim = shape.size
          lead = gx.ndim - ndim
          lead_axis = lead.times.to_a
          axis = shape.each_with_object([]).with_index do |(sx, res), i|
            next unless sx == 1
            res << i + lead
          end
          gx = Chainer::Functions::Math::Sum.sum(gx, axis: lead_axis + axis, keepdims: true)
          return [Chainer::Functions::Array::Squeeze.squeeze(gx, axis: lead_axis)] if lead > 0
          [gx]
        end

        private

        def backward_one(shape, dtype, g)
          return dtype.zeros(shape) unless g

          ndim = shape.size
          if g.ndim != ndim
            g = g.sum(axis: 0...(g.ndim - ndim))
          end

          axis = shape.each_with_index.select{|sx, i| sx == 1 }.map{|sx, i| i }
          if axis.size > 0
            g.sum(keepdims: true, axis: axis)
          else
            g
          end
        end
      end
    end
  end
end
```
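`BroadcastTo#backward` has to undo the broadcast: it sums the incoming gradient over any leading axes that were added and over every original axis of size 1, then squeezes the leading axes away. A plain-Numo sketch of the simplest case (illustrative only, not the gem's API):

```ruby
require 'numo/narray'

x  = Numo::SFloat[[1.0], [2.0], [3.0]]   # shape [3, 1]
gy = Numo::SFloat.ones(3, 4)             # gradient w.r.t. the broadcast output, shape [3, 4]

# forward: broadcast [3, 1] -> [3, 4] (here via arithmetic broadcasting)
y = x + Numo::SFloat.zeros(3, 4)

# backward: reduce over the axis that had size 1 before broadcasting
gx = gy.sum(axis: 1, keepdims: true)     # shape [3, 1], each entry 4.0
```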