red-chainer 0.3.2 → 0.4.0

Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/datasets/cifar.rb
@@ -11,7 +11,7 @@ module Chainer
         get_cifar(100, with_label, ndim, scale)
       end
 
-      def self.get_cifar(n_classes, with_label, ndim, scale)
+      def self.get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default)
         train_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).to_table
         test_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).to_table
 
@@ -25,13 +25,14 @@ module Chainer
           test_labels = test_table[:fine_label]
         end
 
+        xm = device.xm
         [
-          preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
-          preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
+          preprocess_cifar(xm::UInt8[*train_data], xm::UInt8[*train_labels], with_label, ndim, scale),
+          preprocess_cifar(xm::UInt8[*test_data], xm::UInt8[*test_labels], with_label, ndim, scale)
         ]
       end
 
-      def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
+      def self.preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default)
         if ndim == 1
           images = images.reshape(images.shape[0], 3072)
         elsif ndim == 3
@@ -39,11 +40,12 @@ module Chainer
         else
           raise 'invalid ndim for CIFAR dataset'
         end
-        images = images.cast_to(Numo::SFloat)
+        xm = device.xm
+        images = images.cast_to(xm::SFloat)
         images *= scale / 255.0
 
         if withlabel
-          labels = labels.cast_to(Numo::Int32)
+          labels = labels.cast_to(xm::Int32)
           TupleDataset.new(images, labels)
         else
           images
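With the device keyword in place, the CIFAR loader no longer hard-codes Numo; the array module comes from Chainer::Device.default. A rough usage sketch (the public get_cifar10/get_cifar100 wrappers and their keyword arguments are assumed here, they are not shown in this hunk):

  require 'chainer'

  # get_cifar10 is assumed to forward with_label/ndim/scale to get_cifar above.
  train, test = Chainer::Datasets::CIFAR.get_cifar10(ndim: 3)
  # On the default CPU device the images come back as Numo::SFloat scaled to [0, 1];
  # after Chainer::Device.change_default(0) the same call would return Cumo arrays.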
data/lib/chainer/datasets/mnist.rb
@@ -1,13 +1,17 @@
-require 'zlib'
+require 'datasets'
 
 module Chainer
   module Datasets
-    module Mnist
-      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::SFloat, label_dtype: Numo::Int32)
-        train_raw = retrieve_mnist_training
+    module MNIST
+      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
+        xm = Chainer::Device.default.xm
+        dtype ||= xm::SFloat
+        label_dtype ||= xm::Int32
+
+        train_raw = retrieve_mnist(type: :train)
         train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)
 
-        test_raw = retrieve_mnist_test
+        test_raw = retrieve_mnist(type: :test)
         test = preprocess_mnist(test_raw, withlabel, ndim, scale, dtype, label_dtype)
         [train, test]
       end
@@ -24,7 +28,7 @@ module Chainer
 
         images = images.cast_to(image_dtype)
         images *= scale / 255.0
-
+
         if withlabel
           labels = raw[:y].cast_to(label_dtype)
           TupleDataset.new(images, labels)
@@ -33,56 +37,11 @@ module Chainer
         end
       end
 
-      def self.retrieve_mnist_training
-        urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
-                'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz']
-        retrieve_mnist('train.npz', urls)
-      end
-
-      def self.retrieve_mnist_test
-        urls = ['http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
-                'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz']
-        retrieve_mnist('test.npz', urls)
-      end
-
-      def self.retrieve_mnist(name, urls)
-        root = Chainer::Dataset::Download.get_dataset_directory('pfnet/chainer/mnist')
-        path = File.expand_path(name, root)
-        Chainer::Dataset::Download.cache_or_load_file(path) do
-          make_npz(path, urls)
-        end
-      end
-
-      def self.make_npz(path, urls)
-        x_url, y_url = urls
-        x_path = Chainer::Dataset::Download.cached_download(x_url)
-        y_path = Chainer::Dataset::Download.cached_download(y_url)
-
-        x = nil
-        y = nil
-
-        Zlib::GzipReader.open(x_path) do |fx|
-          Zlib::GzipReader.open(y_path) do |fy|
-            fx.read(4)
-            fy.read(4)
-
-            n = fx.read(4).unpack('i>')[0]
-            fy.read(4)
-            fx.read(8)
-
-            x = Numo::UInt8.new(n, 784).rand(n)
-            y = Numo::UInt8.new(n).rand(n)
-
-            n.times do |i|
-              y[i] = fy.read(1).ord
-              784.times do |j|
-                x[i, j] = fx.read(1).ord
-              end
-            end
-          end
-        end
+      def self.retrieve_mnist(type:)
+        train_table = ::Datasets::MNIST.new(type: type).to_table
 
-        { x: x, y: y}
+        xm = Chainer::Device.default.xm
+        { x: xm::UInt8[*train_table[:pixels]], y: xm::UInt8[*train_table[:label]] }
       end
     end
   end
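The loader is renamed from Mnist to MNIST and now pulls the data through the red-datasets gem instead of downloading and parsing the IDX files by hand. A minimal sketch of the call (only the API shown in this diff is used):

  require 'chainer'

  train, test = Chainer::Datasets::MNIST.get_mnist(ndim: 1)
  # dtype and label_dtype now default from Chainer::Device.default:
  # Numo::SFloat / Numo::Int32 on CPU, the Cumo equivalents when a GPU device is the default.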
data/lib/chainer/device.rb (new file)
@@ -0,0 +1,88 @@
+module Chainer
+  module Device
+    # Creates device
+    #
+    # @param [Integer or Chainer::AbstractDevice] device_spec Device specifier.
+    #   Negative integer indicates CPU. 0 or positive integer indicates GPU.
+    #   If a device object is given, itself is returned.
+    # @return [Chainer::AbstractDevice] device object
+    def create(device_spec)
+      return device_spec if device_spec.kind_of?(AbstractDevice)
+      if device_spec.kind_of?(Integer)
+        return CpuDevice.new if device_spec < 0
+        return GpuDevice.new(device_spec)
+      end
+      raise "Invalid device_spec: #{device_spec}"
+    end
+    module_function :create
+
+    # Changes default device
+    #
+    # @param [Object] device_spec
+    # @see Chainer::Device.create
+    def change_default(device_spec)
+      @default = create(device_spec)
+      @default.use
+    end
+    module_function :change_default
+
+    # Gets default device
+    #
+    # @return [Chainer::AbstractDevice] the default device.
+    def default
+      @default ||= CpuDevice.new
+    end
+    module_function :default
+
+    # TODO(sonots): Add Device.from_array after Cumo provides an API
+    # to return GPU device ID from Cumo::NArray.
+  end
+
+  class AbstractDevice
+    def xm
+      raise NotImplementedError
+    end
+
+    def use
+    end
+  end
+
+  class CpuDevice < AbstractDevice
+    def xm
+      Numo
+    end
+
+    def ==(other)
+      return false unless other.is_a?(CpuDevice)
+      true
+    end
+  end
+
+  class GpuDevice < AbstractDevice
+    attr_reader :id
+
+    # @param [Integer] id GPU Device ID. If not given, CUDA current device id is used.
+    def initialize(id = nil)
+      Chainer::CUDA.check_available
+      id ||= Cumo::CUDA::Runtime.cudaGetDevice
+      if id < 0
+        raise 'GPU Device ID must not be negative'
+      end
+      @id = id
+    end
+
+    def xm
+      Cumo
+    end
+
+    def ==(other)
+      return false unless other.is_a?(GpuDevice)
+      id == other.id
+    end
+
+    # Sets CUDA current device with owned GPU Device ID
+    def use
+      Cumo::CUDA::Runtime.cudaSetDevice(@id)
+    end
+  end
+end
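A short sketch of the device API introduced above: select a default device once and derive the array module (xm) from it, which is the pattern the dataset and function changes in this release rely on.

  require 'chainer'

  Chainer::Device.change_default(-1)   # a negative integer selects the CPU
  xm = Chainer::Device.default.xm      # => Numo
  a  = xm::SFloat.new(2, 3).rand

  # With CUDA and the cumo gem available, Chainer::Device.change_default(0)
  # would select GPU 0 and Device.default.xm would return Cumo instead.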
data/lib/chainer/function.rb
@@ -1,70 +1,89 @@
+require 'chainer/function_node'
 module Chainer
   class Function
 
     attr_reader :rank, :inputs, :outputs, :retain_after_backward
-    attr_accessor :output_data
+    attr_accessor :output_data, :owned_node
 
     def initialize
       @rank = 0
     end
 
     def call(*inputs)
-      inputs = inputs.map do |x|
-        if x.kind_of?(Chainer::Variable)
-          x
-        else
-          Variable.new(x, requires_grad: false)
-        end
-      end
+      node = self.node
 
-      in_data = inputs.map(&:data)
-      requires_grad = inputs.any?(&:requires_grad)
+      node.function = self
+      node.weak_function = nil
+      @node = WeakRef.new(node)
+      @owned_node = nil
 
-      @input_indexes_to_retain = nil
-      @output_indexes_to_retain = nil
-      outputs = forward(in_data)
-      raise if !outputs.is_a? Array
+      ret = node.apply(inputs)
 
-      ret = outputs.map do |y|
-        Variable.new(y, requires_grad: requires_grad)
-      end
+      ret.size == 1 ? ret[0] : ret
+    end
 
-      if Chainer.configuration.enable_backprop
-        @rank = inputs.map(&:rank).max || 0
+    def inputs
+      @node.inputs
+    end
 
-        ret.each { |y| y.creator = self }
+    def outputs
+      @node.outputs
+    end
 
-        @inputs = inputs.map(&:node)
-        @outputs = ret.map { |y| WeakRef.new(y.node) }
+    def node
+      noderef = @node
+      nd = noderef ? noderef.__getobj__ : @owned_node
+      return nd if nd
 
-        @input_indexes_to_retain = 0...inputs.size if @input_indexes_to_retain.nil?
-        @input_indexes_to_retain.each do |index|
-          inputs[index].retain_data()
-        end
-        remove_instance_variable(:@input_indexes_to_retain)
+      nd = FunctionAdapter.new(self)
+      @owned_node = nd
+      nd
+    end
 
-        unless @output_indexes_to_retain.nil?
-          @output_indexes_to_retain.each do |index|
-            ret[index].retain_data()
-          end
-          remove_instance_variable(:@output_indexes_to_retain)
-        end
-      end
+    def output_data
+      node.output_data
+    end
 
-      ret.size == 1 ? ret[0] : ret
+    def rank
+      @node.rank
+    end
+
+    def label
+      self.class.to_s
     end
 
     def forward(inputs)
-      # TODO: GPU branch processing
-      forward_cpu(inputs)
+      xm = Chainer.get_array_module(*inputs)
+      if xm == Cumo
+        forward_gpu(inputs)
+      else
+        forward_cpu(inputs)
+      end
     end
 
     def forward_cpu(inputs)
       raise NotImplementedError
     end
 
+    def forward_gpu(inputs)
+      raise NotImplementedError
+    end
+
     def backward(inputs, grad_outputs)
-      backward_cpu(inputs, grad_outputs)
+      xm = Chainer.get_array_module(*(inputs + grad_outputs))
+      if xm == Cumo
+        backward_gpu(inputs, grad_outputs)
+      else
+        backward_cpu(inputs, grad_outputs)
+      end
+    end
+
+    def backward_cpu(inputs, grad_outputs)
+      return [nil] * inputs.size
+    end
+
+    def backward_gpu(inputs, grad_outputs)
+      return [nil] * inputs.size
     end
 
     def retain_inputs(indexes)
@@ -72,10 +91,53 @@ module Chainer
     end
 
     def retain_outputs(indexes, retain_after_backward: false)
-      @output_indexes_to_retain = indexes
-      if retain_after_backward
-        @retain_after_backward = retain_after_backward
+      node.retain_outputs(indexes)
+    end
+  end
+
+  class FunctionAdapter < ::Chainer::FunctionNode
+    attr_accessor :function, :weak_function
+
+    def initialize(function)
+      super()
+      @weak_function = WeakRef.new(function)
+      function.owned_node = self
+    end
+
+    def function
+      func = @function
+      return func if func
+
+      weak_func = @weak_function
+      weak_func.__getobj__
+    end
+
+    def label
+      @function.label
+    end
+
+    def forward(inputs)
+      retain_inputs(inputs.size.times.to_a)
+      @function.forward(inputs)
+    end
+
+    def backward(target_input_indexes, grad_outputs)
+      in_data = @inputs.map { |input| input.data }
+      grad_out_data = grad_outputs.map { |grad| grad.nil? ? nil : grad.data }
+
+      gxs = @function.backward(in_data, grad_out_data)
+      ret = []
+      target_input_indexes.each do |i|
+        if gxs[i].nil?
+          g = nil
+        else
+          g = Chainer::Variable.new(gxs[i])
+          g.node.old_style_grad_generator = @function.label
+        end
+        ret << g
       end
+
+      ret
     end
   end
 end
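The net effect is that old-style Function subclasses keep working: Function#call now builds a FunctionAdapter (a FunctionNode) and lets its apply method do the graph bookkeeping. A sketch with a made-up function (SquareFn is illustrative only, and the Variable#grad= / #backward API is assumed to behave as in 0.3.x):

  class SquareFn < Chainer::Function
    def forward_cpu(inputs)
      x, = inputs
      [x ** 2]
    end

    def backward_cpu(inputs, grad_outputs)
      x, = inputs
      gy, = grad_outputs
      [x * gy * 2]
    end
  end

  x = Chainer::Variable.new(Numo::SFloat[1, 2, 3])
  y = SquareFn.new.(x)            # routed through FunctionAdapter#apply
  y.grad = Numo::SFloat.ones(3)
  y.backward
  p x.grad                        # expected: Numo::SFloat[2, 4, 6]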
data/lib/chainer/function_node.rb (new file)
@@ -0,0 +1,454 @@
+# Function node of the computational graph.
+# FunctionNode is a class representing a node in a computational graph.
+# The node corresponds to an application of a differentiable function to input variables.
+# When a differentiable function is applied to `Chainer::Variable` objects,
+# it creates an instance of a FunctionNode implementation and calls its `apply` method.
+# The `apply` method basically does the following three things.
+#   1. Adding an edge from the function node to the variable node corresponding to each input.
+#      The node of each input is extracted by `Chainer::Variable.node`.
+#   2. Computing the output arrays of the function.
+#   3. Creating a `Chainer::Variable` object for each output array and
+#      adding an edge from the node of the variable to the function node.
+# The output variables are then returned.
+module Chainer
+  class FunctionNode
+    attr_accessor :rank, :inputs, :outputs
+
+    def initialize
+      @rank = 0
+      @inputs = nil
+      @outputs = nil
+
+      @retained_output_data = nil
+      @input_indexes_to_retain = nil
+      @output_indexes_to_retain = nil
+    end
+
+    # Short text that represents the function.
+    #
+    # The default implementation returns its type name.
+    # Each function should override it to give more information.
+    def label
+      self.class.name
+    end
+
+    # A tuple of the retained output arrays.
+    # This property is mainly used by `Function`. Users basically do
+    # not have to use this property; use `get_retained_outputs` instead.
+    def output_data
+      raise RuntimeError, 'retained output data is gone' if @retained_output_data.nil?
+      out_data = [nil] * @outputs.size
+      @output_indexes_to_retain.zip(@retained_output_data).each do |index, data|
+        out_data[index] = data
+      end
+      out_data
+    end
+
+    # Computes output variables and grows the computational graph.
+    #
+    # Basic behavior is expressed in the documentation of `FunctionNode`.
+    # @param [Chainer::Variable, Numo::NArray] inputs If an element is a Numo::NArray,
+    #   it is automatically wrapped with `Chainer::Variable`.
+    # @return [Array<Chainer::Variable>] A tuple of output `Chainer::Variable` objects.
+    def apply(inputs)
+      input_vars = inputs.map { |x| Chainer::Variable.as_variable(x) }
+      in_data = input_vars.map(&:data)
+      requires_grad = input_vars.map(&:requires_grad).any?
+
+      # Forward propagation
+      @input_indexes_to_retain = nil
+      @output_indexes_to_retain = nil
+      outputs = forward(in_data)
+      raise TypeError, "#{outputs.class} not Array" unless outputs.is_a?(Array)
+
+      ret = outputs.map { |y| Chainer::Variable.new(y, requires_grad: requires_grad) }
+
+      if Chainer.configuration.enable_backprop
+        # Topological ordering
+        @rank = input_vars.size > 0 ? input_vars.map(&:rank).max : 0
+
+        # Add backward edges
+        ret.each { |y| y.creator_node = self }
+        @inputs = input_vars.map(&:node)
+        # Add forward edges (must be weak references)
+        @outputs = ret.map { |y| WeakRef.new(y.node) }
+
+        unless @input_indexes_to_retain.nil?
+          @input_indexes_to_retain.each do |index|
+            input_vars[index].retain_data
+          end
+        end
+
+        unless @output_indexes_to_retain.nil?
+          retained_data = []
+          @output_indexes_to_retain.each do |index|
+            ret[index].retain_data
+            retained_data << outputs[index]
+          end
+          @retained_output_data = Array(retained_data)
+        end
+      end
+
+      ret
+    end
+
+    # Computes the output arrays from the input arrays.
+    #
+    # @param [Array] inputs input array(s)
+    # @return [Array] output array(s)
+    def forward(inputs)
+      raise TypeError, "mustt inputs > 0, inputs size is #{inputs.size}" if inputs.size.zero?
+      # TODO GPU
+      forward_cpu(inputs)
+    end
+
+    # Computes the output arrays from the input Numo::NArray.
+    #
+    # @param [Array<Numo::NArray>] inputs Numo::NArray objects.
+    # @return [Array<Numo::NArray>] Array of output arrays.
+    def forward_cpu(inputs)
+      raise NotImplementedError
+    end
+
+    # Lets specified input variable nodes keep data arrays.
+    #
+    # By calling this method from `forward`, the function node can specify which inputs are required for backprop.
+    # The input variables with retained arrays can be obtained by `get_retained_inputs` from `backward`.
+    #
+    # Note that **this method must not be called from the outside of the forward method.**
+    # @param [Integer, Array] indexes Indexes of input variables that the function requires for backprop.
+    def retain_inputs(indexes)
+      @input_indexes_to_retain = indexes
+    end
+
+    # Lets specified output variable nodes keep data arrays.
+    #
+    # By calling this method from `forward`, the function node can specify which outputs are required for backprop.
+    # If this method is not called, no output variables are marked to keep their data arrays at the point of returning from `apply`.
+    # The output variables with retained arrays can be obtained by `get_retained_outputs` from `backward`.
+    # Note that **this method must not be called from the outside of the forward method.**
+    # @param [Integer, Array] indexes Indexes of output variables that the function requires for backprop.
+    def retain_outputs(indexes)
+      @output_indexes_to_retain = indexes
+    end
+
+    # Computes gradients w.r.t. specified inputs given output gradients.
+    #
+    # This method is used to compute one step of the backpropagation corresponding to the forward computation of this function node.
+    # Given the gradients w.r.t. output variables, this method computes the gradients w.r.t. specified input variables.
+    # Note that this method does not need to compute any input gradients not specified by `target_input_indexes`.
+    # It enables the function node to return the input gradients with the full computational history,
+    # in which case it supports *differentiable backpropagation* or *higher-order differentiation*.
+    #
+    # @param [Array<Integer>] target_indexes Indices of the input variables w.r.t. which the gradients are required.
+    #   It is guaranteed that this list contains at least one element.
+    # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
+    #   If the gradient w.r.t. an output variable is not given, the corresponding element is `nil`.
+    # @return [Array<Chainer::Variable>] Array of Chainer::Variable that represent the gradients.
+    def backward(target_indexes, grad_outputs)
+      [nil] * target_indexes.size
+    end
+
+    # Computes gradients w.r.t. specified inputs and accumulates them.
+    #
+    # This method provides a way to fuse the backward computation and the gradient accumulations
+    # in the case that multiple functions are applied to the same variable.
+    # Users have to override either this method or `backward`.
+    # It is often simpler to implement `backward`, which is recommended if you do not need to provide efficient gradient accumulation.
+    #
+    # @param [Array<Integer>] target_input_indexes Indices of the input variables w.r.t. which the gradients are required.
+    #   It is guaranteed that this list contains at least one element.
+    # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
+    #   If the gradient w.r.t. an output variable is not given, the corresponding element is `nil`.
+    # @param [Array<Chainer::Variable>] grad_inputs Gradients w.r.t. the input variables specified by `target_input_indexes`.
+    #   These values are computed by other computation paths.
+    #   If there is no gradient value existing for the variable, the corresponding element is `nil`.
+    # @return [Array<Chainer::Variable>] Array of variables that represent the gradients w.r.t. specified input variables.
+    def backward_accumulate(target_input_indexes, grad_outputs, grad_inputs)
+      gxs = backward(target_input_indexes, grad_outputs)
+
+      len_gxs = gxs.size
+      if len_gxs == @inputs.size
+        gxs = target_input_indexes.map { |i| gxs[i] }
+      elsif len_gxs != target_input_indexes.size
+        raise ArgumentError, "number of gradients returned by #{impl_name} (#{label}) is incorrect."
+      end
+
+      gxs.zip(grad_inputs).map do |gx, g_input|
+        if g_input.nil?
+          gx
+        elsif gx.nil?
+          g_input
+        else
+          gx + g_input
+        end
+      end
+    end
+
+    # Returns an Array of retained input variables.
+    #
+    # This method is used to retrieve the input variables retained in `forward`.
+    #
+    # @return [Array] an Array of retained input variables.
+    def get_retained_inputs
+      @input_indexes_to_retain.map { |index| @inputs[index].get_variable }
+    end
+
+    # Returns an Array of retained output variables.
+    #
+    # This method is used to retrieve the output variables retained in `forward`.
+    #
+    # @return [Array] an Array of retained output variables.
+    def get_retained_outputs
+      ret = []
+      outputs = @outputs
+
+      new_outputs = outputs.dup
+      outputs_modified = false
+
+      @output_indexes_to_retain.zip(@retained_output_data) do |index, data|
+        output = outputs[index].__getobj__
+        if output.nil?
+          output_var = Chainer::Variable.new(data)
+          output_var.creator_node = self
+          new_outputs[index] = WeakRef.new(output_var)
+          outputs_modified = true
+        else
+          output_var = output.get_variable
+        end
+
+        ret << output_var
+      end
+
+      if outputs_modified
+        @outputs = Array(new_outputs)
+      end
+
+      ret
+    end
+
+    # Purges in/out nodes and this function node itself from the graph.
+    def unchain
+      @outputs.each do |y|
+        y_ref = y.()
+        unless y_ref.nil?
+          y_ref.unchain
+        end
+      end
+      @inputs = nil
+    end
+
+    private
+
+    def impl_name
+      self.class.name
+    end
+  end
+
+  def self.grad(outputs, inputs, grad_outputs: nil, grad_inputs: nil, set_grad: false, retain_grad: false, enable_double_backprop: false)
+    # The implementation consists of three steps.
+
+    if !outputs.is_a?(Array)
+      raise TypeError, "outputs must be Array, not #{outputs.class}"
+    end
+    if !inputs.is_a?(Array)
+      raise TypeError, "inputs must be Array, not #{inputs.class}"
+    end
+    if !grad_outputs.nil? && !grad_outputs.is_a?(Array)
+      raise TypeError, "grad_outputs must be Array, not #{grad_outputs.class}"
+    end
+    if !grad_inputs.nil? && !grad_inputs.is_a?(Array)
+      raise TypeError, "grad_inputs must be Array, not #{grad_inputs.class}"
+    end
+
+    # 1. Backward enumeration: all the nodes reachable backward from the output
+    #    nodes are enumerated. The forward direction links are collected in
+    #    this step. Note that the variable nodes whose requires_grad is false
+    #    are ignored and their creators are not searched.
+    candidate_funcs = outputs.map(&:creator_node).compact
+    visited_funcs = Set.new
+    forward_graph = {}
+
+    while func = candidate_funcs.pop
+      next if visited_funcs.include?(func)
+      visited_funcs.add(func)
+
+      func.inputs.each do |x|
+        next unless x.requires_grad
+        forward_graph[x] = [] if forward_graph[x].nil?
+        forward_graph[x] << func
+        creator = x.creator_node
+        if creator && !visited_funcs.include?(creator)
+          candidate_funcs << creator
+        end
+      end
+    end
+
+    # 2. Forward enumeration: all the nodes in the subgraph reachable from the
+    #    input nodes are enumerated. The extracted (sub-)subgraph is the union
+    #    of all paths that backpropagation will visit.
+    candidate_vars = inputs.map(&:node)
+    visited_funcs = Set.new
+    grad_required = Set.new
+    while x = candidate_vars.pop
+      grad_required.add(x)
+      forward_graph[x].each do |func|
+        next if visited_funcs.include?(func)
+        visited_funcs.add(func)
+        func.outputs.each do |y_ref|
+          y = y_ref.__getobj__
+          if y && forward_graph[y]
+            candidate_vars << y
+          end
+        end
+      end
+    end
+
+    # 3. Backpropagation: the backpropagation is executed along the
+    #    (sub-)subgraph. It uses the topological order of the subgraph which is
+    #    induced by the reversed order of function applications ("rank").
+    grads = {} # mapping from variable nodes to their gradients
+
+    # Initialize the gradient mapping.
+    grad_outputs = [nil] * outputs.size if grad_outputs.nil?
+    outputs.zip(grad_outputs).each do |y, gy|
+      if gy.nil?
+        gy_data = y.data.new_ones
+        gy = Chainer::Variable.new(gy_data, requires_grad: false)
+      end
+
+      grads[y.node] = gy
+    end
+
+    unless grad_inputs.nil?
+      inputs.zip(grad_inputs).each do |x, gx|
+        grads[x.node] = gx unless gx.nil?
+      end
+    end
+
+    # Backprop implementation. It edits grads which will only contain the
+    # gradients w.r.t. the inputs.
+    old_enable_backprop = Chainer.configuration.enable_backprop
+    Chainer.configuration.enable_backprop = enable_double_backprop
+    backprop(outputs, inputs, grad_required, retain_grad, grads)
+    Chainer.configuration.enable_backprop = old_enable_backprop
+
+    # Extract the gradients w.r.t. the inputs and return them.
+    ret = inputs.map { |x| grads[x.node] }
+    if set_grad
+      inputs.zip(ret).each do |x, gx|
+        x.grad_var = gx
+      end
+    end
+
+    ret
+  end
+
+  def self.backprop(outputs, inputs, grad_required, retain_grad, grads)
+    candidate_funcs = []
+    visited_funcs = Set.new
+
+    push_candidate = -> (func) do
+      return if visited_funcs.include?(func)
+
+      # Negate since heapq is min-heap
+      # The second element is used to make each item unique
+      visited_funcs.add(func)
+      candidate_funcs.unshift(func)
+      candidate_funcs.sort_by! { |f| f.rank }
+    end
+
+    pop_candidate = -> () do
+      candidate_funcs.pop
+    end
+
+    outputs.each do |y|
+      creator = y.creator_node
+      next if creator.nil?
+      push_candidate.(creator)
+    end
+
+    input_nodes = Set.new(inputs.map(&:node))
+
+    while func = pop_candidate.()
+      # Collect the gradients w.r.t. the outputs
+      gys = []
+
+      func.outputs.each do |y_ref|
+        y = y_ref.__getobj__
+        if y.nil?
+          gys << nil
+          next
+        end
+        gys << grads[y]
+      end
+
+      # Collect the gradients w.r.t. the inputs
+      #
+      # Note (Tokui): when the same variable is passed multiple times as
+      # inputs in the same function (e.g. an expression like f(x, x)), the
+      # current implementation passes None as the current gradient w.r.t.
+      # such an input except for the first one (i.e., it builds gxs like
+      # (gx, None) where gx is the current gradient w.r.t. x).
+      gxs = []
+      input_indexes = []
+      selected_inputs = Set.new
+      func.inputs.each_with_index do |x, i|
+        next unless grad_required.include?(x)
+
+        input_indexes << i
+        if selected_inputs.include?(x)
+          gxs << nil
+        else
+          gxs << grads[x]
+          selected_inputs.add(x)
+        end
+      end
+
+      next if input_indexes.empty?
+
+      # Do backward
+      new_gxs = func.backward_accumulate(input_indexes, gys, gxs)
+
+      # Delete output gradients that are not required to return
+      func.outputs.each do |y_ref|
+        y = y_ref.__getobj__
+        if y && grads[y] && !input_nodes.include?(y)
+          grads.delete(y)
+        end
+      end
+
+      # Update grads
+      selected_inputs = Set.new
+      input_indexes.zip(new_gxs).each do |i, g|
+        next if g.nil?
+
+        node = func.inputs[i]
+        if selected_inputs.include?(node)
+          # Accumulate the duplicated gradients here
+          cur_gx = grads[node]
+          if cur_gx
+            g = g + cur_gx
+          end
+        else
+          selected_inputs.add(node)
+        end
+
+        grads[node] = g
+
+        if retain_grad
+          v = node.get_variable
+          if v
+            v.grad_var = g
+          end
+        end
+
+        creator = node.creator_node
+        if creator
+          push_candidate.(creator)
+        end
+      end
+    end
+  end
+  private_class_method :backprop
+end
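The module-level Chainer.grad defined here computes gradients of chosen outputs with respect to chosen inputs without walking the rest of the graph. A rough usage sketch (it assumes Variable arithmetic from functions/math/basic_math.rb wires up creator_node as described above):

  x = Chainer::Variable.new(Numo::SFloat[1, 2, 3])
  y = x * x

  gx, = Chainer.grad([y], [x])
  p gx.data                       # expected: Numo::SFloat[2, 4, 6]

  # set_grad: true would also store the result in x.grad_var;
  # enable_double_backprop: true keeps the backward graph so that
  # higher-order gradients can be taken.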