red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/testing/array.rb
CHANGED
```diff
@@ -22,7 +22,7 @@ module Chainer
       end
 
       actual.each_with_index{|actual_val, *i|
-        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].abs
+        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].to_f.abs
           raise "assert_allclose Error\n expect: #{expect.inspect}\n actual : #{actual.inspect}\n (#{i})=> #{(expect - actual).abs.max()} > #{atol + rtol * expect[*i].abs}"
         end
       }
```
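The fix coerces `expect[*i]` to a Float before taking its absolute value. A minimal sketch of the allclose criterion itself (plain Ruby, using Numo only for the sample data; this is not the gem's actual test helper):

```ruby
require 'numo/narray'

expect = Numo::DFloat[1.0, 2.0, 3.0]
actual = Numo::DFloat[1.0, 2.0005, 3.0]
atol, rtol = 1e-5, 1e-4

expect.to_a.zip(actual.to_a).each do |e, a|
  # close when |e - a| <= atol + rtol * |e|, the same criterion assert_allclose uses
  close = (e - a).abs <= atol + rtol * e.abs
  puts format('%g vs %g -> %s', e, a, close ? 'close' : 'NOT close')
end
```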
data/lib/chainer/training/extensions/evaluator.rb
CHANGED
```diff
@@ -34,7 +34,7 @@ module Chainer
       # If this is just a link object, the link is registered by the name 'main'.
       # @param [Dataset::Convert] converter Converter function to build input arrays.
       #   `Chainer::Dataset.concat_examples` is used by default.
-      # @param [
+      # @param [Chainer::Device] device Device to which the training data is sent.
       # @param [Function] eval_hook Function to prepare for each evaluation process.
       #   It is called at the beginning of the evaluation.
       #   The evaluator extension object is passed at each call.
```
```diff
@@ -97,8 +97,7 @@ module Chainer
       # @return dict Result dictionary. This dictionary is further reported via `Chainer.save_report` without specifying any observer.
       def evaluate
         iterator = @iterators[:main]
-
-        eval_func = @eval_func || target
+        eval_func = @eval_func || @targets[:main]
 
         @eval_hook.(self) if @eval_hook
 
```
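For reference, a hedged wiring sketch for the evaluator with its newly documented `device` argument, following the pattern of the gem's bundled MNIST example (`test_iter`, `model`, `device`, and `trainer` are assumed to be set up elsewhere):

```ruby
# assumes test_iter, model, device, and trainer were built as in examples/mnist/mnist.rb
evaluator = Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: device)
trainer.extend(evaluator)
```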
data/lib/chainer/training/trainer.rb
CHANGED
```diff
@@ -1,12 +1,11 @@
 module Chainer
   module Training
     class ExtensionEntry
-      attr_accessor :extension, :trigger, :
+      attr_accessor :extension, :trigger, :priority
 
-      def initialize(extension, priority, trigger
+      def initialize(extension, priority, trigger)
         @extension = extension
         @trigger = trigger
-        @invoke_before_training = invoke_before_training
         @priority = priority
       end
     end
@@ -47,7 +46,7 @@ module Chainer
        Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
      end
 
-      def extend(extension, name: nil, trigger: nil, priority: nil
+      def extend(extension, name: nil, trigger: nil, priority: nil)
        if name.nil?
          name = if extension.name
            extension.name
@@ -69,10 +68,6 @@ module Chainer
          priority = extension.methods.include?(:priority) ? extension.priority : Extension::PRIORITY_READER
        end
 
-        if invoke_before_training.nil?
-          invoke_before_training = extension.methods.include?(:invoke_before_training) ? extension.invoke_before_training : false
-        end
-
        modified_name = name
        ordinal = 0
 
@@ -82,7 +77,7 @@ module Chainer
        end
 
        extension.name = modified_name
-        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger
+        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger)
      end
 
      def get_extension(name)
```
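A quick sketch of the slimmed-down registration path; `my_extension` is a hypothetical extension object, and `Trainer#extend` normally builds the entry for you from its keyword arguments:

```ruby
# direct construction, mirroring the new three-argument signature
entry = Chainer::Training::ExtensionEntry.new(
  my_extension,                                   # hypothetical extension object
  Chainer::Training::Extension::PRIORITY_READER,  # priority, as resolved by Trainer#extend
  [1, 'epoch']                                    # trigger; invoke_before_training is gone
)
entry.priority  # now readable through the new attr_accessor
```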
data/lib/chainer/training/triggers/interval.rb
CHANGED
```diff
@@ -13,6 +13,11 @@ module Chainer
         @previous_epoch_detail = 0.0
       end
 
+      # Decides whether the extension should be called on this iteration.
+      #
+      # @param [Chainer::Trainer] trainer Trainer object that this trigger is associated with.
+      #   The updater associated with this trainer is used to determine if the trigger should fire.
+      # @return [boolean] True if the corresponding extension should be invoked in this iteration.
       def call(trainer)
         updater = trainer.updater
         if @unit == 'epoch'
@@ -30,7 +35,7 @@ module Chainer
           iteration = updater.iteration
           previous_iteration = @previous_iteration
           if previous_iteration < 0
-
+            previous_iteration = iteration - 1
           end
           fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
         end
@@ -38,7 +43,7 @@ module Chainer
         # save current values
         @previous_iteration = updater.iteration
         @previous_epoch_detail = updater.epoch_detail
-
+
         fire
       end
 
```
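The iteration branch fires whenever the iteration count crosses a multiple of the period; a tiny standalone sketch of that rule:

```ruby
period = 100

(98..102).each do |iteration|
  previous_iteration = iteration - 1
  # same test as the trigger above: fire when a period boundary is crossed
  fire = previous_iteration.div(period) != iteration.div(period)
  puts "iteration #{iteration}: fire=#{fire}"  # fires only at iteration 100
end
```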
data/lib/chainer/utils/array.rb
CHANGED
```diff
@@ -4,7 +4,8 @@ module Chainer
       def self.force_array(x, dtype=nil)
         if x.is_a? Integer or x.is_a? Float
           if dtype.nil?
-
+            xm = Chainer::Device.default.xm
+            xm::NArray.cast(x)
           else
             dtype.cast(x.dup)
           end
@@ -16,6 +17,84 @@ module Chainer
           end
         end
       end
+
+      def self.take(x, indices, axis: nil)
+        if axis
+          indices = make_indecies_with_axis(x.shape, indices, axis)
+        end
+        x[indices]
+      end
+
+      def self.make_indecies_with_axis(shape, indices, axis, values = [])
+        target_axis = values.size
+        if shape.size == values.size
+          values.zip(shape.drop(1) + [1]).reduce(0) do |sum, (x, ndim)|
+            (sum + x) * ndim
+          end
+        else
+          enum = (axis == target_axis) ? indices : (0...shape[target_axis])
+          if enum.is_a?(Integer)
+            make_indecies_with_axis(shape, indices, axis, values + [indices])
+          else
+            enum.map do |x|
+              make_indecies_with_axis(shape, indices, axis, values + [x])
+            end
+          end
+        end
+      end
+
+      def self.rollaxis(y, axis, start: 0)
+        n = y.ndim
+        # normalize axis
+        axis = axis < 0 ? n + axis : axis
+        if axis >= n
+          raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
+        end
+
+        if start < 0
+          start += n
+        end
+
+        unless 0 <= start && start < n + 1
+          raise ArgumentError, "start arg requires #{-n} <= start < #{n}, but #{start} was passed in"
+        end
+
+        if axis < start
+          start -= 1
+        end
+
+        if axis == start
+          return y
+        end
+
+        axes = (0...n).to_a
+        axes.delete_at(axis)
+        axes.insert(start <= axes.size ? start : -1, axis)
+        y.transpose(*axes)
+      end
+
+      def self.broadcast_to(x, shape)
+        if x.shape.size > shape.size
+          raise TypeError, "Shape of data mismatch\n x.shape.size(#{x.shape.size}) > shape.size(#{shape.size})"
+        end
+
+        tile_shape = []
+        if x.shape.size > 0
+          shape[-x.shape.size..-1].each_with_index do |s, i|
+            if x.shape[i] == 1
+              tile_shape << s
+            elsif x.shape[i] == s
+              tile_shape << 1
+            else
+              raise TypeError, "Shape of data mismatch\n#{x.shape} != #{shape}"
+            end
+          end
+        else
+          tile_shape = shape
+        end
+
+        x.tile(*shape[0...-x.shape.size], *tile_shape)
+      end
     end
   end
 end
```
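A hedged usage sketch of the new helpers, assuming red-chainer 0.4.0 is installed and the default device resolves to the Numo backend:

```ruby
require 'chainer'
require 'numo/narray'

# force_array now resolves the backend (Numo or Cumo) through Chainer::Device
Chainer::Utils::Array.force_array(1.5).class        # => Numo::DFloat (0-dim)

# rollaxis moves axis 2 to the front: (2, 3, 4) -> (4, 2, 3)
x = Numo::DFloat.new(2, 3, 4).seq
Chainer::Utils::Array.rollaxis(x, 2).shape          # => [4, 2, 3]

# broadcast_to tiles a (1, 3) array up to (2, 3)
a = Numo::DFloat[[1, 2, 3]]
Chainer::Utils::Array.broadcast_to(a, [2, 3]).to_a  # => [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
```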
data/lib/chainer/utils/conv.rb
CHANGED
```diff
@@ -10,7 +10,15 @@ module Chainer
         end
       end
 
-      def self.
+      def self.get_deconv_outsize(size, k, s, p, cover_all: false)
+        if cover_all
+          s * (size - 1) + k - s + 1 - 2 * p
+        else
+          s * (size - 1) + k - 2 * p
+        end
+      end
+
+      def self.im2col(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
         n, c, h, w = img.shape
 
         out_h = self.get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
@@ -40,7 +48,7 @@ module Chainer
         col
       end
 
-      def self.
+      def self.col2im(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
         n, c, kh, kw, out_h, out_w = col.shape
         img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
         kh.times do |j|
```
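The deconvolution output size is the inverse of the convolution output size; a quick worked check of the non-cover_all formula, using a hypothetical 2x upsampling layer:

```ruby
# s * (size - 1) + k - 2 * p, the non-cover_all branch above
size, k, s, p = 3, 4, 2, 1        # 3-pixel input, kernel 4, stride 2, pad 1
out = s * (size - 1) + k - 2 * p  # => 2 * 2 + 4 - 2 = 6
puts out                          # a 3-pixel map deconvolves back to 6 pixels
```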
data/lib/chainer/utils/initializer.rb
CHANGED
```diff
@@ -1,10 +1,10 @@
 module Chainer
   module Utils
     module Initializer
-      def self.get_fans(shape)
+      def self.get_fans(shape, device: Chainer::Device.default)
         raise 'shape must be of length >= 2: shape={}' if shape.size < 2
         slice_arr = shape.slice(2, shape.size)
-        receptive_field_size = slice_arr.empty? ? 1 :
+        receptive_field_size = slice_arr.empty? ? 1 : device.xm::Int32[slice_arr].prod
         fan_in = shape[1] * receptive_field_size
         fan_out = shape[0] * receptive_field_size
         [fan_in, fan_out]
```
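For intuition, the fan-in/fan-out numbers for a typical convolution weight shape, worked by hand (`get_fans` computes the same values):

```ruby
# weight shape [out_channels, in_channels, kh, kw]
shape = [32, 16, 3, 3]
receptive_field_size = 3 * 3                 # product of the trailing kernel dims
fan_in  = shape[1] * receptive_field_size    # => 144
fan_out = shape[0] * receptive_field_size    # => 288
```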
data/lib/chainer/variable.rb
CHANGED
```diff
@@ -2,25 +2,47 @@ module Chainer
   class Variable
     attr_accessor :data, :grad, :requires_grad, :node
 
+    # Converts an array or a variable into +Chainer::Variable+.
+    # This is a convenient function to get a +Chainer::Variable+ object
+    # transparently from a raw array or a variable.
+    # Note that this function should only be used for type consistency
+    # (i.e. to enforce the return value of an API having type +Chainer::Variable+).
+    # The +Chainer::Variable.requires_grad+ flag is kept as is; if +obj+ is a raw array,
+    # the newly created variable has +requires_grad = false+.
+    # In order to make a variable w.r.t. which you want to compute the gradient,
+    # you should use +Chainer::Variable+ directly.
+    #
+    # @param [Numo::NArray or Chainer::Variable] obj An array or a variable that you want to convert to +Chainer::Variable+.
+    # @return [Chainer::Variable] A variable converted from +obj+. If +obj+ is a raw array,
+    #   this is a new +Chainer::Variable+ object that wraps the array. If +obj+ is already a +Chainer::Variable+ object, this function returns +obj+ as is.
+    def self.as_variable(obj)
+      return obj if obj.kind_of?(Chainer::Variable)
+      # TODO: if obj.is_backprop_required is true, set requires_grad = true
+      self.new(obj, requires_grad: false)
+    end
+
     def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
-      unless data.nil? ||
-        raise TypeError, "Numo::NArray are expected."
+      unless data.nil? || Chainer.array?(data)
+        raise TypeError, "Numo::NArray or Cumo::NArray are expected."
       end
 
       @data = [data]
       @grad = grad
       @requires_grad = requires_grad
-      @node = VariableNode.new(variable: self, name: name
+      @node = VariableNode.new(variable: self, name: name)
+      @grad_var = grad.nil? ? nil : Chainer::Variable.new(grad)
     end
 
     def data
       return @data[0]
     end
+    alias_method :array, :data
 
     def data=(d)
       @data[0] = d
       @node.set_data_type(d)
     end
+    alias_method :array=, :data=
 
     def name
       return @node.name
```
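A minimal usage sketch of `as_variable` and the new `array` alias (assumes red-chainer 0.4.0 with the Numo backend):

```ruby
require 'chainer'

x = Numo::DFloat[[1, 2], [3, 4]]
v = Chainer::Variable.as_variable(x)
v.requires_grad                              # => false (raw arrays wrap without gradients)
Chainer::Variable.as_variable(v).equal?(v)   # => true (variables pass through unchanged)
v.array                                      # same as v.data, via the new alias
```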
```diff
@@ -34,6 +56,7 @@ module Chainer
       @node.label
     end
 
+    # deprecated FunctionNode
     def creator
       @node.creator
     end
```
```diff
@@ -42,12 +65,30 @@ module Chainer
       @node.creator = func
     end
 
+    def creator_node
+      @node.creator_node
+    end
+
+    def creator_node=(func)
+      @node.creator_node = func
+    end
+
     def grad
-      @
+      gv = @grad_var
+      gv.nil? ? nil : gv.data
     end
 
     def grad=(g)
-
+      self.grad_var = g.nil? ? nil : Chainer::Variable.new(g)
+    end
+
+    def grad_var
+      @grad_var
+    end
+
+    def grad_var=(g)
+      Utils::Variable.check_grad_type(nil, self, g.data) unless g.nil?
+      @grad_var = g
     end
 
     def shape
```
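Gradients are now stored as variables themselves; a small sketch of the new accessors, assuming the 0.4.0 API shown above:

```ruby
v = Chainer::Variable.new(Numo::DFloat[1.0, 2.0])
v.grad = Numo::DFloat[0.5, 0.5]   # the setter wraps the array in a Chainer::Variable
v.grad_var.class                  # => Chainer::Variable
v.grad                            # => Numo::DFloat[0.5, 0.5], unwrapped from grad_var
```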
```diff
@@ -70,123 +111,172 @@ module Chainer
       @node.rank
     end
 
+    def transpose
+      Chainer::Functions::Array::Transpose.transpose(self)
+    end
+
+    def reshape(*shape)
+      if shape.size == 1 && shape[0].kind_of?(::Array)
+        shape = shape[0]
+      end
+      Chainer::Functions::Array::Reshape.reshape(self, shape)
+    end
+
+    # Clears the gradient array.
     def cleargrad
-      @
+      @grad_var = nil
+    end
+
+    # Notifies the variable that the given node is its creator.
+    #
+    # @param [Chainer::FunctionNode] fnode node that has this variable as an output.
+    def set_creator_node(fnode)
+      @node.set_creator_node(fnode)
     end
 
-    def backward(retain_grad: false)
-
+    def backward(retain_grad: false, enable_double_backprop: true)
+      old_enable_backprop = Chainer.configuration.enable_backprop
+      Chainer.configuration.enable_backprop = enable_double_backprop
+      _backward_main(retain_grad)
+      Chainer.configuration.enable_backprop = old_enable_backprop
+    end
+
+    def _backward_main(retain_grad)
+      node.check_old_style_gradient
+      return if self.creator_node.nil?
 
-
+      seen_set = Set.new
+      grads = {}
+      if self.data.size == 1 && self.grad_var.nil?
         self.grad = self.data.new_ones
       end
+      grads[self.node] = self.grad_var
 
-      funcs = [self.
+      funcs = [self.creator_node]
+      seen_set.add(self.creator_node)
 
-      while func = funcs.
+      while func = funcs.shift
+        inputs = func.inputs
+        target_input_indexes = inputs.each_with_index.map { |x, i| i if x.requires_grad }.compact
+        next if target_input_indexes.empty?
         outputs = func.outputs.map(&:__getobj__)
-        in_data = func.inputs.map(&:data)
-        out_grad = outputs.map { |y| y.nil? ? nil : y.grad }
-
-        func.output_data = outputs.map { |y| y.nil? ? nil : y.data }
-        gxs = func.backward(in_data, out_grad)
 
-
-
-
-
+        in_data = inputs.map(&:data)
+        out_grad = outputs.map do |y|
+          next nil if y.nil?
+          next grads[y] unless grads[y].nil?
+          y.grad_var
         end
+        out_grad_data = out_grad.map { |g| g.nil? ? g : g.data }
+
+        # Collect the current input gradients.
+        #
+        # When the same variable is passed to multiple input slots (e.g. an expression like +f(x, x)+),
+        # it makes the gradient accumulation complicated since the back-propagated gradients w.r.t.
+        # the first and second argument should be accumulated to the current gradient w.r.t. the same variable.
+        # In this case, the current implementation passes the current gradient only to the first occurrence of the variable
+        # in the input tuple and passes +nil+ to the rest of the occurrences.
+        # For example, when the input variables are +(x, x)+,
+        # the input gradient passed to the +backward_accumulate+ method is +(gx, nil)+ where +gx+ is the current gradient of +x+.
+        # See also the docstring of +FunctionNode.backward_accumulate+.
+        target_inputs = target_input_indexes.map { |i| inputs[i] }
+        in_grad = []
+        target_input_indexes.each_with_index do |index_i, i|
+          x = inputs[index_i]
+          if target_inputs[0...i].include?(x)
+            gx = nil
+          elsif grads[x]
+            gx = grads[x]
+          elsif x.creator_node.nil?
+            gx = x.grad_var
+          else
+            gx = nil
+          end
+          in_grad << gx
+        end
+
+        gxs = func.backward_accumulate(target_input_indexes, out_grad, in_grad)
+        raise "Unmatched matrices size: gxs.size(#{gxs.size}) != in_grad.size(#{in_grad.size})" unless gxs.size == in_grad.size
 
         unless retain_grad
           outputs.each do |y|
             unless y.nil? || y == @node
-              y
+              grads[y] = nil
+              y_var = y.get_variable
+              y_var.grad_var = nil unless y_var.nil?
             end
           end
         end
 
-
-        need_copy = []
-
-        func.inputs.zip(gxs).each do |x, gx|
+        gxs.each_with_index do |gx, i|
           next if gx.nil?
+          x = target_inputs[i]
           next unless x.requires_grad
 
-          Utils::Variable.check_grad_type(func, x, gx)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-          else # not leaf
-            funcs << x.creator
-            if seen_vars.include?(id_x)
-              if need_copy.include?(id_x)
-                x.grad = Utils::Array.force_array(gx + x.grad)
-                need_copy.delete(id_x)
-              else
-                x.grad += gx
-              end
-            else
-              x.grad = gx
-              seen_vars << id_x
-              need_copy << id_x
-            end
+          Utils::Variable.check_grad_type(func, x, gx.data)
+
+          if target_inputs[0...i].include?(x)
+            cur_gx = grads[x]
+            grads[x] = cur_gx.nil? ? gx : gx + cur_gx
+          else
+            grads[x] = gx
+          end
+
+          x_var = x.get_variable
+          x_var.grad_var = grads[x] if x_var
+
+          if x.creator_node && !seen_set.include?(x.creator_node)
+            funcs << x.creator_node
+            seen_set.add(x.creator_node)
           end
         end
-
+
+        funcs.sort_by! { |f| -f.rank }
+
+      end
     end
 
     def -@
-      Functions::Math::Neg.new.(self)
+      Functions::Math::Neg.new.apply([self]).first
     end
 
     def +(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Add.new.(
+        Functions::Math::Add.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(other).(self)
+        Functions::Math::AddConstant.new(other).apply([self])[0]
       end
     end
 
-    def -(other)
+    def -(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Sub.new.(
+        Functions::Math::Sub.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(-other).(self)
+        Functions::Math::AddConstant.new(-other).apply([self])[0]
       end
     end
 
     def *(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Mul.new.(
+        Functions::Math::Mul.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(other).(self)
+        Functions::Math::MulConstant.new(other).apply([self])[0]
       end
     end
 
     def /(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Div.new.(
+        Functions::Math::Div.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(1 / other).(self)
+        Functions::Math::MulConstant.new(1 / other).apply([self])[0]
       end
     end
 
-    def **(other)
+    def **(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::PowVarVar.new.(
+        Functions::Math::PowVarVar.new.apply([self, other])[0]
       else
-        Functions::Math::PowVarConst.new(other).(self)
+        Functions::Math::PowVarConst.new(other).apply([self])[0]
       end
     end
 
```
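A short sketch of the rewritten backward pass, which also exercises the documented f(x, x) accumulation case since `x` feeds both inputs of `x * x` (assumes red-chainer 0.4.0):

```ruby
require 'chainer'

x = Chainer::Variable.new(Numo::DFloat[3.0])
y = x * x + x   # operators now route through FunctionNode#apply
y.backward      # runs _backward_main; double backprop is enabled by default
x.grad          # => Numo::DFloat[7.0], since d(x^2 + x)/dx = 2x + 1
```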
```diff
@@ -196,7 +286,7 @@ module Chainer
 
     # when left side is Numeric value and right side is Chainer::Variable, call this method.
     def coerce(other)
-      other = self.data.class
+      other = self.data.class.new.fill(other) if other.kind_of?(Numeric)
       [Chainer::Variable.new(other, requires_grad: false), self]
     end
   end
```