red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
@@ -0,0 +1,41 @@
|
|
1
|
+
module Chainer
  module Functions
    module Array
      # Differentiable type conversion between Numo::NArray classes.
      class Cast < FunctionNode
        # Cast an input variable to a given type.
        #
        # If the input already has the requested type, no new graph node is
        # created; the input is only wrapped with Chainer::Variable.as_variable.
        #
        # @param x [Chainer::Variable or Numo::NArray] Input variable to be cast.
        # @param type [Class] Numo::NArray subclass to cast to (e.g. Numo::DFloat).
        # @return [Chainer::Variable] Variable holding the cast array.
        #
        # @example
        #   x = Numo::UInt8.new(3, 5).seq
        #   x.class
        #   # => Numo::UInt8
        #   y = Chainer::Functions::Array::Cast.cast(x, Numo::DFloat)
        #   y.dtype
        #   # => Numo::DFloat
        def self.cast(x, type)
          raw_array_matches = Chainer.array?(x) && x.class == type
          variable_matches = x.is_a?(Chainer::Variable) && x.dtype == type
          return Chainer::Variable.as_variable(x) if raw_array_matches || variable_matches

          new(type).apply([x]).first
        end

        # @param type [Class] target array class used by forward.
        def initialize(type)
          @type = type
        end

        # Remembers the source class so backward can cast the gradient back.
        def forward(inputs)
          source = inputs.first
          @in_type = source.class
          [source.cast_to(@type)]
        end

        # The gradient is cast back to the class the input had in forward.
        def backward(indexes, grad_outputs)
          upstream = grad_outputs.first
          [Cast.cast(upstream, @in_type)]
        end
      end
    end
  end
end
|
41
|
+
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Chainer
  module Functions
    module Array
      # Reshapes an input array without copy.
      class Reshape < FunctionNode
        # @param shape [::Array<Integer>] target shape; -1 marks a dimension to infer.
        def initialize(shape)
          @shape = shape
        end

        # Reshape +x+ to +shape+.
        #
        # @param x [Chainer::Variable or Numo::NArray] input to reshape.
        # @param shape [::Array<Integer>] target shape; an entry of -1 is inferred
        #   from the remaining dimensions.
        # @return [Chainer::Variable] the input itself (wrapped as a variable) when
        #   its shape already equals +shape+, otherwise a reshaped variable.
        def self.reshape(x, shape)
          if x.shape == shape
            Chainer::Variable.as_variable(x)
          else
            new(shape).apply([x]).first
          end
        end

        def forward(inputs)
          # Numo's reshape infers a dimension given as nil, so -1 is mapped to nil.
          target = @shape.map { |dim| dim unless dim == -1 }
          [inputs.first.reshape(*target)]
        end

        # The gradient is reshaped back to the shape of the original input.
        def backward(indexes, grad_outputs)
          upstream = grad_outputs.first
          original_shape = @inputs.first.shape
          [Reshape.reshape(upstream, original_shape)]
        end
      end
    end
  end
end
|
@@ -0,0 +1,57 @@
|
|
1
|
+
module Chainer
  module Functions
    module Array
      # Roll axis of an array.
      class Rollaxis < FunctionNode
        # Roll the axis backwards to the given position.
        #
        # @param [Chainer::Variable] x Input variable
        # @param [Integer] axis The axis to roll backwards.
        # @param [Integer] start The place to which the axis is moved.
        # @return [Chainer::Variable] Variable whose axis is rolled.
        def self.rollaxis(x, axis, start: 0)
          Rollaxis.new(axis, start).apply([x]).first
        end

        # @raise [ArgumentError] if +axis+ or +start+ is not an Integer.
        def initialize(axis, start)
          raise ArgumentError, 'axis must be int' unless axis.is_a?(Integer)
          raise ArgumentError, 'start must be int' unless start.is_a?(Integer)

          @axis = axis
          @start = start
        end

        def forward(inputs)
          # backward needs only the input's rank, so no input arrays are retained.
          retain_inputs([])
          source = inputs.first
          @in_ndim = source.ndim
          [Chainer::Utils::Array.rollaxis(source, @axis, start: @start)]
        end

        # Rolls the gradient with the (start, axis) pair adjusted so that the
        # forward roll is undone.
        def backward(indexes, gy)
          axis = @axis < 0 ? @axis + @in_ndim : @axis
          start = @start < 0 ? @start + @in_ndim : @start

          if axis > start
            axis += 1
          else
            start -= 1
          end

          Rollaxis.new(start, axis).apply(gy)
        end
      end
    end
  end
