red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/datasets/cifar.rb CHANGED
@@ -11,7 +11,7 @@ module Chainer
         get_cifar(100, with_label, ndim, scale)
       end
 
-      def self.get_cifar(n_classes, with_label, ndim, scale)
+      def self.get_cifar(n_classes, with_label, ndim, scale, device: Chainer::Device.default)
         train_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).to_table
         test_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).to_table
 
@@ -25,13 +25,14 @@ module Chainer
           test_labels = test_table[:fine_label]
         end
 
+        xm = device.xm
         [
-          preprocess_cifar(
-          preprocess_cifar(
+          preprocess_cifar(xm::UInt8[*train_data], xm::UInt8[*train_labels], with_label, ndim, scale),
+          preprocess_cifar(xm::UInt8[*test_data], xm::UInt8[*test_labels], with_label, ndim, scale)
         ]
       end
 
-      def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
+      def self.preprocess_cifar(images, labels, withlabel, ndim, scale, device: Chainer::Device.default)
         if ndim == 1
           images = images.reshape(images.shape[0], 3072)
         elsif ndim == 3
@@ -39,11 +40,12 @@ module Chainer
         else
          raise 'invalid ndim for CIFAR dataset'
         end
-
+        xm = device.xm
+        images = images.cast_to(xm::SFloat)
         images *= scale / 255.0
 
         if withlabel
-          labels = labels.cast_to(
+          labels = labels.cast_to(xm::Int32)
           TupleDataset.new(images, labels)
         else
           images
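
The hunks above make the CIFAR loader device-aware: `get_cifar` and `preprocess_cifar` now take a `device:` keyword (defaulting to `Chainer::Device.default`) and build arrays through `device.xm` instead of hard-coding Numo, so the same path can yield Cumo arrays on GPU. A minimal usage sketch, assuming the `get_cifar10`/`get_cifar100` wrappers still delegate to `get_cifar` as in earlier releases:

  require 'chainer'

  # CIFAR-10 on the current default device (CPU unless changed via Chainer::Device)
  train, test = Chainer::Datasets::CIFAR.get_cifar10

  image, label = train[0]          # SFloat image scaled to [0, 1] and an Int32 label
  xm = Chainer::Device.default.xm  # Numo on CPU, Cumo once a GPU device is the default
  puts xm                          # => Numo
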
data/lib/chainer/datasets/mnist.rb CHANGED
@@ -1,13 +1,17 @@
-require '
+require 'datasets'
 
 module Chainer
   module Datasets
-    module
-      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype:
-
+    module MNIST
+      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
+        xm = Chainer::Device.default.xm
+        dtype ||= xm::SFloat
+        label_dtype ||= xm::Int32
+
+        train_raw = retrieve_mnist(type: :train)
         train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)
 
-        test_raw =
+        test_raw = retrieve_mnist(type: :test)
         test = preprocess_mnist(test_raw, withlabel, ndim, scale, dtype, label_dtype)
         [train, test]
       end
@@ -24,7 +28,7 @@ module Chainer
 
         images = images.cast_to(image_dtype)
         images *= scale / 255.0
-
+
         if withlabel
           labels = raw[:y].cast_to(label_dtype)
           TupleDataset.new(images, labels)
@@ -33,56 +37,11 @@ module Chainer
         end
       end
 
-      def self.
-
-                'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz']
-        retrieve_mnist('train.npz', urls)
-      end
-
-      def self.retrieve_mnist_test
-        urls = ['http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
-                'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz']
-        retrieve_mnist('test.npz', urls)
-      end
-
-      def self.retrieve_mnist(name, urls)
-        root = Chainer::Dataset::Download.get_dataset_directory('pfnet/chainer/mnist')
-        path = File.expand_path(name, root)
-        Chainer::Dataset::Download.cache_or_load_file(path) do
-          make_npz(path, urls)
-        end
-      end
-
-      def self.make_npz(path, urls)
-        x_url, y_url = urls
-        x_path = Chainer::Dataset::Download.cached_download(x_url)
-        y_path = Chainer::Dataset::Download.cached_download(y_url)
-
-        x = nil
-        y = nil
-
-        Zlib::GzipReader.open(x_path) do |fx|
-          Zlib::GzipReader.open(y_path) do |fy|
-            fx.read(4)
-            fy.read(4)
-
-            n = fx.read(4).unpack('i>')[0]
-            fy.read(4)
-            fx.read(8)
-
-            x = Numo::UInt8.new(n, 784).rand(n)
-            y = Numo::UInt8.new(n).rand(n)
-
-            n.times do |i|
-              y[i] = fy.read(1).ord
-              784.times do |j|
-                x[i, j] = fx.read(1).ord
-              end
-            end
-          end
-        end
+      def self.retrieve_mnist(type:)
+        train_table = ::Datasets::MNIST.new(type: type).to_table
 
-
+        xm = Chainer::Device.default.xm
+        { x: xm::UInt8[*train_table[:pixels]], y: xm::UInt8[*train_table[:label]] }
       end
     end
   end
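
The MNIST loader drops the hand-rolled download/gzip code (and its dependency on the removed `Chainer::Dataset::Download` helper) in favour of the red-datasets gem, and the dtype defaults are now resolved from the default device's array module. A short usage sketch, assuming the datasets gem is installed so `::Datasets::MNIST` can fetch the data:

  require 'chainer'

  # dtype defaults to xm::SFloat and label_dtype to xm::Int32 for the default device
  train, test = Chainer::Datasets::MNIST.get_mnist(ndim: 1)

  image, label = train[0]   # 784-element SFloat vector scaled to [0, 1], Int32 label
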
data/lib/chainer/device.rb ADDED
@@ -0,0 +1,88 @@
+module Chainer
+  module Device
+    # Creates device
+    #
+    # @param [Integer or Chainer::AbstractDevice] device_spec Device specifier.
+    #     Negative integer indicates CPU. 0 or positive integer indicates GPU.
+    #     If a device object is given, itself is returned.
+    # @return [Chainer::AbstractDevice] device object
+    def create(device_spec)
+      return device_spec if device_spec.kind_of?(AbstractDevice)
+      if device_spec.kind_of?(Integer)
+        return CpuDevice.new if device_spec < 0
+        return GpuDevice.new(device_spec)
+      end
+      raise "Invalid device_spec: #{device_spec}"
+    end
+    module_function :create
+
+    # Changes default device
+    #
+    # @param [Object] device_spec
+    # @see Chainer::Device.create
+    def change_default(device_spec)
+      @default = create(device_spec)
+      @default.use
+    end
+    module_function :change_default
+
+    # Gets default device
+    #
+    # @return [Chainer::AbstractDevice] the default device.
+    def default
+      @default ||= CpuDevice.new
+    end
+    module_function :default
+
+    # TODO(sonots): Add Device.from_array after Cumo provides an API
+    # to return GPU device ID from Cumo::NArray.
+  end
+
+  class AbstractDevice
+    def xm
+      raise NotImplementedError
+    end
+
+    def use
+    end
+  end
+
+  class CpuDevice < AbstractDevice
+    def xm
+      Numo
+    end
+
+    def ==(other)
+      return false unless other.is_a?(CpuDevice)
+      true
+    end
+  end
+
+  class GpuDevice < AbstractDevice
+    attr_reader :id
+
+    # @param [Integer] id GPU Device ID. If not given, CUDA current device id is used.
+    def initialize(id = nil)
+      Chainer::CUDA.check_available
+      id ||= Cumo::CUDA::Runtime.cudaGetDevice
+      if id < 0
+        raise 'GPU Device ID must not be negative'
+      end
+      @id = id
+    end
+
+    def xm
+      Cumo
+    end
+
+    def ==(other)
+      return false unless other.is_a?(GpuDevice)
+      id == other.id
+    end
+
+    # Sets CUDA current device with owned GPU Device ID
+    def use
+      Cumo::CUDA::Runtime.cudaSetDevice(@id)
+    end
+  end
+end
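
data/lib/chainer/device.rb is the new device abstraction: `create` maps an integer spec to `CpuDevice` (negative) or `GpuDevice` (zero or positive), `change_default`/`default` manage the process-wide default, and `xm` returns the array module (Numo or Cumo) the rest of the library uses. A sketch of how calling code is expected to use it:

  require 'chainer'

  Chainer::Device.default              # => a CpuDevice when nothing was configured
  xm = Chainer::Device.default.xm      # => Numo

  Chainer::Device.change_default(-1)   # negative spec keeps the CPU device
  # Chainer::Device.change_default(0)  # would select GPU 0 (needs Cumo and a CUDA runtime)
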
data/lib/chainer/function.rb CHANGED
@@ -1,70 +1,89 @@
+require 'chainer/function_node'
 module Chainer
   class Function
 
     attr_reader :rank, :inputs, :outputs, :retain_after_backward
-    attr_accessor :output_data
+    attr_accessor :output_data, :owned_node
 
     def initialize
      @rank = 0
     end
 
     def call(*inputs)
-
-        if x.kind_of?(Chainer::Variable)
-          x
-        else
-          Variable.new(x, requires_grad: false)
-        end
-      end
+      node = self.node
 
-
-
+      node.function = self
+      node.weak_function = nil
+      @node = WeakRef.new(node)
+      @owned_node = nil
 
-
-      @output_indexes_to_retain = nil
-      outputs = forward(in_data)
-      raise if !outputs.is_a? Array
+      ret = node.apply(inputs)
 
-      ret
-
-    end
+      ret.size == 1 ? ret[0] : ret
+    end
 
-
-
+    def inputs
+      @node.inputs
+    end
 
-
+    def outputs
+      @node.outputs
+    end
 
-
-
+    def node
+      noderef = @node
+      nd = noderef ? noderef.__getobj__ : @owned_node
+      return nd if nd
 
-
-
-
-
-      remove_instance_variable(:@input_indexes_to_retain)
+      nd = FunctionAdapter.new(self)
+      @owned_node = nd
+      nd
+    end
 
-
-
-
-      end
-      remove_instance_variable(:@output_indexes_to_retain)
-    end
-  end
+    def output_data
+      node.output_data
+    end
 
-
+    def rank
+      @node.rank
+    end
+
+    def label
+      self.class.to_s
     end
 
     def forward(inputs)
-
-
+      xm = Chainer.get_array_module(*inputs)
+      if xm == Cumo
+        forward_gpu(inputs)
+      else
+        forward_cpu(inputs)
+      end
     end
 
     def forward_cpu(inputs)
       raise NotImplementedError
     end
 
+    def forward_gpu(inputs)
+      raise NotImplementedError
+    end
+
     def backward(inputs, grad_outputs)
-
+      xm = Chainer.get_array_module(*(inputs + grad_outputs))
+      if xm == Cumo
+        backward_gpu(inputs, grad_outputs)
+      else
+        backward_cpu(inputs, grad_outputs)
+      end
+    end
+
+    def backward_cpu(inputs, grad_outputs)
+      return [nil] * inputs.size
+    end
+
+    def backward_gpu(inputs, grad_outputs)
+      return [nil] * inputs.size
     end
 
     def retain_inputs(indexes)
@@ -72,10 +91,53 @@ module Chainer
     end
 
     def retain_outputs(indexes, retain_after_backward: false)
-
-
-
+      node.retain_outputs(indexes)
+    end
+  end
+
+  class FunctionAdapter < ::Chainer::FunctionNode
+    attr_accessor :function, :weak_function
+
+    def initialize(function)
+      super()
+      @weak_function = WeakRef.new(function)
+      function.owned_node = self
+    end
+
+    def function
+      func = @function
+      return func if func
+
+      weak_func = @weak_function
+      weak_func.__getobj__
+    end
+
+    def label
+      @function.label
+    end
+
+    def forward(inputs)
+      retain_inputs(inputs.size.times.to_a)
+      @function.forward(inputs)
+    end
+
+    def backward(target_input_indexes, grad_outputs)
+      in_data = @inputs.map { |input| input.data }
+      grad_out_data = grad_outputs.map { |grad| grad.nil? ? nil : grad.data }
+
+      gxs = @function.backward(in_data, grad_out_data)
+      ret = []
+      target_input_indexes.each do |i|
+        if gxs[i].nil?
+          g = nil
+        else
+          g = Chainer::Variable.new(gxs[i])
+          g.node.old_style_grad_generator = @function.label
+        end
+        ret << g
       end
+
+      ret
     end
   end
 end
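
With this rewrite, `Function` becomes a thin wrapper over the new `FunctionNode` graph machinery: `call` obtains a `FunctionAdapter` node and delegates to `FunctionNode#apply`, while `forward`/`backward` dispatch to `*_cpu` or `*_gpu` variants based on `Chainer.get_array_module`. A minimal sketch of an old-style function under this API (the `Square` class and the concrete numbers are illustrative, not part of the gem):

  class Square < Chainer::Function
    def forward_cpu(inputs)
      x, = inputs
      [x ** 2]
    end

    def backward_cpu(inputs, grad_outputs)
      x, = inputs
      gy, = grad_outputs
      [x * gy * 2]
    end
  end

  x = Chainer::Variable.new(Numo::SFloat[1, 2, 3])
  y = Square.new.(x)             # Function#call -> FunctionAdapter -> FunctionNode#apply
  y.grad = Numo::SFloat.ones(3)
  y.backward
  p x.grad                       # expected to be roughly [2.0, 4.0, 6.0]
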
data/lib/chainer/function_node.rb ADDED
@@ -0,0 +1,454 @@
+# Function node of the computational graph.
+# FunctionNode is a class representing a node in a computational graph.
+# The node corresponds to an application of a differentiable function to input variables.
+# When a differentiable function is applied to `Chainer::Variable` objects,
+# it creates an instance of FunctionNode implementation and calls its `apply` method.
+# The `apply` method basically does the following three things.
+# 1. Adding an edge from the function node to the variable node corresponding to each input.
+#    The node of each input is extracted by `Chainer::`Variable.node`.
+# 2. Computing the output arrays of the function.
+# 3. Creating a :class:`Variable` object for each output array and
+#    adding an edge from the node of the variable to the function node.
+# The output variables are then returned.
+module Chainer
+  class FunctionNode
+    attr_accessor :rank, :inputs, :outputs
+
+    def initialize
+      @rank = 0
+      @inputs = nil
+      @outputs = nil
+
+      @retained_output_data = nil
+      @input_indexes_to_retain = nil
+      @output_indexes_to_retain = nil
+    end
+
+    # Short text that represents the function.
+    #
+    # The default implementation returns its type name.
+    # Each function should override it to give more information.
+    def label
+      self.class.name
+    end
+
+    # A tuple of the retained output arrays.
+    # This property is mainly used by $Function$. Users basically do
+    # not have to use this property; use $get_retained_outputs$ instead.
+    def output_data
+      raise RuntimeError, 'retained output data is gone' if @retained_output_data.nil?
+      out_data = [nil] * @outputs.size
+      @output_indexes_to_retain.zip(@retained_output_data).each do |index, data|
+        out_data[index] = data
+      end
+      out_data
+    end
+
+    # Computes output variables and grows the computational graph.
+    #
+    # Basic behavior is expressed in the documentation of `FunctionNode`.
+    # @param [Chainer::Variable, Numo::NArray] inputs If the element is an Numo::NArray,
+    #   it is automatically wrapped with `Chainer::Variable`.
+    # @return [Array<Chainer::Variable>] A tuple of output `Chainer::Variable` objectts.
+    def apply(inputs)
+      input_vars = inputs.map { |x| Chainer::Variable.as_variable(x) }
+      in_data = input_vars.map(&:data)
+      requires_grad = input_vars.map(&:requires_grad).any?
+
+      # Forward propagation
+      @input_indexes_to_retain = nil
+      @output_indexes_to_retain = nil
+      outputs = forward(in_data)
+      raise TypeError, "#{outputs.class} not Array" unless outputs.is_a?(Array)
+
+      ret = outputs.map { |y| Chainer::Variable.new(y, requires_grad: requires_grad) }
+
+      if Chainer.configuration.enable_backprop
+        # Topological ordering
+        @rank = input_vars.size > 0 ? input_vars.map(&:rank).max : 0
+
+        # Add backward edges
+        ret.each { |y| y.creator_node = self }
+        @inputs = input_vars.map(&:node)
+        # Add forward edges (must be weak references)
+        @outputs = ret.map { |y| WeakRef.new(y.node) }
+
+        unless @input_indexes_to_retain.nil?
+          @input_indexes_to_retain.each do |index|
+            input_vars[index].retain_data
+          end
+        end
+
+        unless @output_indexes_to_retain.nil?
+          retained_data = []
+          @output_indexes_to_retain.each do |index|
+            ret[index].retain_data
+            retained_data << outputs[index]
+          end
+          @retained_output_data = Array(retained_data)
+        end
+      end
+
+      ret
+    end
+
+    # Computes the output arrays from the input arrays.
+    #
+    # @param [Array] inputs input array(s)
+    # @return [Array] output array(s)
+    def forward(inputs)
+      raise TypeError, "mustt inputs > 0, inputs size is #{inputs.size}" if inputs.size.zero?
+      # TODO GPU
+      forward_cpu(inputs)
+    end
+
+    # Computes the output arrays from the input Numo::NArray.
+    #
+    # @param [Array<Numo::NArray>] inputs Numo::NArray objects.
+    # @return [Array<Numo::NArray>] Array of output arrays.
+    def forward_cpu(inputs)
+      raise NotImplementedError
+    end
+
+    # Lets specified input variable nodes keep data arrays.
+    #
+    # By calling this method from `forward`, the function node can specify which inputs are required for backprop.
+    # The input variables with retained arrays can be obtained by `get_retained_inputs` from `backward`.
+    #
+    # Note that **this method must not be called from the outside of forward method.**
+    # @param [Integer, Array] indexes Indexes of input variables that the function does not require for backprop.
+    def retain_inputs(indexes)
+      @input_indexes_to_retain = indexes
+    end
+
+    # Lets specified output variable nodes keep data arrays.
+    #
+    # By calling this method from `forward`, the function node can specify which outputs are required for backprop.
+    # If this method is not called, any output variables are not marked to keep the data array at the point of returning from `apply`.
+    # The output variables with retained arrays can be obtained by `get_retained_outputs` from `backward`.
+    # Note that **this method must not be called from the outside of forward method.**
+    # @param [Integer, Array] indexes Indexes of input variables that the function does not require for backprop.
+    def retain_outputs(indexes)
+      @output_indexes_to_retain = indexes
+    end
+
+    # Computes gradients w.r.t. specified inputs given output gradients.
+    #
+    # This method is used to compute one step of the backpropagation corresponding to the forward computation of this function node.
+    # Given the gradients w.r.t. output variables, this method computes the gradients w.r.t. specified input variables.
+    # Note that this method does not need to compute any input gradients not specified by `target_input_indexes`
+    # It enables the function node to return the input gradients with the full computational history,
+    # in which case it supports *differentiable backpropagation* or *higher-order differentiation*.
+    #
+    # @param [Array<Integer>] target_indexes Indices of the input variables w.r.t. which the gradients are required.
+    #   It is guaranteed that this tuple contains at least one element.
+    # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
+    #   If the gradient w.r.t. an output variable is not given, the corresponding element is `None`.
+    # @return [Array<Chainer::Variable>] Array of Chainer::Variable that represent the gradients.
+    def backward(target_indexes, grad_outputs)
+      [nil] * target_indexes.size
+    end
+
+    # Computes gradients w.r.t. specified inputs and accumulates them.
+    #
+    # This method provides a way to fuse the backward computation and the gradient accumulations
+    # in the case that the multiple functions are applied to the same variable.
+    # Users have to override either of this method or `backward`.
+    # It is often simpler to implement `backward` and is recommended if you do not need to provide efficient gradient accumulation.
+    #
+    # @param [Array<Integer>] target_input_indexes Indices of the input variables w.r.t. which the gradients are required.
+    #   It is guaranteed that this tuple contains at least one element.
+    # @param [Array<Chainer::Variable>] grad_outputs Gradients w.r.t. the output variables.
+    #   If the gradient w.r.t. an output variable is not given, the corresponding element is `None`.
+    # @param [Array<Chainer::Variable>] grad_inputs Gradients w.r.t. the input variables specified by `target_input_indexes`.
+    #   These values are computed by other computation paths.
+    #   If there is no gradient value existing for the variable, the corresponding element is ``None``.
+    # @return [Array<Chainer::Variable>] Array of variables that represent the gradients w.r.t. specified input variables.
+    def backward_accumulate(target_input_indexes, grad_outputs, grad_inputs)
+      gxs = backward(target_input_indexes, grad_outputs)
+
+      len_gxs = gxs.size
+      if len_gxs == @inputs.size
+        gxs = target_input_indexes.map { |i| gxs[i] }
+      elsif len_gxs != target_input_indexes.size
+        raise ArgumentError, "number of gradients returned by #{impl_name} (#{label}) is incorrect."
+      end
+
+      gxs.zip(grad_inputs).map do |gx, g_input|
+        if g_input.nil?
+          gx
+        elsif gx.nil?
+          g_input
+        else
+          gx + g_input
+        end
+      end
+    end
+
+    # Returns a Array of retained input variables.
+    #
+    # This method is used to retrieve the input variables retained in `forward`.
+    #
+    # @return [Array] a Array of retained input variables.
+    def get_retained_inputs
+      @input_indexes_to_retain.map { |index| @inputs[index].get_variable }
+    end
+
+    # Returns a Array of retained output variables.
+    #
+    # This method is used to retrieve the input variables retained in `forward`.
+    #
+    # @return [Array] a Array of retained input variables.
+    def get_retained_outputs
+      ret = []
+      outputs = @outputs
+
+      new_outputs = outputs.dup
+      outputs_modified = false
+
+      @output_indexes_to_retain.zip(@retained_output_data) do |index, data|
+        output = outputs[index].__getobj__
+        if output.nil?
+          output_var = Chainer::Variable.new(data)
+          output_var.creator_node = self
+          new_outputs[index] = WeakRef.new(output_var)
+          outputs_modified = true
+        else
+          output_var = output.get_variable
+        end
+
+        ret << output_var
+      end
+
+      if outputs_modified
+        @outputs = Array(new_outputs)
+      end
+
+      ret
+    end
+
+    # Purges in/out nodes and this function node itself from the graph.
+    def unchain
+      @outputs.each do |y|
+        y_ref = y.()
+        unless y_ref.nil?
+          y_ref.unchain
+        end
+      end
+      @inputs = nil
+    end
+
+    private
+
+    def impl_name
+      self.class.name
+    end
+  end
+
+  def self.grad(outputs, inputs, grad_outputs: nil, grad_inputs: nil, set_grad: false, retain_grad: false, enable_double_backprop: false)
+    # The implementation consists of three steps.
+
+    if !outputs.is_a?(Array)
+      raise TypeError, "outputs must be Array, not #{outputs.class}"
+    end
+    if !inputs.is_a?(Array)
+      raise TypeError, "inputs must be Array, not #{inputs.class}"
+    end
+    if !grad_outputs.nil? && !grad_outputs.is_a?(Array)
+      raise TypeError, "grad_outputs must be Array, not #{grad_outputs.class}"
+    end
+    if !grad_inputs.nil? && !grad_inputs.is_a?(Array)
+      raise TypeError, "grad_inputs must be Array, not #{grad_inputs.class}"
+    end
+
+    # 1. Backward enumeration: all the nodes reachable backward from the output
+    #    nodes are enumerated. The forward direction links are collected in
+    #    this step. Note that the variable nodes whose requires_grad is false
+    #    are ignored and their creators are not searched.
+    candidate_funcs = outputs.map(&:creator_node).compact
+    visited_funcs = Set.new
+    forward_graph = {}
+
+    while func = candidate_funcs.pop
+      next if visited_funcs.include?(func)
+      visited_funcs.add(func)
+
+      func.inputs.each do |x|
+        next unless x.requires_grad
+        forward_graph[x] = [] if forward_graph[x].nil?
+        forward_graph[x] << func
+        creator = x.creator_node
+        if creator && !visited_funcs.include?(creator)
+          candidate_funcs << creator
+        end
+      end
+    end
+
+    # 2. Forward enumeration: all the nodes in the subgraph reachable from the
+    #    input nodes are enumerated. The extracted (sub-)subgraph is the union
+    #    of all paths that backpropagation will visit.
+    candidate_vars = inputs.map(&:node)
+    visited_funcs = Set.new
+    grad_required = Set.new
+    while x = candidate_vars.pop
+      grad_required.add(x)
+      forward_graph[x].each do |func|
+        next if visited_funcs.include?(func)
+        visited_funcs.add(func)
+        func.outputs.each do |y_ref|
+          y = y_ref.__getobj__
+          if y && forward_graph[y]
+            candidate_vars << y
+          end
+        end
+      end
+    end
+
+    # 3. Backpropagation: the backpropagation is executed along the
+    #    (sub-)subgraph. It uses the topological order of the subgraph which is
+    #    induced by the reversed order of function applications ("rank").
+    grads = {} # mapping from variable nodes to their gradients
+
+    # Initialize the gradient mapping.
+    grad_outputs = [nil] * outputs.size if grad_outputs.nil?
+    outputs.zip(grad_outputs).each do |y, gy|
+      if gy.nil?
+        gy_data = y.data.new_ones
+        gy = Chainer::Variable.new(gy_data, requires_grad: false)
+      end
+
+      grads[y.node] = gy
+    end
+
+    unless grad_inputs.nil?
+      inputs.zip(grad_inputs).each do |x, gx|
+        grads[x.node] = gx unless gx.nil?
+      end
+    end
+
+    # Backprop implementation. It edits grads which will only contain the
+    #    gradients w.r.t. the inputs.
+    old_enable_backprop = Chainer.configuration.enable_backprop
+    Chainer.configuration.enable_backprop = enable_double_backprop
+    backprop(outputs, inputs, grad_required, retain_grad, grads)
+    Chainer.configuration.enable_backprop = old_enable_backprop
+
+    # Extract the gradients w.r.t. the inputs and return them.
+    ret = inputs.map { |x| grads[x.node] }
+    if set_grad
+      inputs.zip(ret).each do |x, gx|
+        x.grad_var = gx
+      end
+    end
+
+    ret
+  end
+
+  def self.backprop(outputs, inputs, grad_required, retain_grad, grads)
+    candidate_funcs = []
+    visited_funcs = Set.new
+
+    push_candidate = -> (func) do
+      return if visited_funcs.include?(func)
+
+      # Negate since heapq is min-heap
+      # The second element is used to make each item unique
+      visited_funcs.add(func)
+      candidate_funcs.unshift(func)
+      candidate_funcs.sort_by! { |f| f.rank }
+    end
+
+    pop_candidate = -> () do
+      candidate_funcs.pop
+    end
+
+    outputs.each do |y|
+      creator = y.creator_node
+      next if creator.nil?
+      push_candidate.(creator)
+    end
+
+    input_nodes = Set.new(inputs.map(&:node))
+
+    while func = pop_candidate.()
+      # Collect the gradients w.r.t. the outputs
+      gys = []
+
+      func.outputs.each do |y_ref|
+        y = y_ref.__getobj__
+        if y.nil?
+          gys << nil
+          next
+        end
+        gys << grads[y]
+      end
+
+      # Collect the gradients w.r.t. the inputs
+      #
+      # Note (Tokui): when the same variable is passed multiple times as
+      # inputs in the same function (e.g. an expression like f(x, x)), the
+      # current implementation passes None as the current gradient w.r.t.
+      # such an input except for the first one (i.e., it builds gxs like
+      # (gx, None) where gx is the current gradient w.r.t. x).
+      gxs = []
+      input_indexes = []
+      selected_inputs = Set.new
+      func.inputs.each_with_index do |x, i|
+        next unless grad_required.include?(x)
+
+        input_indexes << i
+        if selected_inputs.include?(x)
+          gxs << nil
+        else
+          gxs << grads[x]
+          selected_inputs.add(x)
+        end
+      end
+
+      next if input_indexes.empty?
+
+      # Do backward
+      new_gxs = func.backward_accumulate(input_indexes, gys, gxs)
+
+      # Delete output gradients that are not required to return
+      func.outputs.each do |y_ref|
+        y = y_ref.__getobj__
+        if y && grads[y] && !input_nodes.include?(y)
+          grads.delete(y)
+        end
+      end
+
+      # Update grads
+      selected_inputs = Set.new
+      input_indexes.zip(new_gxs).each do |i, g|
+        next if g.nil?
+
+        node = func.inputs[i]
+        if selected_inputs.include?(node)
+          # Accumulate the duplicated gradients here
+          cur_gx = grads[node]
+          if cur_gx
+            g = g + cur_gx
+          end
+        else
+          selected_inputs.add(node)
+        end
+
+        grads[node] = g
+
+        if retain_grad
+          v = node.get_variable
+          if v
+            v.grad_var = g
+          end
+        end
+
+        creator = node.creator_node
+        if creator
+          push_candidate.(creator)
+        end
+      end
+    end
+  end
+  private_class_method :backprop
+end
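
data/lib/chainer/function_node.rb also adds a module-level `Chainer.grad`, which walks the graph as described in the comments above (backward enumeration, forward enumeration, then rank-ordered backprop) and returns the gradients of `outputs` with respect to `inputs`. A usage sketch; the concrete numbers are illustrative and assume Variable arithmetic from basic_math builds the graph:

  require 'chainer'

  x = Chainer::Variable.new(Numo::SFloat[1.0, 2.0, 3.0])
  y = x * x

  gy = Chainer::Variable.new(Numo::SFloat.ones(3), requires_grad: false)
  gx, = Chainer.grad([y], [x], grad_outputs: [gy])
  p gx.data   # expected to be roughly [2.0, 4.0, 6.0]
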