red-chainer 0.2.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
@@ -0,0 +1,78 @@
|
|
1
|
+
module Chainer
  module Training
    module Extensions
      # Trainer extension to exponentially shift an optimizer attribute.
      #
      # On every invocation the specified optimizer attribute is set to
      # +init * rate**t+, optionally clamped at +target+. The typical use
      # case is an exponential decay of the learning rate. The extension is
      # also invoked once before the training loop starts (via #init).
      class ExponentialShift < Extension
        attr_reader :last_value

        # @param [string] attr Name of the attribute to shift
        # @param [float] rate Rate of the exponential shift.
        # @param [float] init Initial value of the attribute.
        # @param [float] target Target value of the attribute.
        # @param [Chainer::Optimizer] optimizer Target optimizer to adjust the attribute.
        def initialize(attr, rate, init: nil, target: nil, optimizer: nil)
          raise 'ExponentialShift does not support negative rate' if rate < 0

          @attr = attr
          @rate = rate
          @init = init
          @target = target
          @optimizer = optimizer
          @t = 0
          @last_value = nil
        end

        # Called once before training; captures the attribute's current value
        # as the initial value when none was given, and (re)applies either the
        # initial or the last shifted value (e.g. after deserialization).
        def init(trainer)
          optimizer = get_optimizer(trainer)
          @init = optimizer.send(@attr) if @init.nil?
          update_value(optimizer, @last_value.nil? ? @init : @last_value)
        end

        # Advances the step counter and applies the shifted value, clamping
        # at +@target+ once the shift has passed it (in either direction).
        def call(trainer)
          @t += 1

          optimizer = get_optimizer(trainer)
          value = @init * (@rate**@t)
          if @target
            # A growing shift (@rate > 1) overshoots from below; a decaying
            # one overshoots from above.
            passed_target = @rate > 1 ? value / @target > 1 : value / @target < 1
            value = @target if passed_target
          end
          update_value(optimizer, value)
        end

        # Serializes/deserializes the step counter and last applied value.
        def serialize(serializer)
          @t = serializer.('t', @t)
          @last_value = serializer.('last_value', @last_value)
          # A value loaded back from a snapshot arrives as a 1-element NArray;
          # unwrap it to a plain scalar.
          @last_value = @last_value[0] if @last_value.is_a?(Numo::NArray)
        end

        private

        # Explicit optimizer if one was given, otherwise the trainer's main one.
        def get_optimizer(trainer)
          @optimizer || trainer.updater.get_optimizer(:main)
        end

        # Writes +value+ into the optimizer attribute and remembers it.
        def update_value(optimizer, value)
          optimizer.send("#{@attr}=", value)
          @last_value = value
        end
      end
    end
  end
end
|
@@ -23,7 +23,7 @@ module Chainer
|
|
23
23
|
filename = filename_proc.call(trainer)
|
24
24
|
prefix = "tmp#{filename}"
|
25
25
|
temp_file = Tempfile.create(basename: prefix, tmpdir: trainer.out)
|
26
|
-
save_class.save_file(temp_file, trainer)
|
26
|
+
save_class.save_file(temp_file.path, trainer)
|
27
27
|
FileUtils.move(temp_file.path, File.join(trainer.out, filename))
|
28
28
|
end
|
29
29
|
end
|
data/lib/chainer/utils/array.rb
CHANGED
@@ -2,8 +2,19 @@ module Chainer
|
|
2
2
|
module Utils
|
3
3
|
module Array
|
4
4
|
# Coerces +x+ to an array-like value.
#
# Ruby scalars (Integer/Float) are wrapped into an NArray — via
# +Numo::NArray.cast+ when no dtype is given, or via the requested dtype's
# +cast+ otherwise. Non-scalar inputs are returned as-is when no dtype is
# given, or cast to the requested dtype.
#
# @param x [Integer, Float, Object] value to coerce
# @param dtype [Class, nil] optional NArray dtype class responding to +cast+
# @return the coerced value
def self.force_array(x, dtype=nil)
  scalar = x.is_a?(Integer) || x.is_a?(Float)
  if dtype.nil?
    scalar ? Numo::NArray.cast(x) : x
  else
    # dup the scalar so casting never aliases the caller's object
    scalar ? dtype.cast(x.dup) : dtype.cast(x)
  end
