red-chainer 0.2.1 → 0.3.0
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
data/lib/chainer/training/extensions/exponential_shift.rb
ADDED
@@ -0,0 +1,78 @@
+module Chainer
+  module Training
+    module Extensions
+      # Trainer extension to exponentially shift an optimizer attribute.
+      #
+      # This extension exponentially increases or decreases the specified attribute of the optimizer.
+      # The typical use case is an exponential decay of the learning rate.
+      # This extension is also called before the training loop starts by default.
+      class ExponentialShift < Extension
+        attr_reader :last_value
+
+        # @param [string] attr Name of the attribute to shift
+        # @param [float] rate Rate of the exponential shift.
+        # @param [float] init Initial value of the attribute.
+        # @param [float] target Target value of the attribute.
+        # @param [Chainer::Optimizer] optimizer Target optimizer to adjust the attribute.
+        def initialize(attr, rate, init: nil, target: nil, optimizer: nil)
+          @attr = attr
+          raise 'ExponentialShift does not support negative rate' if rate < 0
+          @rate = rate
+          @init = init
+          @target = target
+          @optimizer = optimizer
+          @t = 0
+          @last_value = nil
+        end
+
+        def init(trainer)
+          optimizer = get_optimizer(trainer)
+          @init = optimizer.send(@attr) if @init.nil?
+          if @last_value.nil?
+            update_value(optimizer, @init)
+          else
+            update_value(optimizer, @last_value)
+          end
+        end
+
+        def call(trainer)
+          @t += 1
+
+          optimizer = get_optimizer(trainer)
+          value = @init * (@rate ** @t)
+          if @target
+            if @rate > 1
+              if value / @target > 1
+                value = @target
+              end
+            else
+              if value / @target < 1
+                value = @target
+              end
+            end
+          end
+          update_value(optimizer, value)
+        end
+
+        def serialize(serializer)
+          @t = serializer.('t', @t)
+          @last_value = serializer.('last_value', @last_value)
+          if @last_value.is_a?(Numo::NArray)
+            @last_value = @last_value[0]
+          end
+        end
+
+        private
+
+        def get_optimizer(trainer)
+          @optimizer || trainer.updater.get_optimizer(:main)
+        end
+
+        def update_value(optimizer, value)
+          optimizer.send("#{@attr}=", value)
+          @last_value = value
+        end
+      end
+    end
+  end
+end
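For context, a minimal usage sketch of the new extension, assuming the trainer/optimizer wiring from the bundled CIFAR example; the attribute name 'lr', the rate 0.5, and the trigger interval are illustrative values, not prescribed by this diff:

    # Halve the optimizer's lr every 25 epochs (hypothetical values).
    optimizer = Chainer::Optimizers::MomentumSGD.new(lr: 0.1)
    # ... build model, updater and trainer as in examples/cifar/train_cifar.rb ...
    trainer.extend(
      Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5),
      trigger: [25, 'epoch']
    )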
data/lib/chainer/training/extensions/snapshot.rb
CHANGED
@@ -23,7 +23,7 @@ module Chainer
         filename = filename_proc.call(trainer)
         prefix = "tmp#{filename}"
         temp_file = Tempfile.create(basename: prefix, tmpdir: trainer.out)
-        save_class.save_file(temp_file, trainer)
+        save_class.save_file(temp_file.path, trainer)
         FileUtils.move(temp_file.path, File.join(trainer.out, filename))
       end
     end
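The snapshot fix above passes the temporary file's path (a String) to save_file rather than the Tempfile IO object itself. A small sketch of the distinction, using only the Ruby standard library (file names here are illustrative):

    require 'tempfile'

    temp_file = Tempfile.create('tmp-example')
    temp_file.path  # => a String such as "/tmp/tmp-example20180519-1234-abc"
    # Writing through the path, as save_file does after the fix:
    File.binwrite(temp_file.path, Marshal.dump(t: 1))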
data/lib/chainer/utils/array.rb
CHANGED
@@ -2,8 +2,19 @@ module Chainer
   module Utils
     module Array
       def self.force_array(x, dtype=nil)
-
-
+        if x.is_a? Integer or x.is_a? Float
+          if dtype.nil?
+            Numo::NArray.cast(x)
+          else
+            dtype.cast(x.dup)
+          end
+        else
+          if dtype.nil?
+            x
+          else
+            dtype.cast(x)
+          end
+        end
       end
     end
   end
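A quick sketch of what the rewritten force_array returns in each of its four branches (scalar vs. NArray, with and without an explicit dtype), assuming Numo is loaded; the results follow directly from the code above:

    require 'numo/narray'

    Chainer::Utils::Array.force_array(3)                  # scalar -> NArray via Numo::NArray.cast
    Chainer::Utils::Array.force_array(3.5, Numo::SFloat)  # scalar -> Numo::SFloat
    x = Numo::DFloat[1, 2, 3]
    Chainer::Utils::Array.force_array(x)                  # NArray, no dtype -> returned as-is
    Chainer::Utils::Array.force_array(x, Numo::SFloat)    # NArray -> cast to Numo::SFloat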
data/lib/chainer/utils/conv.rb
ADDED
@@ -0,0 +1,59 @@
+module Chainer
+  module Utils
+    module Conv
+      def self.get_conv_outsize(size, k, s, p, cover_all: false, d: 1)
+        dk = k + (k - 1) * (d - 1)
+        if cover_all
+          (size + p * 2 - dk + s - 1).div(s) + 1
+        else
+          (size + p * 2 - dk).div(s) + 1
+        end
+      end
+
+      def self.im2col_cpu(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
+        n, c, h, w = img.shape
+
+        out_h = self.get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
+        raise 'Height in the output should be positive.' if out_h <= 0
+        out_w = self.get_conv_outsize(w, kw, sx, pw, cover_all: cover_all, d: dx)
+        raise 'Width in the output should be positive.' if out_w <= 0
+
+        # padding
+        # TODO: ref: numpy.pad
+        pad_bottom = ph + sy - 1
+        pad_right = pw + sx - 1
+        pad_img = img.class.new(n, c, (h + ph + pad_bottom), (w + pw + pad_right)).fill(pval)
+        pad_img[nil, nil, ph...(ph+h), pw...(pw+w)] = img
+
+        col = pad_img.class.new(n, c, kh, kw, out_h, out_w).rand(1)
+
+        kh.times do |j|
+          jdy = j * dy
+          j_lim = [jdy + sy * out_h, pad_img.shape[2]].min
+          kw.times do |i|
+            idx = i * dx
+            i_lim = [idx + sx * out_w, pad_img.shape[3]].min
+            col[nil, nil, j, i, nil, nil] = pad_img[nil, nil, (jdy...j_lim).step(sy), (idx...i_lim).step(sx)]
+          end
+        end
+
+        col
+      end
+
+      def self.col2im_cpu(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
+        n, c, kh, kw, out_h, out_w = col.shape
+        img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
+        kh.times do |j|
+          jdy = j * dy
+          j_lim = [jdy + sy * out_h, img.shape[2]].min
+          kw.times do |i|
+            idx = i * dx
+            i_lim = [idx + sx * out_w, img.shape[3]].min
+            img[nil, nil, (jdy...j_lim).step(sy), (idx...i_lim).step(sx)] += col[nil, nil, j, i, nil, nil]
+          end
+        end
+        return img[nil, nil, (ph...(h + ph)), (pw...(w + pw))]
+      end
+    end
+  end
+end
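To make the output-size arithmetic concrete: with cover_all: false and no dilation, get_conv_outsize computes (size + 2p - k) / s + 1. A small check against im2col_cpu's output shape, assuming Numo is available (the sizes are illustrative):

    require 'numo/narray'

    # (7 + 2*1 - 3) / 2 + 1 = 4
    Chainer::Utils::Conv.get_conv_outsize(7, 3, 2, 1)  # => 4

    img = Numo::SFloat.new(1, 1, 7, 7).seq
    col = Chainer::Utils::Conv.im2col_cpu(img, 3, 3, 2, 2, 1, 1)
    col.shape  # => [1, 1, 3, 3, 4, 4], i.e. (n, c, kh, kw, out_h, out_w)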
data/lib/chainer/utils/math.rb
ADDED
@@ -0,0 +1,72 @@
+module Chainer
+  module Utils
+    module Math
+      def self.tensordot(a, b, axes)
+        if axes.is_a?(Integer)
+          axes_a = (-axes...0).to_a
+          axes_b = (0...axes).to_a
+        elsif axes.is_a?(Array)
+          axes_a, axes_b = axes
+        end
+
+        axes_a = Array(axes_a)
+        axes_b = Array(axes_b)
+        na = axes_a.size
+        nb = axes_b.size
+
+        as = a.shape
+        nda = a.ndim
+        bs = b.shape
+        ndb = b.ndim
+        equal = true
+        if na != nb
+          equal = false
+        else
+          na.times do |k|
+            if as[axes_a[k]] != bs[axes_b[k]]
+              equal = false
+              break
+            end
+
+            if axes_a[k] < 0
+              axes_a[k] += nda
+            end
+
+            if axes_b[k] < 0
+              axes_b[k] += ndb
+            end
+          end
+        end
+
+        raise "shape-mismatch for sum" unless equal
+
+        notin = (0...nda).reject { |i| axes_a.include?(i) }
+        newaxes_a = notin + axes_a
+        n2 = 1
+        axes_a.each do |axis|
+          n2 *= as[axis]
+        end
+        tmp = a.shape.reduce(:*) / n2
+        newshape_a = [tmp, n2]
+        olda = notin.map { |axis| as[axis] }
+
+        notin = (0...ndb).reject { |i| axes_b.include?(i) }
+        newaxes_b = axes_b + notin
+        n2 = 1
+        axes_b.each do |axis|
+          n2 *= bs[axis]
+        end
+        tmp = b.shape.reduce(:*) / n2
+        newshape_b = [n2, tmp]
+        oldb = notin.map { |axis| bs[axis] }
+
+        at = a.transpose(*newaxes_a).reshape(*newshape_a)
+        bt = b.transpose(*newaxes_b).reshape(*newshape_b)
+        res = at.dot(bt)
+
+        return res.reshape(*(olda + oldb))
+      end
+    end
+  end
+end
+
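A brief sanity check of tensordot: for 2-D inputs with axes = 1 it contracts the last axis of a with the first axis of b, so it should agree with an ordinary matrix product. A sketch assuming Numo:

    require 'numo/narray'

    a = Numo::DFloat.new(2, 3).seq
    b = Numo::DFloat.new(3, 4).seq
    c = Chainer::Utils::Math.tensordot(a, b, 1)
    c.shape                    # => [2, 4]
    (c - a.dot(b)).abs.max     # => 0.0 if the contraction matches a.dot(b)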
data/lib/chainer/utils/variable.rb
CHANGED
@@ -6,12 +6,16 @@ module Chainer
         return
       end
 
-      unless gx.
-        raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}"
+      unless gx.is_a?(x.data.class.superclass)
+        raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}"
+      end
+
+      unless gx.class == x.data.class
+        raise TypeError, "Dtype(Class) of data and grad mismatch\n#{x.data.class} != #{gx.class}"
       end
 
       unless gx.shape == x.data.shape
-        raise TypeError, "Shape of data and grad mismatch\n#{x.
+        raise TypeError, "Shape of data and grad mismatch\n#{x.data.shape} != #{gx.shape}"
       end
     end
   end
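The two checks are now distinct: the first accepts any NArray for the gradient (every Numo dtype class shares the Numo::NArray superclass), while the second additionally requires the same concrete dtype class. A sketch of what each one distinguishes, assuming Numo arrays:

    require 'numo/narray'

    data = Numo::DFloat[1.0, 2.0]
    grad = Numo::SFloat[0.1, 0.2]
    grad.is_a?(data.class.superclass)  # => true:  both are Numo::NArray subclasses
    grad.class == data.class           # => false: DFloat vs SFloat, a dtype mismatch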
data/lib/chainer/version.rb
CHANGED
data/red-chainer.gemspec
CHANGED
@@ -20,6 +20,7 @@ Gem::Specification.new do |spec|
   spec.require_paths = ["lib"]
 
   spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
+  spec.add_runtime_dependency "red-datasets"
 
   spec.add_development_dependency "bundler", "~> 1.15"
   spec.add_development_dependency "rake", "~> 10.0"
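Since red-datasets is now a runtime dependency, consumers pull it in transitively; a Gemfile needs only the gem itself. A minimal sketch (the version constraint is illustrative):

    # Gemfile
    source 'https://rubygems.org'
    gem 'red-chainer', '~> 0.3.0'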
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: red-chainer
 version: !ruby/object:Gem::Version
-  version: 0.2.1
+  version: 0.3.0
 platform: ruby
 authors:
 - Yusaku Hatanaka
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2018-
+date: 2018-05-19 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
@@ -24,6 +24,20 @@ dependencies:
     - - ">="
       - !ruby/object:Gem::Version
         version: 0.9.1.1
+- !ruby/object:Gem::Dependency
+  name: red-datasets
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: '0'
+  type: :runtime
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: '0'
 - !ruby/object:Gem::Dependency
   name: bundler
   requirement: !ruby/object:Gem::Requirement
@@ -82,12 +96,17 @@ files:
 - Rakefile
 - bin/console
 - bin/setup
+- examples/cifar/models/vgg.rb
+- examples/cifar/train_cifar.rb
+- examples/iris.rb
 - examples/mnist.rb
 - lib/chainer.rb
 - lib/chainer/configuration.rb
+- lib/chainer/cuda.rb
 - lib/chainer/dataset/convert.rb
 - lib/chainer/dataset/download.rb
 - lib/chainer/dataset/iterator.rb
+- lib/chainer/datasets/cifar.rb
 - lib/chainer/datasets/mnist.rb
 - lib/chainer/datasets/tuple_dataset.rb
 - lib/chainer/function.rb
@@ -96,10 +115,18 @@ files:
 - lib/chainer/functions/activation/relu.rb
 - lib/chainer/functions/activation/sigmoid.rb
 - lib/chainer/functions/activation/tanh.rb
+- lib/chainer/functions/connection/convolution_2d.rb
 - lib/chainer/functions/connection/linear.rb
 - lib/chainer/functions/evaluation/accuracy.rb
+- lib/chainer/functions/loss/mean_squared_error.rb
 - lib/chainer/functions/loss/softmax_cross_entropy.rb
 - lib/chainer/functions/math/basic_math.rb
+- lib/chainer/functions/math/identity.rb
+- lib/chainer/functions/noise/dropout.rb
+- lib/chainer/functions/normalization/batch_normalization.rb
+- lib/chainer/functions/pooling/max_pooling_2d.rb
+- lib/chainer/functions/pooling/pooling_2d.rb
+- lib/chainer/gradient_check.rb
 - lib/chainer/gradient_method.rb
 - lib/chainer/hyperparameter.rb
 - lib/chainer/initializer.rb
@@ -108,16 +135,21 @@ files:
 - lib/chainer/initializers/normal.rb
 - lib/chainer/iterators/serial_iterator.rb
 - lib/chainer/link.rb
+- lib/chainer/links/connection/convolution_2d.rb
 - lib/chainer/links/connection/linear.rb
 - lib/chainer/links/model/classifier.rb
+- lib/chainer/links/normalization/batch_normalization.rb
 - lib/chainer/optimizer.rb
 - lib/chainer/optimizers/adam.rb
+- lib/chainer/optimizers/momentum_sgd.rb
 - lib/chainer/parameter.rb
 - lib/chainer/reporter.rb
 - lib/chainer/serializer.rb
 - lib/chainer/serializers/marshal.rb
+- lib/chainer/testing/array.rb
 - lib/chainer/training/extension.rb
 - lib/chainer/training/extensions/evaluator.rb
+- lib/chainer/training/extensions/exponential_shift.rb
 - lib/chainer/training/extensions/log_report.rb
 - lib/chainer/training/extensions/print_report.rb
 - lib/chainer/training/extensions/progress_bar.rb
@@ -128,7 +160,9 @@ files:
 - lib/chainer/training/updater.rb
 - lib/chainer/training/util.rb
 - lib/chainer/utils/array.rb
+- lib/chainer/utils/conv.rb
 - lib/chainer/utils/initializer.rb
+- lib/chainer/utils/math.rb
 - lib/chainer/utils/variable.rb
 - lib/chainer/variable.rb
 - lib/chainer/variable_node.rb
@@ -154,7 +188,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
       version: '0'
 requirements: []
 rubyforge_project:
-rubygems_version: 2.7.
+rubygems_version: 2.7.6
 signing_key:
 specification_version: 4
 summary: A flexible framework for neural network for Ruby
|