red-chainer 0.3.2 → 0.4.0

Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/testing/array.rb
@@ -22,7 +22,7 @@ module Chainer
       end

       actual.each_with_index{|actual_val, *i|
-        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].abs
+        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].to_f.abs
           raise "assert_allclose Error\n expect: #{expect.inspect}\n actual : #{actual.inspect}\n (#{i})=> #{(expect - actual).abs.max()} > #{atol + rtol * expect[*i].abs}"
         end
       }
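For reference, the tolerance rule above amounts to the following standalone check (a sketch only; the helper name and the default tolerances are illustrative, not part of the gem's API):

    def close?(expected, actual, atol: 1e-5, rtol: 1e-4)
      (expected.to_f - actual.to_f).abs <= atol + rtol * expected.to_f.abs
    end

    close?(1.0, 1.000001)  # => true
    close?(1.0, 1.1)       # => false

The change simply converts the expected element to a Float before taking its absolute value, so integer-typed expectation arrays compare the same way as float ones.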
data/lib/chainer/training/extensions/evaluator.rb
@@ -34,7 +34,7 @@ module Chainer
     # If this is just a link object, the link is registered by the name 'main'.
     # @param [Dataset::Convert] converter Converter function to build input arrays.
     #   `Chainer::Dataset.concat_examples` is used by default.
-    # @param [integer] device Device to which the training data is sent. Negative value indicates the host memory (CPU).
+    # @param [Chainer::Device] device Device to which the training data is sent.
     # @param [Function] eval_hook Function to prepare for each evaluation process.
     #   It is called at the beginning of the evaluation.
     #   The evaluator extension object is passed at each call.
@@ -97,8 +97,7 @@ module Chainer
     # @return dict Result dictionary. This dictionary is further reported via `Chainer.save_report` without specifying any observer.
     def evaluate
       iterator = @iterators[:main]
-      target = @targets[:main]
-      eval_func = @eval_func || target
+      eval_func = @eval_func || @targets[:main]

       @eval_hook.(self) if @eval_hook

data/lib/chainer/training/extensions/exponential_shift.rb
@@ -57,7 +57,7 @@ module Chainer
     def serialize(serializer)
       @t = serializer.('t', @t)
       @last_value = serializer.('last_value', @last_value)
-      if @last_value.is_a?(Numo::NArray)
+      if Chainer.array?(@last_value)
         @last_value = @last_value[0]
       end
     end
data/lib/chainer/training/extensions/progress_bar.rb
@@ -1,4 +1,5 @@
 require 'erb'
+require 'time'

 module Chainer
   module Training
data/lib/chainer/training/trainer.rb
@@ -1,12 +1,11 @@
 module Chainer
   module Training
     class ExtensionEntry
-      attr_accessor :extension, :trigger, :invoke_before_training, :priority
+      attr_accessor :extension, :trigger, :priority

-      def initialize(extension, priority, trigger, invoke_before_training)
+      def initialize(extension, priority, trigger)
         @extension = extension
         @trigger = trigger
-        @invoke_before_training = invoke_before_training
         @priority = priority
       end
     end
@@ -47,7 +46,7 @@ module Chainer
        Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
      end

-      def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
+      def extend(extension, name: nil, trigger: nil, priority: nil)
        if name.nil?
          name = if extension.name
            extension.name
@@ -69,10 +68,6 @@ module Chainer
          priority = extension.methods.include?(:priority) ? extension.priority : Extension::PRIORITY_READER
        end

-        if invoke_before_training.nil?
-          invoke_before_training = extension.methods.include?(:invoke_before_training) ? extension.invoke_before_training : false
-        end
-
        modified_name = name
        ordinal = 0

@@ -82,7 +77,7 @@ module Chainer
        end

        extension.name = modified_name
-        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
+        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger)
      end

      def get_extension(name)
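With the `invoke_before_training` keyword gone, registering an extension now takes at most a name, a trigger, and a priority. A minimal sketch (the extension object and the trigger format are illustrative, not taken from this diff):

    trainer.extend(my_extension,
                   name: 'validation',
                   trigger: [1, 'epoch'],
                   priority: Chainer::Training::Extension::PRIORITY_READER)
    # Passing invoke_before_training: here would now raise an ArgumentError,
    # since the keyword was removed from Trainer#extend.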
data/lib/chainer/training/triggers/interval.rb
@@ -13,6 +13,11 @@ module Chainer
       @previous_epoch_detail = 0.0
     end

+    # Decides whether the extension should be called on this iteration.
+    #
+    # @param [Chainer::Trainer] trainer Trainer object that this trigger is associated with.
+    #   The updater associated with this trainer is used to determine if the trigger should fire.
+    # @return [boolean] True if the corresponding extension should be invoked in this iteration.
     def call(trainer)
       updater = trainer.updater
       if @unit == 'epoch'
@@ -30,7 +35,7 @@ module Chainer
         iteration = updater.iteration
         previous_iteration = @previous_iteration
         if previous_iteration < 0
-          previous_iteration = iteration - 1
+          previous_iteration = iteration - 1
         end
         fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
       end
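For reference, the iteration-based condition above fires whenever the iteration count crosses a multiple of the period. Plain Ruby arithmetic, not the gem's API:

    period = 100
    previous_iteration = 299
    iteration = 300
    previous_iteration.div(period).floor != iteration.div(period).floor
    # 299 / 100 => 2, 300 / 100 => 3, so the trigger fires on iteration 300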
@@ -38,7 +43,7 @@ module Chainer
       # save current values
       @previous_iteration = updater.iteration
       @previous_epoch_detail = updater.epoch_detail
-
+
       fire
     end

data/lib/chainer/utils/array.rb
@@ -4,7 +4,8 @@ module Chainer
      def self.force_array(x, dtype=nil)
        if x.is_a? Integer or x.is_a? Float
          if dtype.nil?
-            Numo::NArray.cast(x)
+            xm = Chainer::Device.default.xm
+            xm::NArray.cast(x)
          else
            dtype.cast(x.dup)
          end
@@ -16,6 +17,84 @@ module Chainer
          end
        end
      end
+
+      def self.take(x, indices, axis: nil)
+        if axis
+          indices = make_indecies_with_axis(x.shape, indices, axis)
+        end
+        x[indices]
+      end
+
+      def self.make_indecies_with_axis(shape, indices, axis, values = [])
+        target_axis = values.size
+        if shape.size == values.size
+          values.zip(shape.drop(1) + [1]).reduce(0) do |sum, (x, ndim)|
+            (sum + x) * ndim
+          end
+        else
+          enum = (axis == target_axis) ? indices : (0...shape[target_axis])
+          if enum.is_a?(Integer)
+            make_indecies_with_axis(shape, indices, axis, values + [indices])
+          else
+            enum.map do |x|
+              make_indecies_with_axis(shape, indices, axis, values + [x])
+            end
+          end
+        end
+      end
+
+      def self.rollaxis(y, axis, start: 0)
+        n = y.ndim
+        # normalize axis
+        axis = axis < 0 ? n + axis : axis
+        if axis >= n
+          raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
+        end
+
+        if start < 0
+          start += n
+        end
+
+        unless 0 <= start && start < n + 1
+          raise ArgumentError, "start arg requires #{-n} <= start < #{n}, but #{start} was passed in"
+        end
+
+        if axis < start
+          start -= 1
+        end
+
+        if axis == start
+          return y
+        end
+
+        axes = (0...n).to_a
+        axes.delete_at(axis)
+        axes.insert(start <= axes.size ? start : -1, axis)
+        y.transpose(*axes)
+      end
+
+      def self.broadcast_to(x, shape)
+        if x.shape.size > shape.size
+          raise TypeError, "Shape of data mismatch\n x.shape.size(#{x.shape.size}) > shape.size(#{shape.size})"
+        end
+
+        tile_shape = []
+        if x.shape.size > 0
+          shape[-x.shape.size..-1].each_with_index do |s, i|
+            if x.shape[i] == 1
+              tile_shape << s
+            elsif x.shape[i] == s
+              tile_shape << 1
+            else
+              raise TypeError, "Shape of data mismatch\n#{x.shape} != #{shape}"
+            end
+          end
+        else
+          tile_shape = shape
+        end
+
+        x.tile(*shape[0...-x.shape.size], *tile_shape)
+      end
   end
 end
 end
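A rough sketch of how the new array helpers could be exercised with a Numo backend (assuming the module path is Chainer::Utils::Array, consistent with the other files under lib/chainer/utils):

    require 'numo/narray'

    x = Numo::DFloat.new(2, 3, 4).seq
    Chainer::Utils::Array.rollaxis(x, 2).shape      # => [4, 2, 3], axis 2 rolled to the front

    b = Numo::DFloat.new(1, 3).seq
    Chainer::Utils::Array.broadcast_to(b, [2, 3]).shape  # => [2, 3], size-1 axis tiled out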
data/lib/chainer/utils/conv.rb
@@ -10,7 +10,15 @@ module Chainer
        end
      end

-      def self.im2col_cpu(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
+      def self.get_deconv_outsize(size, k, s, p, cover_all: false)
+        if cover_all
+          s * (size - 1) + k -s + 1 - 2 * p
+        else
+          s * (size - 1) + k - 2 * p
+        end
+      end
+
+      def self.im2col(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
        n, c, h, w = img.shape

        out_h = self.get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
@@ -40,7 +48,7 @@ module Chainer
        col
      end

-      def self.col2im_cpu(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
+      def self.col2im(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
        n, c, kh, kw, out_h, out_w = col.shape
        img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
        kh.times do |j|
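Hand-checking the new get_deconv_outsize formula for the default cover_all: false case, out = s * (size - 1) + k - 2 * p. With size = 4, kernel k = 3, stride s = 2, and padding p = 1, that gives 2 * (4 - 1) + 3 - 2 * 1 = 7:

    # assuming the module path Chainer::Utils::Conv for this file
    Chainer::Utils::Conv.get_deconv_outsize(4, 3, 2, 1)  # => 7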
data/lib/chainer/utils/initializer.rb
@@ -1,10 +1,10 @@
 module Chainer
   module Utils
     module Initializer
-      def self.get_fans(shape)
+      def self.get_fans(shape, device: Chainer::Device.default)
        raise 'shape must be of length >= 2: shape={}' if shape.size < 2
        slice_arr = shape.slice(2, shape.size)
-        receptive_field_size = slice_arr.empty? ? 1 : Numo::Int32[slice_arr].prod
+        receptive_field_size = slice_arr.empty? ? 1 : device.xm::Int32[slice_arr].prod
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
        [fan_in, fan_out]
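Illustrative fan computation for a convolution weight of shape [out_channels, in_channels, kh, kw] = [32, 16, 3, 3], assuming the default (CPU) device is in effect:

    # receptive_field_size = 3 * 3 = 9
    # fan_in  = 16 * 9 = 144
    # fan_out = 32 * 9 = 288
    fan_in, fan_out = Chainer::Utils::Initializer.get_fans([32, 16, 3, 3])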
data/lib/chainer/variable.rb
@@ -2,25 +2,47 @@ module Chainer
   class Variable
     attr_accessor :data, :grad, :requires_grad, :node

+    # Converts an array or a variable into +Chainer::Variable+.
+    # This is a convenient function to get a +Chainer::Variable+ object
+    # transparently from a raw array or a variable.
+    # Note: that this function should only be used for type consistency
+    # (i.e. to enforce the return value of an API having type +Chainer::Variable+).
+    # The +Chianer::Variable.requires_grad+ flag is kept as is; if +obj+ is a raw array,
+    # the newly created variable has +requires_grad = false+.
+    # In order to make a variable w.r.t. which you want to compute the gradient,
+    # you should use $Chainer::Variable$ directly.
+    #
+    # @param [Numo::NArray or Chainer::Variable] obj An array or a variable that you want to convert to $Chainer::Variable$.
+    # @return [Chainer::Variable] A variable converted from +obj+. If +obj+ is a raw array,
+    #   this is a new +Chianer::Variable+ object that wraps the array. If +obj+ is already a +Chainer::Variable+ object, this function returns +obj+ as is.
+    def self.as_variable(obj)
+      return obj if obj.kind_of?(Chainer::Variable)
+      # TODO if obj is_backprop_required is true, set requires_grad = true
+      self.new(obj, requires_grad: false)
+    end
+
     def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
-      unless data.nil? || data.is_a?(Numo::NArray)
-        raise TypeError, "Numo::NArray are expected."
+      unless data.nil? || Chainer.array?(data)
+        raise TypeError, "Numo::NArray or Cumo::NArray are expected."
       end

       @data = [data]
       @grad = grad
       @requires_grad = requires_grad
-      @node = VariableNode.new(variable: self, name: name, grad: grad)
+      @node = VariableNode.new(variable: self, name: name)
+      @grad_var = grad.nil? ? nil : Chainer::Variable.new(grad)
     end

     def data
       return @data[0]
     end
+    alias_method :array, :data

     def data=(d)
       @data[0] = d
       @node.set_data_type(d)
     end
+    alias_method :array=, :data=

     def name
       return @node.name
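A minimal sketch of the new as_variable helper, assuming a Numo backend:

    x = Numo::SFloat.new(2, 3).rand
    v = Chainer::Variable.as_variable(x)        # wraps the raw array; requires_grad is false
    Chainer::Variable.as_variable(v).equal?(v)  # => true, an existing Variable is returned as is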
@@ -34,6 +56,7 @@ module Chainer
       @node.label
     end

+    # deprecated FunctionNode
     def creator
       @node.creator
     end
@@ -42,12 +65,30 @@ module Chainer
       @node.creator = func
     end

+    def creator_node
+      @node.creator_node
+    end
+
+    def creator_node=(func)
+      @node.creator_node = func
+    end
+
     def grad
-      @node.grad
+      gv = @grad_var
+      gv.nil? ? nil : gv.data
     end

     def grad=(g)
-      @node.set_grad_with_check(g, nil, self)
+      self.grad_var = g.nil? ? nil : Chainer::Variable.new(g)
+    end
+
+    def grad_var
+      @grad_var
+    end
+
+    def grad_var=(g)
+      Utils::Variable.check_grad_type(nil, self, g.data) unless g.nil?
+      @grad_var = g
     end

     def shape
@@ -70,123 +111,172 @@ module Chainer
       @node.rank
     end

+    def transpose
+      Chainer::Functions::Array::Transpose.transpose(self)
+    end
+
+    def reshape(*shape)
+      if shape.size == 1 && shape[0].kind_of?(::Aray)
+        shape = shape[0]
+      end
+      Chainer::Functions::Array::Reshape.reshape(self, shape)
+    end
+
+    # Clears the gradient array.
     def cleargrad
-      @node.grad = nil
+      @grad_var = nil
+    end
+
+    # Notifies the variable that the given node is its creator.
+    #
+    # @param [Chainer::FunctionNode] fnode node that has this variable as an output.
+    def set_creator_node(fnode)
+      @node.set_creator_node(fnode)
     end

-    def backward(retain_grad: false)
-      return if self.creator.nil?
+    def backward(retain_grad: false, enable_double_backprop: true)
+      old_enable_backprop = Chainer.configuration.enable_backprop
+      Chainer.configuration.enable_backprop = enable_double_backprop
+      _backward_main(retain_grad)
+      Chainer.configuration.enable_backprop = old_enable_backprop
+    end
+
+    def _backward_main(retain_grad)
+      node.check_old_style_gradient
+      return if self.creator_node.nil?

-      if self.data.size == 1 && self.grad.nil?
+      seen_set = Set.new
+      grads = {}
+      if self.data.size == 1 && self.grad_var.nil?
        self.grad = self.data.new_ones
      end
+      grads[self.node] = self.grad_var

-      funcs = [self.creator]
+      funcs = [self.creator_node]
+      seen_set.add(self.creator_node)

-      while func = funcs.pop
+      while func = funcs.shift
+        inputs = func.inputs
+        target_input_indexes = inputs.each_with_index.map { |x, i| i if x.requires_grad }.compact
+        next if target_input_indexes.empty?
        outputs = func.outputs.map(&:__getobj__)
-        in_data = func.inputs.map(&:data)
-        out_grad = outputs.map { |y| y.nil? ? nil : y.grad }
-
-        func.output_data = outputs.map { |y| y.nil? ? nil : y.data }
-        gxs = func.backward(in_data, out_grad)

-        raise unless gxs.size == in_data.size
-
-        unless func.retain_after_backward
-          func.output_data = nil
+        in_data = inputs.map(&:data)
+        out_grad = outputs.map do |y|
+          next nil if y.nil?
+          next grads[y] unless grads[y].nil?
+          y.grad_var
        end
+        out_grad_data = out_grad.map { |g| g.nil? ? g : g.data }
+
+        # Collect the current input gradients.
+        #
+        # When the same variable is passed to multiple input slots (e.g. an expression like +f(x, x)+),
+        # it makes the gradient accumulation complicated since the back-propagated gradients w.r.t.
+        # the first and second argument should be accumulated to the current gradient w.r.t. the same variable.
+        # In this case, the current implementation passes the current gradient only to the first occurrence of the variable
+        # in the input tuple and passes +nil+ to the rest of the occurrences.
+        # For example, when the input variables are +(x, x)+,
+        # the input gradient passed to the +backward_accumulate+ method is +(gx, nil)+ where +gx+ is the current gradient of ++x++.
+        # See also the docstring of +FunctionNode.backward_accumulate+.
+        target_inputs = target_input_indexes.map { |i| inputs[i] }
+        in_grad = []
+        target_input_indexes.each_with_index do |index_i, i|
+          x = inputs[index_i]
+          if target_inputs[0...i].include?(x)
+            gx = nil
+          elsif grads[x]
+            gx = grads[x]
+          elsif x.creator_node.nil?
+            gx = x.grad_var
+          else
+            gx = nil
+          end
+          in_grad << gx
+        end
+
+        gxs = func.backward_accumulate(target_input_indexes, out_grad, in_grad)
+        raise "Unmatched matries size: gxs.size(#{gxs.size}) != in_grad.size(#{in_grad.size})" unless gxs.size == in_grad.size

        unless retain_grad
          outputs.each do |y|
            unless y.nil? || y == @node
-              y.grad = nil
+              grads[y] = nil
+              y_var = y.get_variable
+              y_var.grad_var = nil unless y_var.nil?
            end
          end
        end

-        seen_vars = []
-        need_copy = []
-
-        func.inputs.zip(gxs).each do |x, gx|
+        gxs.each_with_index do |gx, i|
          next if gx.nil?
+          x = target_inputs[i]
          next unless x.requires_grad

-          Utils::Variable.check_grad_type(func, x, gx)
-
-          id_x = x.object_id
-          if x.creator.nil? # leaf
-            if x.grad.nil?
-              x.grad = gx
-              need_copy << id_x
-            else
-              if need_copy.include?(id_x)
-                x.grad = Utils::Array.force_array(x.grad + gx)
-                need_copy.delete(id_x)
-              else
-                x.grad += gx
-              end
-            end
-          else # not leaf
-            funcs << x.creator
-            if seen_vars.include?(id_x)
-              if need_copy.include?(id_x)
-                x.grad = Utils::Array.force_array(gx + x.grad)
-                need_copy.delete(id_x)
-              else
-                x.grad += gx
-              end
-            else
-              x.grad = gx
-              seen_vars << id_x
-              need_copy << id_x
-            end
+          Utils::Variable.check_grad_type(func, x, gx.data)
+
+          if target_inputs[0...i].include?(x)
+            cur_gx = grads[x]
+            grads[x] = cur_gx.nil? ? gx : gx + cur_gx
+          else
+            grads[x] = gx
+          end
+
+          x_var = x.get_variable
+          x_var.grad_var = grads[x] if x_var
+
+          if x.creator_node && !seen_set.include?(x.creator_node)
+            funcs << x.creator_node
+            seen_set.add(x.creator_node)
          end
        end
-      end
+
+        funcs.sort_by! { |f| -f.rank }
+
+      end
     end

     def -@
-      Functions::Math::Neg.new.(self)
+      Functions::Math::Neg.new.apply([self]).first
     end

     def +(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Add.new.(*[self, other])
+        Functions::Math::Add.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(other).(self)
+        Functions::Math::AddConstant.new(other).apply([self])[0]
       end
     end

-    def -(other)
+    def -(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Sub.new.(*[self, other])
+        Functions::Math::Sub.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(-other).(self)
+        Functions::Math::AddConstant.new(-other).apply([self])[0]
       end
     end

     def *(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Mul.new.(*[self, other])
+        Functions::Math::Mul.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(other).(self)
+        Functions::Math::MulConstant.new(other).apply([self])[0]
       end
     end

     def /(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Div.new.(*[self, other])
+        Functions::Math::Div.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(1 / other).(self)
+        Functions::Math::MulConstant.new(1 / other).apply([self])[0]
       end
     end

-    def **(other)
+    def **(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::PowVarVar.new.(*[self, other])
+        Functions::Math::PowVarVar.new.apply([self, other])[0]
       else
-        Functions::Math::PowVarConst.new(other).(self)
+        Functions::Math::PowVarConst.new(other).apply([self])[0]
       end
     end

@@ -196,7 +286,7 @@ module Chainer

     # when left side is Numeric value and right side is Chainer::Value, call this method.
     def coerce(other)
-      other = self.data.class[*other] if other.kind_of?(Numeric)
+      other = self.data.class.new.fill(other) if other.kind_of?(Numeric)
       [Chainer::Variable.new(other, requires_grad: false), self]
     end
   end
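A rough sketch of the reworked backward pass and operator overloads above, assuming a Numo backend; the tiny expression graph is illustrative only:

    x = Chainer::Variable.new(Numo::SFloat[2.0])
    y = x * x                                   # the same variable feeds both input slots of Mul
    y.backward(enable_double_backprop: false)   # skip building the graph needed for higher-order gradients
    x.grad                                      # expected: Numo::SFloat[4.0], since d(x*x)/dx = 2x

Because the operators now go through FunctionNode#apply, the graph records function nodes rather than the old-style Function objects, which is what allows the enable_double_backprop switch and the grad_var accessors introduced in this release.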