end
|
8
19
|
end
|
9
20
|
end
|
@@ -0,0 +1,59 @@
|
|
1
|
+
module Chainer
  module Utils
    module Conv
      # Output length of a convolution along one dimension.
      #
      # @param size [Integer] input length
      # @param k [Integer] kernel size
      # @param s [Integer] stride
      # @param p [Integer] padding
      # @param cover_all [Boolean] when true, every input pixel is covered by
      #   at least one kernel application (rounds the window count up)
      # @param d [Integer] dilation factor
      # @return [Integer] output length
      def self.get_conv_outsize(size, k, s, p, cover_all: false, d: 1)
        dk = (k - 1) * d + 1 # effective extent of the dilated kernel
        span = size + 2 * p - dk
        span += s - 1 if cover_all
        span.div(s) + 1
      end

      # Rearranges image patches into columns (im2col) on the CPU.
      #
      # @param img NArray of shape (n, c, h, w)
      # @return NArray of shape (n, c, kh, kw, out_h, out_w)
      # @raise [RuntimeError] when the computed output height/width is not positive
      def self.im2col_cpu(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
        n, c, h, w = img.shape

        out_h = get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
        raise 'Height in the output should be positive.' if out_h <= 0
        out_w = get_conv_outsize(w, kw, sx, pw, cover_all: cover_all, d: dx)
        raise 'Width in the output should be positive.' if out_w <= 0

        # Pad with pval; bottom/right get extra stride-1 slack so strided
        # slicing below never runs off the edge.
        # TODO: ref: numpy.pad
        pad_bottom = ph + sy - 1
        pad_right = pw + sx - 1
        pad_img = img.class.new(n, c, h + ph + pad_bottom, w + pw + pad_right).fill(pval)
        pad_img[nil, nil, ph...(ph + h), pw...(pw + w)] = img

        # NOTE(review): .rand(1) appears to be a force-allocation workaround;
        # every element is overwritten in the loops below — confirm before
        # removing.
        col = pad_img.class.new(n, c, kh, kw, out_h, out_w).rand(1)

        kh.times do |j|
          row = j * dy
          row_lim = [row + sy * out_h, pad_img.shape[2]].min
          kw.times do |i|
            cidx = i * dx
            col_lim = [cidx + sx * out_w, pad_img.shape[3]].min
            col[nil, nil, j, i, nil, nil] =
              pad_img[nil, nil, (row...row_lim).step(sy), (cidx...col_lim).step(sx)]
          end
        end

        col
      end

      # Inverse of im2col_cpu: scatters columns back into an image,
      # accumulating overlapping contributions.
      #
      # @param col NArray of shape (n, c, kh, kw, out_h, out_w)
      # @return NArray of shape (n, c, h, w)
      def self.col2im_cpu(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
        n, c, kh, kw, out_h, out_w = col.shape
        # Same padded canvas as im2col builds, cropped at the end.
        img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
        kh.times do |j|
          row = j * dy
          row_lim = [row + sy * out_h, img.shape[2]].min
          kw.times do |i|
            cidx = i * dx
            col_lim = [cidx + sx * out_w, img.shape[3]].min
            img[nil, nil, (row...row_lim).step(sy), (cidx...col_lim).step(sx)] +=
              col[nil, nil, j, i, nil, nil]
          end
        end
        img[nil, nil, ph...(h + ph), pw...(w + pw)]
      end
    end
  end
end
|
@@ -0,0 +1,72 @@
|
|
1
|
+
module Chainer
  module Utils
    module Math
      # Tensor dot product of +a+ and +b+ along the given axes, mirroring
      # numpy.tensordot: the contracted axes are moved to the inner positions,
      # both operands are flattened to 2-D, multiplied with +dot+, and the
      # result is reshaped to the surviving axes of +a+ followed by those of +b+.
      #
      # @param a first operand (NArray-like: #shape, #ndim, #transpose, #reshape, #dot)
      # @param b second operand
      # @param axes [Integer, Array] an Integer N contracts the last N axes of
      #   +a+ with the first N axes of +b+; an Array is a pair [axes_a, axes_b]
      #   of axis indices (scalars allowed)
      # @return the contracted tensor
      # @raise [RuntimeError] "shape-mismatch for sum" when the contracted axes disagree
      def self.tensordot(a, b, axes)
        # Fixed: the original wrote `else axes.is_a?(Array)`, which evaluated
        # and discarded the predicate — any non-Integer value fell through to
        # destructuring regardless. Behavior is preserved; the dead check is gone.
        if axes.is_a?(Integer)
          axes_a = (-axes...0).to_a
          axes_b = (0...axes).to_a
        else
          axes_a, axes_b = axes
        end

        # Allow a bare axis index in either slot of the pair form.
        axes_a = Array(axes_a)
        axes_b = Array(axes_b)
        na = axes_a.size
        nb = axes_b.size

        as = a.shape
        nda = a.ndim
        bs = b.shape
        ndb = b.ndim
        equal = true
        if na != nb
          equal = false
        else
          na.times do |k|
            # Negative indices are valid for the shape lookup; normalize them
            # afterwards so transpose below receives non-negative axes.
            if as[axes_a[k]] != bs[axes_b[k]]
              equal = false
              break
            end
            axes_a[k] += nda if axes_a[k] < 0
            axes_b[k] += ndb if axes_b[k] < 0
          end
        end

        raise "shape-mismatch for sum" unless equal

        # a: move kept axes first, contracted axes last -> (kept, contracted)
        notin = (0...nda).reject { |i| axes_a.include?(i) }
        newaxes_a = notin + axes_a
        n2 = 1
        axes_a.each { |axis| n2 *= as[axis] }
        newshape_a = [a.shape.reduce(:*) / n2, n2]
        olda = notin.map { |axis| as[axis] }

        # b: move contracted axes first, kept axes last -> (contracted, kept)
        notin = (0...ndb).reject { |i| axes_b.include?(i) }
        newaxes_b = axes_b + notin
        n2 = 1
        axes_b.each { |axis| n2 *= bs[axis] }
        newshape_b = [n2, b.shape.reduce(:*) / n2]
        oldb = notin.map { |axis| bs[axis] }

        at = a.transpose(*newaxes_a).reshape(*newshape_a)
        bt = b.transpose(*newaxes_b).reshape(*newshape_b)
        res = at.dot(bt)

        res.reshape(*(olda + oldb))
      end
    end
  end
end
|
72
|
+
|
@@ -6,12 +6,16 @@ module Chainer
|
|
6
6
|
return
|
7
7
|
end
|
8
8
|
|
9
|
-
unless gx.
|
10
|
-
raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}"
|
9
|
+
unless gx.is_a?(x.data.class.superclass)
|
10
|
+
raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}"
|
11
|
+
end
|
12
|
+
|
13
|
+
unless gx.class == x.data.class
|
14
|
+
raise TypeError, "Dtype(Class) of data and grad mismatch\n#{x.data.class} != #{gx.class}"
|
11
15
|
end
|
12
16
|
|
13
17
|
unless gx.shape == x.data.shape
|
14
|
-
raise TypeError, "Shape of data and grad mismatch\n#{x.
|
18
|
+
raise TypeError, "Shape of data and grad mismatch\n#{x.data.shape} != #{gx.shape}"
|
15
19
|
end
|
16
20
|
end
|
17
21
|
end
|
data/lib/chainer/version.rb
CHANGED
data/red-chainer.gemspec
CHANGED
@@ -20,6 +20,7 @@ Gem::Specification.new do |spec|
|
|
20
20
|
spec.require_paths = ["lib"]
|
21
21
|
|
22
22
|
spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
|
23
|
+
spec.add_runtime_dependency "red-datasets"
|
23
24
|
|
24
25
|
spec.add_development_dependency "bundler", "~> 1.15"
|
25
26
|
spec.add_development_dependency "rake", "~> 10.0"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-chainer
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.1
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Yusaku Hatanaka
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-05-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -24,6 +24,20 @@ dependencies:
|
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
26
|
version: 0.9.1.1
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: red-datasets
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
27
41
|
- !ruby/object:Gem::Dependency
|
28
42
|
name: bundler
|
29
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -82,12 +96,17 @@ files:
|
|
82
96
|
- Rakefile
|
83
97
|
- bin/console
|
84
98
|
- bin/setup
|
99
|
+
- examples/cifar/models/vgg.rb
|
100
|
+
- examples/cifar/train_cifar.rb
|
101
|
+
- examples/iris.rb
|
85
102
|
- examples/mnist.rb
|
86
103
|
- lib/chainer.rb
|
87
104
|
- lib/chainer/configuration.rb
|
105
|
+
- lib/chainer/cuda.rb
|
88
106
|
- lib/chainer/dataset/convert.rb
|
89
107
|
- lib/chainer/dataset/download.rb
|
90
108
|
- lib/chainer/dataset/iterator.rb
|
109
|
+
- lib/chainer/datasets/cifar.rb
|
91
110
|
- lib/chainer/datasets/mnist.rb
|
92
111
|
- lib/chainer/datasets/tuple_dataset.rb
|
93
112
|
- lib/chainer/function.rb
|
@@ -96,10 +115,18 @@ files:
|
|
96
115
|
- lib/chainer/functions/activation/relu.rb
|
97
116
|
- lib/chainer/functions/activation/sigmoid.rb
|
98
117
|
- lib/chainer/functions/activation/tanh.rb
|
118
|
+
- lib/chainer/functions/connection/convolution_2d.rb
|
99
119
|
- lib/chainer/functions/connection/linear.rb
|
100
120
|
- lib/chainer/functions/evaluation/accuracy.rb
|
121
|
+
- lib/chainer/functions/loss/mean_squared_error.rb
|
101
122
|
- lib/chainer/functions/loss/softmax_cross_entropy.rb
|
102
123
|
- lib/chainer/functions/math/basic_math.rb
|
124
|
+
- lib/chainer/functions/math/identity.rb
|
125
|
+
- lib/chainer/functions/noise/dropout.rb
|
126
|
+
- lib/chainer/functions/normalization/batch_normalization.rb
|
127
|
+
- lib/chainer/functions/pooling/max_pooling_2d.rb
|
128
|
+
- lib/chainer/functions/pooling/pooling_2d.rb
|
129
|
+
- lib/chainer/gradient_check.rb
|
103
130
|
- lib/chainer/gradient_method.rb
|
104
131
|
- lib/chainer/hyperparameter.rb
|
105
132
|
- lib/chainer/initializer.rb
|
@@ -108,16 +135,21 @@ files:
|
|
108
135
|
- lib/chainer/initializers/normal.rb
|
109
136
|
- lib/chainer/iterators/serial_iterator.rb
|
110
137
|
- lib/chainer/link.rb
|
138
|
+
- lib/chainer/links/connection/convolution_2d.rb
|
111
139
|
- lib/chainer/links/connection/linear.rb
|
112
140
|
- lib/chainer/links/model/classifier.rb
|
141
|
+
- lib/chainer/links/normalization/batch_normalization.rb
|
113
142
|
- lib/chainer/optimizer.rb
|
114
143
|
- lib/chainer/optimizers/adam.rb
|
144
|
+
- lib/chainer/optimizers/momentum_sgd.rb
|
115
145
|
- lib/chainer/parameter.rb
|
116
146
|
- lib/chainer/reporter.rb
|
117
147
|
- lib/chainer/serializer.rb
|
118
148
|
- lib/chainer/serializers/marshal.rb
|
149
|
+
- lib/chainer/testing/array.rb
|
119
150
|
- lib/chainer/training/extension.rb
|
120
151
|
- lib/chainer/training/extensions/evaluator.rb
|
152
|
+
- lib/chainer/training/extensions/exponential_shift.rb
|
121
153
|
- lib/chainer/training/extensions/log_report.rb
|
122
154
|
- lib/chainer/training/extensions/print_report.rb
|
123
155
|
- lib/chainer/training/extensions/progress_bar.rb
|
@@ -128,7 +160,9 @@ files:
|
|
128
160
|
- lib/chainer/training/updater.rb
|
129
161
|
- lib/chainer/training/util.rb
|
130
162
|
- lib/chainer/utils/array.rb
|
163
|
+
- lib/chainer/utils/conv.rb
|
131
164
|
- lib/chainer/utils/initializer.rb
|
165
|
+
- lib/chainer/utils/math.rb
|
132
166
|
- lib/chainer/utils/variable.rb
|
133
167
|
- lib/chainer/variable.rb
|
134
168
|
- lib/chainer/variable_node.rb
|
@@ -154,7 +188,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
154
188
|
version: '0'
|
155
189
|
requirements: []
|
156
190
|
rubyforge_project:
|
157
|
-
rubygems_version: 2.7.
|
191
|
+
rubygems_version: 2.7.6
|
158
192
|
signing_key:
|
159
193
|
specification_version: 4
|
160
194
|
summary: A flexible framework for neural network for Ruby
|