red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/array/cast.rb

@@ -0,0 +1,41 @@
+module Chainer
+  module Functions
+    module Array
+      class Cast < FunctionNode
+        # Cast an input variable to a given type.
+        #
+        # @param x [Chainer::Variable or Numo::NArray] Input variable to be cast.
+        # @param type [Numo::NArray class] Data class to cast to.
+        # @return [Chainer::Variable] Variable holding the cast array.
+        #
+        # @example
+        #   > x = Numo::UInt8.new(3, 5).seq
+        #   > x.class
+        #   # => Numo::UInt8
+        #   > y = Chainer::Functions::Array::Cast.cast(x, Numo::DFloat)
+        #   > y.dtype
+        #   # => Numo::DFloat
+        def self.cast(x, type)
+          if (Chainer.array?(x) && x.class == type) || (x.is_a?(Chainer::Variable) && x.dtype == type)
+            return Chainer::Variable.as_variable(x)
+          end
+          self.new(type).apply([x]).first
+        end
+
+        def initialize(type)
+          @type = type
+        end
+
+        def forward(x)
+          @in_type = x.first.class
+          [x.first.cast_to(@type)]
+        end
+
+        def backward(indexes, g)
+          [Cast.cast(g.first, @in_type)]
+        end
+      end
+    end
+  end
+end
+
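Cast is a differentiable dtype conversion: forward remembers the input class and backward casts the incoming gradient back to it. A minimal usage sketch (editorial, not part of the diff; assumes red-chainer 0.4.0 with Numo loaded):

    x = Chainer::Variable.new(Numo::SFloat.new(2, 3).seq)  # single-precision input
    y = Chainer::Functions::Array::Cast.cast(x, Numo::DFloat)
    # y.dtype => Numo::DFloat; gradients reaching x are cast back to Numo::SFloat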
data/lib/chainer/functions/array/reshape.rb

@@ -0,0 +1,28 @@
+module Chainer
+  module Functions
+    module Array
+      # Reshapes an input array without copy.
+      class Reshape < FunctionNode
+        def initialize(shape)
+          @shape = shape
+        end
+
+        def self.reshape(x, shape)
+          return Chainer::Variable.as_variable(x) if x.shape == shape
+          return self.new(shape).apply([x]).first
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          new_shape = @shape.map { |s| s == -1 ? nil : s }
+          [x.reshape(*new_shape)]
+        end
+
+        def backward(indexes, grad_outputs)
+          gx = grad_outputs.first
+          [Reshape.reshape(gx, @inputs.first.shape)]
+        end
+      end
+    end
+  end
+end
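Note how forward maps a -1 entry to nil: Numo::NArray#reshape infers a nil dimension from the total size, mirroring NumPy's -1 convention, and backward reshapes the gradient back to the input shape. A minimal sketch (editorial; assumes red-chainer 0.4.0):

    x = Chainer::Variable.new(Numo::SFloat.new(2, 6).seq)
    y = Chainer::Functions::Array::Reshape.reshape(x, [3, -1])
    # y.shape => [3, 4]; the gradient at x is reshaped back to [2, 6]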
data/lib/chainer/functions/array/rollaxis.rb

@@ -0,0 +1,57 @@
+module Chainer
+  module Functions
+    module Array
+      # Roll axis of an array.
+      class Rollaxis < FunctionNode
+        # Roll the axis backwards to the given position.
+        #
+        # @param [Chainer::Variable] x Input variable.
+        # @param [Integer] axis The axis to roll backwards.
+        # @param [Integer] start The place to which the axis is moved.
+        # @return [Chainer::Variable] Variable whose axis is rolled.
+        def self.rollaxis(x, axis, start: 0)
+          Rollaxis.new(axis, start).apply([x]).first
+        end
+
+        def initialize(axis, start)
+          unless axis.is_a?(Integer)
+            raise ArgumentError, 'axis must be int'
+          end
+
+          unless start.is_a?(Integer)
+            raise ArgumentError, 'start must be int'
+          end
+
+          @axis = axis
+          @start = start
+        end
+
+        def forward(inputs)
+          retain_inputs([])
+          @in_ndim = inputs.first.ndim
+
+          [Chainer::Utils::Array.rollaxis(inputs.first, @axis, start: @start)]
+        end
+
+        def backward(indexes, gy)
+          axis = @axis
+          if axis < 0
+            axis += @in_ndim
+          end
+          start = @start
+          if start < 0
+            start += @in_ndim
+          end
+
+          if axis > start
+            axis += 1
+          else
+            start -= 1
+          end
+
+          Rollaxis.new(start, axis).apply(gy)
+        end
+      end
+    end
+  end
+end
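The backward pass undoes the roll by constructing a new Rollaxis with axis and start swapped (and shifted by one, depending on their order), which is the inverse permutation. A minimal sketch (editorial; assumes red-chainer 0.4.0):

    x = Chainer::Variable.new(Numo::SFloat.new(2, 3, 4).seq)
    y = Chainer::Functions::Array::Rollaxis.rollaxis(x, 2)  # roll axis 2 to the front
    # y.shape => [4, 2, 3]; the gradient at x has shape [2, 3, 4] again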
data/lib/chainer/functions/array/select_item.rb

@@ -0,0 +1,72 @@
+module Chainer
+  module Functions
+    module Array
+      # Select elements stored in given indices.
+      class SelectItem < FunctionNode
+        # Select elements stored in given indices.
+        # This function returns $t.choose(x.T)$, that is,
+        # $y[i] == x[i, t[i]]$ for all $i$.
+        #
+        # @param [Chainer::Variable] x Variable storing arrays.
+        # @param [Chainer::Variable] t Variable storing index numbers.
+        # @return [Chainer::Variable] Variable that holds the $t$-th element of $x$.
+        def self.select_item(x, t)
+          SelectItem.new.apply([x, t]).first
+        end
+
+        def forward(inputs)
+          retain_inputs([1])
+          x, t = inputs
+          @in_shape = x.shape
+          @in_dtype = x.class
+
+          # TODO: x[six.moves.range(t.size), t]
+          new_x = x.class.zeros(t.size)
+          t.size.times.each do |i|
+            new_x[i] = x[i, t[i]]
+          end
+          x = new_x
+
+          [x]
+        end
+
+        def backward(indexes, gy)
+          t = get_retained_inputs.first
+          ret = []
+          if indexes.include?(0)
+            ggx = Assign.new(@in_shape, @in_dtype, t).apply(gy).first
+            ret << ggx
+          end
+          if indexes.include?(1)
+            ret << nil
+          end
+          ret
+        end
+      end
+
+      class Assign < FunctionNode
+        def initialize(shape, dtype, t)
+          @shape = shape
+          @dtype = dtype
+          @t = t.data
+        end
+
+        def forward(inputs)
+          gx = @dtype.zeros(*@shape)
+
+          # TODO: gx[six.moves.range(self.t.size), self.t] = inputs[0]
+          @t.size.times.each do |i|
+            gx[i, @t[i]] = inputs[0][i]
+          end
+
+          [gx]
+        end
+
+        def backward(indexes, gy)
+          SelectItem.new.apply([gy[0], @t])
+        end
+      end
+    end
+  end
+end
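SelectItem picks one element per row, e.g. the score of the labelled class from a batch of per-class scores; its gradient (Assign) scatters the incoming gradient back into a zero array of the input shape. A minimal sketch (editorial; assumes red-chainer 0.4.0 and that apply wraps raw arrays as variables):

    x = Chainer::Variable.new(Numo::SFloat[[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])
    t = Numo::Int32[1, 0]
    y = Chainer::Functions::Array::SelectItem.select_item(x, t)
    # y.data => Numo::SFloat[0.8, 0.7], i.e. y[i] == x[i, t[i]]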
data/lib/chainer/functions/array/squeeze.rb

@@ -0,0 +1,78 @@
+module Chainer
+  module Functions
+    module Array
+      class Squeeze < FunctionNode
+        # Remove dimensions of size one from the shape of a Numo::NArray.
+        # @param [Chainer::Variable or Numo::NArray] x Input data.
+        # @param [nil or Integer or Array<Integer>] axis A subset of the single-dimensional entries in the shape to remove.
+        #   If `nil` is supplied, all of them are removed. The dimension index starts at zero.
+        #   If an axis with dimension greater than one is selected, an error is raised.
+        # @return [Chainer::Variable] Variable whose dimensions of size 1 are removed.
+        def self.squeeze(x, axis: nil)
+          self.new(axis: axis).apply([x]).first
+        end
+
+        def initialize(axis: nil)
+          if axis.nil?
+            @axis = nil
+          elsif axis.kind_of?(Integer)
+            @axis = [axis]
+          elsif axis.kind_of?(::Array) && Array(axis).all? { |i| i.kind_of?(Integer) }
+            @axis = axis
+          else
+            raise TypeError, 'axis must be None, int or tuple of ints'
+          end
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          shape = x.shape
+
+          # TODO: numpy.squeeze
+          if @axis.nil?
+            new_shape = shape.reject { |axis| axis == 1 }
+          else
+            new_shape = shape
+            @axis.map do |a|
+              raise StandardError, "cannot select an axis to squeeze out which has size not equal to one" unless shape[a] == 1
+              new_shape[a] = nil
+            end
+            new_shape.compact!
+          end
+          ret = new_shape.size.zero? ? x.class.new.fill(x[0]) : x.reshape(*new_shape)
+
+          [ret]
+        end
+
+        def backward(indexes, grad_outputs)
+          if @axis.nil?
+            axis = argone(@inputs[0].shape)
+          else
+            axis = @axis
+            ndim = @inputs[0].shape.size
+            axis = axis.map { |x| x < 0 ? x + ndim : x }
+            axis.sort!
+          end
+          gx = grad_outputs.first
+
+          shape = gx.shape
+          axis.each do |x|
+            shape.insert(x, 1)
+          end
+          [gx.reshape(*shape)]
+        end
+
+        private
+
+        def argone(iterable)
+          result = []
+          Array(iterable).each_with_index do |x, i|
+            raise StandardError, "elements in iterable must be int" unless x.kind_of?(Integer)
+            result << i if x == 1
+          end
+          result
+        end
+      end
+    end
+  end
+end
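A minimal sketch of both modes (editorial; assumes red-chainer 0.4.0):

    x = Chainer::Variable.new(Numo::SFloat.new(1, 3, 1, 2).seq)
    Chainer::Functions::Array::Squeeze.squeeze(x).shape            # => [3, 2]
    Chainer::Functions::Array::Squeeze.squeeze(x, axis: 0).shape   # => [3, 1, 2]
    # squeeze(x, axis: 1) raises, since that axis has size 3, not 1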
data/lib/chainer/functions/array/transpose.rb

@@ -0,0 +1,44 @@
+module Chainer
+  module Functions
+    module Array
+      # Permute the dimensions of an array.
+      class Transpose < FunctionNode
+        # Permute the dimensions of an input variable without copy.
+        #
+        # @param [Chainer::Variable] x Input variable.
+        # @param [::Array<Integer>] axes By default, reverse the dimensions,
+        #   otherwise permute the axes according to the values given.
+        # @return [Chainer::Variable] Variable whose axes are permuted.
+        def self.transpose(x, axes: nil)
+          Transpose.new(axes: axes).apply([x]).first
+        end
+
+        def initialize(axes: nil)
+          @axes = axes
+        end
+
+        def label
+          'Transpose'
+        end
+
+        def forward(inputs)
+          x = inputs.first
+          [x.transpose(*@axes)]
+        end
+
+        def backward(indexes, grad_outputs)
+          inv_axes = @axes
+          if inv_axes
+            axes_len = inv_axes.size
+
+            axes = inv_axes.map { |ax| ax % axes_len }
+            inv_axes = Numo::NArray[*axes].sort_index.to_a
+          end
+
+          Transpose.new(axes: inv_axes).apply(grad_outputs)
+        end
+      end
+    end
+  end
+end
+
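Backward applies the inverse permutation, recovered via sort_index over the normalized axes. A minimal sketch (editorial; assumes red-chainer 0.4.0):

    x = Chainer::Variable.new(Numo::SFloat.new(2, 3, 4).seq)
    y = Chainer::Functions::Array::Transpose.transpose(x, axes: [1, 2, 0])
    # y.shape => [3, 4, 2]; backward transposes gradients with the inverse [2, 0, 1]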
data/lib/chainer/functions/connection/convolution_2d.rb

@@ -1,7 +1,8 @@
 module Chainer
   module Functions
     module Connection
-      class Convolution2DFunction < Chainer::Function
+      class Convolution2DFunction < Chainer::FunctionNode
+        attr_reader :sy, :sx, :ph, :pw, :cover_all
         # Two-dimensional convolution function.
         # This is an implementation of two-dimensional convolution in ConvNets.
         # It takes three variables: the input image `x`, the filter weight `w`, and the bias vector `b`.
@@ -33,9 +34,9 @@ module Chainer
         # w_O &= (w_I + 2w_P - w_K + s_X - 1) / s_X + 1.
         # If the bias vector is given, then it is added to all spatial locations of the output of convolution.
         #
-        # @param [Chainer::Variable or Numo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
-        # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
-        # @param [Chainer::Variable or Numo::NArray] b Bias variable of length :math:`c_O`
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] b Bias variable of length :math:`c_O`
         # @param [Int or 2-D Array] stride Stride of filter applications. `stride=s` and `stride=(s, s)` are equivalent.
         # @param [Int or 2-D Array] pad Spatial padding width for input arrays.
         # @param [Boolean] cover_all If `true`, all spatial locations are convoluted into some output pixels.
@@ -43,48 +44,64 @@ module Chainer
         def self.convolution_2d(x, w, b: nil, stride: 1, pad: 0, cover_all: false)
           func = self.new(stride: stride, pad: pad, cover_all: cover_all)
           if b.nil?
-            func.(x, w)
+            args = [x, w]
           else
-            func.(x, w, b)
+            args = [x, w, b]
           end
+
+          func.apply(args).first
         end
 
         def initialize(stride: 1, pad: 0, cover_all: false)
-          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
-          @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(::Array) ? pad : [pad, pad]
           @cover_all = cover_all
         end
 
-        def forward_cpu(inputs)
+        def forward(inputs)
+          retain_inputs([0, 1])
           x = inputs[0]
           w = inputs[1]
           b = inputs.size == 3 ? inputs[2] : nil
 
-
+          unless inputs.all? { |i| i.is_a?(Numo::NArray) }
+            if b.nil?
+              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}"
+            else
+              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}, b: #{b.class}"
+            end
+          end
 
-
-
+          kh, kw = w.shape[2..-1]
+          col = Chainer::Utils::Conv.im2col(x, kh, kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
+          y = Chainer::Utils::Math.tensordot(col, w, [[1, 2, 3], [1, 2, 3]]).cast_to(x.class)
           y += b if b
-
+
           [y.transpose(0, 3, 1, 2)]
         end
 
-        def backward_cpu(inputs, grad_outputs)
-          x, w
-          gy = grad_outputs
-          height, width = x.shape[2..-1]
+        def backward(indexes, grad_outputs)
+          x, w = get_retained_inputs
+          gy = grad_outputs.first
 
-
-
-
-
+          ret = []
+          if indexes.include?(0)
+            xh, xw = x.shape[2..-1]
+            gx = Deconvolution2DFunction.deconvolution_2d(gy, w, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [xh, xw])
+            ret << gx
+          end
 
-          if
-          [
-
-
-
+          if indexes.include?(1)
+            gw = Chainer::Functions::Connection::Convolution2DGradW.new(self).apply([x, gy]).first
+            ret << gw
+          end
+
+          if indexes.include?(2)
+            gb = Chainer::Functions::Math::Sum.sum(gy, axis: [0, 2, 3])
+            ret << gb
           end
+
+          ret
         end
       end
     end
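With the move to FunctionNode, backward returns one gradient per requested input index instead of recomputing everything itself: gx comes from a deconvolution of gy with w, gw from the new Convolution2DGradW node, and gb from summing gy over the batch and spatial axes; each piece is itself built from differentiable nodes, which is what enables double backprop. A minimal forward call (editorial sketch; assumes red-chainer 0.4.0):

    x = Numo::SFloat.new(1, 3, 8, 8).rand   # (n, c_I, h_I, w_I)
    w = Numo::SFloat.new(4, 3, 3, 3).rand   # (c_O, c_I, h_K, w_K)
    b = Numo::SFloat.new(4).rand            # (c_O)
    y = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(x, w, b: b, stride: 1, pad: 1)
    # y.shape => [1, 4, 8, 8], since (8 + 2*1 - 3) / 1 + 1 = 8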
data/lib/chainer/functions/connection/convolution_2d_grad_w.rb

@@ -0,0 +1,48 @@
+module Chainer
+  module Functions
+    module Connection
+      class Convolution2DGradW < Chainer::FunctionNode
+        def initialize(conv2d)
+          w_node = conv2d.inputs[1]
+
+          @kh, @kw = w_node.shape[2..-1]
+          @sy = conv2d.sy
+          @sx = conv2d.sx
+          @ph = conv2d.ph
+          @pw = conv2d.pw
+          @cover_all = conv2d.cover_all
+          @w_dtype = w_node.dtype
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          x, gy = inputs
+          col = Chainer::Utils::Conv.im2col(x, @kh, @kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
+
+          gw = Chainer::Utils::Math.tensordot(gy, col, [[0, 2, 3], [0, 4, 5]]).cast_to(@w_dtype)
+          [gw]
+        end
+
+        def backward(indexes, grad_outputs)
+          x, gy = get_retained_inputs
+          ggw = grad_outputs.first
+
+          ret = []
+          if indexes.include?(0)
+            xh, xw = x.shape[2..-1]
+            gx = Deconvolution2DFunction.deconvolution_2d(gy, ggw, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [xh, xw])
+            ret << gx
+          end
+
+          if indexes.include?(1)
+            ggy = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(x, ggw, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all)
+            ret << ggy
+          end
+
+          ret
+        end
+      end
+    end
+  end
+end
+
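The weight gradient is a tensor contraction of the upstream gradient with the im2col patches of the input. A shape walkthrough of the tensordot in forward (editorial; assumes the im2col layout implied by the tensordot axes in both convolution files):

    # x   : (n, c_I, h, w)
    # col : (n, c_I, kh, kw, out_h, out_w)   -- patches from im2col
    # gy  : (n, c_O, out_h, out_w)
    # tensordot(gy, col, [[0, 2, 3], [0, 4, 5]]) contracts n, out_h, out_w
    #   => gw : (c_O, c_I, kh, kw), matching the shape of w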