end
|
@@ -0,0 +1,72 @@
|
|
1
|
+
module Chainer
  module Functions
    module Array
      # Select elements stored in given indices.
      class SelectItem < FunctionNode
        # Select elements stored in given indices.
        # This function returns $t.choose(x.T)$, that means
        # $y[i] == x[i, t[i]]$ for all $i$.
        #
        # @param [Chainer::Variable] x Variable storing arrays.
        # @param [Chainer::Variable] t Variable storing index numbers.
        # @return [Chainer::Variable] Variable that holds $t$-th element of $x$.
        def self.select_item(x, t)
          SelectItem.new.apply([x, t]).first
        end

        def forward(inputs)
          # Only t (input 1) is needed later, to scatter the gradient back.
          retain_inputs([1])
          x, t = inputs
          @in_shape = x.shape
          @in_dtype = x.class

          # TODO: x[six.moves.range(t.size), t] -- gathered row by row for now.
          picked = x.class.zeros(t.size)
          (0...t.size).each { |row| picked[row] = x[row, t[row]] }

          [picked]
        end

        def backward(indexes, gy)
          t = get_retained_inputs.first
          grads = []
          grads << Assign.new(@in_shape, @in_dtype, t).apply(gy).first if indexes.include?(0)
          # The index input t receives no gradient.
          grads << nil if indexes.include?(1)
          grads
        end
      end

      # Backward counterpart of SelectItem: writes gy[i] into position
      # [i, t[i]] of a zero array with the original input's shape and class.
      class Assign < FunctionNode
        # @param shape [::Array<Integer>] shape of SelectItem's input x.
        # @param dtype [Class] array class of SelectItem's input x.
        # @param t [Chainer::Variable] index variable; its raw data is kept.
        def initialize(shape, dtype, t)
          @shape = shape
          @dtype = dtype
          @t = t.data
        end

        def forward(inputs)
          upstream = inputs[0]
          gx = @dtype.zeros(*@shape)

          # TODO: gx[six.moves.range(self.t.size), self.t] = inputs[0]
          @t.size.times { |row| gx[row, @t[row]] = upstream[row] }

          [gx]
        end

        # Scattering is linear, so its gradient is the matching gather.
        def backward(indexes, gy)
          SelectItem.new.apply([gy[0], @t])
        end
      end
    end
  end
