red-chainer 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/examples/cifar/models/resnet18.rb +90 -0
- data/examples/cifar/models/vgg.rb +2 -2
- data/examples/cifar/train_cifar.rb +13 -4
- data/lib/chainer.rb +1 -0
- data/lib/chainer/datasets/cifar.rb +1 -1
- data/lib/chainer/datasets/tuple_dataset.rb +1 -1
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +45 -0
- data/lib/chainer/initializers/init.rb +2 -2
- data/lib/chainer/link.rb +141 -13
- data/lib/chainer/links/normalization/batch_normalization.rb +1 -1
- data/lib/chainer/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 18e976d36f56e110baeae2d799471b4be2e4932778316ffcb3aafa51563bd3fd
|
4
|
+
data.tar.gz: f5429173d5b175738f4358774d9f058e31a86eec1418954cbad72d7875a2a7b8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: b1a617684703b92491e7876d919c2f6ef9c144039b9e12086d5b2ec7aed8644eca3ea8b2f4b32f2ef98fb6f31db4068b67467f95aa0a5868b3a903e073909b2e
|
7
|
+
data.tar.gz: '095870edbd1c68d35da023e9dd9a16096585a264f8e9cedf2c49dc465a62687a46634c58ee8a2d5fca4b2c41872f2f335455e076da66996a47f6c658ca6987e6'
|
@@ -0,0 +1,90 @@
|
|
1
|
+
module ResNet18
|
2
|
+
class Plain < Chainer::Chain
|
3
|
+
include Chainer::Functions::Activation
|
4
|
+
include Chainer::Initializers
|
5
|
+
include Chainer::Links::Connection
|
6
|
+
include Chainer::Links::Normalization
|
7
|
+
|
8
|
+
def initialize(ch, stride, use_conv: false)
|
9
|
+
super()
|
10
|
+
|
11
|
+
@use_conv = use_conv
|
12
|
+
w = HeNormal.new
|
13
|
+
|
14
|
+
init_scope do
|
15
|
+
@conv1 = Convolution2D.new(nil, ch, 3, stride: stride, pad: 1, nobias: true, initial_w: w)
|
16
|
+
@bn1 = BatchNormalization.new(ch)
|
17
|
+
@conv2 = Convolution2D.new(nil, ch, 3, stride: 1, pad: 1, nobias: true, initial_w: w)
|
18
|
+
@bn2 = BatchNormalization.new(ch)
|
19
|
+
if @use_conv
|
20
|
+
@conv3 = Convolution2D.new(nil, ch, 3, stride: stride, pad: 1, nobias: true, initial_w: w)
|
21
|
+
@bn3 = BatchNormalization.new(ch)
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
|
26
|
+
def call(x)
|
27
|
+
h = Relu.relu(@bn1.(@conv1.(x)))
|
28
|
+
h = @bn2.(@conv2.(h))
|
29
|
+
if @use_conv
|
30
|
+
h2 = @bn3.(@conv3.(x))
|
31
|
+
Relu.relu(h + h2)
|
32
|
+
else
|
33
|
+
Relu.relu(h + x)
|
34
|
+
end
|
35
|
+
end
|
36
|
+
end
|
37
|
+
|
38
|
+
class Block < Chainer::ChainList
|
39
|
+
def initialize(layer, ch, stride=2)
|
40
|
+
super()
|
41
|
+
add_link(Plain.new(ch, stride, use_conv: true))
|
42
|
+
(layer-1).times do
|
43
|
+
add_link(Plain.new(ch, 1))
|
44
|
+
end
|
45
|
+
end
|
46
|
+
|
47
|
+
def call(x)
|
48
|
+
@children.each do |f|
|
49
|
+
x = f.(x)
|
50
|
+
end
|
51
|
+
x
|
52
|
+
end
|
53
|
+
end
|
54
|
+
|
55
|
+
class Model < Chainer::Chain
|
56
|
+
include Chainer::Functions::Activation
|
57
|
+
include Chainer::Functions::Evaluation
|
58
|
+
include Chainer::Functions::Loss
|
59
|
+
include Chainer::Functions::Pooling
|
60
|
+
include Chainer::Initializers
|
61
|
+
include Chainer::Links::Connection
|
62
|
+
include Chainer::Links::Normalization
|
63
|
+
|
64
|
+
def initialize(n_classes: 10)
|
65
|
+
super()
|
66
|
+
initial_w = HeNormal.new
|
67
|
+
|
68
|
+
init_scope do
|
69
|
+
@conv = Convolution2D.new(3, 64, 7, stride: 2, pad: 3, initial_w: initial_w)
|
70
|
+
@bn = BatchNormalization.new(64)
|
71
|
+
|
72
|
+
@res2 = Block.new(2, 64, 1)
|
73
|
+
@res3 = Block.new(2, 128)
|
74
|
+
@res4 = Block.new(2, 256)
|
75
|
+
@res5 = Block.new(2, 512)
|
76
|
+
@fc = Linear.new(nil, out_size: n_classes)
|
77
|
+
end
|
78
|
+
end
|
79
|
+
|
80
|
+
def call(x)
|
81
|
+
h = Relu.relu(@bn.(@conv.(x)))
|
82
|
+
h = @res2.(h)
|
83
|
+
h = @res3.(h)
|
84
|
+
h = @res4.(h)
|
85
|
+
h = @res5.(h)
|
86
|
+
h = AveragePooling2D.average_pooling_2d(h, h.shape[2..-1])
|
87
|
+
@fc.(h)
|
88
|
+
end
|
89
|
+
end
|
90
|
+
end
|
@@ -15,7 +15,7 @@ class Block < Chainer::Chain
|
|
15
15
|
end
|
16
16
|
|
17
17
|
class VGG < Chainer::Chain
|
18
|
-
def initialize(
|
18
|
+
def initialize(n_classes: 10)
|
19
19
|
super()
|
20
20
|
init_scope do
|
21
21
|
@block1_1 = Block.new(64, 3)
|
@@ -33,7 +33,7 @@ class VGG < Chainer::Chain
|
|
33
33
|
@block5_3 = Block.new(512, 3)
|
34
34
|
@fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
|
35
35
|
@bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
|
36
|
-
@fc2 = Chainer::Links::Connection::Linear.new(nil, out_size:
|
36
|
+
@fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: n_classes, nobias: true)
|
37
37
|
end
|
38
38
|
end
|
39
39
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
require 'chainer'
|
2
2
|
require __dir__ + '/models/vgg'
|
3
|
+
require __dir__ + '/models/resnet18'
|
3
4
|
require 'optparse'
|
4
5
|
|
5
6
|
args = {
|
@@ -9,7 +10,8 @@ args = {
|
|
9
10
|
learnrate: 0.05,
|
10
11
|
epoch: 300,
|
11
12
|
out: 'result',
|
12
|
-
resume: nil
|
13
|
+
resume: nil,
|
14
|
+
model: 'vgg',
|
13
15
|
}
|
14
16
|
|
15
17
|
|
@@ -21,6 +23,7 @@ opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learn
|
|
21
23
|
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
|
22
24
|
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
|
23
25
|
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
|
26
|
+
opt.on('-m', '--model VALUE', "Use model") { |v| args[:model] = v }
|
24
27
|
opt.parse!(ARGV)
|
25
28
|
|
26
29
|
# Set up a neural network to train.
|
@@ -38,9 +41,15 @@ else
|
|
38
41
|
raise 'Invalid dataset choice.'
|
39
42
|
end
|
40
43
|
|
41
|
-
|
44
|
+
if args[:model] == 'vgg'
|
45
|
+
puts 'Using VGG model'
|
46
|
+
model_class = VGG
|
47
|
+
elsif args[:model] == 'resnet18'
|
48
|
+
puts 'Using ResNet-18 model'
|
49
|
+
model_class = ResNet18::Model
|
50
|
+
end
|
42
51
|
|
43
|
-
model = Chainer::Links::Model::Classifier.new(
|
52
|
+
model = Chainer::Links::Model::Classifier.new(model_class.new(n_classes: class_labels))
|
44
53
|
|
45
54
|
optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
|
46
55
|
optimizer.setup(model)
|
@@ -58,7 +67,7 @@ trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), t
|
|
58
67
|
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
|
59
68
|
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])
|
60
69
|
|
61
|
-
trainer.extend(Chainer::Training::Extensions::LogReport.new)
|
70
|
+
trainer.extend(Chainer::Training::Extensions::LogReport.new)
|
62
71
|
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
|
63
72
|
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)
|
64
73
|
|
data/lib/chainer.rb
CHANGED
@@ -43,6 +43,7 @@ require 'chainer/functions/connection/linear'
|
|
43
43
|
require 'chainer/functions/noise/dropout'
|
44
44
|
require 'chainer/functions/normalization/batch_normalization'
|
45
45
|
require 'chainer/functions/pooling/pooling_2d'
|
46
|
+
require 'chainer/functions/pooling/average_pooling_2d'
|
46
47
|
require 'chainer/functions/pooling/max_pooling_2d'
|
47
48
|
require 'chainer/testing/array'
|
48
49
|
require 'chainer/training/extension'
|
@@ -0,0 +1,45 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Functions
|
3
|
+
module Pooling
|
4
|
+
class AveragePooling2D < Pooling2D
|
5
|
+
# Spatial average pooling function.
|
6
|
+
#
|
7
|
+
# This function acts similarly to `Convolution2D`,
|
8
|
+
# but it computes the average of each input spatial patch, per channel,
|
9
|
+
# without any parameter instead of computing the inner products.
|
10
|
+
# @param [Chainer::Variable] x Input variable.
|
11
|
+
# @param [Integer, Array<Integer>] ksize Size of pooling window. `ksize=k` and `ksize=[k, k]` are equivalent.
|
12
|
+
# @param [Integer, Array<Integer>, nil] stride Stride of pooling applications. `stride=s` and `stride=[s, s]` are equivalent.
|
13
|
+
# If `nil` is specified, the stride is the same as the pooling window size.
|
14
|
+
# @param [Integer, Array<Integer>] pad Spatial padding width for the input array. `pad=p` and `pad=[p, p]` are equivalent.
|
15
|
+
# @return [Chainer::Variable] Output variable
|
16
|
+
def self.average_pooling_2d(x, ksize, stride: nil, pad: 0)
|
17
|
+
self.new(ksize, stride: stride, pad: pad, cover_all: false).(x)
|
18
|
+
end
|
19
|
+
|
20
|
+
# Average pooling over a set of 2d planes.
|
21
|
+
def forward_cpu(x)
|
22
|
+
retain_inputs([])
|
23
|
+
@in_shape = x[0].shape
|
24
|
+
@in_dtype = x[0].class
|
25
|
+
|
26
|
+
col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
|
27
|
+
y = col.mean(axis: [2, 3])
|
28
|
+
|
29
|
+
[y]
|
30
|
+
end
|
31
|
+
|
32
|
+
def backward_cpu(x, gy)
|
33
|
+
h, w = @in_shape[2..-1]
|
34
|
+
shape = gy[0].shape
|
35
|
+
shape.insert(2, 1, 1)
|
36
|
+
gcol = gy[0].reshape(*shape).tile(1, 1, @kh, @kw, 1, 1)
|
37
|
+
|
38
|
+
gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
|
39
|
+
gx /= @kh * @kw
|
40
|
+
[gx]
|
41
|
+
end
|
42
|
+
end
|
43
|
+
end
|
44
|
+
end
|
45
|
+
end
|
@@ -13,8 +13,8 @@ module Chainer
|
|
13
13
|
return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
|
14
14
|
return Constant.new(initializer) if initializer.kind_of?(Numeric)
|
15
15
|
return Constant.new(initializer) if initializer.kind_of?(Numo::NArray)
|
16
|
-
|
17
|
-
unless initializer.
|
16
|
+
|
17
|
+
unless initializer.respond_to?(:call)
|
18
18
|
raise TypeError, "invalid type of initializer: #{initializer.class}"
|
19
19
|
end
|
20
20
|
|
data/lib/chainer/link.rb
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
module Chainer
|
2
2
|
class Link
|
3
|
+
attr_accessor :name
|
4
|
+
|
3
5
|
def initialize
|
4
6
|
@params = []
|
5
7
|
@persistent = []
|
@@ -17,22 +19,28 @@ module Chainer
|
|
17
19
|
|
18
20
|
begin
|
19
21
|
yield
|
20
|
-
|
22
|
+
self.instance_variables.each do |name|
|
23
|
+
set_attr(name, self.instance_variable_get(name))
|
24
|
+
end
|
21
25
|
ensure
|
22
26
|
@within_init_scope = old_flag
|
23
27
|
end
|
24
28
|
end
|
25
29
|
|
26
|
-
def set_attr
|
27
|
-
|
28
|
-
value =
|
29
|
-
|
30
|
-
|
31
|
-
@persistent.delete(name)
|
32
|
-
end
|
30
|
+
def set_attr(name, value)
|
31
|
+
if within_init_scope && value.kind_of?(Chainer::Parameter)
|
32
|
+
value.name = name
|
33
|
+
@params << name
|
34
|
+
@persistent.delete(name)
|
33
35
|
end
|
34
36
|
end
|
35
37
|
|
38
|
+
def del_attr(name)
|
39
|
+
@params.delete(name)
|
40
|
+
@persistent.delete(name)
|
41
|
+
self.remove_instance_variable(name)
|
42
|
+
end
|
43
|
+
|
36
44
|
def cleargrads
|
37
45
|
params do |param|
|
38
46
|
param.cleargrad
|
@@ -99,16 +107,22 @@ module Chainer
|
|
99
107
|
@children = []
|
100
108
|
end
|
101
109
|
|
102
|
-
def set_attr
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
@children << name
|
110
|
+
def set_attr(name, value)
|
111
|
+
if within_init_scope && value.kind_of?(Chainer::Link)
|
112
|
+
if self.respond_to?(name)
|
113
|
+
raise TypeError, "cannot register a new link #{name}: attribute exists"
|
107
114
|
end
|
115
|
+
value.name = name
|
116
|
+
@children << name
|
108
117
|
end
|
109
118
|
super
|
110
119
|
end
|
111
120
|
|
121
|
+
def del_attr(name)
|
122
|
+
@children.delete(name)
|
123
|
+
super
|
124
|
+
end
|
125
|
+
|
112
126
|
def params(include_uninit: true)
|
113
127
|
super(include_uninit: include_uninit) do |param|
|
114
128
|
yield param
|
@@ -155,4 +169,118 @@ module Chainer
|
|
155
169
|
end
|
156
170
|
end
|
157
171
|
end
|
172
|
+
|
173
|
+
|
174
|
+
# Composable link with list-like interface.
|
175
|
+
#
|
176
|
+
# This is another example of a compositional link. Unlike `Chainer::Chain`,
|
177
|
+
# this class can be used like a list of child links.
|
178
|
+
# Each child link is indexed by a non-negative integer,
|
179
|
+
# and it maintains the current number of registered child links.
|
180
|
+
# The `add_link` method inserts a new link at the end of the list.
|
181
|
+
# It is useful for writing a chain with an arbitrary number of child links,
|
182
|
+
# e.g. an arbitrarily deep multi-layer perceptron.
|
183
|
+
class ChainList < Link
|
184
|
+
attr_reader :children
|
185
|
+
|
186
|
+
def initialize(*links)
|
187
|
+
super()
|
188
|
+
@children = []
|
189
|
+
|
190
|
+
links.each do |link|
|
191
|
+
add_link(link)
|
192
|
+
end
|
193
|
+
end
|
194
|
+
|
195
|
+
def set_attr(name, value)
|
196
|
+
if within_init_scope && value.kind_of?(Chainer::Link)
|
197
|
+
raise TypeError, 'cannot register a new link within a "with chainlist.init_scope:" block.'
|
198
|
+
end
|
199
|
+
super
|
200
|
+
end
|
201
|
+
|
202
|
+
def [](index)
|
203
|
+
@children[index]
|
204
|
+
end
|
205
|
+
|
206
|
+
def each(&block)
|
207
|
+
@children.each(&block)
|
208
|
+
end
|
209
|
+
|
210
|
+
def size
|
211
|
+
@children.size
|
212
|
+
end
|
213
|
+
|
214
|
+
def <<(link)
|
215
|
+
add_link(link)
|
216
|
+
end
|
217
|
+
|
218
|
+
def add_link(link)
|
219
|
+
link.name = @children.size.to_s
|
220
|
+
@children << link
|
221
|
+
end
|
222
|
+
|
223
|
+
def params(include_uninit: true)
|
224
|
+
super(include_uninit: include_uninit) do |param|
|
225
|
+
yield param
|
226
|
+
end
|
227
|
+
|
228
|
+
@children.each do |link|
|
229
|
+
link.params(include_uninit: include_uninit) do |param|
|
230
|
+
yield param
|
231
|
+
end
|
232
|
+
end
|
233
|
+
end
|
234
|
+
|
235
|
+
def namedparams(include_uninit: true)
|
236
|
+
super(include_uninit: include_uninit) do |ret|
|
237
|
+
yield ret
|
238
|
+
end
|
239
|
+
@children.each_with_index do |link, idx|
|
240
|
+
prefix = "/#{idx}"
|
241
|
+
link.namedparams(include_uninit: include_uninit) do |path, param|
|
242
|
+
yield [prefix + path, param]
|
243
|
+
end
|
244
|
+
end
|
245
|
+
end
|
246
|
+
|
247
|
+
def links(skipself: false)
|
248
|
+
unless skipself
|
249
|
+
yield self
|
250
|
+
end
|
251
|
+
|
252
|
+
@children.each do |child|
|
253
|
+
child.links do |link|
|
254
|
+
yield link
|
255
|
+
end
|
256
|
+
end
|
257
|
+
end
|
258
|
+
|
259
|
+
def namedlinks(skipself: false)
|
260
|
+
unless skipself
|
261
|
+
yield '/', self
|
262
|
+
end
|
263
|
+
|
264
|
+
@children.each_with_index do |child, idx|
|
265
|
+
prefix = "/#{idx}"
|
266
|
+
yield prefix, child
|
267
|
+
child.namedlinks(skipself: true) do |path, link|
|
268
|
+
yield [prefix + path, link]
|
269
|
+
end
|
270
|
+
end
|
271
|
+
end
|
272
|
+
|
273
|
+
def children
|
274
|
+
@children.each do |child|
|
275
|
+
yield child
|
276
|
+
end
|
277
|
+
end
|
278
|
+
|
279
|
+
def serialize(serializer)
|
280
|
+
super
|
281
|
+
@children.each_with_index do |child, idx|
|
282
|
+
child.serialize(serializer[idx.to_s])
|
283
|
+
end
|
284
|
+
end
|
285
|
+
end
|
158
286
|
end
|
@@ -23,7 +23,7 @@ module Chainer
|
|
23
23
|
# @param [Numo::NArray.dtype] dtype Type to use in computing.
|
24
24
|
# @param [Boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1), which has no effect.
|
25
25
|
# @param [Boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0), which has no effect.
|
26
|
-
def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::
|
26
|
+
def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::SFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
|
27
27
|
super()
|
28
28
|
@avg_mean = dtype.zeros(size)
|
29
29
|
register_persistent('avg_mean')
|
data/lib/chainer/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-chainer
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Yusaku Hatanaka
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-05-
|
11
|
+
date: 2018-05-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -96,6 +96,7 @@ files:
|
|
96
96
|
- Rakefile
|
97
97
|
- bin/console
|
98
98
|
- bin/setup
|
99
|
+
- examples/cifar/models/resnet18.rb
|
99
100
|
- examples/cifar/models/vgg.rb
|
100
101
|
- examples/cifar/train_cifar.rb
|
101
102
|
- examples/iris.rb
|
@@ -124,6 +125,7 @@ files:
|
|
124
125
|
- lib/chainer/functions/math/identity.rb
|
125
126
|
- lib/chainer/functions/noise/dropout.rb
|
126
127
|
- lib/chainer/functions/normalization/batch_normalization.rb
|
128
|
+
- lib/chainer/functions/pooling/average_pooling_2d.rb
|
127
129
|
- lib/chainer/functions/pooling/max_pooling_2d.rb
|
128
130
|
- lib/chainer/functions/pooling/pooling_2d.rb
|
129
131
|
- lib/chainer/gradient_check.rb
|