red-chainer 0.3.2 → 0.4.0

This diff shows the changes between two publicly released versions of the package, as published to its public registry, and is provided for informational purposes only.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/datasets/cifar.rb

@@ -11,7 +11,7 @@ module Chainer
    get_cifar(100, with_label, ndim, scale)
  end

- def self.get_cifar(n_classes, with_label, ndim, scale)
+ def self.get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default)
  train_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).to_table
  test_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).to_table

@@ -25,13 +25,14 @@ module Chainer
  test_labels = test_table[:fine_label]
  end

+ xm = device.xm
  [
- preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
- preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
+ preprocess_cifar(xm::UInt8[*train_data], xm::UInt8[*train_labels], with_label, ndim, scale),
+ preprocess_cifar(xm::UInt8[*test_data], xm::UInt8[*test_labels], with_label, ndim, scale)
  ]
  end

- def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
+ def self.preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default)
  if ndim == 1
    images = images.reshape(images.shape[0], 3072)
  elsif ndim == 3
@@ -39,11 +40,12 @@ module Chainer
  else
    raise 'invalid ndim for CIFAR dataset'
  end
- images = images.cast_to(Numo::SFloat)
+ xm = device.xm
+ images = images.cast_to(xm::SFloat)
  images *= scale / 255.0

  if withlabel
- labels = labels.cast_to(Numo::Int32)
+ labels = labels.cast_to(xm::Int32)
  TupleDataset.new(images, labels)
  else
    images
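
A usage sketch of the new device keyword. Only the signatures in the hunks above are confirmed; the module path Chainer::Datasets::CIFAR and calling get_cifar directly are assumptions for illustration.

```ruby
# Hypothetical sketch: build CIFAR arrays with the default device's array module.
cpu = Chainer::Device.default   # CpuDevice, so device.xm is Numo
train, test = Chainer::Datasets::CIFAR.get_cifar(
  10, true, 3, 1.0,             # n_classes, with_label, ndim, scale (as in the hunk)
  device: cpu                   # new keyword introduced in 0.4.0
)
```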
data/lib/chainer/datasets/mnist.rb

@@ -1,13 +1,17 @@
- require 'zlib'
+ require 'datasets'

  module Chainer
  module Datasets
- module Mnist
- def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::SFloat, label_dtype: Numo::Int32)
- train_raw = retrieve_mnist_training
+ module MNIST
+ def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
+ xm = Chainer::Device.default.xm
+ dtype ||= xm::SFloat
+ label_dtype ||= xm::Int32
+
+ train_raw = retrieve_mnist(type: :train)
  train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)

- test_raw = retrieve_mnist_test
+ test_raw = retrieve_mnist(type: :test)
  test = preprocess_mnist(test_raw, withlabel, ndim, scale, dtype, label_dtype)
  [train, test]
  end
@@ -24,7 +28,7 @@ module Chainer

  images = images.cast_to(image_dtype)
  images *= scale / 255.0
-
+
  if withlabel
  labels = raw[:y].cast_to(label_dtype)
  TupleDataset.new(images, labels)
@@ -33,56 +37,11 @@ module Chainer
  end
  end

- def self.retrieve_mnist_training
- urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
- 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz']
- retrieve_mnist('train.npz', urls)
- end
-
- def self.retrieve_mnist_test
- urls = ['http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
- 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz']
- retrieve_mnist('test.npz', urls)
- end
-
- def self.retrieve_mnist(name, urls)
- root = Chainer::Dataset::Download.get_dataset_directory('pfnet/chainer/mnist')
- path = File.expand_path(name, root)
- Chainer::Dataset::Download.cache_or_load_file(path) do
- make_npz(path, urls)
- end
- end
-
- def self.make_npz(path, urls)
- x_url, y_url = urls
- x_path = Chainer::Dataset::Download.cached_download(x_url)
- y_path = Chainer::Dataset::Download.cached_download(y_url)
-
- x = nil
- y = nil
-
- Zlib::GzipReader.open(x_path) do |fx|
- Zlib::GzipReader.open(y_path) do |fy|
- fx.read(4)
- fy.read(4)
-
- n = fx.read(4).unpack('i>')[0]
- fy.read(4)
- fx.read(8)
-
- x = Numo::UInt8.new(n, 784).rand(n)
- y = Numo::UInt8.new(n).rand(n)
-
- n.times do |i|
- y[i] = fy.read(1).ord
- 784.times do |j|
- x[i, j] = fx.read(1).ord
- end
- end
- end
- end
+ def self.retrieve_mnist(type:)
+ train_table = ::Datasets::MNIST.new(type: type).to_table

- { x: x, y: y}
+ xm = Chainer::Device.default.xm
+ { x: xm::UInt8[*train_table[:pixels]], y: xm::UInt8[*train_table[:label]] }
  end
  end
  end
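
A minimal sketch reflecting the hunks above; the keyword names come from the new get_mnist signature, and the module constant is now MNIST (renamed from Mnist).

```ruby
# dtype/label_dtype now default to nil and are resolved from the default
# device's array module (Numo::SFloat / Numo::Int32 on CPU, Cumo::* on GPU).
train, test = Chainer::Datasets::MNIST.get_mnist(withlabel: true, ndim: 1, scale: 1.0)
# Callers that referenced Chainer::Datasets::Mnist in 0.3.x must switch to
# the new MNIST constant.
```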
data/lib/chainer/device.rb (new file)

@@ -0,0 +1,88 @@
+ module Chainer
+ module Device
+ # Creates device
+ #
+ # @param [Integer or Chainer::AbstractDevice] device_spec Device specifier.
+ # Negative integer indicates CPU. 0 or positive integer indicates GPU.
+ # If a device object is given, itself is returned.
+ # @return [Chainer::AbstractDevice] device object
+ def create(device_spec)
+ return device_spec if device_spec.kind_of?(AbstractDevice)
+ if device_spec.kind_of?(Integer)
+ return CpuDevice.new if device_spec < 0
+ return GpuDevice.new(device_spec)
+ end
+ raise "Invalid device_spec: #{device_spec}"
+ end
+ module_function :create
+
+ # Changes default device
+ #
+ # @param [Object] device_spec
+ # @see Chainer::Device.create
+ def change_default(device_spec)
+ @default = create(device_spec)
+ @default.use
+ end
+ module_function :change_default
+
+ # Gets default device
+ #
+ # @return [Chainer::AbstractDevice] the default device.
+ def default
+ @default ||= CpuDevice.new
+ end
+ module_function :default
+
+ # TODO(sonots): Add Device.from_array after Cumo provides an API
+ # to return GPU device ID from Cumo::NArray.
+ end
+
+ class AbstractDevice
+ def xm
+ raise NotImplementedError
+ end
+
+ def use
+ end
+ end
+
+ class CpuDevice < AbstractDevice
+ def xm
+ Numo
+ end
+
+ def ==(other)
+ return false unless other.is_a?(CpuDevice)
+ true
+ end
+ end
+
+ class GpuDevice < AbstractDevice
+ attr_reader :id
+
+ # @param [Integer] id GPU Device ID. If not given, CUDA current device id is used.
+ def initialize(id = nil)
+ Chainer::CUDA.check_available
+ id ||= Cumo::CUDA::Runtime.cudaGetDevice
+ if id < 0
+ raise 'GPU Device ID must not be negative'
+ end
+ @id = id
+ end
+
+ def xm
+ Cumo
+ end
+
+ def ==(other)
+ return false unless other.is_a?(GpuDevice)
+ id == other.id
+ end
+
+ # Sets CUDA current device with owned GPU Device ID
+ def use
+ Cumo::CUDA::Runtime.cudaSetDevice(@id)
+ end
+ end
+ end
1
+ require 'chainer/function_node'
1
2
  module Chainer
2
3
  class Function
3
4
 
4
5
  attr_reader :rank, :inputs, :outputs, :retain_after_backward
5
- attr_accessor :output_data
6
+ attr_accessor :output_data, :owned_node
6
7
 
7
8
  def initialize
8
9
  @rank = 0
9
10
  end
10
11
 
11
12
  def call(*inputs)
12
- inputs = inputs.map do |x|
13
- if x.kind_of?(Chainer::Variable)
14
- x
15
- else
16
- Variable.new(x, requires_grad: false)
17
- end
18
- end
13
+ node = self.node
19
14
 
20
- in_data = inputs.map(&:data)
21
- requires_grad = inputs.any?(&:requires_grad)
15
+ node.function = self
16
+ node.weak_function = nil
17
+ @node = WeakRef.new(node)
18
+ @owned_node = nil
22
19
 
23
- @input_indexes_to_retain = nil
24
- @output_indexes_to_retain = nil
25
- outputs = forward(in_data)
26
- raise if !outputs.is_a? Array
20
+ ret = node.apply(inputs)
27
21
 
28
- ret = outputs.map do |y|
29
- Variable.new(y, requires_grad: requires_grad)
30
- end
22
+ ret.size == 1 ? ret[0] : ret
23
+ end
31
24
 
32
- if Chainer.configuration.enable_backprop
33
- @rank = inputs.map(&:rank).max || 0
25
+ def inputs
26
+ @node.inputs
27
+ end
34
28
 
35
- ret.each { |y| y.creator = self }
29
+ def outputs
30
+ @node.outputs
31
+ end
36
32
 
37
- @inputs = inputs.map(&:node)
38
- @outputs = ret.map { |y| WeakRef.new(y.node) }
33
+ def node
34
+ noderef = @node
35
+ nd = noderef ? noderef.__getobj__ : @owned_node
36
+ return nd if nd
39
37
 
40
- @input_indexes_to_retain = 0...inputs.size if @input_indexes_to_retain.nil?
41
- @input_indexes_to_retain.each do |index|
42
- inputs[index].retain_data()
43
- end
44
- remove_instance_variable(:@input_indexes_to_retain)
38
+ nd = FunctionAdapter.new(self)
39
+ @owned_node = nd
40
+ nd
41
+ end
45
42
 
46
- unless @output_indexes_to_retain.nil?
47
- @output_indexes_to_retain.each do |index|
48
- ret[index].retain_data()
49
- end
50
- remove_instance_variable(:@output_indexes_to_retain)
51
- end
52
- end
43
+ def output_data
44
+ node.output_data
45
+ end
53
46
 
54
- ret.size == 1 ? ret[0] : ret
47
+ def rank
48
+ @node.rank
49
+ end
50
+
51
+ def label
52
+ self.class.to_s
55
53
  end
56
54
 
57
55
  def forward(inputs)
58
- # TODO: GPU branch processing
59
- forward_cpu(inputs)
56
+ xm = Chainer.get_array_module(*inputs)
57
+ if xm == Cumo
58
+ forward_gpu(inputs)
59
+ else
60
+ forward_cpu(inputs)
61
+ end
60
62
  end
61
63
 
62
64
  def forward_cpu(inputs)
63
65
  raise NotImplementedError
64
66
  end
65
67
 
68
+ def forward_gpu(inputs)
69
+ raise NotImplementedError
70
+ end
71
+
66
72
  def backward(inputs, grad_outputs)
67
- backward_cpu(inputs, grad_outputs)
73
+ xm = Chainer.get_array_module(*(inputs + grad_outputs))
74
+ if xm == Cumo
75
+ backward_gpu(inputs, grad_outputs)
76
+ else
77
+ backward_cpu(inputs, grad_outputs)
78
+ end
79
+ end
80
+
81
+ def backward_cpu(inputs, grad_outputs)
82
+ return [nil] * inputs.size
83
+ end
84
+
85
+ def backward_gpu(inputs, grad_outputs)
86
+ return [nil] * inputs.size
68
87
  end
69
88
 
70
89
  def retain_inputs(indexes)
@@ -72,10 +91,53 @@ module Chainer
72
91
  end
73
92
 
74
93
  def retain_outputs(indexes, retain_after_backward: false)
75
- @output_indexes_to_retain = indexes
76
- if retain_after_backward
77
- @retain_after_backward = retain_after_backward
94
+ node.retain_outputs(indexes)
95
+ end
96
+ end
97
+
98
+ class FunctionAdapter < ::Chainer::FunctionNode
99
+ attr_accessor :function, :weak_function
100
+
101
+ def initialize(function)
102
+ super()
103
+ @weak_function = WeakRef.new(function)
104
+ function.owned_node = self
105
+ end
106
+
107
+ def function
108
+ func = @function
109
+ return func if func
110
+
111
+ weak_func = @weak_function
112
+ weak_func.__getobj__
113
+ end
114
+
115
+ def label
116
+ @function.label
117
+ end
118
+
119
+ def forward(inputs)
120
+ retain_inputs(inputs.size.times.to_a)
121
+ @function.forward(inputs)
122
+ end
123
+
124
+ def backward(target_input_indexes, grad_outputs)
125
+ in_data = @inputs.map { |input| input.data }
126
+ grad_out_data = grad_outputs.map { |grad| grad.nil? ? nil : grad.data }
127
+
128
+ gxs = @function.backward(in_data, grad_out_data)
129
+ ret = []
130
+ target_input_indexes.each do |i|
131
+ if gxs[i].nil?
132
+ g = nil
133
+ else
134
+ g = Chainer::Variable.new(gxs[i])
135
+ g.node.old_style_grad_generator = @function.label
136
+ end
137
+ ret << g
78
138
  end
139
+
140
+ ret
79
141
  end
80
142
  end
81
143
  end
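
A sketch of what the new dispatch means for an old-style Function subclass. The Square class below is illustrative, not part of the gem, and assumes the 0.3.x forward_cpu/backward_cpu contract is otherwise unchanged.

```ruby
class Square < Chainer::Function
  # CPU path; with Cumo arrays the new Function#forward would route to
  # forward_gpu instead, which this sketch does not define.
  def forward_cpu(inputs)
    x, = inputs
    [x ** 2]
  end

  def backward_cpu(inputs, grad_outputs)
    x, = inputs
    gy, = grad_outputs
    [2 * x * gy]
  end
end

x = Chainer::Variable.new(Numo::SFloat[1, 2, 3])
y = Square.new.(x)
# Function#call now wraps the function in a FunctionAdapter (a FunctionNode),
# so graph construction and input retention are handled by FunctionNode#apply.
```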
data/lib/chainer/function_node.rb (new file)

@@ -0,0 +1,454 @@
+ # Function node of the computational graph.
+ # FunctionNode is a class representing a node in a computational graph.
+ # The node corresponds to an application of a differentiable function to input variables.
+ # When a differentiable function is applied to `Chainer::Variable` objects,
+ # it creates an instance of FunctionNode implementation and calls its `apply` method.
+ # The `apply` method basically does the following three things.
+ # 1. Adding an edge from the function node to the variable node corresponding to each input.
+ # The node of each input is extracted by `Chainer::`Variable.node`.
+ # 2. Computing the output arrays of the function.
+ # 3. Creating a :class:`Variable` object for each output array and
+ # adding an edge from the node of the variable to the function node.
+ # The output variables are then returned.
+ module Chainer
+ class FunctionNode
+ attr_accessor :rank, :inputs, :outputs
+
+ def initialize
+ @rank = 0
+ @inputs = nil
+ @outputs = nil
+
+ @retained_output_data = nil
+ @input_indexes_to_retain = nil
+ @output_indexes_to_retain = nil
+ end
+
+ # Short text that represents the function.
+ #
+ # The default implementation returns its type name.
+ # Each function should override it to give more information.
+ def label
+ self.class.name
+ end
+
+ # A tuple of the retained output arrays.
+ # This property is mainly used by $Function$. Users basically do
+ # not have to use this property; use $get_retained_outputs$ instead.
+ def output_data
+ raise RuntimeError, 'retained output data is gone' if @retained_output_data.nil?
+ out_data = [nil] * @outputs.size
+ @output_indexes_to_retain.zip(@retained_output_data).each do |index, data|
+ out_data[index] = data
+ end
+ out_data
+ end
+
+ # Computes output variables and grows the computational graph.
+ #
+ # Basic behavior is expressed in the documentation of `FunctionNode`.
+ # @param [Chainer::Variable, Numo::NArray] inputs If the element is an Numo::NArray,
+ # it is automatically wrapped with `Chainer::Variable`.
+ # @return [Array<Chainer::Variable>] A tuple of output `Chainer::Variable` objectts.
+ def apply(inputs)
+ input_vars = inputs.map { |x| Chainer::Variable.as_variable(x) }
+ in_data = input_vars.map(&:data)
+ requires_grad = input_vars.map(&:requires_grad).any?
+
+ # Forward propagation
+ @input_indexes_to_retain = nil
+ @output_indexes_to_retain = nil
+ outputs = forward(in_data)
+ raise TypeError, "#{outputs.class} not Array" unless outputs.is_a?(Array)
+
+ ret = outputs.map { |y| Chainer::Variable.new(y, requires_grad: requires_grad) }
+
+ if Chainer.configuration.enable_backprop
+ # Topological ordering
+ @rank = input_vars.size > 0 ? input_vars.map(&:rank).max : 0
+
+ # Add backward edges
+ ret.each { |y| y.creator_node = self }
+ @inputs = input_vars.map(&:node)
+ # Add forward edges (must be weak references)
+ @outputs = ret.map { |y| WeakRef.new(y.node) }
+
+ unless @input_indexes_to_retain.nil?
+ @input_indexes_to_retain.each do |index|
+ input_vars[index].retain_data
+ end
+ end
+
+ unless @output_indexes_to_retain.nil?
+ retained_data = []
+ @output_indexes_to_retain.each do |index|
+ ret[index].retain_data
+ retained_data << outputs[index]
+ end
+ @retained_output_data = Array(retained_data)
+ end
+ end
+
+ ret
+ end
+
+ # Computes the output arrays from the input arrays.
+ #
+ # @param [Array] inputs input array(s)
+ # @return [Array] output array(s)
+ def forward(inputs)
+ raise TypeError, "mustt inputs > 0, inputs size is #{inputs.size}" if inputs.size.zero?
+ # TODO GPU
+ forward_cpu(inputs)
+ end
+
+ # Computes the output arrays from the input Numo::NArray.
+ #
+ # @param [Array<Numo::NArray>] inputs Numo::NArray objects.
+ # @return [Array<Numo::NArray>] Array of output arrays.
+ def forward_cpu(inputs)
+ raise NotImplementedError
+ end
+
+ # Lets specified input variable nodes keep data arrays.
+ #
+ # By calling this method from `forward`, the function node can specify which inputs are required for backprop.
+ # The input variables with retained arrays can be obtained by `get_retained_inputs` from `backward`.
+ #
+ # Note that **this method must not be called from the outside of forward method.**
+ # @param [Integer, Array] indexes Indexes of input variables that the function does not require for backprop.
+ def retain_inputs(indexes)
+ @input_indexes_to_retain = indexes
+ end
+
+ # Lets specified output variable nodes keep data arrays.
+ #
+ # By calling this method from `forward`, the function node can specify which outputs are required for backprop.
+ # If this method is not called, any output variables are not marked to keep the data array at the point of returning from `apply`.
+ # The output variables with retained arrays can be obtained by `get_retained_outputs` from `backward`.
+ # Note that **this method must not be called from the outside of forward method.**
+ # @param [Integer, Array] indexes Indexes of input variables that the function does not require for backprop.
+ def retain_outputs(indexes)
+ @output_indexes_to_retain = indexes
+ end
+
+ # Computes gradients w.r.t. specified inputs given output gradients.
+ #
+ # This method is used to compute one step of the backpropagation corresponding to the forward computation of this function node.
+ # Given the gradients w.r.t. output variables, this method computes the gradients w.r.t. specified input variables.
+ # Note that this method does not need to compute any input gradients not specified by `target_input_indexes`
+ # It enables the function node to return the input gradients with the full computational history,
+ # in which case it supports *differentiable backpropagation* or *higher-order differentiation*.
+ #
+ # @param [Array<Integer>] target_indexes Indices of the input variables w.r.t. which the gradients are required.
+ # It is guaranteed that this tuple contains at least one element.
+ # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
+ # If the gradient w.r.t. an output variable is not given, the corresponding element is `None`.
+ # @return [Array<Chainer::Variable>] Array of Chainer::Variable that represent the gradients.
+ def backward(target_indexes, grad_outputs)
+ [nil] * target_indexes.size
+ end
+
+ # Computes gradients w.r.t. specified inputs and accumulates them.
+ #
+ # This method provides a way to fuse the backward computation and the gradient accumulations
+ # in the case that the multiple functions are applied to the same variable.
+ # Users have to override either of this method or `backward`.
+ # It is often simpler to implement `backward` and is recommended if you do not need to provide efficient gradient accumulation.
+ #
+ # @param [Array<Integer>] target_input_indexes Indices of the input variables w.r.t. which the gradients are required.
+ # It is guaranteed that this tuple contains at least one element.
+ # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
+ # If the gradient w.r.t. an output variable is not given, the corresponding element is `None`.
+ # @param [Array<Chainer::Variable>] grad_inputs Gradients w.r.t. the input variables specified by `target_input_indexes`.
+ # These values are computed by other computation paths.
+ # If there is no gradient value existing for the variable, the corresponding element is ``None``.
+ # @return [Array<Chainer::Variable>] Array of variables that represent the gradients w.r.t. specified input variables.
+ def backward_accumulate(target_input_indexes, grad_outputs, grad_inputs)
+ gxs = backward(target_input_indexes, grad_outputs)
+
+ len_gxs = gxs.size
+ if len_gxs == @inputs.size
+ gxs = target_input_indexes.map { |i| gxs[i] }
+ elsif len_gxs != target_input_indexes.size
+ raise ArgumentError, "number of gradients returned by #{impl_name} (#{label}) is incorrect."
+ end
+
+ gxs.zip(grad_inputs).map do |gx, g_input|
+ if g_input.nil?
+ gx
+ elsif gx.nil?
+ g_input
+ else
+ gx + g_input
+ end
+ end
+ end
+
+ # Returns a Array of retained input variables.
+ #
+ # This method is used to retrieve the input variables retained in `forward`.
+ #
+ # @return [Array] a Array of retained input variables.
+ def get_retained_inputs
+ @input_indexes_to_retain.map { |index| @inputs[index].get_variable }
+ end
+
+ # Returns a Array of retained output variables.
+ #
+ # This method is used to retrieve the input variables retained in `forward`.
+ #
+ # @return [Array] a Array of retained input variables.
+ def get_retained_outputs
+ ret = []
+ outputs = @outputs
+
+ new_outputs = outputs.dup
+ outputs_modified = false
+
+ @output_indexes_to_retain.zip(@retained_output_data) do |index, data|
+ output = outputs[index].__getobj__
+ if output.nil?
+ output_var = Chainer::Variable.new(data)
+ output_var.creator_node = self
+ new_outputs[index] = WeakRef.new(output_var)
+ outputs_modified = true
+ else
+ output_var = output.get_variable
+ end
+
+ ret << output_var
+ end
+
+ if outputs_modified
+ @outputs = Array(new_outputs)
+ end
+
+ ret
+ end
+
+ # Purges in/out nodes and this function node itself from the graph.
+ def unchain
+ @outputs.each do |y|
+ y_ref = y.()
+ unless y_ref.nil?
+ y_ref.unchain
+ end
+ end
+ @inputs = nil
+ end
+
+ private
+
+ def impl_name
+ self.class.name
+ end
+ end
+
+ def self.grad(outputs, inputs, grad_outputs: nil, grad_inputs: nil, set_grad: false, retain_grad: false, enable_double_backprop: false)
+ # The implementation consists of three steps.
+
+ if !outputs.is_a?(Array)
+ raise TypeError, "outputs must be Array, not #{outputs.class}"
+ end
+ if !inputs.is_a?(Array)
+ raise TypeError, "inputs must be Array, not #{inputs.class}"
+ end
+ if !grad_outputs.nil? && !grad_outputs.is_a?(Array)
+ raise TypeError, "grad_outputs must be Array, not #{grad_outputs.class}"
+ end
+ if !grad_inputs.nil? && !grad_inputs.is_a?(Array)
+ raise TypeError, "grad_inputs must be Array, not #{grad_inputs.class}"
+ end
+
+ # 1. Backward enumeration: all the nodes reachable backward from the output
+ # nodes are enumerated. The forward direction links are collected in
+ # this step. Note that the variable nodes whose requires_grad is false
+ # are ignored and their creators are not searched.
+ candidate_funcs = outputs.map(&:creator_node).compact
+ visited_funcs = Set.new
+ forward_graph = {}
+
+ while func = candidate_funcs.pop
+ next if visited_funcs.include?(func)
+ visited_funcs.add(func)
+
+ func.inputs.each do |x|
+ next unless x.requires_grad
+ forward_graph[x] = [] if forward_graph[x].nil?
+ forward_graph[x] << func
+ creator = x.creator_node
+ if creator && !visited_funcs.include?(creator)
+ candidate_funcs << creator
+ end
+ end
+ end
+
+ # 2. Forward enumeration: all the nodes in the subgraph reachable from the
+ # input nodes are enumerated. The extracted (sub-)subgraph is the union
+ # of all paths that backpropagation will visit.
+ candidate_vars = inputs.map(&:node)
+ visited_funcs = Set.new
+ grad_required = Set.new
+ while x = candidate_vars.pop
+ grad_required.add(x)
+ forward_graph[x].each do |func|
+ next if visited_funcs.include?(func)
+ visited_funcs.add(func)
+ func.outputs.each do |y_ref|
+ y = y_ref.__getobj__
+ if y && forward_graph[y]
+ candidate_vars << y
+ end
+ end
+ end
+ end
+
+ # 3. Backpropagation: the backpropagation is executed along the
+ # (sub-)subgraph. It uses the topological order of the subgraph which is
+ # induced by the reversed order of function applications ("rank").
+ grads = {} # mapping from variable nodes to their gradients
+
+ # Initialize the gradient mapping.
+ grad_outputs = [nil] * outputs.size if grad_outputs.nil?
+ outputs.zip(grad_outputs).each do |y, gy|
+ if gy.nil?
+ gy_data = y.data.new_ones
+ gy = Chainer::Variable.new(gy_data, requires_grad: false)
+ end
+
+ grads[y.node] = gy
+ end
+
+ unless grad_inputs.nil?
+ inputs.zip(grad_inputs).each do |x, gx|
+ grads[x.node] = gx unless gx.nil?
+ end
+ end
+
+ # Backprop implementation. It edits grads which will only contain the
+ # gradients w.r.t. the inputs.
+ old_enable_backprop = Chainer.configuration.enable_backprop
+ Chainer.configuration.enable_backprop = enable_double_backprop
+ backprop(outputs, inputs, grad_required, retain_grad, grads)
+ Chainer.configuration.enable_backprop = old_enable_backprop
+
+ # Extract the gradients w.r.t. the inputs and return them.
+ ret = inputs.map { |x| grads[x.node] }
+ if set_grad
+ inputs.zip(ret).each do |x, gx|
+ x.grad_var = gx
+ end
+ end
+
+ ret
+ end
+
+ def self.backprop(outputs, inputs, grad_required, retain_grad, grads)
+ candidate_funcs = []
+ visited_funcs = Set.new
+
+ push_candidate = -> (func) do
+ return if visited_funcs.include?(func)
+
+ # Negate since heapq is min-heap
+ # The second element is used to make each item unique
+ visited_funcs.add(func)
+ candidate_funcs.unshift(func)
+ candidate_funcs.sort_by! { |f| f.rank }
+ end
+
+ pop_candidate = -> () do
+ candidate_funcs.pop
+ end
+
+ outputs.each do |y|
+ creator = y.creator_node
+ next if creator.nil?
+ push_candidate.(creator)
+ end
+
+ input_nodes = Set.new(inputs.map(&:node))
+
+ while func = pop_candidate.()
+ # Collect the gradients w.r.t. the outputs
+ gys = []
+
+ func.outputs.each do |y_ref|
+ y = y_ref.__getobj__
+ if y.nil?
+ gys << nil
+ next
+ end
+ gys << grads[y]
+ end
+
+ # Collect the gradients w.r.t. the inputs
+ #
+ # Note (Tokui): when the same variable is passed multiple times as
+ # inputs in the same function (e.g. an expression like f(x, x)), the
+ # current implementation passes None as the current gradient w.r.t.
+ # such an input except for the first one (i.e., it builds gxs like
+ # (gx, None) where gx is the current gradient w.r.t. x).
+ gxs = []
+ input_indexes = []
+ selected_inputs = Set.new
+ func.inputs.each_with_index do |x, i|
+ next unless grad_required.include?(x)
+
+ input_indexes << i
+ if selected_inputs.include?(x)
+ gxs << nil
+ else
+ gxs << grads[x]
+ selected_inputs.add(x)
+ end
+ end
+
+ next if input_indexes.empty?
+
+ # Do backward
+ new_gxs = func.backward_accumulate(input_indexes, gys, gxs)
+
+ # Delete output gradients that are not required to return
+ func.outputs.each do |y_ref|
+ y = y_ref.__getobj__
+ if y && grads[y] && !input_nodes.include?(y)
+ grads.delete(y)
+ end
+ end
+
+ # Update grads
+ selected_inputs = Set.new
+ input_indexes.zip(new_gxs).each do |i, g|
+ next if g.nil?
+
+ node = func.inputs[i]
+ if selected_inputs.include?(node)
+ # Accumulate the duplicated gradients here
+ cur_gx = grads[node]
+ if cur_gx
+ g = g + cur_gx
+ end
+ else
+ selected_inputs.add(node)
+ end
+
+ grads[node] = g
+
+ if retain_grad
+ v = node.get_variable
+ if v
+ v.grad_var = g
+ end
+ end
+
+ creator = node.creator_node
+ if creator
+ push_candidate.(creator)
+ end
+ end
+ end
+ end
+ private_class_method :backprop
+ end
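
A usage sketch of the Chainer.grad entry point defined at the bottom of this file; the argument names follow its signature, and the Variable arithmetic used below comes from basic_math.rb, which is also updated in this release.

```ruby
x = Chainer::Variable.new(Numo::SFloat[1, 2, 3])
y = x * x                       # builds a small computational graph

gx, = Chainer.grad([y], [x])    # gradient of y w.r.t. x, as a Chainer::Variable
# With enable_double_backprop: true the returned gradients keep their own
# computational history, so Chainer.grad can be applied to them again.
```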