red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/testing/array.rb
@@ -22,7 +22,7 @@ module Chainer
       end

       actual.each_with_index{|actual_val, *i|
-        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].abs
+        if (expect[*i].to_f - actual_val.to_f).abs > atol + rtol * expect[*i].to_f.abs
           raise "assert_allclose Error\n expect: #{expect.inspect}\n actual : #{actual.inspect}\n (#{i})=> #{(expect - actual).abs.max()} > #{atol + rtol * expect[*i].abs}"
         end
       }
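The fix makes the relative-tolerance term use the float value of the expected element, matching the `.to_f` conversions already applied on the left-hand side. The predicate is the usual |expect - actual| <= atol + rtol * |expect| check; a minimal standalone sketch (hypothetical helper name, not the library API):

    # Scalar version of the tolerance test used by assert_allclose above.
    def close_enough?(expect, actual, atol: 1e-5, rtol: 1e-4)
      (expect.to_f - actual.to_f).abs <= atol + rtol * expect.to_f.abs
    end

    close_enough?(1.0, 1.00001)  # => true  (1.0e-05 <= 1.1e-04)
    close_enough?(1.0, 1.1)      # => false (0.1 > 1.1e-04)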
data/lib/chainer/training/extensions/evaluator.rb
@@ -34,7 +34,7 @@ module Chainer
     # If this is just a link object, the link is registered by the name 'main'.
     # @param [Dataset::Convert] converter Converter function to build input arrays.
     #   `Chainer::Dataset.concat_examples` is used by default.
-    # @param [integer] device Device to which the training data is sent. Negative value indicates the host memory (CPU).
+    # @param [Chainer::Device] device Device to which the training data is sent.
     # @param [Function] eval_hook Function to prepare for each evaluation process.
     #   It is called at the beginning of the evaluation.
     #   The evaluator extension object is passed at each call.
@@ -97,8 +97,7 @@ module Chainer
     # @return dict Result dictionary. This dictionary is further reported via `Chainer.save_report` without specifying any observer.
     def evaluate
       iterator = @iterators[:main]
-      target = @targets[:main]
-      eval_func = @eval_func || target
+      eval_func = @eval_func || @targets[:main]

       @eval_hook.(self) if @eval_hook

data/lib/chainer/training/extensions/exponential_shift.rb
@@ -57,7 +57,7 @@ module Chainer
     def serialize(serializer)
       @t = serializer.('t', @t)
       @last_value = serializer.('last_value', @last_value)
-      if @last_value.is_a?(Numo::NArray)
+      if Chainer.array?(@last_value)
         @last_value = @last_value[0]
       end
     end
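`Chainer.array?` is provided by the new lib/chainer/backend.rb (file 11 above) and abstracts over the Numo and Cumo backends, which is why the Numo-specific check is replaced here. Conceptually it behaves like this sketch (an assumption based on the file's role, not its verbatim implementation):

    # Sketch: true for either backend's arrays; Cumo counts only when it is loaded.
    def array?(obj)
      return true if obj.is_a?(Numo::NArray)
      return true if defined?(Cumo) && obj.is_a?(Cumo::NArray)
      false
    end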
data/lib/chainer/training/extensions/progress_bar.rb
@@ -1,4 +1,5 @@
 require 'erb'
+require 'time'

 module Chainer
   module Training
data/lib/chainer/training/trainer.rb
@@ -1,12 +1,11 @@
 module Chainer
   module Training
     class ExtensionEntry
-      attr_accessor :extension, :trigger, :invoke_before_training, :priority
+      attr_accessor :extension, :trigger, :priority

-      def initialize(extension, priority, trigger, invoke_before_training)
+      def initialize(extension, priority, trigger)
         @extension = extension
         @trigger = trigger
-        @invoke_before_training = invoke_before_training
         @priority = priority
       end
     end
@@ -47,7 +46,7 @@ module Chainer
       Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
     end

-    def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
+    def extend(extension, name: nil, trigger: nil, priority: nil)
       if name.nil?
         name = if extension.name
           extension.name
@@ -69,10 +68,6 @@ module Chainer
         priority = extension.methods.include?(:priority) ? extension.priority : Extension::PRIORITY_READER
       end

-      if invoke_before_training.nil?
-        invoke_before_training = extension.methods.include?(:invoke_before_training) ? extension.invoke_before_training : false
-      end
-
       modified_name = name
       ordinal = 0
@@ -82,7 +77,7 @@ module Chainer
       end

       extension.name = modified_name
-      @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
+      @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger)
     end

     def get_extension(name)
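With `invoke_before_training` removed from both `ExtensionEntry` and `Trainer#extend`, code upgrading from 0.3.2 has to drop the keyword. A hedged before/after sketch (the extension and trigger values are placeholders):

    # 0.3.2:
    # trainer.extend(my_extension, trigger: [1, 'epoch'], invoke_before_training: false)
    # 0.4.0 -- passing the removed keyword now raises ArgumentError:
    trainer.extend(my_extension, trigger: [1, 'epoch'])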
data/lib/chainer/training/triggers/interval.rb
@@ -13,6 +13,11 @@ module Chainer
       @previous_epoch_detail = 0.0
     end

+    # Decides whether the extension should be called on this iteration.
+    #
+    # @param [Chainer::Trainer] trainer Trainer object that this trigger is associated with.
+    #   The updater associated with this trainer is used to determine if the trigger should fire.
+    # @return [boolean] True if the corresponding extension should be invoked in this iteration.
     def call(trainer)
       updater = trainer.updater
       if @unit == 'epoch'
@@ -30,7 +35,7 @@ module Chainer
       iteration = updater.iteration
       previous_iteration = @previous_iteration
       if previous_iteration < 0
-      previous_iteration = iteration - 1
+        previous_iteration = iteration - 1
       end
       fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
     end
@@ -38,7 +43,7 @@ module Chainer
       # save current values
       @previous_iteration = updater.iteration
       @previous_epoch_detail = updater.epoch_detail
-
+
       fire
     end
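In iteration mode the trigger fires exactly when the iteration count crosses a multiple of @period, which is what the div/floor comparison above computes. A worked sketch with a period of 100:

    period = 100
    fire = ->(prev, cur) { prev.div(period) != cur.div(period) }
    fire.(99, 100)   # => true  (0 != 1: the 100th iteration was just crossed)
    fire.(100, 101)  # => false (1 == 1)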
data/lib/chainer/utils/array.rb
@@ -4,7 +4,8 @@ module Chainer
       def self.force_array(x, dtype=nil)
         if x.is_a? Integer or x.is_a? Float
           if dtype.nil?
-            Numo::NArray.cast(x)
+            xm = Chainer::Device.default.xm
+            xm::NArray.cast(x)
           else
             dtype.cast(x.dup)
           end
@@ -16,6 +17,84 @@ module Chainer
           end
         end
       end
+
+      def self.take(x, indices, axis: nil)
+        if axis
+          indices = make_indecies_with_axis(x.shape, indices, axis)
+        end
+        x[indices]
+      end
+
+      def self.make_indecies_with_axis(shape, indices, axis, values = [])
+        target_axis = values.size
+        if shape.size == values.size
+          values.zip(shape.drop(1) + [1]).reduce(0) do |sum, (x, ndim)|
+            (sum + x) * ndim
+          end
+        else
+          enum = (axis == target_axis) ? indices : (0...shape[target_axis])
+          if enum.is_a?(Integer)
+            make_indecies_with_axis(shape, indices, axis, values + [indices])
+          else
+            enum.map do |x|
+              make_indecies_with_axis(shape, indices, axis, values + [x])
+            end
+          end
+        end
+      end
+
+      def self.rollaxis(y, axis, start: 0)
+        n = y.ndim
+        # normalize axis
+        axis = axis < 0 ? n + axis : axis
+        if axis >= n
+          raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{n}"
+        end
+
+        if start < 0
+          start += n
+        end
+
+        unless 0 <= start && start < n + 1
+          raise ArgumentError, "start arg requires #{-n} <= start < #{n}, but #{start} was passed in"
+        end
+
+        if axis < start
+          start -= 1
+        end
+
+        if axis == start
+          return y
+        end
+
+        axes = (0...n).to_a
+        axes.delete_at(axis)
+        axes.insert(start <= axes.size ? start : -1, axis)
+        y.transpose(*axes)
+      end
+
+      def self.broadcast_to(x, shape)
+        if x.shape.size > shape.size
+          raise TypeError, "Shape of data mismatch\n x.shape.size(#{x.shape.size}) > shape.size(#{shape.size})"
+        end
+
+        tile_shape = []
+        if x.shape.size > 0
+          shape[-x.shape.size..-1].each_with_index do |s, i|
+            if x.shape[i] == 1
+              tile_shape << s
+            elsif x.shape[i] == s
+              tile_shape << 1
+            else
+              raise TypeError, "Shape of data mismatch\n#{x.shape} != #{shape}"
+            end
+          end
+        else
+          tile_shape = shape
+        end
+
+        x.tile(*shape[0...-x.shape.size], *tile_shape)
+      end
     end
   end
 end
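The new helpers mirror NumPy's take, rollaxis, and broadcast_to on top of the backend NArray. Assuming the Numo backend, usage looks like this (shapes follow from the definitions above):

    require 'numo/narray'

    x = Numo::DFloat.new(2, 3, 4).seq
    Chainer::Utils::Array.rollaxis(x, 2).shape            # => [4, 2, 3]
    Chainer::Utils::Array.rollaxis(x, 1, start: 3).shape  # => [2, 4, 3]

    b = Numo::DFloat[[1], [2]]                            # shape [2, 1]
    Chainer::Utils::Array.broadcast_to(b, [2, 3]).shape   # => [2, 3]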
data/lib/chainer/utils/conv.rb
@@ -10,7 +10,15 @@ module Chainer
       end
     end

-    def self.im2col_cpu(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
+    def self.get_deconv_outsize(size, k, s, p, cover_all: false)
+      if cover_all
+        s * (size - 1) + k - s + 1 - 2 * p
+      else
+        s * (size - 1) + k - 2 * p
+      end
+    end
+
+    def self.im2col(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
       n, c, h, w = img.shape

       out_h = self.get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
@@ -40,7 +48,7 @@ module Chainer
       col
     end

-    def self.col2im_cpu(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
+    def self.col2im(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
       n, c, kh, kw, out_h, out_w = col.shape
       img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
       kh.times do |j|
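Besides the im2col/col2im renames (the _cpu suffix is dropped now that the backend array may live on a GPU via Cumo), get_deconv_outsize inverts get_conv_outsize: for stride s, kernel k, and padding p the output length is s * (size - 1) + k - 2p, with a correction when cover_all is set. A worked example:

    # input size 3, kernel 4, stride 2, pad 1:
    # 2 * (3 - 1) + 4 - 2 * 1 = 6
    Chainer::Utils::Conv.get_deconv_outsize(3, 4, 2, 1)  # => 6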
data/lib/chainer/utils/initializer.rb
@@ -1,10 +1,10 @@
 module Chainer
   module Utils
     module Initializer
-      def self.get_fans(shape)
+      def self.get_fans(shape, device: Chainer::Device.default)
         raise 'shape must be of length >= 2: shape={}' if shape.size < 2
         slice_arr = shape.slice(2, shape.size)
-        receptive_field_size = slice_arr.empty? ? 1 : Numo::Int32[slice_arr].prod
+        receptive_field_size = slice_arr.empty? ? 1 : device.xm::Int32[slice_arr].prod
         fan_in = shape[1] * receptive_field_size
         fan_out = shape[0] * receptive_field_size
         [fan_in, fan_out]
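For a convolution weight of shape [out_channels, in_channels, kh, kw] the receptive field size is kh * kw, so for example:

    fan_in, fan_out = Chainer::Utils::Initializer.get_fans([32, 16, 3, 3])
    # receptive_field_size = 3 * 3 = 9
    # fan_in  = 16 * 9 = 144
    # fan_out = 32 * 9 = 288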
data/lib/chainer/variable.rb
@@ -2,25 +2,47 @@ module Chainer
   class Variable
     attr_accessor :data, :grad, :requires_grad, :node

+    # Converts an array or a variable into +Chainer::Variable+.
+    # This is a convenient function to get a +Chainer::Variable+ object
+    # transparently from a raw array or a variable.
+    # Note that this function should only be used for type consistency
+    # (i.e. to enforce the return value of an API having type +Chainer::Variable+).
+    # The +Chainer::Variable.requires_grad+ flag is kept as is; if +obj+ is a raw array,
+    # the newly created variable has +requires_grad = false+.
+    # In order to make a variable w.r.t. which you want to compute the gradient,
+    # you should use +Chainer::Variable+ directly.
+    #
+    # @param [Numo::NArray or Chainer::Variable] obj An array or a variable that you want to convert to +Chainer::Variable+.
+    # @return [Chainer::Variable] A variable converted from +obj+. If +obj+ is a raw array,
+    #   this is a new +Chainer::Variable+ object that wraps the array. If +obj+ is already a +Chainer::Variable+ object, this function returns +obj+ as is.
+    def self.as_variable(obj)
+      return obj if obj.kind_of?(Chainer::Variable)
+      # TODO: if obj.is_backprop_required is true, set requires_grad = true
+      self.new(obj, requires_grad: false)
+    end
+
     def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
-      unless data.nil? || data.is_a?(Numo::NArray)
-        raise TypeError, "Numo::NArray are expected."
+      unless data.nil? || Chainer.array?(data)
+        raise TypeError, "Numo::NArray or Cumo::NArray are expected."
       end

       @data = [data]
       @grad = grad
       @requires_grad = requires_grad
-      @node = VariableNode.new(variable: self, name: name, grad: grad)
+      @node = VariableNode.new(variable: self, name: name)
+      @grad_var = grad.nil? ? nil : Chainer::Variable.new(grad)
     end

     def data
       return @data[0]
     end
+    alias_method :array, :data

     def data=(d)
       @data[0] = d
       @node.set_data_type(d)
     end
+    alias_method :array=, :data=

     def name
       return @node.name
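Per the docstring above, as_variable is idempotent on variables and wraps raw arrays without gradients; a short sketch assuming the Numo backend:

    x = Numo::SFloat.new(2, 2).rand
    v = Chainer::Variable.as_variable(x)
    v.requires_grad                             # => false (raw arrays wrap without grad)
    Chainer::Variable.as_variable(v).equal?(v)  # => true  (variables pass through as is)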
@@ -34,6 +56,7 @@ module Chainer
       @node.label
     end

+    # deprecated FunctionNode
     def creator
       @node.creator
     end
@@ -42,12 +65,30 @@ module Chainer
       @node.creator = func
     end

+    def creator_node
+      @node.creator_node
+    end
+
+    def creator_node=(func)
+      @node.creator_node = func
+    end
+
     def grad
-      @node.grad
+      gv = @grad_var
+      gv.nil? ? nil : gv.data
     end

     def grad=(g)
-      @node.set_grad_with_check(g, nil, self)
+      self.grad_var = g.nil? ? nil : Chainer::Variable.new(g)
+    end
+
+    def grad_var
+      @grad_var
+    end
+
+    def grad_var=(g)
+      Utils::Variable.check_grad_type(nil, self, g.data) unless g.nil?
+      @grad_var = g
     end

     def shape
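Gradients are now stored as variables (grad_var) so they can participate in further graph construction, while grad remains a raw-array view for compatibility; a sketch assuming the Numo backend:

    v = Chainer::Variable.new(Numo::SFloat[1.0, 2.0])
    v.grad = Numo::SFloat[0.5, 0.5]
    v.grad_var.class  # => Chainer::Variable (gradient wrapped as a variable)
    v.grad.class      # => Numo::SFloat      (raw data of that variable)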
@@ -70,123 +111,172 @@ module Chainer
       @node.rank
     end

+    def transpose
+      Chainer::Functions::Array::Transpose.transpose(self)
+    end
+
+    def reshape(*shape)
+      if shape.size == 1 && shape[0].kind_of?(::Array)
+        shape = shape[0]
+      end
+      Chainer::Functions::Array::Reshape.reshape(self, shape)
+    end
+
+    # Clears the gradient array.
     def cleargrad
-      @node.grad = nil
+      @grad_var = nil
+    end
+
+    # Notifies the variable that the given node is its creator.
+    #
+    # @param [Chainer::FunctionNode] fnode node that has this variable as an output.
+    def set_creator_node(fnode)
+      @node.set_creator_node(fnode)
     end

-    def backward(retain_grad: false)
-      return if self.creator.nil?
+    def backward(retain_grad: false, enable_double_backprop: true)
+      old_enable_backprop = Chainer.configuration.enable_backprop
+      Chainer.configuration.enable_backprop = enable_double_backprop
+      _backward_main(retain_grad)
+      Chainer.configuration.enable_backprop = old_enable_backprop
+    end
+
+    def _backward_main(retain_grad)
+      node.check_old_style_gradient
+      return if self.creator_node.nil?

-      if self.data.size == 1 && self.grad.nil?
+      seen_set = Set.new
+      grads = {}
+      if self.data.size == 1 && self.grad_var.nil?
         self.grad = self.data.new_ones
       end
+      grads[self.node] = self.grad_var

-      funcs = [self.creator]
+      funcs = [self.creator_node]
+      seen_set.add(self.creator_node)

-      while func = funcs.pop
+      while func = funcs.shift
+        inputs = func.inputs
+        target_input_indexes = inputs.each_with_index.map { |x, i| i if x.requires_grad }.compact
+        next if target_input_indexes.empty?
         outputs = func.outputs.map(&:__getobj__)
-        in_data = func.inputs.map(&:data)
-        out_grad = outputs.map { |y| y.nil? ? nil : y.grad }
-
-        func.output_data = outputs.map { |y| y.nil? ? nil : y.data }
-        gxs = func.backward(in_data, out_grad)

-        raise unless gxs.size == in_data.size
-
-        unless func.retain_after_backward
-          func.output_data = nil
+        in_data = inputs.map(&:data)
+        out_grad = outputs.map do |y|
+          next nil if y.nil?
+          next grads[y] unless grads[y].nil?
+          y.grad_var
         end
+        out_grad_data = out_grad.map { |g| g.nil? ? g : g.data }
+
+        # Collect the current input gradients.
+        #
+        # When the same variable is passed to multiple input slots (e.g. an expression like +f(x, x)+),
+        # it makes the gradient accumulation complicated since the back-propagated gradients w.r.t.
+        # the first and second argument should be accumulated to the current gradient w.r.t. the same variable.
+        # In this case, the current implementation passes the current gradient only to the first occurrence of the variable
+        # in the input tuple and passes +nil+ to the rest of the occurrences.
+        # For example, when the input variables are +(x, x)+,
+        # the input gradient passed to the +backward_accumulate+ method is +(gx, nil)+ where +gx+ is the current gradient of +x+.
+        # See also the docstring of +FunctionNode.backward_accumulate+.
+        target_inputs = target_input_indexes.map { |i| inputs[i] }
+        in_grad = []
+        target_input_indexes.each_with_index do |index_i, i|
+          x = inputs[index_i]
+          if target_inputs[0...i].include?(x)
+            gx = nil
+          elsif grads[x]
+            gx = grads[x]
+          elsif x.creator_node.nil?
+            gx = x.grad_var
+          else
+            gx = nil
+          end
+          in_grad << gx
+        end
+
+        gxs = func.backward_accumulate(target_input_indexes, out_grad, in_grad)
+        raise "Unmatched matrices size: gxs.size(#{gxs.size}) != in_grad.size(#{in_grad.size})" unless gxs.size == in_grad.size

         unless retain_grad
           outputs.each do |y|
             unless y.nil? || y == @node
-              y.grad = nil
+              grads[y] = nil
+              y_var = y.get_variable
+              y_var.grad_var = nil unless y_var.nil?
             end
           end
         end

-        seen_vars = []
-        need_copy = []
-
-        func.inputs.zip(gxs).each do |x, gx|
+        gxs.each_with_index do |gx, i|
           next if gx.nil?
+          x = target_inputs[i]
           next unless x.requires_grad

-          Utils::Variable.check_grad_type(func, x, gx)
-
-          id_x = x.object_id
-          if x.creator.nil? # leaf
-            if x.grad.nil?
-              x.grad = gx
-              need_copy << id_x
-            else
-              if need_copy.include?(id_x)
-                x.grad = Utils::Array.force_array(x.grad + gx)
-                need_copy.delete(id_x)
-              else
-                x.grad += gx
-              end
-            end
-          else # not leaf
-            funcs << x.creator
-            if seen_vars.include?(id_x)
-              if need_copy.include?(id_x)
-                x.grad = Utils::Array.force_array(gx + x.grad)
-                need_copy.delete(id_x)
-              else
-                x.grad += gx
-              end
-            else
-              x.grad = gx
-              seen_vars << id_x
-              need_copy << id_x
-            end
+          Utils::Variable.check_grad_type(func, x, gx.data)
+
+          if target_inputs[0...i].include?(x)
+            cur_gx = grads[x]
+            grads[x] = cur_gx.nil? ? gx : gx + cur_gx
+          else
+            grads[x] = gx
+          end
+
+          x_var = x.get_variable
+          x_var.grad_var = grads[x] if x_var
+
+          if x.creator_node && !seen_set.include?(x.creator_node)
+            funcs << x.creator_node
+            seen_set.add(x.creator_node)
           end
         end
-      end
+
+        funcs.sort_by! { |f| -f.rank }
+      end
     end

     def -@
-      Functions::Math::Neg.new.(self)
+      Functions::Math::Neg.new.apply([self]).first
     end

     def +(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Add.new.(*[self, other])
+        Functions::Math::Add.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(other).(self)
+        Functions::Math::AddConstant.new(other).apply([self])[0]
       end
     end

     def -(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Sub.new.(*[self, other])
+        Functions::Math::Sub.new.apply([self, other])[0]
       else
-        Functions::Math::AddConstant.new(-other).(self)
+        Functions::Math::AddConstant.new(-other).apply([self])[0]
       end
     end

     def *(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Mul.new.(*[self, other])
+        Functions::Math::Mul.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(other).(self)
+        Functions::Math::MulConstant.new(other).apply([self])[0]
       end
     end

     def /(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::Div.new.(*[self, other])
+        Functions::Math::Div.new.apply([self, other])[0]
       else
-        Functions::Math::MulConstant.new(1 / other).(self)
+        Functions::Math::MulConstant.new(1 / other).apply([self])[0]
       end
     end

     def **(other)
       if other.instance_of?(Chainer::Variable)
-        Functions::Math::PowVarVar.new.(*[self, other])
+        Functions::Math::PowVarVar.new.apply([self, other])[0]
       else
-        Functions::Math::PowVarConst.new(other).(self)
+        Functions::Math::PowVarConst.new(other).apply([self])[0]
       end
     end

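Because backward now runs under the enable_backprop configuration and gradients are variables, a second differentiation through the gradient graph becomes possible. A minimal double-backprop sketch (assuming Mul's backward is itself differentiable in 0.4.0):

    x = Chainer::Variable.new(Numo::SFloat[3.0])
    y = x * x
    y.backward(enable_double_backprop: true)
    gx = x.grad_var        # Variable wrapping [6.0] = dy/dx at x = 3
    x.cleargrad
    gx.backward            # differentiate dy/dx = 2x once more
    x.grad                 # => [2.0] = d2y/dx2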
@@ -196,7 +286,7 @@ module Chainer

     # when left side is a Numeric value and right side is a Chainer::Variable, call this method.
     def coerce(other)
-      other = self.data.class[*other] if other.kind_of?(Numeric)
+      other = self.data.class.new.fill(other) if other.kind_of?(Numeric)
       [Chainer::Variable.new(other, requires_grad: false), self]
     end
   end
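The new coerce builds a 0-dimensional NArray of the receiver's dtype and fills it with the scalar, instead of the old class[*other] splat, so plain Numerics on the left-hand side broadcast correctly:

    v = Chainer::Variable.new(Numo::SFloat[1.0, 2.0, 3.0])
    w = 2.0 * v   # Float#* calls v.coerce(2.0) under the hood
    w.data        # => Numo::SFloat[2, 4, 6]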