end
|
@@ -0,0 +1,78 @@
|
|
1
|
+
module Chainer
  module Functions
    module Array
      # Removes dimensions of size one from the shape of an array.
      class Squeeze < FunctionNode
        # Remove dimensions of size one from the shape of a Numo::NArray.
        # @param [Chainer::Variable or Numo::NArray] x Input data.
        # @param [nil or integer or array of integer] axis A subset of the single-dimensional entries in the shape to remove.
        #   If `nil` is supplied, all of them are removed. The dimension index starts at zero;
        #   negative indices count from the last dimension.
        #   If an axis with dimension greater than one is selected, an error is raised.
        # @return [Chainer::Variable] Variable whose dimensions of size 1 are removed.
        def self.squeeze(x, axis: nil)
          new(axis: axis).apply([x]).first
        end

        # @param axis [nil, Integer, ::Array<Integer>] axes to squeeze (nil = every size-1 axis).
        # @raise [TypeError] if +axis+ is not nil, an Integer or an Array of Integers.
        def initialize(axis: nil)
          @axis =
            case axis
            when nil then nil
            when Integer then [axis]
            when ::Array
              raise TypeError, 'axis must be None, int or tuple of ints' unless axis.all? { |a| a.is_a?(Integer) }

              axis
            else
              raise TypeError, 'axis must be None, int or tuple of ints'
            end
        end

        # @raise [StandardError] if a selected axis is out of range, repeated,
        #   or refers to a dimension whose size is not one.
        def forward(inputs)
          x = inputs.first
          shape = x.shape

          # TODO: numpy.squeeze
          new_shape =
            if @axis.nil?
              shape.reject { |dim| dim == 1 }
            else
              dropped = normalize_axes(@axis, shape)
              shape.reject.with_index { |_dim, i| dropped.include?(i) }
            end

          # Squeezing out every dimension leaves a 0-dimensional array holding the single element.
          ret = new_shape.empty? ? x.class.new.fill(x[0]) : x.reshape(*new_shape)
          [ret]
        end

        # Re-inserts the squeezed size-1 dimensions into the gradient's shape.
        def backward(indexes, grad_outputs)
          in_shape = @inputs[0].shape
          axis =
            if @axis.nil?
              argone(in_shape)
            else
              ndim = in_shape.size
              @axis.map { |a| a < 0 ? a + ndim : a }.sort
            end

          gx = grad_outputs.first
          shape = gx.shape.dup
          # axis is sorted ascending, so every insertion lands at its final index.
          axis.each { |a| shape.insert(a, 1) }
          [gx.reshape(*shape)]
        end

        private

        # Converts +axes+ to non-negative indices into +shape+, checking that each
        # is in range and has size one, and that no axis is given twice.
        #
        # @raise [StandardError] on an out-of-range, non-unit or repeated axis.
        def normalize_axes(axes, shape)
          ndim = shape.size
          normalized = axes.map do |a|
            i = a < 0 ? a + ndim : a
            raise StandardError, "axis #{a} is out of bounds for array of dimension #{ndim}" unless i >= 0 && i < ndim
            raise StandardError, "cannot select an axis to squeeze out which has size not equal to one" unless shape[i] == 1

            i
          end
          raise StandardError, 'repeated axis in squeeze' unless normalized.uniq.size == normalized.size

          normalized
        end

        # Returns the indices of the entries of +iterable+ that equal 1.
        #
        # @raise [StandardError] if an entry is not an Integer.
        def argone(iterable)
          # Array(...) here is Kernel#Array (a method call), not the enclosing module.
          Array(iterable).each_with_index.each_with_object([]) do |(dim, i), result|
            raise StandardError, "elements in iterable must be int" unless dim.is_a?(Integer)

            result << i if dim == 1
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,44 @@
|
|
1
|
+
module Chainer
  module Functions
    module Array
      # Permute the dimensions of an array.
      class Transpose < FunctionNode
        # Permute the dimensions of an input variable without copy.
        #
        # @param [Chainer::Variable] x Input Variable.
        # @param [::Array<Integer>] axes By default, reverse the dimensions,
        #   otherwise permute the axes according to the values given.
        # @return [Chainer::Variable] Variable whose axes are permuted.
        def self.transpose(x, axes: nil)
          Transpose.new(axes: axes).apply([x]).first
        end

        # @param axes [::Array<Integer>, nil] permutation; nil reverses the dimensions.
        def initialize(axes: nil)
          @axes = axes
        end

        def label
          'Transpose'
        end

        def forward(inputs)
          source = inputs.first
          [source.transpose(*@axes)]
        end

        # The gradient is transposed by the inverse permutation.
        def backward(indexes, grad_outputs)
          Transpose.new(axes: inverse_axes).apply(grad_outputs)
        end

        private

        # Inverse of @axes, obtained as the argsort of the axes reduced to
        # non-negative values. When @axes is nil, both passes reverse the
        # dimensions, and reversal is its own inverse, so nil is returned unchanged.
        def inverse_axes
          return @axes unless @axes

          rank = @axes.size
          non_negative = @axes.map { |ax| ax % rank }
          Numo::NArray[*non_negative].sort_index.to_a
        end
      end
    end
  end
