red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/testing/array.rb
CHANGED
@@ -22,7 +22,7 @@ module Chainer
       end

       actual.each_with_index{|actual_val, *i|
-        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].abs
+        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].to_f.abs
          raise "assert_allclose Error\n expect: #{expect.inspect}\n actual : #{actual.inspect}\n (#{i})=> #{(expect - actual).abs.max()} > #{atol + rtol * expect[*i].abs}"
        end
      }
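The fix converts the expected value to a Ruby Float before taking its absolute value, so both sides of the tolerance comparison are plain Floats. The criterion is the usual |expect - actual| <= atol + rtol * |expect|; a standalone sketch in plain Ruby (the tolerance defaults here are illustrative, not the gem's):

    def allclose?(expect, actual, atol: 1e-5, rtol: 1e-4)
      (expect.to_f - actual.to_f).abs <= atol + rtol * expect.to_f.abs
    end

    allclose?(1.0001, 1.0) # => true  (1.0e-4 <= 1.0e-5 + 1.0e-4 * 1.0001)
    allclose?(1.1, 1.0)    # => false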
data/lib/chainer/training/extensions/evaluator.rb
CHANGED
@@ -34,7 +34,7 @@ module Chainer
       # If this is just a link object, the link is registered by the name 'main'.
       # @param [Dataset::Convert] converter Converter function to build input arrays.
       #   `Chainer::Dataset.concat_examples` is used by default.
-      # @param [
+      # @param [Chainer::Device] device Device to which the training data is sent.
       # @param [Function] eval_hook Function to prepare for each evaluation process.
       #   It is called at the beginning of the evaluation.
       #   The evaluator extension object is passed at each call.
@@ -97,8 +97,7 @@ module Chainer
       # @return dict Result dictionary. This dictionary is further reported via `Chainer.save_report` without specifying any observer.
       def evaluate
         iterator = @iterators[:main]
-
-        eval_func = @eval_func || target
+        eval_func = @eval_func || @targets[:main]

         @eval_hook.(self) if @eval_hook

data/lib/chainer/training/trainer.rb
CHANGED
@@ -1,12 +1,11 @@
 module Chainer
   module Training
     class ExtensionEntry
-      attr_accessor :extension, :trigger, :invoke_before_training, :priority
+      attr_accessor :extension, :trigger, :priority

-      def initialize(extension, priority, trigger, invoke_before_training)
+      def initialize(extension, priority, trigger)
         @extension = extension
         @trigger = trigger
-        @invoke_before_training = invoke_before_training
         @priority = priority
       end
     end
@@ -47,7 +46,7 @@ module Chainer
        Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
      end

-      def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
+      def extend(extension, name: nil, trigger: nil, priority: nil)
        if name.nil?
          name = if extension.name
            extension.name
@@ -69,10 +68,6 @@ module Chainer
          priority = extension.methods.include?(:priority) ? extension.priority : Extension::PRIORITY_READER
        end

-        if invoke_before_training.nil?
-          invoke_before_training = extension.methods.include?(:invoke_before_training) ? extension.invoke_before_training : false
-        end
-
        modified_name = name
        ordinal = 0

@@ -82,7 +77,7 @@ module Chainer
        end

        extension.name = modified_name
-        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
+        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger)
      end

      def get_extension(name)
data/lib/chainer/training/triggers/interval.rb
CHANGED
@@ -13,6 +13,11 @@ module Chainer
        @previous_epoch_detail = 0.0
      end

+      # Decides whether the extension should be called on this iteration.
+      #
+      # @param [Chainer::Trainer] trainer Trainer object that this trigger is associated with.
+      #   The updater associated with this trainer is used to determine if the trigger should fire.
+      # @return [boolean] True if the corresponding extension should be invoked in this iteration.
      def call(trainer)
        updater = trainer.updater
        if @unit == 'epoch'
@@ -30,7 +35,7 @@ module Chainer
          iteration = updater.iteration
          previous_iteration = @previous_iteration
          if previous_iteration < 0
-
+            previous_iteration = iteration - 1
          end
          fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
        end
@@ -38,7 +43,7 @@ module Chainer
        # save current values
        @previous_iteration = updater.iteration
        @previous_epoch_detail = updater.epoch_detail
-
+
        fire
      end

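The fix above gives previous_iteration a fallback of one step back when no previous value was recorded (previously the body of that `if` was empty). The firing rule itself compares integer quotients of the iteration counter, so the trigger fires exactly when the counter crosses a multiple of the period. A standalone sketch of the 'iteration' branch, in plain Ruby with a hypothetical period:

    period = 3
    fired = (1..10).select do |iteration|
      previous_iteration = iteration - 1
      # fires whenever the quotient changes, i.e. at every multiple of period
      (previous_iteration / period) != (iteration / period)
    end
    fired # => [3, 6, 9]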
data/lib/chainer/utils/array.rb
CHANGED
@@ -4,7 +4,8 @@ module Chainer
      def self.force_array(x, dtype=nil)
        if x.is_a? Integer or x.is_a? Float
          if dtype.nil?
-            Numo::NArray.cast(x)
+            xm = Chainer::Device.default.xm
+            xm::NArray.cast(x)
          else
            dtype.cast(x.dup)
          end
@@ -16,6 +17,84 @@ module Chainer
          end
        end
      end
+
+      def self.take(x, indices, axis: nil)
+        if axis
+          indices = make_indecies_with_axis(x.shape, indices, axis)
+        end
+        x[indices]
+      end
+
+      def self.make_indecies_with_axis(shape, indices, axis, values = [])
+        target_axis = values.size
+        if shape.size == values.size
+          values.zip(shape.drop(1) + [1]).reduce(0) do |sum, (x, ndim)|
+            (sum + x) * ndim
+          end
+        else
+          enum = (axis == target_axis) ? indices : (0...shape[target_axis])
+          if enum.is_a?(Integer)
+            make_indecies_with_axis(shape, indices, axis, values + [indices])
+          else
+            enum.map do |x|
+              make_indecies_with_axis(shape, indices, axis, values + [x])
+            end
+          end
+        end
+      end
+
+      def self.rollaxis(y, axis, start: 0)
+        n = y.ndim
+        # normalize axis
+        axis = axis < 0 ? n + axis : axis
+        if axis >= n
+          raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
+        end
+
+        if start < 0
+          start += n
+        end
+
+        unless 0 <= start && start < n + 1
+          raise ArgumentError, "start arg requires #{-n} <= start < #{n}, but #{start} was passed in"
+        end
+
+        if axis < start
+          start -= 1
+        end
+
+        if axis == start
+          return y
+        end
+
+        axes = (0...n).to_a
+        axes.delete_at(axis)
+        axes.insert(start <= axes.size ? start : -1, axis)
+        y.transpose(*axes)
+      end
+
+      def self.broadcast_to(x, shape)
+        if x.shape.size > shape.size
+          raise TypeError, "Shape of data mismatch\n x.shape.size(#{x.shape.size}) > shape.size(#{shape.size})"
+        end
+
+        tile_shape = []
+        if x.shape.size > 0
+          shape[-x.shape.size..-1].each_with_index do |s, i|
+            if x.shape[i] == 1
+              tile_shape << s
+            elsif x.shape[i] == s
+              tile_shape << 1
+            else
+              raise TypeError, "Shape of data mismatch\n#{x.shape} != #{shape}"
+            end
+          end
+        else
+          tile_shape = shape
+        end
+
+        x.tile(*shape[0...-x.shape.size], *tile_shape)
+      end
    end
  end
end
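The new helpers mirror NumPy's take/rollaxis/broadcast_to semantics on top of the backend arrays. A short usage sketch, assuming the default Numo::NArray backend:

    require 'chainer'
    require 'numo/narray'

    x = Numo::DFloat.zeros(2, 3, 4)
    # roll axis 2 to the front: (2, 3, 4) -> (4, 2, 3)
    Chainer::Utils::Array.rollaxis(x, 2).shape           # => [4, 2, 3]
    # roll axis 0 until before position 2: (2, 3, 4) -> (3, 2, 4)
    Chainer::Utils::Array.rollaxis(x, 0, start: 2).shape # => [3, 2, 4]

    v = Numo::DFloat[1, 2, 3]
    # replicate a length-3 vector into a (2, 3) matrix
    Chainer::Utils::Array.broadcast_to(v, [2, 3]).to_a   # => [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]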
data/lib/chainer/utils/conv.rb
CHANGED
@@ -10,7 +10,15 @@ module Chainer
        end
      end

-      def self.im2col(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false)
+      def self.get_deconv_outsize(size, k, s, p, cover_all: false)
+        if cover_all
+          s * (size - 1) + k - s + 1 - 2 * p
+        else
+          s * (size - 1) + k - 2 * p
+        end
+      end
+
+      def self.im2col(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
        n, c, h, w = img.shape

        out_h = self.get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
@@ -40,7 +48,7 @@ module Chainer
        col
      end

-      def self.col2im(col, sy, sx, ph, pw, h, w)
+      def self.col2im(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
        n, c, kh, kw, out_h, out_w = col.shape
        img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
        kh.times do |j|
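get_deconv_outsize is the inverse of the existing get_conv_outsize: without cover_all, a stride-s, kernel-k, padding-p deconvolution expands size input columns to s * (size - 1) + k - 2 * p output pixels. A quick round-trip check with hypothetical sizes (the positional arguments follow the order used in the diff above: size, kernel, stride, padding):

    require 'chainer'

    # forward convolution: (7 + 2*1 - 3) / 2 + 1 = 4
    out = Chainer::Utils::Conv.get_conv_outsize(7, 3, 2, 1) # => 4
    # deconvolution recovers the original spatial size: 2 * (4 - 1) + 3 - 2 * 1 = 7
    Chainer::Utils::Conv.get_deconv_outsize(out, 3, 2, 1)   # => 7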
data/lib/chainer/utils/initializer.rb
CHANGED
@@ -1,10 +1,10 @@
 module Chainer
   module Utils
     module Initializer
-      def self.get_fans(shape)
+      def self.get_fans(shape, device: Chainer::Device.default)
        raise 'shape must be of length >= 2: shape={}' if shape.size < 2
        slice_arr = shape.slice(2, shape.size)
-        receptive_field_size = slice_arr.empty? ? 1 : Numo::Int32[slice_arr].prod
+        receptive_field_size = slice_arr.empty? ? 1 : device.xm::Int32[slice_arr].prod
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
        [fan_in, fan_out]
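get_fans drives the scale of the random weight initializers: for a convolution weight shaped [out_channels, in_channels, kh, kw], the receptive field size is kh * kw, so fan_in = in_channels * kh * kw and fan_out = out_channels * kh * kw. A sketch with a hypothetical shape, assuming the default CPU device:

    require 'chainer'

    fan_in, fan_out = Chainer::Utils::Initializer.get_fans([6, 3, 5, 5])
    fan_in  # => 75  (3 * 5 * 5)
    fan_out # => 150 (6 * 5 * 5)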
data/lib/chainer/variable.rb
CHANGED
@@ -2,25 +2,47 @@ module Chainer
   class Variable
     attr_accessor :data, :grad, :requires_grad, :node

+    # Converts an array or a variable into +Chainer::Variable+.
+    # This is a convenient function to get a +Chainer::Variable+ object
+    # transparently from a raw array or a variable.
+    # Note that this function should only be used for type consistency
+    # (i.e. to enforce the return value of an API having type +Chainer::Variable+).
+    # The +Chainer::Variable.requires_grad+ flag is kept as is; if +obj+ is a raw array,
+    # the newly created variable has +requires_grad = false+.
+    # In order to make a variable w.r.t. which you want to compute the gradient,
+    # you should use +Chainer::Variable+ directly.
+    #
+    # @param [Numo::NArray or Chainer::Variable] obj An array or a variable that you want to convert to +Chainer::Variable+.
+    # @return [Chainer::Variable] A variable converted from +obj+. If +obj+ is a raw array,
+    #   this is a new +Chainer::Variable+ object that wraps the array. If +obj+ is already a +Chainer::Variable+ object, this function returns +obj+ as is.
+    def self.as_variable(obj)
+      return obj if obj.kind_of?(Chainer::Variable)
+      # TODO: if obj.is_backprop_required is true, set requires_grad = true
+      self.new(obj, requires_grad: false)
+    end
+
     def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
-      unless data.nil? || data.is_a?(Numo::NArray)
-        raise TypeError, "Numo::NArray are expected."
+      unless data.nil? || Chainer.array?(data)
+        raise TypeError, "Numo::NArray or Cumo::NArray are expected."
       end

       @data = [data]
       @grad = grad
       @requires_grad = requires_grad
-      @node = VariableNode.new(variable: self, name: name, grad: grad)
+      @node = VariableNode.new(variable: self, name: name)
+      @grad_var = grad.nil? ? nil : Chainer::Variable.new(grad)
     end

     def data
       return @data[0]
     end
+    alias_method :array, :data

     def data=(d)
       @data[0] = d
       @node.set_data_type(d)
     end
+    alias_method :array=, :data=

     def name
       return @node.name
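A sketch of the as_variable contract described in the new doc comment, assuming the Numo backend:

    require 'chainer'

    a = Numo::DFloat[1, 2, 3]
    v = Chainer::Variable.as_variable(a)
    v.requires_grad                            # => false: raw arrays are wrapped without gradient tracking
    Chainer::Variable.as_variable(v).equal?(v) # => true: existing variables are returned as is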
@@ -34,6 +56,7 @@ module Chainer
       @node.label
     end

+    # deprecated FunctionNode
     def creator
       @node.creator
     end
@@ -42,12 +65,30 @@ module Chainer
       @node.creator = func
     end

+    def creator_node
+      @node.creator_node
+    end
+
+    def creator_node=(func)
+      @node.creator_node = func
+    end
+
     def grad
-      @grad
+      gv = @grad_var
+      gv.nil? ? nil : gv.data
     end

     def grad=(g)
-      @grad = g
+      self.grad_var = g.nil? ? nil : Chainer::Variable.new(g)
+    end
+
+    def grad_var
+      @grad_var
+    end
+
+    def grad_var=(g)
+      Utils::Variable.check_grad_type(nil, self, g.data) unless g.nil?
+      @grad_var = g
     end

     def shape
@@ -70,123 +111,172 @@ module Chainer
       @node.rank
     end

+    def transpose
+      Chainer::Functions::Array::Transpose.transpose(self)
+    end
+
+    def reshape(*shape)
+      if shape.size == 1 && shape[0].kind_of?(::Array)
+        shape = shape[0]
+      end
+      Chainer::Functions::Array::Reshape.reshape(self, shape)
+    end
+
+    # Clears the gradient array.
     def cleargrad
-      @grad = nil
+      @grad_var = nil
+    end
+
+    # Notifies the variable that the given node is its creator.
+    #
+    # @param [Chainer::FunctionNode] fnode node that has this variable as an output.
+    def set_creator_node(fnode)
+      @node.set_creator_node(fnode)
     end

-    def backward(retain_grad: false)
-      return if self.creator.nil?
+    def backward(retain_grad: false, enable_double_backprop: true)
+      old_enable_backprop = Chainer.configuration.enable_backprop
+      Chainer.configuration.enable_backprop = enable_double_backprop
+      _backward_main(retain_grad)
+      Chainer.configuration.enable_backprop = old_enable_backprop
+    end
+
+    def _backward_main(retain_grad)
+      node.check_old_style_gradient
+      return if self.creator_node.nil?

-      if self.data.size == 1 && self.grad.nil?
+      seen_set = Set.new
+      grads = {}
+      if self.data.size == 1 && self.grad_var.nil?
         self.grad = self.data.new_ones
       end
+      grads[self.node] = self.grad_var

-      funcs = [self.creator]
+      funcs = [self.creator_node]
+      seen_set.add(self.creator_node)

-      while func = funcs.pop
+      while func = funcs.shift
+        inputs = func.inputs
+        target_input_indexes = inputs.each_with_index.map { |x, i| i if x.requires_grad }.compact
+        next if target_input_indexes.empty?
         outputs = func.outputs.map(&:__getobj__)
-        in_data = func.inputs.map(&:data)
-        out_grad = outputs.map { |y| y.nil? ? nil : y.grad }
-
-        func.output_data = outputs.map { |y| y.nil? ? nil : y.data }
-        gxs = func.backward(in_data, out_grad)

-
-
-
-
+        in_data = inputs.map(&:data)
+        out_grad = outputs.map do |y|
+          next nil if y.nil?
+          next grads[y] unless grads[y].nil?
+          y.grad_var
         end
+        out_grad_data = out_grad.map { |g| g.nil? ? g : g.data }
+
+        # Collect the current input gradients.
+        #
+        # When the same variable is passed to multiple input slots (e.g. an expression like +f(x, x)+),
+        # it makes the gradient accumulation complicated since the back-propagated gradients w.r.t.
+        # the first and second argument should be accumulated to the current gradient w.r.t. the same variable.
+        # In this case, the current implementation passes the current gradient only to the first occurrence of the variable
+        # in the input tuple and passes +nil+ to the rest of the occurrences.
+        # For example, when the input variables are +(x, x)+,
+        # the input gradient passed to the +backward_accumulate+ method is +(gx, nil)+ where +gx+ is the current gradient of +x+.
+        # See also the docstring of +FunctionNode.backward_accumulate+.
+        target_inputs = target_input_indexes.map { |i| inputs[i] }
+        in_grad = []
+        target_input_indexes.each_with_index do |index_i, i|
+          x = inputs[index_i]
+          if target_inputs[0...i].include?(x)
+            gx = nil
+          elsif grads[x]
+            gx = grads[x]
+          elsif x.creator_node.nil?
+            gx = x.grad_var
+          else
+            gx = nil
+          end
+          in_grad << gx
+        end
+
+        gxs = func.backward_accumulate(target_input_indexes, out_grad, in_grad)
+        raise "Unmatched matrices size: gxs.size(#{gxs.size}) != in_grad.size(#{in_grad.size})" unless gxs.size == in_grad.size

         unless retain_grad
           outputs.each do |y|
             unless y.nil? || y == @node
-              y.grad = nil
+              grads[y] = nil
+              y_var = y.get_variable
+              y_var.grad_var = nil unless y_var.nil?
             end
           end
         end

-        seen_vars = []
-        need_copy = []
-
-        func.inputs.zip(gxs).each do |x, gx|
+        gxs.each_with_index do |gx, i|
           next if gx.nil?
+          x = target_inputs[i]
           next unless x.requires_grad

-          Utils::Variable.check_grad_type(func, x, gx)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-          else # not leaf
-            funcs << x.creator
-            if seen_vars.include?(id_x)
-              if need_copy.include?(id_x)
-                x.grad = Utils::Array.force_array(gx + x.grad)
-                need_copy.delete(id_x)
-              else
-                x.grad += gx
-              end
-            else
-              x.grad = gx
-              seen_vars << id_x
-              need_copy << id_x
-            end
-          end
+          Utils::Variable.check_grad_type(func, x, gx.data)
+
+          if target_inputs[0...i].include?(x)
+            cur_gx = grads[x]
+            grads[x] = cur_gx.nil? ? gx : gx + cur_gx
+          else
+            grads[x] = gx
+          end
+
+          x_var = x.get_variable
+          x_var.grad_var = grads[x] if x_var
+
+          if x.creator_node && !seen_set.include?(x.creator_node)
+            funcs << x.creator_node
+            seen_set.add(x.creator_node)
+          end
         end
-
+
+        funcs.sort_by! { |f| -f.rank }
+
+      end
     end

     def -@
-      Functions::Math::Neg.new.(self)
+      Functions::Math::Neg.new.apply([self]).first
     end

     def +(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Add.new.(self, other)
+        Functions::Math::Add.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(other).(self)
+        Functions::Math::AddConstant.new(other).apply([self])[0]
       end
     end

-    def -(other)
+    def -(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Sub.new.(self, other)
+        Functions::Math::Sub.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(-other).(self)
+        Functions::Math::AddConstant.new(-other).apply([self])[0]
       end
     end

     def *(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Mul.new.(self, other)
+        Functions::Math::Mul.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(other).(self)
+        Functions::Math::MulConstant.new(other).apply([self])[0]
       end
     end

     def /(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Div.new.(self, other)
+        Functions::Math::Div.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(1 / other).(self)
+        Functions::Math::MulConstant.new(1 / other).apply([self])[0]
       end
     end

-    def **(other)
+    def **(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::PowVarVar.new.(self, other)
+        Functions::Math::PowVarVar.new.apply([self, other])[0]
       else
-        Functions::Math::PowVarConst.new(other).(self)
+        Functions::Math::PowVarConst.new(other).apply([self])[0]
       end
     end

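backward now walks creator_node links (the new FunctionNode graph) instead of the old Function creators, accumulating gradients in a hash keyed by variable node, and enable_double_backprop toggles Chainer.configuration.enable_backprop for the duration of the walk so that the backward pass itself can be recorded for higher-order gradients. A minimal first-order sketch, assuming the Numo backend:

    require 'chainer'

    x = Chainer::Variable.new(Numo::DFloat[3.0])
    y = x * x
    y.backward
    x.grad # => Numo::DFloat[6]  (d(x^2)/dx = 2x)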
@@ -196,7 +286,7 @@ module Chainer

     # when left side is Numeric value and right side is Chainer::Value, call this method.
     def coerce(other)
-      other = self.data.class
+      other = self.data.class.new.fill(other) if other.kind_of?(Numeric)
       [Chainer::Variable.new(other, requires_grad: false), self]
     end
   end
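coerce is what lets a plain Numeric sit on the left of an operator: Ruby evaluates 2.0 * v by calling v.coerce(2.0), which now wraps the scalar in a 0-dimensional array of the same class as v.data before dispatching to the Variable operator. A sketch of the resulting behavior, assuming Numo (broadcasting the 0-dimensional scalar against v.data is assumed to be handled by the backend):

    require 'chainer'

    v = Chainer::Variable.new(Numo::DFloat[1.0, 2.0])
    (2.0 * v).data # => Numo::DFloat[2, 4]
    (1.0 + v).data # => Numo::DFloat[2, 3]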