red-chainer 0.1.1 → 0.2.0
- checksums.yaml +5 -5
- data/.gitignore +1 -0
- data/.travis.yml +3 -1
- data/README.md +8 -12
- data/examples/mnist.rb +32 -4
- data/lib/chainer/functions/activation/leaky_relu.rb +64 -0
- data/lib/chainer/functions/activation/log_softmax.rb +50 -8
- data/lib/chainer/functions/activation/relu.rb +21 -1
- data/lib/chainer/functions/activation/sigmoid.rb +43 -0
- data/lib/chainer/functions/activation/tanh.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +2 -2
- data/lib/chainer/iterators/serial_iterator.rb +26 -1
- data/lib/chainer/link.rb +29 -0
- data/lib/chainer/optimizer.rb +38 -0
- data/lib/chainer/serializer.rb +50 -0
- data/lib/chainer/serializers/marshal.rb +83 -0
- data/lib/chainer/training/extensions/log_report.rb +15 -0
- data/lib/chainer/training/extensions/print_report.rb +6 -0
- data/lib/chainer/training/extensions/snapshot.rb +33 -0
- data/lib/chainer/training/standard_updater.rb +12 -0
- data/lib/chainer/training/trainer.rb +26 -3
- data/lib/chainer/training/triggers/interval.rb +29 -4
- data/lib/chainer/version.rb +1 -1
- data/lib/chainer.rb +6 -0
- data/red-chainer.gemspec +1 -1
- metadata +11 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
-…
-  metadata.gz: …
-  data.tar.gz: …
+SHA256:
+  metadata.gz: 01f7ae8c937f4fb6384d9d6a0bf975d236da004d9165a734f9f5ab78c48d1197
+  data.tar.gz: 218ae1d0ebe3f90db0861d98d88a0be2a52e160f8001952d223df64ee8d22838
 SHA512:
-  metadata.gz: …
-  data.tar.gz: …
+  metadata.gz: aba81905dd93397b368b64fdd517b3b99487cc47dfb2ca7477cadf5e8357038be3f920f950a4968394040bd34e04e6537f24965a52e93f7d4855f174ff305d60
+  data.tar.gz: 96752ee9d7b151eb63e36ddda56d82910a6eee07109376de47cd977a77d1bb2bf0ecf78353835a049b795b35441197ede4cfb1dbff78ad4bd4a220da1c6f5009
data/.gitignore
CHANGED
data/.travis.yml
CHANGED
data/README.md
CHANGED
@@ -5,39 +5,35 @@
 Red Cahiner
 
 ## Description
-
 Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/chainer`. To experiment with that code, run `bin/console` for an interactive prompt.
 
-TODO: Delete this and the text above, and describe your gem
-
 ## Installation
 
 Add this line to your application's Gemfile:
 
-```
-gem 'red-chainer'
+```bash
+gem 'red-chainer'
 ```
 
 And then execute:
 
-```
+```bash
 $ bundle
 ```
 
 Or install it yourself as:
 
-```
-gem install
-gem specific_install -l 'https://github.com/red-data-tools/red-chainer'
+```bash
+$ gem install red-chainer
 ```
 
 ## Usage
 mnist sample program is [here](./examples/mnist.rb)
 
-```
-# install Gemfile
+```bash
+# when install Gemfile
 $ bundle exec ruby examples/mnist.rb
-# install yourself
+# when install yourself
 $ ruby examples/mnist.rb
 ```
 
data/examples/mnist.rb
CHANGED
@@ -1,5 +1,6 @@
 require 'chainer'
 require 'fileutils'
+require 'optparse'
 require 'tmpdir'
 
@@ -22,21 +23,48 @@ class MLP < Chainer::Chain
   end
 end
 
-…
+args = {
+  batchsize: 100,
+  frequency: -1,
+  epoch: 20,
+  resume: nil,
+  unit: 1000,
+  out: 'result'
+}
+
+opt = OptionParser.new
+opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
+opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
+opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
+opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
+opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
+opt.on('-u', '--unit VALUE', "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v.to_i }
+opt.parse!(ARGV)
+
+model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10))
 
 optimizer = Chainer::Optimizers::Adam.new
 optimizer.setup(model)
 train, test = Chainer::Datasets::Mnist.get_mnist
 
-train_iter = Chainer::Iterators::SerialIterator.new(train, …
-test_iter = Chainer::Iterators::SerialIterator.new(test, …
+train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
+test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
 
 updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
-trainer = Chainer::Training::Trainer.new(updater, stop_trigger: […
+trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
 
 trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
+
+# Take a snapshot for each specified epoch
+frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
+trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'], priority: -100)
+
 trainer.extend(Chainer::Training::Extensions::LogReport.new)
 trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
 trainer.extend(Chainer::Training::Extensions::ProgressBar.new)
 
+if args[:resume]
+  Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
+end
+
 trainer.run
data/lib/chainer/functions/activation/leaky_relu.rb
ADDED
@@ -0,0 +1,64 @@
+module Chainer
+  module Functions
+    module Activation
+      # Leaky rectifier unit.
+      class LeakyReLU < Function
+        # Leaky Rectified Linear Unit function.
+        #
+        # This function is expressed as
+        #
+        # $$
+        # f(x)=\\max(x, ax),
+        # $$
+        #
+        # where $a$ is a configurable slope value.
+        #
+        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @param [float] slope Slope value $a$.
+        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @example
+        #   > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
+        #   > x
+        #   => Numo::DFloat#shape=[3,2]
+        #   [[-1, 0],
+        #    [2, -3],
+        #    [-2, 1]]
+        #   > F = Chainer::Functions::Activation::LeakyReLU
+        #   > F.leaky_relu(x, slope:0.2).data
+        #   => Numo::DFloat#shape=[3,2]
+        #   [[-0.2, 0],
+        #    [2, -0.6],
+        #    [-0.4, 1]]
+        #
+        def self.leaky_relu(x, slope: 0.2)
+          self.new(slope: slope).(x)
+        end
+
+        def initialize(slope:0.2)
+          @slope = slope
+        end
+
+        def forward_cpu(x)
+          y = x[0].dup()
+          y[x[0] < 0] *= @slope
+          if @slope >= 0
+            retain_inputs([])
+            retain_outputs([0])
+          end
+          [y]
+        end
+
+        def backward_cpu(x, gy)
+          gx = gy[0].dup()
+          if @slope >= 0
+            y = @output_data
+            gx[y[0] < 0] *= @slope
+          else
+            gx[x[0] < 0] *= @slope
+          end
+          [gx]
+        end
+      end
+    end
+  end
+end
data/lib/chainer/functions/activation/log_softmax.rb
CHANGED
@@ -10,26 +10,68 @@ module Chainer
        m + s
      end
 
-     def self.…
+     def self._log_softmax(x)
        log_z = logsumexp(x)
        x - log_z
      end
 
+     # Log-softmax activation function.
      class LogSoftmax < Function
-…
+       # Channel-wise log-softmax function.
+       #
+       # This function computes its logarithm of softmax along the second axis.
+       # Let $c = (c_1, c_2, \\dots, c_D)$ be the slice of +x+ along with
+       # the second axis. For each slice $c$, it computes the logarithm of
+       # the function $f(\c)$ defined as
+       #
+       # $$
+       # f(\c) = { \\exp(\c) \\over \\sum_{ d } \\exp(c_d) }.
+       # $$
+       #
+       # This method is theoretically equivalent to +log(softmax(x))+ but is more
+       # stable.
+       #
+       # @note
+       #   +log(softmax(x))+ may cause underflow when +x+ is too small,
+       #   because +softmax(x)+ may returns +0+.
+       #   +log_softmax+ method is more stable.
+       #
+       # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
+       # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
+       #
+       # @see Chainer::Functions::Softmax
+       #
+       # @example
+       #   > x = Numo::DFloat[[0, 1, 2], [0, 2, 4]]
+       #   => Numo::DFloat#shape=[2,3]
+       #   [[0, 1, 2],
+       #    [0, 2, 4]]
+       #   > F = Chainer::Functions::Activation::LogSoftmax
+       #   > F.log_softmax(x).data
+       #   => Numo::DFloat#shape=[2,3]
+       #   [[-2.40761, -1.40761, -0.407606],
+       #    [-4.14293, -2.14293, -0.142932]]
+       # @example (T.B.I : F.log, F.softmax)
+       #   > F.log_softmax(x).data.nearly_eq(F.log(F.softmax(x)).data).all?)
+       #   => true
+       #
+       def self.log_softmax(x)
         self.new.(x)
        end
 
-       def …
+       def forward(xs)
+         y = Chainer::Functions::Activation._log_softmax(xs[0])
+         @x_shape = xs[0].shape
+         @x_dtype = xs[0].class
         retain_inputs([])
         retain_outputs([0])
-
-        [Utils::Array.force_array(x[0])]
+         [y]
       end
 
-       def …
-        y = output_data[0]
-        […
+       def backward(x, gy)
+         y = @output_data[0]
+         gx = gy[0] - Numo::NMath.exp(y) * gy[0].sum(axis: 1, keepdims: true)
+         [gx]
       end
     end
   end
data/lib/chainer/functions/activation/relu.rb
CHANGED
@@ -1,7 +1,27 @@
 module Chainer
   module Functions
     module Activation
+      # Rectified Linear Unit.
       class Relu < Function
+        # Rectified Linear Unit function.
+        #
+        # $$
+        # f(x)=\\max(0, x).
+        # $$
+        #
+        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @example
+        #   > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
+        #   > (x < 0).any?
+        #   => true
+        #   > F = Chainer::Functions::Activation::Relu
+        #   > y = F.relu(x)
+        #   > (y.data < 0).any?
+        #   => false
+        #   > y.shape
+        #   => [3, 2]
+        #
         def self.relu(x)
           self.new.(x)
         end
@@ -14,7 +34,7 @@ module Chainer
         end
 
         def backward_cpu(x, gy)
-          y = output_data[0]
+          y = @output_data[0]
           [Utils::Array.force_array(gy[0] * (y > 0))]
         end
       end
data/lib/chainer/functions/activation/sigmoid.rb
ADDED
@@ -0,0 +1,43 @@
+module Chainer
+  module Functions
+    module Activation
+      # Logistic sigmoid function.
+      class Sigmoid < Function
+        # Element-wise sigmoid logistic function.
+        #
+        # $$
+        # f(x)=(1 + \\exp(-x))^ { -1 }.
+        # $$
+        #
+        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @example It maps the input values into the range of $`[0, 1]`$.
+        #   > x = Numo::DFloat.new(3).seq(-2, 2)
+        #   => Numo::DFloat#shape=[3]
+        #   [-2, 0, 2]
+        #   > F = Chainer::Functions::Activation::Sigmoid
+        #   > F.sigmoid(x).data
+        #   => Numo::DFloat#shape=[3]
+        #   [0.119203, 0.5, 0.880797]
+        #
+        def self.sigmoid(x)
+          self.new.(x)
+        end
+
+        def forward_cpu(x)
+          half = 0.5
+          y = Utils::Array.force_array((Numo::NMath.tanh(x[0] * half) * half)+ half)
+          retain_inputs([])
+          retain_outputs([0])
+          return [y]
+        end
+
+        def backward_cpu(x, gy)
+          one = 1
+          y = @output_data[0]
+          [Utils::Array.force_array((gy[0] * y) * (one - y))]
+        end
+      end
+    end
+  end
+end
data/lib/chainer/functions/activation/tanh.rb
ADDED
@@ -0,0 +1,42 @@
+module Chainer
+  module Functions
+    module Activation
+      # Hyperbolic tangent function.
+      class Tanh < Function
+        # Elementwise hyperbolic tangent function.
+        #
+        # $$
+        # f(x)=\\tanh(x).
+        # $$
+        #
+        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+        # @example
+        #   > x = Numo::DFloat.new(3).seq(-1, 2)
+        #   => Numo::DFloat#shape=[3]
+        #   [-1, 1, 3]
+        #   > F = Chainer::Functions::Activation::Tanh
+        #   > F.tanh(x).data
+        #   => Numo::DFloat#shape=[3]
+        #   [-0.761594, 0.761594, 0.995055]
+        #
+        def self.tanh(x)
+          self.new.(x)
+        end
+
+        def forward_cpu(x)
+          y = Utils::Array.force_array(Numo::NMath.tanh(x[0]))
+          retain_inputs([])
+          retain_outputs([0])
+          return [y]
+        end
+
+        def backward_cpu(x, gy)
+          y = @output_data[0]
+          one = y.dtype.type(1)
+          [Utils::Array.force_array(gy[0] * (one - y * y))]
+        end
+      end
+    end
+  end
+end
data/lib/chainer/functions/loss/softmax_cross_entropy.rb
CHANGED
@@ -31,7 +31,7 @@ module Chainer
 
        def forward_cpu(inputs)
          x, t = inputs
-         log_y = Activation.…
+         log_y = Activation._log_softmax(x)
 
          if @cache_score
            @y = Numo::NMath.exp(log_y)
@@ -81,7 +81,7 @@ module Chainer
          if self.instance_variable_defined?(:'@y')
            y = @y.dup
          else
-           y = Activation.…
+           y = Activation._log_softmax(x)
            y = Numo::NMath.exp(y)
          end
 
data/lib/chainer/iterators/serial_iterator.rb
CHANGED
@@ -15,7 +15,7 @@ module Chainer
      def next
        raise StopIteration if !@repeat && @epoch > 0
 
-       @previous_epoch_detail = …
+       @previous_epoch_detail = epoch_detail
 
        i = @current_position
        i_end = i + @batch_size
@@ -55,6 +55,31 @@ module Chainer
        @epoch + @current_position.to_f / @dataset.size
      end
 
+     def serialize(serializer)
+       @current_position = serializer.('current_position', @current_position)
+       @epoch = serializer.('epoch', @epoch)
+       @is_new_epoch = serializer.('is_new_epoch', @is_new_epoch)
+       unless @order.nil?
+         begin
+           serializer.('order', @order)
+         rescue KeyError
+           serializer('_order', @order)
+         end
+       end
+
+       begin
+         @previous_epoch_detail = serializer.( 'previous_epoch_detail', @previous_epoch_detail)
+       rescue KeyError
+         # guess previous_epoch_detail for older version
+         @previous_epoch_detail = @epoch + (@current_position - @batch_size) / @dataset.size
+         if epoch_detail > 0
+           @previous_epoch_detail = [@previous_epoch_detail, 0.0].max
+         else
+           @previous_epoch_detail = -1.0
+         end
+       end
+     end
+
      def reset
        if @shuffle
          order = @dataset.size.times.map(&:to_i).shuffle
data/lib/chainer/link.rb
CHANGED
@@ -59,6 +59,27 @@ module Chainer
    def namedlinks(skipself: false)
      yield('/', self) unless skipself
    end
+
+   def serialize(serializer)
+     d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
+     @params.each do |name|
+       param = d[name]
+       data = serializer.(name.to_s, param.data)
+       if param.data.nil? && !data.nil?
+         # Initialize the parameter here
+         param.init(data.shape)
+         if param.data.is_a?(Numo::NArray)
+           param.data.store(data)
+         else
+           param.data.set(Numo::NArray.cast(data))
+         end
+       end
+     end
+
+     @persistent.each do |name|
+       d[name] = serializer.(name.to_s, d[name])
+     end
+   end
  end
 
  class Chain < Link
@@ -114,5 +135,13 @@ module Chainer
        end
      end
    end
+
+   def serialize(serializer)
+     super(serializer)
+     d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
+     @children.each do |name|
+       d[name].serialize(serializer[name.to_s])
+     end
+   end
  end
 end
data/lib/chainer/optimizer.rb
CHANGED
@@ -19,6 +19,17 @@ module Chainer
        hook(self)
      end
    end
+
+   def serialize(serializer)
+     @t = serializer.('t', @t)
+     @epoch = serializer.('epoch', @epoch)
+
+     @target.namedparams() do |(name, param)|
+       if param.respond_to?(:update_rule)
+         param.update_rule.serialize(serializer[name.to_s])
+       end
+     end
+   end
  end
 
  class UpdateRule
@@ -56,6 +67,33 @@ module Chainer
      raise NotImplementedError
    end
 
+
+   # Serializes the update rule state.
+   # Be careful that this method only saves/loads the state of the update rule.
+   # The parameters of the target link is not saved/loaded by this
+   # method, and so you need to serialize the target link separately if you
+   # want to fully recover the training state including parameters.
+   #
+   # @param [Chainer::AbstractSerializer] serializer: Serializer object.
+   def serialize(serializer)
+     if @state.nil?
+       if serializer.is_a?(Chainer::Deserializer)
+         # try to initialize the state to retrieve state entries
+         @state = {}
+         self_copy = self.dup
+         arr = Numo::DFloat.new(1)
+         self_copy.init_state(Chainer::Variable.new(arr, grad: arr))
+         @state.keys.each do |key|
+           @state[key] = serializer.(key.to_s, nil)
+         end
+       end
+     else
+       @state.each do |key, val|
+         @state[key] = serializer.(key.to_s, val)
+       end
+     end
+   end
+
    private
 
    def prepare(param)
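The optimizer and each UpdateRule now take part in the same serialization protocol as links and iterators. A minimal sketch, assuming `optimizer` has already been set up as in examples/mnist.rb, of persisting optimizer state on its own with the Marshal serializer added in this release (data/lib/chainer/serializers/marshal.rb, below); the file name and child-serializer keys are illustrative, mirroring what StandardUpdater#serialize does for each optimizer:

```ruby
# Sketch only: save step count, epoch and per-parameter update-rule state,
# plus the model parameters, into one Marshal-dumped Hash.
s = Chainer::Serializers::MarshalSerializer.new
optimizer.serialize(s['optimizer'])      # keys are prefixed with "optimizer/"
optimizer.target.serialize(s['model'])   # the model parameters go under "model/"
File.open('optimizer.state', 'wb') { |f| Marshal.dump(s.target, f) }
```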
data/lib/chainer/serializer.rb
ADDED
@@ -0,0 +1,50 @@
+module Chainer
+  # Abstract base class of all serializers and deserializers.
+  class AbstractSerializer
+    # Gets a child serializer.
+    # This operator creates a child serializer represented by the given key.
+    #
+    # @param [string] key: Name of the child serializer.
+    def [](key)
+      raise NotImplementedError
+    end
+
+    # Serializes or deserializes a value by given name.
+    # This operator saves or loads a value by given name.
+    # If this is a serializer, then the value is simply saved at the key.
+    # Note that some type information might be missed depending on the
+    # implementation (and the target file format).
+    # If this is a deserializer, then the value is loaded by the key.
+    # The deserialization differently works on scalars and arrays.
+    # For scalars, the ``value`` argument is used just for determining the type of
+    # restored value to be converted, and the converted value is returned.
+    # For arrays, the restored elements are directly copied into the
+    # ``value`` argument. String values are treated like scalars.
+    #
+    # @param [string] key: Name of the serialization entry.
+    # @param [any] value: Object to be (de)serialized.
+    #   ``None`` is only supported by deserializers.
+    # @return Serialized or deserialized value.
+    def call(key, value)
+      raise NotImplementedError
+    end
+  end
+
+  # Base class of all serializers.
+  class Serializer < AbstractSerializer
+    # Saves an object by this serializer.
+    # This is equivalent to ``obj.serialize(self)``.
+    #
+    # @param [any] obj: Target object to be serialized.
+    def save(obj)
+      obj.serialize(self)
+    end
+  end
+
+  # Base class of all deserializers.
+  class Deserializer < AbstractSerializer
+    def load(obj)
+      obj.serialize(self)
+    end
+  end
+end
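A minimal sketch of the protocol these base classes define: an object exposes `serialize(serializer)` and passes each piece of state through the serializer by key; the same method serves both saving and loading because `call` returns the (de)serialized value. `Accumulator` is a hypothetical class used only for illustration, and the concrete serializers come from data/lib/chainer/serializers/marshal.rb below:

```ruby
class Accumulator
  attr_reader :totals

  def initialize
    @totals = Numo::DFloat.zeros(3)
  end

  def add(i)
    @totals[i] += 1
  end

  def serialize(serializer)
    # one call per piece of state; saving stores it, loading restores it in place
    @totals = serializer.('totals', @totals)
  end
end

a = Accumulator.new
a.add(0); a.add(2)

saver = Chainer::Serializers::MarshalSerializer.new
saver.save(a)                                       # Serializer#save -> a.serialize(saver)

loader = Chainer::Serializers::MarshalDeserializer.new(saver.target)
restored = Accumulator.new
loader.load(restored)                               # Deserializer#load -> restored.serialize(loader)
```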
data/lib/chainer/serializers/marshal.rb
ADDED
@@ -0,0 +1,83 @@
+module Chainer
+  module Serializers
+    class MarshalSerializer < Chainer::Serializer
+      attr_accessor :target, :path
+
+      def self.save_file(filename, obj)
+        s = self.new
+        s.save(obj)
+        Marshal.dump(s.target, filename)
+      end
+
+      def initialize(target: nil, path: "")
+        @target = target.nil? ? {} : target
+        @path = path
+      end
+
+      def [](key)
+        self.class.new(target: @target, path: File.join(@path, key, '/'))
+      end
+
+      def call(key, value)
+        ret = value
+        if value.is_a?(TrueClass)
+          arr = Numo::Bit[1]
+        elsif value.is_a?(FalseClass)
+          arr = Numo::Bit[0]
+        elsif value.instance_of?(String)
+          arr = value
+        else
+          arr = Numo::NArray.cast(value)
+        end
+        @target[File.join(@path, key)] = arr
+        ret
+      end
+    end
+
+    class MarshalDeserializer < Chainer::Deserializer
+      # Loads an object from the file in Marshal format.
+      # This is a short-cut function to load from an Marshal file that contains only one object.
+      #
+      # @param [string ]filename: Name of the file to be loaded.
+      # @param [object] obj: Object to be deserialized. It must support serialization protocol.
+      def self.load_file(filename, obj)
+        File.open(filename) do |f|
+          d = self.new(Marshal.load(f))
+          d.load(obj)
+        end
+      end
+
+      def initialize(marshalData, path: '', strict: true)
+        @marshal_data = marshalData
+        @path = path
+        @strict = strict
+      end
+
+      def [](key)
+        self.class.new(@marshal_data, path: File.join(@path, key, '/'), strict: @strict)
+      end
+
+      def call(key, value)
+        key = File.join(@path, key)
+        if !@strict && !@marshal_data.keys.include?(key)
+          return value
+        end
+
+        dataset = @marshal_data[key]
+        if value.nil?
+          return dataset
+        elsif value.instance_of?(String)
+          return dataset
+        elsif value.is_a?(Numo::NArray)
+          value.store(dataset)
+          return value
+        elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)
+          return dataset[0] == 1
+        else
+          return dataset[0]
+        end
+      end
+    end
+  end
+end
+
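Together with Trainer#serialize further below, these classes give the snapshot/resume round trip that the updated mnist example relies on. A hedged sketch, assuming `trainer` is a fully configured Chainer::Training::Trainer and the file name is arbitrary:

```ruby
# Save: MarshalSerializer.save_file passes its first argument to Marshal.dump,
# so hand it an open IO, as the Snapshot extension does with a Tempfile.
File.open('trainer_snapshot', 'wb') do |f|
  Chainer::Serializers::MarshalSerializer.save_file(f, trainer)
end

# Resume: rebuild the same trainer object in a later run, then restore its state,
# exactly as examples/mnist.rb does for --resume.
Chainer::Serializers::MarshalDeserializer.load_file('trainer_snapshot', trainer)
trainer.run  # continues from the restored iteration and epoch
```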
data/lib/chainer/training/extensions/log_report.rb
CHANGED
@@ -61,6 +61,21 @@ module Chainer
          init_summary
        end
 
+        def serialize(serializer)
+          if @trigger.respond_to?(:serialize)
+            @trigger.serialize(serializer['_trigger'])
+          end
+          # Note that this serialization may lose some information of small
+          # numerical differences.
+          if serializer.is_a?(Chainer::Serializer)
+            log = JSON.generate(@log)
+            serializer.('_log', log)
+          else
+            log = serializer.('_log', '')
+            @log = JSON.parse(log)
+          end
+        end
+
        private
 
        def init_summary
data/lib/chainer/training/extensions/snapshot.rb
ADDED
@@ -0,0 +1,33 @@
+module Chainer
+  module Training
+    module Extensions
+      class Snapshot < Extension
+        attr_accessor :save_class, :filename_proc, :target
+
+        def self.snapshot_object(target:, save_class:, &block)
+          self.new(save_class: save_class, filename_proc: block, target: target)
+        end
+
+        def self.snapshot(save_class: nil, &block)
+          self.new(save_class: save_class, filename_proc: block)
+        end
+
+        def initialize(save_class: nil, filename_proc: nil, target: nil)
+          @save_class = save_class || Chainer::Serializers::MarshalSerializer
+          @filename_proc = filename_proc || Proc.new { |trainer| "snapshot_iter_#{trainer.updater.iteration}" }
+          @target = target
+        end
+
+        def call(trainer)
+          target = @target || trainer
+          filename = filename_proc.call(trainer)
+          prefix = "tmp#{filename}"
+          temp_file = Tempfile.create(basename: prefix, tmpdir: trainer.out)
+          save_class.save_file(temp_file, trainer)
+          FileUtils.move(temp_file.path, File.join(trainer.out, filename))
+        end
+      end
+    end
+  end
+end
+
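Registering the extension follows the pattern from the updated mnist example; by default each snapshot of the whole trainer is written to trainer.out as snapshot_iter_N. A short sketch, assuming `trainer` and `frequency` are set up as in examples/mnist.rb; the custom file-name pattern in the second form is purely illustrative:

```ruby
# Default file name, one snapshot every `frequency` epochs.
trainer.extend(Chainer::Training::Extensions::Snapshot.new,
               trigger: [frequency, 'epoch'], priority: -100)

# Custom file name via the Snapshot.snapshot block form (hypothetical pattern).
snap = Chainer::Training::Extensions::Snapshot.snapshot { |t| "snap_#{t.updater.iteration}" }
trainer.extend(snap, trigger: [1, 'epoch'], priority: -100)
```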
data/lib/chainer/training/standard_updater.rb
CHANGED
@@ -58,6 +58,18 @@ module Chainer
          iterator.finalize
        end
      end
+
+      def serialize(serializer)
+        @iterators.each do |name, iterator|
+          iterator.serialize(serializer["iterator:#{name}"])
+        end
+        @optimizers.each do |name, optimizer|
+          optimizer.serialize(serializer["optimizer:#{name}"])
+          optimizer.target.serialize(serializer["model:#{name}"])
+        end
+
+        @iteration = serializer.('iteration', @iteration)
+      end
    end
  end
 end
data/lib/chainer/training/trainer.rb
CHANGED
@@ -44,7 +44,7 @@ module Chainer
        return @final_elapsed_time if @done
        raise "training has not been started yet" if @start_at.nil?
 
-        Time.now.to_f - @start_at
+        Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
      end
 
      def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
@@ -97,7 +97,7 @@ module Chainer
        raise 'cannot run training loop multiple times' if @done
        FileUtils.mkdir_p(@out)
 
-        extensions = @extensions.sort_by { |(_, e)| e.priority }.map { |(name, extension)| [name, extension] }
+        extensions = @extensions.sort_by { |(_, e)| -e.priority }.map { |(name, extension)| [name, extension] }
 
        @start_at = Time.now.to_f
 
@@ -115,7 +115,7 @@ module Chainer
          @observation = {}
          reporter.scope(@observation) do
            update.call
-            extensions.each do |(…
+            extensions.each do |(name, entry)|
              entry.extension.(self) if entry.trigger.(self)
            end
          end
@@ -131,6 +131,29 @@ module Chainer
        @final_elapsed_time = @elapsed_time
        @done = true
      end
+
+      def serialize(serializer)
+        updater.serialize(serializer['updater'])
+        if @stop_trigger.respond_to?(:serialize)
+          @stop_trigger.serialize(serializer['stop_trigger'])
+        end
+
+        s = serializer['extensions']
+        t = serializer['extension_triggers']
+        @extensions.each do |name, entry|
+          if entry.extension.respond_to?(:serialize)
+            entry.extension.serialize(s[name])
+          end
+          if entry.trigger.respond_to?(:serialize)
+            entry.trigger.serialize(t[name])
+          end
+        end
+        if serializer.is_a?(Chainer::Serializer)
+          serializer.('_snapshot_elapsed_time', elapsed_time)
+        else
+          @snapshot_elapsed_time = serializer.('_snapshot_elapsed_time', 0.0)
+        end
+      end
    end
  end
 end
data/lib/chainer/training/triggers/interval.rb
CHANGED
@@ -8,18 +8,43 @@ module Chainer
        @period = period
        @unit = unit
        @count = 0
+
+        @previous_iteration = 0
+        @previous_epoch_detail = 0.0
      end
 
      def call(trainer)
        updater = trainer.updater
        if @unit == 'epoch'
-…
-…
-…
+          epoch_detail = updater.epoch_detail
+          previous_epoch_detail = @previous_epoch_detail
+
+          if previous_epoch_detail < 0
+            previous_epoch_detail = updater.previous_epoch_detail
+          end
+
+          @count = epoch_detail.div(@period).floor
+
+          fire = previous_epoch_detail.div(@period).floor != epoch_detail.div(@period).floor
        else
          iteration = updater.iteration
-…
+          previous_iteration = @previous_iteration
+          if previous_iteration < 0
+            previous_iteration = iteration - 1
+          end
+          fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
        end
+
+        # save current values
+        @previous_iteration = updater.iteration
+        @previous_epoch_detail = updater.epoch_detail
+
+        fire
+      end
+
+      def serialize(serializer)
+        @previous_iteration = serializer.('previous_iteration', @previous_iteration)
+        @previous_epoch_detail = serializer.('previous_epoch_detail', @previous_epoch_detail)
      end
    end
  end
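A worked instance of the epoch-based firing rule above, with illustrative numbers: with a period of 2 epochs, a step that moves `epoch_detail` from 1.9 to 2.05 crosses a multiple of the period, so the trigger fires exactly once.

```ruby
period                = 2
previous_epoch_detail = 1.9
epoch_detail          = 2.05

previous_epoch_detail.div(period).floor   # => 0
epoch_detail.div(period).floor            # => 1

# the two quotients differ, so the trigger fires on this call
fire = previous_epoch_detail.div(period).floor != epoch_detail.div(period).floor
# => true
```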
data/lib/chainer/version.rb
CHANGED
data/lib/chainer.rb
CHANGED
@@ -22,7 +22,10 @@ require 'chainer/variable_node'
 require 'chainer/utils/initializer'
 require 'chainer/utils/variable'
 require 'chainer/utils/array'
+require 'chainer/functions/activation/leaky_relu'
 require 'chainer/functions/activation/relu'
+require 'chainer/functions/activation/sigmoid'
+require 'chainer/functions/activation/tanh'
 require 'chainer/functions/activation/log_softmax'
 require 'chainer/functions/evaluation/accuracy'
 require 'chainer/functions/math/basic_math'
@@ -33,6 +36,7 @@ require 'chainer/training/extensions/evaluator'
 require 'chainer/training/extensions/log_report'
 require 'chainer/training/extensions/print_report'
 require 'chainer/training/extensions/progress_bar'
+require 'chainer/training/extensions/snapshot'
 require 'chainer/training/trainer'
 require 'chainer/training/updater'
 require 'chainer/training/util'
@@ -44,6 +48,8 @@ require 'chainer/dataset/download'
 require 'chainer/datasets/mnist'
 require 'chainer/datasets/tuple_dataset'
 require 'chainer/reporter'
+require 'chainer/serializer'
+require 'chainer/serializers/marshal'
 
 require 'numo/narray'
 
data/red-chainer.gemspec
CHANGED
@@ -19,7 +19,7 @@ Gem::Specification.new do |spec|
   spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
   spec.require_paths = ["lib"]
 
-  spec.add_runtime_dependency "numo-narray", ">= 0.9.…
+  spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
 
   spec.add_development_dependency "bundler", "~> 1.15"
   spec.add_development_dependency "rake", "~> 10.0"
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: red-chainer
 version: !ruby/object:Gem::Version
-  version: 0.1.1
+  version: 0.2.0
 platform: ruby
 authors:
 - Yusaku Hatanaka
 autorequire:
 bindir: exe
 cert_chain: []
-date: …
+date: 2018-02-01 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
@@ -16,14 +16,14 @@ dependencies:
     requirements:
     - - ">="
      - !ruby/object:Gem::Version
-        version: 0.9.…
+        version: 0.9.1.1
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
    requirements:
    - - ">="
      - !ruby/object:Gem::Version
-        version: 0.9.…
+        version: 0.9.1.1
 - !ruby/object:Gem::Dependency
  name: bundler
  requirement: !ruby/object:Gem::Requirement
@@ -92,8 +92,11 @@ files:
 - lib/chainer/datasets/mnist.rb
 - lib/chainer/datasets/tuple_dataset.rb
 - lib/chainer/function.rb
+- lib/chainer/functions/activation/leaky_relu.rb
 - lib/chainer/functions/activation/log_softmax.rb
 - lib/chainer/functions/activation/relu.rb
+- lib/chainer/functions/activation/sigmoid.rb
+- lib/chainer/functions/activation/tanh.rb
 - lib/chainer/functions/connection/linear.rb
 - lib/chainer/functions/evaluation/accuracy.rb
 - lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -112,11 +115,14 @@ files:
 - lib/chainer/optimizers/adam.rb
 - lib/chainer/parameter.rb
 - lib/chainer/reporter.rb
+- lib/chainer/serializer.rb
+- lib/chainer/serializers/marshal.rb
 - lib/chainer/training/extension.rb
 - lib/chainer/training/extensions/evaluator.rb
 - lib/chainer/training/extensions/log_report.rb
 - lib/chainer/training/extensions/print_report.rb
 - lib/chainer/training/extensions/progress_bar.rb
+- lib/chainer/training/extensions/snapshot.rb
 - lib/chainer/training/standard_updater.rb
 - lib/chainer/training/trainer.rb
 - lib/chainer/training/triggers/interval.rb
@@ -149,7 +155,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
     version: '0'
 requirements: []
 rubyforge_project:
-rubygems_version: 2.…
+rubygems_version: 2.7.3
 signing_key:
 specification_version: 4
 summary: A flexible framework for neural network for Ruby