end
|
44
|
+
|
@@ -1,7 +1,8 @@
|
|
1
1
|
module Chainer
|
2
2
|
module Functions
|
3
3
|
module Connection
|
4
|
-
class Convolution2DFunction < Chainer::
|
4
|
+
class Convolution2DFunction < Chainer::FunctionNode
|
5
|
+
attr_reader :sy, :sx, :ph, :pw, :cover_all
|
5
6
|
# Two-dimensional convolution function.
|
6
7
|
# This is an implementation of two-dimensional convolution in ConvNets.
|
7
8
|
# It takes three variables: the input image `x`, the filter weight `w`, and the bias vector `b`.
|
@@ -33,9 +34,9 @@ module Chainer
|
|
33
34
|
# w_O &= (w_I + 2w_P - w_K + s_X - 1) / s_X + 1.
|
34
35
|
# If the bias vector is given, then it is added to all spatial locations of the output of convolution.
|
35
36
|
#
|
36
|
-
# @param [Chainer::Variable or Numo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
|
37
|
-
# @param [Chainer::Variable or Numo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
|
38
|
-
# @param [Chainer::Variable or Numo::NArray] b Bias variable of length :math:`c_O`
|
37
|
+
# @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
|
38
|
+
# @param [Chainer::Variable or Numo::NArray or Cumo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
|
39
|
+
# @param [Chainer::Variable or Numo::NArray or Cumo::NArray] b Bias variable of length :math:`c_O`
|
39
40
|
# @param [Int or 2-D Array] stride Stride of filter applications. `stride=s` and `stride=(s, s)` are equivalent.
|
40
41
|
# @param [Int or 2-D Array] pad Spatial padding width for input arrays.
|
41
42
|
# @param [Boolean] cover_all If `true`, all spatial locations are convoluted into some output pixels.
|
@@ -43,48 +44,64 @@ module Chainer
|
|
43
44
|
def self.convolution_2d(x, w, b: nil, stride: 1, pad: 0, cover_all: false)
  # The bias is optional: without it the node is applied to (x, w) only.
  inputs = [x, w]
  inputs << b unless b.nil?
  new(stride: stride, pad: pad, cover_all: cover_all).apply(inputs).first
end
|
51
54
|
|
52
55
|
def initialize(stride: 1, pad: 0, cover_all: false)
  # A scalar applies to both spatial axes; a two-element array gives (y, x).
  # ::Array is written explicitly because an unqualified Array could resolve to
  # Chainer::Functions::Array from this lexical scope when that module is loaded.
  to_pair = ->(value) { value.is_a?(::Array) ? value : [value, value] }
  @sy, @sx = to_pair.call(stride)
  @ph, @pw = to_pair.call(pad)
  @cover_all = cover_all
