red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
@@ -11,7 +11,7 @@ module Chainer
|
|
11
11
|
get_cifar(100, with_label, ndim, scale)
|
12
12
|
end
|
13
13
|
|
14
|
-
def self.get_cifar(n_classes, with_label, ndim, scale)
|
14
|
+
def self.get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default)
|
15
15
|
train_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).to_table
|
16
16
|
test_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).to_table
|
17
17
|
|
@@ -25,13 +25,14 @@ module Chainer
|
|
25
25
|
test_labels = test_table[:fine_label]
|
26
26
|
end
|
27
27
|
|
28
|
+
xm = device.xm
|
28
29
|
[
|
29
|
-
preprocess_cifar(
|
30
|
-
preprocess_cifar(
|
30
|
+
preprocess_cifar(xm::UInt8[*train_data], xm::UInt8[*train_labels], with_label, ndim, scale),
|
31
|
+
preprocess_cifar(xm::UInt8[*test_data], xm::UInt8[*test_labels], with_label, ndim, scale)
|
31
32
|
]
|
32
33
|
end
|
33
34
|
|
34
|
-
def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
|
35
|
+
def self.preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default)
|
35
36
|
if ndim == 1
|
36
37
|
images = images.reshape(images.shape[0], 3072)
|
37
38
|
elsif ndim == 3
|
@@ -39,11 +40,12 @@ module Chainer
|
|
39
40
|
else
|
40
41
|
raise 'invalid ndim for CIFAR dataset'
|
41
42
|
end
|
42
|
-
|
43
|
+
xm = device.xm
|
44
|
+
images = images.cast_to(xm::SFloat)
|
43
45
|
images *= scale / 255.0
|
44
46
|
|
45
47
|
if withlabel
|
46
|
-
labels = labels.cast_to(
|
48
|
+
labels = labels.cast_to(xm::Int32)
|
47
49
|
TupleDataset.new(images, labels)
|
48
50
|
else
|
49
51
|
images
|
@@ -1,13 +1,17 @@
|
|
1
|
-
require '
|
1
|
+
require 'datasets'
|
2
2
|
|
3
3
|
module Chainer
|
4
4
|
module Datasets
|
5
|
-
module
|
6
|
-
def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype:
|
7
|
-
|
5
|
+
module MNIST
|
6
|
+
# Loads the MNIST train and test splits.
#
# dtype / label_dtype default to SFloat / Int32 of the default device's array
# module (Numo on CPU, Cumo on GPU).
#
# @return [Array] two-element array [train, test] of preprocessed datasets
def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
  array_module = Chainer::Device.default.xm
  dtype ||= array_module::SFloat
  label_dtype ||= array_module::Int32

  %i[train test].map do |split|
    raw = retrieve_mnist(type: split)
    preprocess_mnist(raw, withlabel, ndim, scale, dtype, label_dtype)
  end
end
|
@@ -24,7 +28,7 @@ module Chainer
|
|
24
28
|
|
25
29
|
images = images.cast_to(image_dtype)
|
26
30
|
images *= scale / 255.0
|
27
|
-
|
31
|
+
|
28
32
|
if withlabel
|
29
33
|
labels = raw[:y].cast_to(label_dtype)
|
30
34
|
TupleDataset.new(images, labels)
|
@@ -33,56 +37,11 @@ module Chainer
|
|
33
37
|
end
|
34
38
|
end
|
35
39
|
|
36
|
-
def self.
|
37
|
-
|
38
|
-
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz']
|
39
|
-
retrieve_mnist('train.npz', urls)
|
40
|
-
end
|
41
|
-
|
42
|
-
def self.retrieve_mnist_test
|
43
|
-
urls = ['http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
44
|
-
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz']
|
45
|
-
retrieve_mnist('test.npz', urls)
|
46
|
-
end
|
47
|
-
|
48
|
-
def self.retrieve_mnist(name, urls)
|
49
|
-
root = Chainer::Dataset::Download.get_dataset_directory('pfnet/chainer/mnist')
|
50
|
-
path = File.expand_path(name, root)
|
51
|
-
Chainer::Dataset::Download.cache_or_load_file(path) do
|
52
|
-
make_npz(path, urls)
|
53
|
-
end
|
54
|
-
end
|
55
|
-
|
56
|
-
def self.make_npz(path, urls)
|
57
|
-
x_url, y_url = urls
|
58
|
-
x_path = Chainer::Dataset::Download.cached_download(x_url)
|
59
|
-
y_path = Chainer::Dataset::Download.cached_download(y_url)
|
60
|
-
|
61
|
-
x = nil
|
62
|
-
y = nil
|
63
|
-
|
64
|
-
Zlib::GzipReader.open(x_path) do |fx|
|
65
|
-
Zlib::GzipReader.open(y_path) do |fy|
|
66
|
-
fx.read(4)
|
67
|
-
fy.read(4)
|
68
|
-
|
69
|
-
n = fx.read(4).unpack('i>')[0]
|
70
|
-
fy.read(4)
|
71
|
-
fx.read(8)
|
72
|
-
|
73
|
-
x = Numo::UInt8.new(n, 784).rand(n)
|
74
|
-
y = Numo::UInt8.new(n).rand(n)
|
75
|
-
|
76
|
-
n.times do |i|
|
77
|
-
y[i] = fy.read(1).ord
|
78
|
-
784.times do |j|
|
79
|
-
x[i, j] = fx.read(1).ord
|
80
|
-
end
|
81
|
-
end
|
82
|
-
end
|
83
|
-
end
|
40
|
+
# Fetches one MNIST split via red-datasets and converts it to UInt8 arrays
# of the default device's array module.
#
# @param [Symbol] type split to load (passed through to ::Datasets::MNIST)
# @return [Hash] { x: pixel array, y: label array }
def self.retrieve_mnist(type:)
  table = ::Datasets::MNIST.new(type: type).to_table
  array_module = Chainer::Device.default.xm
  pixels = array_module::UInt8[*table[:pixels]]
  labels = array_module::UInt8[*table[:label]]
  { x: pixels, y: labels }
end
|
87
46
|
end
|
88
47
|
end
|
@@ -0,0 +1,88 @@
|
|
1
|
+
module Chainer
  # Creation and selection of the device (CPU via Numo, GPU via Cumo) used by default.
  module Device
    # Creates a device object from a device specifier.
    #
    # @param [Integer, Chainer::AbstractDevice] device_spec Device specifier.
    #   A negative integer indicates the CPU; 0 or a positive integer indicates the GPU with that ID.
    #   If a device object is given, it is returned as-is.
    # @return [Chainer::AbstractDevice] device object
    # @raise [RuntimeError] if device_spec is neither an Integer nor an AbstractDevice
    def create(device_spec)
      return device_spec if device_spec.kind_of?(AbstractDevice)
      if device_spec.kind_of?(Integer)
        return CpuDevice.new if device_spec < 0
        return GpuDevice.new(device_spec)
      end
      raise "Invalid device_spec: #{device_spec}"
    end
    module_function :create

    # Changes the default device and activates it via AbstractDevice#use.
    #
    # @param [Integer, Chainer::AbstractDevice] device_spec
    # @see Chainer::Device.create
    def change_default(device_spec)
      @default = create(device_spec)
      @default.use
    end
    module_function :change_default

    # Gets the default device, lazily initialized to a CpuDevice.
    #
    # @return [Chainer::AbstractDevice] the default device.
    def default
      @default ||= CpuDevice.new
    end
    module_function :default

    # TODO(sonots): Add Device.from_array after Cumo provides an API
    # to return GPU device ID from Cumo::NArray.
  end

  # Base class of devices. Subclasses provide the array module via #xm.
  class AbstractDevice
    # @return [Module] array module (Numo or Cumo)
    def xm
      raise NotImplementedError
    end

    # Activates this device. No-op by default.
    def use
    end
  end

  # CPU device backed by Numo.
  class CpuDevice < AbstractDevice
    def xm
      Numo
    end

    # All CPU devices are equal to each other.
    def ==(other)
      other.is_a?(CpuDevice)
    end
    # eql?/hash agree with == so equal devices collapse in Hash, Set and #uniq.
    alias_method :eql?, :==

    def hash
      CpuDevice.hash
    end
  end

  # GPU device backed by Cumo.
  class GpuDevice < AbstractDevice
    attr_reader :id

    # @param [Integer] id GPU Device ID. If not given, CUDA current device id is used.
    # @raise [RuntimeError] if id is negative
    def initialize(id = nil)
      Chainer::CUDA.check_available
      id ||= Cumo::CUDA::Runtime.cudaGetDevice
      if id < 0
        raise 'GPU Device ID must not be negative'
      end
      @id = id
    end

    def xm
      Cumo
    end

    # GPU devices are equal when they refer to the same device ID.
    def ==(other)
      return false unless other.is_a?(GpuDevice)
      id == other.id
    end
    # eql?/hash agree with == so equal devices collapse in Hash, Set and #uniq.
    alias_method :eql?, :==

    def hash
      [GpuDevice, id].hash
    end

    # Sets CUDA current device with owned GPU Device ID
    def use
      Cumo::CUDA::Runtime.cudaSetDevice(@id)
    end
  end
end
|
data/lib/chainer/function.rb
CHANGED
@@ -1,70 +1,89 @@
|
|
1
|
+
require 'chainer/function_node'
|
1
2
|
module Chainer
|
2
3
|
class Function
|
3
4
|
|
4
5
|
# NOTE: the rank/inputs/outputs/output_data readers generated here are
# overridden by explicit methods defined later in this class.
attr_reader :rank
attr_reader :inputs
attr_reader :outputs
attr_reader :retain_after_backward
attr_accessor :output_data
attr_accessor :owned_node

# Creates an old-style function with rank 0.
def initialize
  @rank = 0
end
|
10
11
|
|
11
12
|
# Applies this old-style function to the given inputs.
#
# Ownership is swapped: the node keeps a strong reference to this function,
# while this function only keeps a weak reference to the node.
#
# @return [Chainer::Variable, Array<Chainer::Variable>] a single output
#   variable, or an Array when there are several outputs
def call(*inputs)
  adapter = node

  adapter.function = self
  adapter.weak_function = nil
  @node = WeakRef.new(adapter)
  @owned_node = nil

  outputs = adapter.apply(inputs)
  return outputs.first if outputs.size == 1
  outputs
end
|
31
24
|
|
32
|
-
|
33
|
-
|
25
|
+
# Input variable nodes of the underlying function node.
#
# Reads through #node rather than @node so that it also works before the
# first #call, when no weak node reference has been stored yet.
def inputs
  node.inputs
end
|
34
28
|
|
35
|
-
|
29
|
+
# Output references of the underlying function node.
#
# Reads through #node rather than @node so that it also works before the
# first #call, when no weak node reference has been stored yet.
def outputs
  node.outputs
end
|
36
32
|
|
37
|
-
|
38
|
-
|
33
|
+
# Returns the FunctionAdapter node that represents this function in the graph.
#
# After #call, the node is only weakly referenced (@node). Before #call, or once
# that weak reference has been garbage collected, a new FunctionAdapter is created
# and owned by this function.
# WeakRef#__getobj__ raises WeakRef::RefError for a collected referent, so that case
# is treated as "no node".
def node
  noderef = @node
  nd =
    if noderef
      begin
        noderef.__getobj__
      rescue WeakRef::RefError
        nil
      end
    else
      @owned_node
    end
  return nd if nd

  nd = FunctionAdapter.new(self)
  @owned_node = nd
  nd
end
|
45
42
|
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
end
|
50
|
-
remove_instance_variable(:@output_indexes_to_retain)
|
51
|
-
end
|
52
|
-
end
|
43
|
+
# Retained output arrays, delegated to the underlying function node.
def output_data
  current_node = node
  current_node.output_data
end
|
53
46
|
|
54
|
-
|
47
|
+
# Topological rank of the underlying function node.
#
# Reads through #node rather than @node so that it also works before the
# first #call, when no weak node reference has been stored yet.
def rank
  node.rank
end
|
50
|
+
|
51
|
+
# Short text representing this function: its class name as a string.
def label
  "#{self.class}"
end
|
56
54
|
|
57
55
|
# Dispatches forward computation to #forward_gpu when the inputs are Cumo
# arrays, and to #forward_cpu otherwise.
def forward(inputs)
  array_module = Chainer.get_array_module(*inputs)
  return forward_gpu(inputs) if array_module == Cumo
  forward_cpu(inputs)
end
|
61
63
|
|
62
64
|
# CPU forward computation; subclasses must override.
# @raise [NotImplementedError] always, in this base implementation
def forward_cpu(_inputs)
  raise NotImplementedError
end
|
65
67
|
|
68
|
+
# GPU forward computation; subclasses must override.
# @raise [NotImplementedError] always, in this base implementation
def forward_gpu(_inputs)
  raise NotImplementedError
end
|
71
|
+
|
66
72
|
# Dispatches backward computation to #backward_gpu when the inputs and output
# gradients are Cumo arrays, and to #backward_cpu otherwise.
def backward(inputs, grad_outputs)
  all_arrays = inputs + grad_outputs
  array_module = Chainer.get_array_module(*all_arrays)
  return backward_gpu(inputs, grad_outputs) if array_module == Cumo
  backward_cpu(inputs, grad_outputs)
end
|
80
|
+
|
81
|
+
# Default CPU backward: no gradient for any input.
# @return [Array<nil>] one nil per input
def backward_cpu(inputs, _grad_outputs)
  Array.new(inputs.size)
end
|
84
|
+
|
85
|
+
# Default GPU backward: no gradient for any input.
# @return [Array<nil>] one nil per input
def backward_gpu(inputs, _grad_outputs)
  Array.new(inputs.size)
end
|
69
88
|
|
70
89
|
def retain_inputs(indexes)
|
@@ -72,10 +91,53 @@ module Chainer
|
|
72
91
|
end
|
73
92
|
|
74
93
|
# Marks which outputs the underlying node should retain for backprop.
#
# @param [Integer, Array] indexes indexes of outputs to retain
# @param [Boolean] retain_after_backward accepted for API compatibility; it is not used
def retain_outputs(indexes, retain_after_backward: false)
  target = node
  target.retain_outputs(indexes)
end
|
96
|
+
end
|
97
|
+
|
98
|
+
# Adapter that lets an old-style Chainer::Function participate in the
# FunctionNode-based graph.
#
# Before Function#call the wrapped function is only weakly referenced
# (weak_function); #call then stores it strongly via #function=.
class FunctionAdapter < ::Chainer::FunctionNode
  attr_writer :function
  attr_accessor :weak_function

  # @param [Chainer::Function] function old-style function to wrap
  def initialize(function)
    super()
    @weak_function = WeakRef.new(function)
    function.owned_node = self
  end

  # The wrapped function: the strong reference if set, otherwise the weakly
  # referenced one.
  #
  # @return [Chainer::Function, nil] nil when neither reference is available
  #   or the weak referent has been garbage collected
  def function
    return @function if @function

    weak_func = @weak_function
    return nil if weak_func.nil?
    weak_func.__getobj__
  rescue WeakRef::RefError
    nil
  end

  # Label of the wrapped function.
  def label
    function.label
  end

  # Retains every input (the old-style backward receives all input arrays)
  # and delegates to the wrapped function's forward.
  def forward(inputs)
    retain_inputs(inputs.size.times.to_a)
    function.forward(inputs)
  end

  # Runs the wrapped function's array-level backward and wraps each requested
  # gradient in a Chainer::Variable.
  #
  # @param [Array<Integer>] target_input_indexes indexes of inputs whose gradients are required
  # @param [Array<Chainer::Variable, nil>] grad_outputs gradients w.r.t. the outputs
  # @return [Array<Chainer::Variable, nil>] one entry per target index
  def backward(target_input_indexes, grad_outputs)
    func = function
    in_data = @inputs.map { |input| input.data }
    grad_out_data = grad_outputs.map { |grad| grad.nil? ? nil : grad.data }

    gxs = func.backward(in_data, grad_out_data)
    target_input_indexes.map do |i|
      gx = gxs[i]
      next nil if gx.nil?

      g = Chainer::Variable.new(gx)
      g.node.old_style_grad_generator = func.label
      g
    end
  end
end
|
81
143
|
end
|
@@ -0,0 +1,454 @@
|
|
1
|
+
# Function node of the computational graph.
#
# FunctionNode is a class representing a node in a computational graph.
# The node corresponds to an application of a differentiable function to input variables.
# When a differentiable function is applied to `Chainer::Variable` objects,
# it creates an instance of FunctionNode implementation and calls its `apply` method.
# The `apply` method basically does the following three things.
# 1. Adding an edge from the function node to the variable node corresponding to each input.
#    The node of each input is extracted by `Chainer::Variable#node`.
# 2. Computing the output arrays of the function.
# 3. Creating a `Chainer::Variable` object for each output array and
#    adding an edge from the node of the variable to the function node.
# The output variables are then returned.
module Chainer
  class FunctionNode
    attr_accessor :rank, :inputs, :outputs

    def initialize
      @rank = 0
      @inputs = nil
      @outputs = nil

      @retained_output_data = nil
      @input_indexes_to_retain = nil
      @output_indexes_to_retain = nil
    end

    # Dereferences a weak reference held in the graph.
    #
    # `WeakRef#__getobj__` raises `WeakRef::RefError` once the referent has been
    # garbage collected (it does not return nil), so forward edges stored in
    # `@outputs` are read through this helper.
    #
    # @param [WeakRef, nil] ref weak reference to dereference
    # @return [Object, nil] the referent, or nil when `ref` is nil or its target is gone
    def self.dereference(ref)
      return nil if ref.nil?
      ref.__getobj__
    rescue WeakRef::RefError
      nil
    end

    # Short text that represents the function.
    #
    # The default implementation returns its type name.
    # Each function should override it to give more information.
    def label
      self.class.name
    end

    # An Array of the retained output arrays, with nil for outputs that were not retained.
    #
    # This property is mainly used by `Chainer::Function`. Users basically do
    # not have to use this property; use `get_retained_outputs` instead.
    #
    # @raise [RuntimeError] if no output data has been retained
    def output_data
      raise RuntimeError, 'retained output data is gone' if @retained_output_data.nil?
      out_data = [nil] * @outputs.size
      @output_indexes_to_retain.zip(@retained_output_data).each do |index, data|
        out_data[index] = data
      end
      out_data
    end

    # Computes output variables and grows the computational graph.
    #
    # Basic behavior is expressed in the documentation of `FunctionNode`.
    # @param [Array<Chainer::Variable, Numo::NArray>] inputs If an element is an Numo::NArray,
    #   it is automatically wrapped with `Chainer::Variable`.
    # @return [Array<Chainer::Variable>] An Array of output `Chainer::Variable` objects.
    # @raise [TypeError] if #forward does not return an Array
    def apply(inputs)
      input_vars = inputs.map { |x| Chainer::Variable.as_variable(x) }
      in_data = input_vars.map(&:data)
      requires_grad = input_vars.map(&:requires_grad).any?

      # Forward propagation
      @input_indexes_to_retain = nil
      @output_indexes_to_retain = nil
      outputs = forward(in_data)
      raise TypeError, "#{outputs.class} not Array" unless outputs.is_a?(Array)

      ret = outputs.map { |y| Chainer::Variable.new(y, requires_grad: requires_grad) }

      if Chainer.configuration.enable_backprop
        # Topological ordering
        @rank = input_vars.size > 0 ? input_vars.map(&:rank).max : 0

        # Add backward edges
        ret.each { |y| y.creator_node = self }
        @inputs = input_vars.map(&:node)
        # Add forward edges (must be weak references)
        @outputs = ret.map { |y| WeakRef.new(y.node) }

        unless @input_indexes_to_retain.nil?
          @input_indexes_to_retain.each do |index|
            input_vars[index].retain_data
          end
        end

        unless @output_indexes_to_retain.nil?
          retained_data = []
          @output_indexes_to_retain.each do |index|
            ret[index].retain_data
            retained_data << outputs[index]
          end
          @retained_output_data = retained_data
        end
      end

      ret
    end

    # Computes the output arrays from the input arrays.
    #
    # @param [Array] inputs input array(s)
    # @return [Array] output array(s)
    # @raise [TypeError] if inputs is empty
    def forward(inputs)
      raise TypeError, "inputs must not be empty, inputs size is #{inputs.size}" if inputs.size.zero?
      # TODO GPU
      forward_cpu(inputs)
    end

    # Computes the output arrays from the input Numo::NArray.
    #
    # @param [Array<Numo::NArray>] inputs Numo::NArray objects.
    # @return [Array<Numo::NArray>] Array of output arrays.
    def forward_cpu(inputs)
      raise NotImplementedError
    end

    # Lets specified input variable nodes keep data arrays.
    #
    # By calling this method from `forward`, the function node can specify which inputs are required for backprop.
    # The input variables with retained arrays can be obtained by `get_retained_inputs` from `backward`.
    #
    # Note that **this method must not be called from the outside of forward method.**
    # @param [Integer, Array] indexes Indexes of input variables to retain for backprop.
    #   A single Integer is treated as a one-element Array.
    def retain_inputs(indexes)
      @input_indexes_to_retain = Array(indexes)
    end

    # Lets specified output variable nodes keep data arrays.
    #
    # By calling this method from `forward`, the function node can specify which outputs are required for backprop.
    # If this method is not called, no output variable is marked to keep its data array at the point of returning from `apply`.
    # The output variables with retained arrays can be obtained by `get_retained_outputs` from `backward`.
    # Note that **this method must not be called from the outside of forward method.**
    # @param [Integer, Array] indexes Indexes of output variables to retain for backprop.
    #   A single Integer is treated as a one-element Array.
    def retain_outputs(indexes)
      @output_indexes_to_retain = Array(indexes)
    end

    # Computes gradients w.r.t. specified inputs given output gradients.
    #
    # This method is used to compute one step of the backpropagation corresponding to the forward computation of this function node.
    # Given the gradients w.r.t. output variables, this method computes the gradients w.r.t. specified input variables.
    # Note that this method does not need to compute any input gradients not specified by `target_input_indexes`.
    # It enables the function node to return the input gradients with the full computational history,
    # in which case it supports *differentiable backpropagation* or *higher-order differentiation*.
    #
    # @param [Array<Integer>] target_indexes Indices of the input variables w.r.t. which the gradients are required.
    #   It is guaranteed that this Array contains at least one element.
    # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
    #   If the gradient w.r.t. an output variable is not given, the corresponding element is `nil`.
    # @return [Array<Chainer::Variable>] Array of Chainer::Variable that represent the gradients.
    def backward(target_indexes, grad_outputs)
      [nil] * target_indexes.size
    end

    # Computes gradients w.r.t. specified inputs and accumulates them.
    #
    # This method provides a way to fuse the backward computation and the gradient accumulations
    # in the case that the multiple functions are applied to the same variable.
    # Users have to override either of this method or `backward`.
    # It is often simpler to implement `backward` and is recommended if you do not need to provide efficient gradient accumulation.
    #
    # @param [Array<Integer>] target_input_indexes Indices of the input variables w.r.t. which the gradients are required.
    #   It is guaranteed that this Array contains at least one element.
    # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
    #   If the gradient w.r.t. an output variable is not given, the corresponding element is `nil`.
    # @param [Array<Chainer::Variable>] grad_inputs Gradients w.r.t. the input variables specified by `target_input_indexes`.
    #   These values are computed by other computation paths.
    #   If there is no gradient value existing for the variable, the corresponding element is `nil`.
    # @return [Array<Chainer::Variable>] Array of variables that represent the gradients w.r.t. specified input variables.
    # @raise [ArgumentError] if `backward` returns a wrong number of gradients
    def backward_accumulate(target_input_indexes, grad_outputs, grad_inputs)
      gxs = backward(target_input_indexes, grad_outputs)

      len_gxs = gxs.size
      if len_gxs == @inputs.size
        gxs = target_input_indexes.map { |i| gxs[i] }
      elsif len_gxs != target_input_indexes.size
        raise ArgumentError, "number of gradients returned by #{impl_name} (#{label}) is incorrect."
      end

      gxs.zip(grad_inputs).map do |gx, g_input|
        if g_input.nil?
          gx
        elsif gx.nil?
          g_input
        else
          gx + g_input
        end
      end
    end

    # Returns an Array of retained input variables.
    #
    # This method is used to retrieve the input variables retained in `forward`.
    #
    # @return [Array] an Array of retained input variables.
    def get_retained_inputs
      @input_indexes_to_retain.map { |index| @inputs[index].get_variable }
    end

    # Returns an Array of retained output variables.
    #
    # This method is used to retrieve the output variables retained in `forward`.
    # If an output variable node has already been garbage collected, a new variable
    # is rebuilt from the retained data and re-attached to this node.
    #
    # @return [Array] an Array of retained output variables.
    # @raise [RuntimeError] if `retain_outputs` was not called in `forward`
    def get_retained_outputs
      if @output_indexes_to_retain.nil? || @retained_output_data.nil?
        raise RuntimeError, 'retain_outputs is not called in forward.'
      end

      ret = []
      new_outputs = @outputs.dup
      outputs_modified = false

      @output_indexes_to_retain.zip(@retained_output_data) do |index, data|
        output = FunctionNode.dereference(@outputs[index])
        if output.nil?
          output_var = Chainer::Variable.new(data)
          output_var.creator_node = self
          # Forward edges point at variable nodes, not variables.
          new_outputs[index] = WeakRef.new(output_var.node)
          outputs_modified = true
        else
          output_var = output.get_variable
        end

        ret << output_var
      end

      @outputs = new_outputs if outputs_modified

      ret
    end

    # Purges in/out nodes and this function node itself from the graph.
    def unchain
      (@outputs || []).each do |y_ref|
        y = FunctionNode.dereference(y_ref)
        y.unchain unless y.nil?
      end
      @inputs = nil
    end

    private

    def impl_name
      self.class.name
    end
  end

  # Computes the gradients of outputs w.r.t. inputs.
  #
  # @param [Array<Chainer::Variable>] outputs output variables
  # @param [Array<Chainer::Variable>] inputs input variables
  # @param [Array, nil] grad_outputs initial output gradients; defaults to ones
  # @param [Array, nil] grad_inputs initial input gradients to accumulate into
  # @param [Boolean] set_grad whether to store the results in each input's grad_var
  # @param [Boolean] retain_grad whether to store intermediate gradients in grad_var
  # @param [Boolean] enable_double_backprop value of enable_backprop during backprop
  # @return [Array] gradients w.r.t. inputs (nil where no gradient flows)
  # @raise [TypeError] if an argument is not an Array
  def self.grad(outputs, inputs, grad_outputs: nil, grad_inputs: nil, set_grad: false, retain_grad: false, enable_double_backprop: false)
    # The implementation consists of three steps.

    if !outputs.is_a?(Array)
      raise TypeError, "outputs must be Array, not #{outputs.class}"
    end
    if !inputs.is_a?(Array)
      raise TypeError, "inputs must be Array, not #{inputs.class}"
    end
    if !grad_outputs.nil? && !grad_outputs.is_a?(Array)
      raise TypeError, "grad_outputs must be Array, not #{grad_outputs.class}"
    end
    if !grad_inputs.nil? && !grad_inputs.is_a?(Array)
      raise TypeError, "grad_inputs must be Array, not #{grad_inputs.class}"
    end

    # 1. Backward enumeration: all the nodes reachable backward from the output
    #    nodes are enumerated. The forward direction links are collected in
    #    this step. Note that the variable nodes whose requires_grad is false
    #    are ignored and their creators are not searched.
    candidate_funcs = outputs.map(&:creator_node).compact
    visited_funcs = Set.new
    forward_graph = {}

    while (func = candidate_funcs.pop)
      next if visited_funcs.include?(func)
      visited_funcs.add(func)

      func.inputs.each do |x|
        next unless x.requires_grad
        (forward_graph[x] ||= []) << func
        creator = x.creator_node
        if creator && !visited_funcs.include?(creator)
          candidate_funcs << creator
        end
      end
    end

    # 2. Forward enumeration: all the nodes in the subgraph reachable from the
    #    input nodes are enumerated. The extracted (sub-)subgraph is the union
    #    of all paths that backpropagation will visit.
    candidate_vars = inputs.map(&:node)
    visited_funcs = Set.new
    grad_required = Set.new
    while (x = candidate_vars.pop)
      grad_required.add(x)
      # An input that no function consumes has no forward edges.
      forward_graph.fetch(x, []).each do |func|
        next if visited_funcs.include?(func)
        visited_funcs.add(func)
        func.outputs.each do |y_ref|
          y = FunctionNode.dereference(y_ref)
          candidate_vars << y if y && forward_graph.key?(y)
        end
      end
    end

    # 3. Backpropagation: the backpropagation is executed along the
    #    (sub-)subgraph. It uses the topological order of the subgraph which is
    #    induced by the reversed order of function applications ("rank").
    grads = {} # mapping from variable nodes to their gradients

    # Initialize the gradient mapping.
    grad_outputs = [nil] * outputs.size if grad_outputs.nil?
    outputs.zip(grad_outputs).each do |y, gy|
      if gy.nil?
        gy_data = y.data.new_ones
        gy = Chainer::Variable.new(gy_data, requires_grad: false)
      end

      grads[y.node] = gy
    end

    unless grad_inputs.nil?
      inputs.zip(grad_inputs).each do |x, gx|
        grads[x.node] = gx unless gx.nil?
      end
    end

    # Backprop implementation. It edits grads which will only contain the
    # gradients w.r.t. the inputs. The configuration is restored even if
    # backprop raises.
    old_enable_backprop = Chainer.configuration.enable_backprop
    Chainer.configuration.enable_backprop = enable_double_backprop
    begin
      backprop(outputs, inputs, grad_required, retain_grad, grads)
    ensure
      Chainer.configuration.enable_backprop = old_enable_backprop
    end

    # Extract the gradients w.r.t. the inputs and return them.
    ret = inputs.map { |x| grads[x.node] }
    if set_grad
      inputs.zip(ret).each do |x, gx|
        x.grad_var = gx
      end
    end

    ret
  end

  # Runs backpropagation over the functions reachable from outputs, in
  # descending rank order, accumulating gradients into `grads`.
  def self.backprop(outputs, inputs, grad_required, retain_grad, grads)
    candidate_funcs = []
    visited_funcs = Set.new

    # Functions are kept sorted by rank so that #pop yields the highest rank first.
    push_candidate = -> (func) do
      return if visited_funcs.include?(func)

      visited_funcs.add(func)
      candidate_funcs.unshift(func)
      candidate_funcs.sort_by! { |f| f.rank }
    end

    outputs.each do |y|
      creator = y.creator_node
      next if creator.nil?
      push_candidate.(creator)
    end

    input_nodes = Set.new(inputs.map(&:node))

    while (func = candidate_funcs.pop)
      # Collect the gradients w.r.t. the outputs
      gys = func.outputs.map do |y_ref|
        y = FunctionNode.dereference(y_ref)
        y.nil? ? nil : grads[y]
      end

      # Collect the gradients w.r.t. the inputs
      #
      # Note (Tokui): when the same variable is passed multiple times as
      # inputs in the same function (e.g. an expression like f(x, x)), the
      # current implementation passes nil as the current gradient w.r.t.
      # such an input except for the first one (i.e., it builds gxs like
      # (gx, nil) where gx is the current gradient w.r.t. x).
      gxs = []
      input_indexes = []
      selected_inputs = Set.new
      func.inputs.each_with_index do |x, i|
        next unless grad_required.include?(x)

        input_indexes << i
        if selected_inputs.include?(x)
          gxs << nil
        else
          gxs << grads[x]
          selected_inputs.add(x)
        end
      end

      next if input_indexes.empty?

      # Do backward
      new_gxs = func.backward_accumulate(input_indexes, gys, gxs)

      # Delete output gradients that are not required to return
      func.outputs.each do |y_ref|
        y = FunctionNode.dereference(y_ref)
        if y && grads[y] && !input_nodes.include?(y)
          grads.delete(y)
        end
      end

      # Update grads
      selected_inputs = Set.new
      input_indexes.zip(new_gxs).each do |i, g|
        next if g.nil?

        node = func.inputs[i]
        if selected_inputs.include?(node)
          # Accumulate the duplicated gradients here
          cur_gx = grads[node]
          g = g + cur_gx if cur_gx
        else
          selected_inputs.add(node)
        end

        grads[node] = g

        if retain_grad
          v = node.get_variable
          v.grad_var = g if v
        end

        creator = node.creator_node
        push_candidate.(creator) if creator
      end
    end
  end
  private_class_method :backprop
end
|