end
|
57
60
|
|
58
|
-
def
|
61
|
+
# Computes the convolution of x with w (plus b when given) via im2col
# followed by a tensor contraction.
#
# @param inputs [::Array] x, w and optionally b; all must be arrays of the
#   same backend (all Numo::NArray, or all non-Numo arrays such as Cumo::NArray).
# @return [::Array] one-element array holding y, transposed to put the output
#   channel axis second.
# @raise [TypeError] if an input is not an array, or Numo and non-Numo arrays are mixed.
def forward(inputs)
  # x and w are needed by backward; the bias gradient only needs gy.
  retain_inputs([0, 1])
  x, w, b = inputs # b is nil when only two inputs are given

  # Previously only Numo::NArray was accepted, which rejected the Cumo arrays the
  # documentation advertises. What is required is a single backend for all inputs.
  numo_flags = inputs.map { |i| i.is_a?(Numo::NArray) }
  unless inputs.all? { |i| Chainer.array?(i) } && numo_flags.uniq.size == 1
    parts = ["w: #{w.class}", "x: #{x.class}"]
    parts << "b: #{b.class}" unless b.nil?
    raise TypeError, "Numo::NArray and Cumo::NArray must not be used together: #{parts.join(', ')}"
  end

  kh, kw = w.shape[2..-1]
  col = Chainer::Utils::Conv.im2col(x, kh, kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
  # Contracts col axes 1..3 with w axes 1..3 (c_I, h_K, w_K); assumes im2col
  # returns (n, c_I, kh, kw, out_h, out_w), giving y as (n, out_h, out_w, c_O).
  y = Chainer::Utils::Math.tensordot(col, w, [[1, 2, 3], [1, 2, 3]]).cast_to(x.class)
  y += b if b

  [y.transpose(0, 3, 1, 2)]
end
|
71
82
|
|
72
|
-
def
|
73
|
-
x, w
|
74
|
-
gy = grad_outputs
|
75
|
-
height, width = x.shape[2..-1]
|
83
|
+
# Returns gradients for the requested inputs, in index order:
# 0 => x (via deconvolution), 1 => w (via Convolution2DGradW), 2 => b (sum of gy).
def backward(indexes, grad_outputs)
  x, w = get_retained_inputs
  gy = grad_outputs.first
  grads = []

  if indexes.include?(0)
    xh, xw = x.shape[2..-1]
    grads << Deconvolution2DFunction.deconvolution_2d(gy, w, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [xh, xw])
  end

  grads << Chainer::Functions::Connection::Convolution2DGradW.new(self).apply([x, gy]).first if indexes.include?(1)

  # The bias is broadcast over batch and spatial axes, so its gradient sums them out.
  grads << Chainer::Functions::Math::Sum.sum(gy, axis: [0, 2, 3]) if indexes.include?(2)

  grads
end
|
89
106
|
end
|
90
107
|
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
module Chainer
  module Functions
    module Connection
      # Gradient of a 2-D convolution with respect to its filter W, written as a
      # FunctionNode of its own so that it defines a backward of its own as well.
      class Convolution2DGradW < Chainer::FunctionNode
        # @param conv2d [Chainer::Functions::Connection::Convolution2DFunction]
        #   forward node whose stride/pad/cover_all settings and weight
        #   input's kernel size and dtype are copied.
        def initialize(conv2d)
          w_node = conv2d.inputs[1]

          @kh, @kw = w_node.shape[2..-1]
          @sy, @sx = conv2d.sy, conv2d.sx
          @ph, @pw = conv2d.ph, conv2d.pw
          @cover_all = conv2d.cover_all
          @w_dtype = w_node.dtype
        end

        # @param inputs [::Array] the convolution input x and output gradient gy.
        # @return [::Array] one-element array holding gW, cast to W's dtype.
        def forward(inputs)
          retain_inputs([0, 1])
          x, gy = inputs
          col = Chainer::Utils::Conv.im2col(x, @kh, @kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)

          # Contracts gy axes (0, 2, 3) with col axes (0, 4, 5).
          gw = Chainer::Utils::Math.tensordot(gy, col, [[0, 2, 3], [0, 4, 5]])
          [gw.cast_to(@w_dtype)]
        end

        # Returns gradients for the requested inputs: 0 => x, 1 => gy.
        def backward(indexes, grad_outputs)
          x, gy = get_retained_inputs
          ggw = grad_outputs.first
          grads = []

          if indexes.include?(0)
            xh, xw = x.shape[2..-1]
            grads << Deconvolution2DFunction.deconvolution_2d(gy, ggw, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [xh, xw])
          end

          if indexes.include?(1)
            grads << Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(
              x, ggw, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all
            )
          end

          grads
        end
      end
    end
  end
end
|
48
|
+
|