red-chainer 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -5
- data/.gitignore +1 -0
- data/.travis.yml +3 -1
- data/README.md +8 -12
- data/examples/mnist.rb +32 -4
- data/lib/chainer/functions/activation/leaky_relu.rb +64 -0
- data/lib/chainer/functions/activation/log_softmax.rb +50 -8
- data/lib/chainer/functions/activation/relu.rb +21 -1
- data/lib/chainer/functions/activation/sigmoid.rb +43 -0
- data/lib/chainer/functions/activation/tanh.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +2 -2
- data/lib/chainer/iterators/serial_iterator.rb +26 -1
- data/lib/chainer/link.rb +29 -0
- data/lib/chainer/optimizer.rb +38 -0
- data/lib/chainer/serializer.rb +50 -0
- data/lib/chainer/serializers/marshal.rb +83 -0
- data/lib/chainer/training/extensions/log_report.rb +15 -0
- data/lib/chainer/training/extensions/print_report.rb +6 -0
- data/lib/chainer/training/extensions/snapshot.rb +33 -0
- data/lib/chainer/training/standard_updater.rb +12 -0
- data/lib/chainer/training/trainer.rb +26 -3
- data/lib/chainer/training/triggers/interval.rb +29 -4
- data/lib/chainer/version.rb +1 -1
- data/lib/chainer.rb +6 -0
- data/red-chainer.gemspec +1 -1
- metadata +11 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
|
-
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 01f7ae8c937f4fb6384d9d6a0bf975d236da004d9165a734f9f5ab78c48d1197
|
4
|
+
data.tar.gz: 218ae1d0ebe3f90db0861d98d88a0be2a52e160f8001952d223df64ee8d22838
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: aba81905dd93397b368b64fdd517b3b99487cc47dfb2ca7477cadf5e8357038be3f920f950a4968394040bd34e04e6537f24965a52e93f7d4855f174ff305d60
|
7
|
+
data.tar.gz: 96752ee9d7b151eb63e36ddda56d82910a6eee07109376de47cd977a77d1bb2bf0ecf78353835a049b795b35441197ede4cfb1dbff78ad4bd4a220da1c6f5009
|
data/.gitignore
CHANGED
data/.travis.yml
CHANGED
data/README.md
CHANGED
@@ -5,39 +5,35 @@
|
|
5
5
|
Red Chainer
|
6
6
|
|
7
7
|
## Description
|
8
|
-
|
9
8
|
Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/chainer`. To experiment with that code, run `bin/console` for an interactive prompt.
|
10
9
|
|
11
|
-
TODO: Delete this and the text above, and describe your gem
|
12
|
-
|
13
10
|
## Installation
|
14
11
|
|
15
12
|
Add this line to your application's Gemfile:
|
16
13
|
|
17
|
-
```
|
18
|
-
gem 'red-chainer'
|
14
|
+
```bash
|
15
|
+
gem 'red-chainer'
|
19
16
|
```
|
20
17
|
|
21
18
|
And then execute:
|
22
19
|
|
23
|
-
```
|
20
|
+
```bash
|
24
21
|
$ bundle
|
25
22
|
```
|
26
23
|
|
27
24
|
Or install it yourself as:
|
28
25
|
|
29
|
-
```
|
30
|
-
gem install
|
31
|
-
gem specific_install -l 'https://github.com/red-data-tools/red-chainer'
|
26
|
+
```bash
|
27
|
+
$ gem install red-chainer
|
32
28
|
```
|
33
29
|
|
34
30
|
## Usage
|
35
31
|
mnist sample program is [here](./examples/mnist.rb)
|
36
32
|
|
37
|
-
```
|
38
|
-
# install Gemfile
|
33
|
+
```bash
|
34
|
+
# when install Gemfile
|
39
35
|
$ bundle exec ruby examples/mnist.rb
|
40
|
-
# install yourself
|
36
|
+
# when install yourself
|
41
37
|
$ ruby examples/mnist.rb
|
42
38
|
```
|
43
39
|
|
data/examples/mnist.rb
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
require 'chainer'
|
2
2
|
require 'fileutils'
|
3
|
+
require 'optparse'
|
3
4
|
require 'tmpdir'
|
4
5
|
|
5
6
|
class MLP < Chainer::Chain
|
@@ -22,21 +23,48 @@ class MLP < Chainer::Chain
|
|
22
23
|
end
|
23
24
|
end
|
24
25
|
|
25
|
-
|
26
|
+
args = {
|
27
|
+
batchsize: 100,
|
28
|
+
frequency: -1,
|
29
|
+
epoch: 20,
|
30
|
+
resume: nil,
|
31
|
+
unit: 1000,
|
32
|
+
out: 'result'
|
33
|
+
}
|
34
|
+
|
35
|
+
opt = OptionParser.new
|
36
|
+
opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
|
37
|
+
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
|
38
|
+
opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
|
39
|
+
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
|
40
|
+
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
|
41
|
+
opt.on('-u', '--unit VALUE', "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v.to_i }
|
42
|
+
opt.parse!(ARGV)
|
43
|
+
|
44
|
+
model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10))
|
26
45
|
|
27
46
|
optimizer = Chainer::Optimizers::Adam.new
|
28
47
|
optimizer.setup(model)
|
29
48
|
train, test = Chainer::Datasets::Mnist.get_mnist
|
30
49
|
|
31
|
-
train_iter = Chainer::Iterators::SerialIterator.new(train,
|
32
|
-
test_iter = Chainer::Iterators::SerialIterator.new(test,
|
50
|
+
train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
|
51
|
+
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
|
33
52
|
|
34
53
|
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
|
35
|
-
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [
|
54
|
+
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
|
36
55
|
|
37
56
|
trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
|
57
|
+
|
58
|
+
# Take a snapshot for each specified epoch
|
59
|
+
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
|
60
|
+
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'], priority: -100)
|
61
|
+
|
38
62
|
trainer.extend(Chainer::Training::Extensions::LogReport.new)
|
39
63
|
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
|
40
64
|
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)
|
41
65
|
|
66
|
+
if args[:resume]
|
67
|
+
Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
|
68
|
+
end
|
69
|
+
|
42
70
|
trainer.run
|
@@ -0,0 +1,64 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Functions
|
3
|
+
module Activation
|
4
|
+
# Leaky rectifier unit.
|
5
|
+
class LeakyReLU < Function
|
6
|
+
# Leaky Rectified Linear Unit function.
|
7
|
+
#
|
8
|
+
# This function is expressed as
|
9
|
+
#
|
10
|
+
# $$
|
11
|
+
# f(x)=\\max(x, ax),
|
12
|
+
# $$
|
13
|
+
#
|
14
|
+
# where $a$ is a configurable slope value.
|
15
|
+
#
|
16
|
+
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
17
|
+
# @param [float] slope Slope value $a$.
|
18
|
+
# @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
19
|
+
# @example
|
20
|
+
# > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
|
21
|
+
# > x
|
22
|
+
# => Numo::DFloat#shape=[3,2]
|
23
|
+
# [[-1, 0],
|
24
|
+
# [2, -3],
|
25
|
+
# [-2, 1]]
|
26
|
+
# > F = Chainer::Functions::Activation::LeakyReLU
|
27
|
+
# > F.leaky_relu(x, slope:0.2).data
|
28
|
+
# => Numo::DFloat#shape=[3,2]
|
29
|
+
# [[-0.2, 0],
|
30
|
+
# [2, -0.6],
|
31
|
+
# [-0.4, 1]]
|
32
|
+
#
|
33
|
+
def self.leaky_relu(x, slope: 0.2)
|
34
|
+
self.new(slope: slope).(x)
|
35
|
+
end
|
36
|
+
|
37
|
+
def initialize(slope:0.2)
|
38
|
+
@slope = slope
|
39
|
+
end
|
40
|
+
|
41
|
+
def forward_cpu(x)
|
42
|
+
y = x[0].dup()
|
43
|
+
y[x[0] < 0] *= @slope
|
44
|
+
if @slope >= 0
|
45
|
+
retain_inputs([])
|
46
|
+
retain_outputs([0])
|
47
|
+
end
|
48
|
+
[y]
|
49
|
+
end
|
50
|
+
|
51
|
+
def backward_cpu(x, gy)
|
52
|
+
gx = gy[0].dup()
|
53
|
+
if @slope >= 0
|
54
|
+
y = @output_data
|
55
|
+
gx[y[0] < 0] *= @slope
|
56
|
+
else
|
57
|
+
gx[x[0] < 0] *= @slope
|
58
|
+
end
|
59
|
+
[gx]
|
60
|
+
end
|
61
|
+
end
|
62
|
+
end
|
63
|
+
end
|
64
|
+
end
|
@@ -10,26 +10,68 @@ module Chainer
|
|
10
10
|
m + s
|
11
11
|
end
|
12
12
|
|
13
|
-
def self.
|
13
|
+
def self._log_softmax(x)
|
14
14
|
log_z = logsumexp(x)
|
15
15
|
x - log_z
|
16
16
|
end
|
17
17
|
|
18
|
+
# Log-softmax activation function.
|
18
19
|
class LogSoftmax < Function
|
19
|
-
|
20
|
+
# Channel-wise log-softmax function.
|
21
|
+
#
|
22
|
+
# This function computes its logarithm of softmax along the second axis.
|
23
|
+
# Let $c = (c_1, c_2, \\dots, c_D)$ be the slice of +x+ along with
|
24
|
+
# the second axis. For each slice $c$, it computes the logarithm of
|
25
|
+
# the function $f(\c)$ defined as
|
26
|
+
#
|
27
|
+
# $$
|
28
|
+
# f(\c) = { \\exp(\c) \\over \\sum_{ d } \\exp(c_d) }.
|
29
|
+
# $$
|
30
|
+
#
|
31
|
+
# This method is theoretically equivalent to +log(softmax(x))+ but is more
|
32
|
+
# stable.
|
33
|
+
#
|
34
|
+
# @note
|
35
|
+
# +log(softmax(x))+ may cause underflow when +x+ is too small,
|
36
|
+
# because +softmax(x)+ may returns +0+.
|
37
|
+
# +log_softmax+ method is more stable.
|
38
|
+
#
|
39
|
+
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
|
40
|
+
# @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
|
41
|
+
#
|
42
|
+
# @see Chainer::Functions::Softmax
|
43
|
+
#
|
44
|
+
# @example
|
45
|
+
# > x = Numo::DFloat[[0, 1, 2], [0, 2, 4]]
|
46
|
+
# => Numo::DFloat#shape=[2,3]
|
47
|
+
# [[0, 1, 2],
|
48
|
+
# [0, 2, 4]]
|
49
|
+
# > F = Chainer::Functions::Activation::LogSoftmax
|
50
|
+
# > F.log_softmax(x).data
|
51
|
+
# => Numo::DFloat#shape=[2,3]
|
52
|
+
# [[-2.40761, -1.40761, -0.407606],
|
53
|
+
# [-4.14293, -2.14293, -0.142932]]
|
54
|
+
# @example (T.B.I : F.log, F.softmax)
|
55
|
+
# > F.log_softmax(x).data.nearly_eq(F.log(F.softmax(x)).data).all?)
|
56
|
+
# => true
|
57
|
+
#
|
58
|
+
def self.log_softmax(x)
|
20
59
|
self.new.(x)
|
21
60
|
end
|
22
61
|
|
23
|
-
def
|
62
|
+
def forward(xs)
|
63
|
+
y = Chainer::Functions::Activation._log_softmax(xs[0])
|
64
|
+
@x_shape = xs[0].shape
|
65
|
+
@x_dtype = xs[0].class
|
24
66
|
retain_inputs([])
|
25
67
|
retain_outputs([0])
|
26
|
-
|
27
|
-
[Utils::Array.force_array(x[0])]
|
68
|
+
[y]
|
28
69
|
end
|
29
70
|
|
30
|
-
def
|
31
|
-
y = output_data[0]
|
32
|
-
[
|
71
|
+
def backward(x, gy)
|
72
|
+
y = @output_data[0]
|
73
|
+
gx = gy[0] - Numo::NMath.exp(y) * gy[0].sum(axis: 1, keepdims: true)
|
74
|
+
[gx]
|
33
75
|
end
|
34
76
|
end
|
35
77
|
end
|
@@ -1,7 +1,27 @@
|
|
1
1
|
module Chainer
|
2
2
|
module Functions
|
3
3
|
module Activation
|
4
|
+
# Rectified Linear Unit.
|
4
5
|
class Relu < Function
|
6
|
+
# Rectified Linear Unit function.
|
7
|
+
#
|
8
|
+
# $$
|
9
|
+
# f(x)=\\max(0, x).
|
10
|
+
# $$
|
11
|
+
#
|
12
|
+
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
13
|
+
# @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
14
|
+
# @example
|
15
|
+
# > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
|
16
|
+
# > (x < 0).any?
|
17
|
+
# => true
|
18
|
+
# > F = Chainer::Functions::Activation::Relu
|
19
|
+
# > y = F.relu(x)
|
20
|
+
# > (y.data < 0).any?
|
21
|
+
# => false
|
22
|
+
# > y.shape
|
23
|
+
# => [3, 2]
|
24
|
+
#
|
5
25
|
def self.relu(x)
|
6
26
|
self.new.(x)
|
7
27
|
end
|
@@ -14,7 +34,7 @@ module Chainer
|
|
14
34
|
end
|
15
35
|
|
16
36
|
def backward_cpu(x, gy)
|
17
|
-
y = output_data[0]
|
37
|
+
y = @output_data[0]
|
18
38
|
[Utils::Array.force_array(gy[0] * (y > 0))]
|
19
39
|
end
|
20
40
|
end
|
@@ -0,0 +1,43 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Functions
|
3
|
+
module Activation
|
4
|
+
# Logistic sigmoid function.
|
5
|
+
class Sigmoid < Function
|
6
|
+
# Element-wise sigmoid logistic function.
|
7
|
+
#
|
8
|
+
# $$
|
9
|
+
# f(x)=(1 + \\exp(-x))^ { -1 }.
|
10
|
+
# $$
|
11
|
+
#
|
12
|
+
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
13
|
+
# @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
14
|
+
# @example It maps the input values into the range of $`[0, 1]`$.
|
15
|
+
# > x = Numo::DFloat.new(3).seq(-2, 2)
|
16
|
+
# => Numo::DFloat#shape=[3]
|
17
|
+
# [-2, 0, 2]
|
18
|
+
# > F = Chainer::Functions::Activation::Sigmoid
|
19
|
+
# > F.sigmoid(x).data
|
20
|
+
# => Numo::DFloat#shape=[3]
|
21
|
+
# [0.119203, 0.5, 0.880797]
|
22
|
+
#
|
23
|
+
def self.sigmoid(x)
|
24
|
+
self.new.(x)
|
25
|
+
end
|
26
|
+
|
27
|
+
def forward_cpu(x)
|
28
|
+
half = 0.5
|
29
|
+
y = Utils::Array.force_array((Numo::NMath.tanh(x[0] * half) * half)+ half)
|
30
|
+
retain_inputs([])
|
31
|
+
retain_outputs([0])
|
32
|
+
return [y]
|
33
|
+
end
|
34
|
+
|
35
|
+
def backward_cpu(x, gy)
|
36
|
+
one = 1
|
37
|
+
y = @output_data[0]
|
38
|
+
[Utils::Array.force_array((gy[0] * y) * (one - y))]
|
39
|
+
end
|
40
|
+
end
|
41
|
+
end
|
42
|
+
end
|
43
|
+
end
|
@@ -0,0 +1,42 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Functions
|
3
|
+
module Activation
|
4
|
+
# Hyperbolic tangent function.
|
5
|
+
class Tanh < Function
|
6
|
+
# Elementwise hyperbolic tangent function.
|
7
|
+
#
|
8
|
+
# $$
|
9
|
+
# f(x)=\\tanh(x).
|
10
|
+
# $$
|
11
|
+
#
|
12
|
+
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
13
|
+
# @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
|
14
|
+
# @example
|
15
|
+
# > x = Numo::DFloat.new(3).seq(-1, 2)
|
16
|
+
# => Numo::DFloat#shape=[3]
|
17
|
+
# [-1, 1, 3]
|
18
|
+
# > F = Chainer::Functions::Activation::Tanh
|
19
|
+
# > F.tanh(x).data
|
20
|
+
# => Numo::DFloat#shape=[3]
|
21
|
+
# [-0.761594, 0.761594, 0.995055]
|
22
|
+
#
|
23
|
+
def self.tanh(x)
|
24
|
+
self.new.(x)
|
25
|
+
end
|
26
|
+
|
27
|
+
def forward_cpu(x)
|
28
|
+
y = Utils::Array.force_array(Numo::NMath.tanh(x[0]))
|
29
|
+
retain_inputs([])
|
30
|
+
retain_outputs([0])
|
31
|
+
return [y]
|
32
|
+
end
|
33
|
+
|
34
|
+
def backward_cpu(x, gy)
|
35
|
+
y = @output_data[0]
|
36
|
+
one = y.dtype.type(1)
|
37
|
+
[Utils::Array.force_array(gy[0] * (one - y * y))]
|
38
|
+
end
|
39
|
+
end
|
40
|
+
end
|
41
|
+
end
|
42
|
+
end
|
@@ -31,7 +31,7 @@ module Chainer
|
|
31
31
|
|
32
32
|
def forward_cpu(inputs)
|
33
33
|
x, t = inputs
|
34
|
-
log_y = Activation.
|
34
|
+
log_y = Activation._log_softmax(x)
|
35
35
|
|
36
36
|
if @cache_score
|
37
37
|
@y = Numo::NMath.exp(log_y)
|
@@ -81,7 +81,7 @@ module Chainer
|
|
81
81
|
if self.instance_variable_defined?(:'@y')
|
82
82
|
y = @y.dup
|
83
83
|
else
|
84
|
-
y = Activation.
|
84
|
+
y = Activation._log_softmax(x)
|
85
85
|
y = Numo::NMath.exp(y)
|
86
86
|
end
|
87
87
|
|
@@ -15,7 +15,7 @@ module Chainer
|
|
15
15
|
def next
|
16
16
|
raise StopIteration if !@repeat && @epoch > 0
|
17
17
|
|
18
|
-
@previous_epoch_detail =
|
18
|
+
@previous_epoch_detail = epoch_detail
|
19
19
|
|
20
20
|
i = @current_position
|
21
21
|
i_end = i + @batch_size
|
@@ -55,6 +55,31 @@ module Chainer
|
|
55
55
|
@epoch + @current_position.to_f / @dataset.size
|
56
56
|
end
|
57
57
|
|
58
|
+
def serialize(serializer)
|
59
|
+
@current_position = serializer.('current_position', @current_position)
|
60
|
+
@epoch = serializer.('epoch', @epoch)
|
61
|
+
@is_new_epoch = serializer.('is_new_epoch', @is_new_epoch)
|
62
|
+
unless @order.nil?
|
63
|
+
begin
|
64
|
+
serializer.('order', @order)
|
65
|
+
rescue KeyError
|
66
|
+
serializer.('_order', @order)
|
67
|
+
end
|
68
|
+
end
|
69
|
+
|
70
|
+
begin
|
71
|
+
@previous_epoch_detail = serializer.( 'previous_epoch_detail', @previous_epoch_detail)
|
72
|
+
rescue KeyError
|
73
|
+
# guess previous_epoch_detail for older version
|
74
|
+
@previous_epoch_detail = @epoch + (@current_position - @batch_size) / @dataset.size
|
75
|
+
if epoch_detail > 0
|
76
|
+
@previous_epoch_detail = [@previous_epoch_detail, 0.0].max
|
77
|
+
else
|
78
|
+
@previous_epoch_detail = -1.0
|
79
|
+
end
|
80
|
+
end
|
81
|
+
end
|
82
|
+
|
58
83
|
def reset
|
59
84
|
if @shuffle
|
60
85
|
order = @dataset.size.times.map(&:to_i).shuffle
|
data/lib/chainer/link.rb
CHANGED
@@ -59,6 +59,27 @@ module Chainer
|
|
59
59
|
def namedlinks(skipself: false)
|
60
60
|
yield('/', self) unless skipself
|
61
61
|
end
|
62
|
+
|
63
|
+
def serialize(serializer)
|
64
|
+
d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
|
65
|
+
@params.each do |name|
|
66
|
+
param = d[name]
|
67
|
+
data = serializer.(name.to_s, param.data)
|
68
|
+
if param.data.nil? && !data.nil?
|
69
|
+
# Initialize the parameter here
|
70
|
+
param.init(data.shape)
|
71
|
+
if param.data.is_a?(Numo::NArray)
|
72
|
+
param.data.store(data)
|
73
|
+
else
|
74
|
+
param.data.set(Numo::NArray.cast(data))
|
75
|
+
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
|
79
|
+
@persistent.each do |name|
|
80
|
+
d[name] = serializer.(name.to_s, d[name])
|
81
|
+
end
|
82
|
+
end
|
62
83
|
end
|
63
84
|
|
64
85
|
class Chain < Link
|
@@ -114,5 +135,13 @@ module Chainer
|
|
114
135
|
end
|
115
136
|
end
|
116
137
|
end
|
138
|
+
|
139
|
+
def serialize(serializer)
|
140
|
+
super(serializer)
|
141
|
+
d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
|
142
|
+
@children.each do |name|
|
143
|
+
d[name].serialize(serializer[name.to_s])
|
144
|
+
end
|
145
|
+
end
|
117
146
|
end
|
118
147
|
end
|
data/lib/chainer/optimizer.rb
CHANGED
@@ -19,6 +19,17 @@ module Chainer
|
|
19
19
|
hook(self)
|
20
20
|
end
|
21
21
|
end
|
22
|
+
|
23
|
+
def serialize(serializer)
|
24
|
+
@t = serializer.('t', @t)
|
25
|
+
@epoch = serializer.('epoch', @epoch)
|
26
|
+
|
27
|
+
@target.namedparams() do |(name, param)|
|
28
|
+
if param.respond_to?(:update_rule)
|
29
|
+
param.update_rule.serialize(serializer[name.to_s])
|
30
|
+
end
|
31
|
+
end
|
32
|
+
end
|
22
33
|
end
|
23
34
|
|
24
35
|
class UpdateRule
|
@@ -56,6 +67,33 @@ module Chainer
|
|
56
67
|
raise NotImplementedError
|
57
68
|
end
|
58
69
|
|
70
|
+
|
71
|
+
# Serializes the update rule state.
|
72
|
+
# Be careful that this method only saves/loads the state of the update rule.
|
73
|
+
# The parameters of the target link is not saved/loaded by this
|
74
|
+
# method, and so you need to serialize the target link separately if you
|
75
|
+
# want to fully recover the training state including parameters.
|
76
|
+
#
|
77
|
+
# @param [Chainer::AbstractSerializer] serializer: Serializer object.
|
78
|
+
def serialize(serializer)
|
79
|
+
if @state.nil?
|
80
|
+
if serializer.is_a?(Chainer::Deserializer)
|
81
|
+
# try to initialize the state to retrieve state entries
|
82
|
+
@state = {}
|
83
|
+
self_copy = self.dup
|
84
|
+
arr = Numo::DFloat.new(1)
|
85
|
+
self_copy.init_state(Chainer::Variable.new(arr, grad: arr))
|
86
|
+
@state.keys.each do |key|
|
87
|
+
@state[key] = serializer.(key.to_s, nil)
|
88
|
+
end
|
89
|
+
end
|
90
|
+
else
|
91
|
+
@state.each do |key, val|
|
92
|
+
@state[key] = serializer.(key.to_s, val)
|
93
|
+
end
|
94
|
+
end
|
95
|
+
end
|
96
|
+
|
59
97
|
private
|
60
98
|
|
61
99
|
def prepare(param)
|
@@ -0,0 +1,50 @@
|
|
1
|
+
module Chainer
|
2
|
+
# Abstract base class of all serializers and deserializers.
|
3
|
+
class AbstractSerializer
|
4
|
+
# Gets a child serializer.
|
5
|
+
# This operator creates a child serializer represented by the given key.
|
6
|
+
#
|
7
|
+
# @param [string] key: Name of the child serializer.
|
8
|
+
def [](key)
|
9
|
+
raise NotImplementedError
|
10
|
+
end
|
11
|
+
|
12
|
+
# Serializes or deserializes a value by given name.
|
13
|
+
# This operator saves or loads a value by given name.
|
14
|
+
# If this is a serializer, then the value is simply saved at the key.
|
15
|
+
# Note that some type information might be missed depending on the
|
16
|
+
# implementation (and the target file format).
|
17
|
+
# If this is a deserializer, then the value is loaded by the key.
|
18
|
+
# The deserialization differently works on scalars and arrays.
|
19
|
+
# For scalars, the ``value`` argument is used just for determining the type of
|
20
|
+
# restored value to be converted, and the converted value is returned.
|
21
|
+
# For arrays, the restored elements are directly copied into the
|
22
|
+
# ``value`` argument. String values are treated like scalars.
|
23
|
+
#
|
24
|
+
# @param [string] key: Name of the serialization entry.
|
25
|
+
# @param [any] value: Object to be (de)serialized.
|
26
|
+
# ``None`` is only supported by deserializers.
|
27
|
+
# @return Serialized or deserialized value.
|
28
|
+
def call(key, value)
|
29
|
+
raise NotImplementedError
|
30
|
+
end
|
31
|
+
end
|
32
|
+
|
33
|
+
# Base class of all serializers.
|
34
|
+
class Serializer < AbstractSerializer
|
35
|
+
# Saves an object by this serializer.
|
36
|
+
# This is equivalent to ``obj.serialize(self)``.
|
37
|
+
#
|
38
|
+
# @param [any] obj: Target object to be serialized.
|
39
|
+
def save(obj)
|
40
|
+
obj.serialize(self)
|
41
|
+
end
|
42
|
+
end
|
43
|
+
|
44
|
+
# Base class of all deserializers.
|
45
|
+
class Deserializer < AbstractSerializer
|
46
|
+
def load(obj)
|
47
|
+
obj.serialize(self)
|
48
|
+
end
|
49
|
+
end
|
50
|
+
end
|
@@ -0,0 +1,83 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Serializers
|
3
|
+
class MarshalSerializer < Chainer::Serializer
|
4
|
+
attr_accessor :target, :path
|
5
|
+
|
6
|
+
def self.save_file(filename, obj)
|
7
|
+
s = self.new
|
8
|
+
s.save(obj)
|
9
|
+
Marshal.dump(s.target, filename)
|
10
|
+
end
|
11
|
+
|
12
|
+
def initialize(target: nil, path: "")
|
13
|
+
@target = target.nil? ? {} : target
|
14
|
+
@path = path
|
15
|
+
end
|
16
|
+
|
17
|
+
def [](key)
|
18
|
+
self.class.new(target: @target, path: File.join(@path, key, '/'))
|
19
|
+
end
|
20
|
+
|
21
|
+
def call(key, value)
|
22
|
+
ret = value
|
23
|
+
if value.is_a?(TrueClass)
|
24
|
+
arr = Numo::Bit[1]
|
25
|
+
elsif value.is_a?(FalseClass)
|
26
|
+
arr = Numo::Bit[0]
|
27
|
+
elsif value.instance_of?(String)
|
28
|
+
arr = value
|
29
|
+
else
|
30
|
+
arr = Numo::NArray.cast(value)
|
31
|
+
end
|
32
|
+
@target[File.join(@path, key)] = arr
|
33
|
+
ret
|
34
|
+
end
|
35
|
+
end
|
36
|
+
|
37
|
+
class MarshalDeserializer < Chainer::Deserializer
|
38
|
+
# Loads an object from the file in Marshal format.
|
39
|
+
# This is a short-cut function to load from an Marshal file that contains only one object.
|
40
|
+
#
|
41
|
+
# @param [string ]filename: Name of the file to be loaded.
|
42
|
+
# @param [object] obj: Object to be deserialized. It must support serialization protocol.
|
43
|
+
def self.load_file(filename, obj)
|
44
|
+
File.open(filename) do |f|
|
45
|
+
d = self.new(Marshal.load(f))
|
46
|
+
d.load(obj)
|
47
|
+
end
|
48
|
+
end
|
49
|
+
|
50
|
+
def initialize(marshalData, path: '', strict: true)
|
51
|
+
@marshal_data = marshalData
|
52
|
+
@path = path
|
53
|
+
@strict = strict
|
54
|
+
end
|
55
|
+
|
56
|
+
def [](key)
|
57
|
+
self.class.new(@marshal_data, path: File.join(@path, key, '/'), strict: @strict)
|
58
|
+
end
|
59
|
+
|
60
|
+
def call(key, value)
|
61
|
+
key = File.join(@path, key)
|
62
|
+
if !@strict && !@marshal_data.keys.include?(key)
|
63
|
+
return value
|
64
|
+
end
|
65
|
+
|
66
|
+
dataset = @marshal_data[key]
|
67
|
+
if value.nil?
|
68
|
+
return dataset
|
69
|
+
elsif value.instance_of?(String)
|
70
|
+
return dataset
|
71
|
+
elsif value.is_a?(Numo::NArray)
|
72
|
+
value.store(dataset)
|
73
|
+
return value
|
74
|
+
elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)
|
75
|
+
return dataset[0] == 1
|
76
|
+
else
|
77
|
+
return dataset[0]
|
78
|
+
end
|
79
|
+
end
|
80
|
+
end
|
81
|
+
end
|
82
|
+
end
|
83
|
+
|
@@ -61,6 +61,21 @@ module Chainer
|
|
61
61
|
init_summary
|
62
62
|
end
|
63
63
|
|
64
|
+
def serialize(serializer)
|
65
|
+
if @trigger.respond_to?(:serialize)
|
66
|
+
@trigger.serialize(serializer['_trigger'])
|
67
|
+
end
|
68
|
+
# Note that this serialization may lose some information of small
|
69
|
+
# numerical differences.
|
70
|
+
if serializer.is_a?(Chainer::Serializer)
|
71
|
+
log = JSON.generate(@log)
|
72
|
+
serializer.('_log', log)
|
73
|
+
else
|
74
|
+
log = serializer.('_log', '')
|
75
|
+
@log = JSON.parse(log)
|
76
|
+
end
|
77
|
+
end
|
78
|
+
|
64
79
|
private
|
65
80
|
|
66
81
|
def init_summary
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Training
|
3
|
+
module Extensions
|
4
|
+
class Snapshot < Extension
|
5
|
+
attr_accessor :save_class, :filename_proc, :target
|
6
|
+
|
7
|
+
def self.snapshot_object(target:, save_class:, &block)
|
8
|
+
self.new(save_class: save_class, filename_proc: block, target: target)
|
9
|
+
end
|
10
|
+
|
11
|
+
def self.snapshot(save_class: nil, &block)
|
12
|
+
self.new(save_class: save_class, filename_proc: block)
|
13
|
+
end
|
14
|
+
|
15
|
+
def initialize(save_class: nil, filename_proc: nil, target: nil)
|
16
|
+
@save_class = save_class || Chainer::Serializers::MarshalSerializer
|
17
|
+
@filename_proc = filename_proc || Proc.new { |trainer| "snapshot_iter_#{trainer.updater.iteration}" }
|
18
|
+
@target = target
|
19
|
+
end
|
20
|
+
|
21
|
+
def call(trainer)
|
22
|
+
target = @target || trainer
|
23
|
+
filename = filename_proc.call(trainer)
|
24
|
+
prefix = "tmp#{filename}"
|
25
|
+
temp_file = Tempfile.create(basename: prefix, tmpdir: trainer.out)
|
26
|
+
save_class.save_file(temp_file, trainer)
|
27
|
+
FileUtils.move(temp_file.path, File.join(trainer.out, filename))
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
32
|
+
end
|
33
|
+
|
@@ -58,6 +58,18 @@ module Chainer
|
|
58
58
|
iterator.finalize
|
59
59
|
end
|
60
60
|
end
|
61
|
+
|
62
|
+
def serialize(serializer)
|
63
|
+
@iterators.each do |name, iterator|
|
64
|
+
iterator.serialize(serializer["iterator:#{name}"])
|
65
|
+
end
|
66
|
+
@optimizers.each do |name, optimizer|
|
67
|
+
optimizer.serialize(serializer["optimizer:#{name}"])
|
68
|
+
optimizer.target.serialize(serializer["model:#{name}"])
|
69
|
+
end
|
70
|
+
|
71
|
+
@iteration = serializer.('iteration', @iteration)
|
72
|
+
end
|
61
73
|
end
|
62
74
|
end
|
63
75
|
end
|
@@ -44,7 +44,7 @@ module Chainer
|
|
44
44
|
return @final_elapsed_time if @done
|
45
45
|
raise "training has not been started yet" if @start_at.nil?
|
46
46
|
|
47
|
-
Time.now.to_f - @start_at
|
47
|
+
Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
|
48
48
|
end
|
49
49
|
|
50
50
|
def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
|
@@ -97,7 +97,7 @@ module Chainer
|
|
97
97
|
raise 'cannot run training loop multiple times' if @done
|
98
98
|
FileUtils.mkdir_p(@out)
|
99
99
|
|
100
|
-
extensions = @extensions.sort_by { |(_, e)| e.priority }.map { |(name, extension)| [name, extension] }
|
100
|
+
extensions = @extensions.sort_by { |(_, e)| -e.priority }.map { |(name, extension)| [name, extension] }
|
101
101
|
|
102
102
|
@start_at = Time.now.to_f
|
103
103
|
|
@@ -115,7 +115,7 @@ module Chainer
|
|
115
115
|
@observation = {}
|
116
116
|
reporter.scope(@observation) do
|
117
117
|
update.call
|
118
|
-
extensions.each do |(
|
118
|
+
extensions.each do |(name, entry)|
|
119
119
|
entry.extension.(self) if entry.trigger.(self)
|
120
120
|
end
|
121
121
|
end
|
@@ -131,6 +131,29 @@ module Chainer
|
|
131
131
|
@final_elapsed_time = @elapsed_time
|
132
132
|
@done = true
|
133
133
|
end
|
134
|
+
|
135
|
+
def serialize(serializer)
|
136
|
+
updater.serialize(serializer['updater'])
|
137
|
+
if @stop_trigger.respond_to?(:serialize)
|
138
|
+
@stop_trigger.serialize(serializer['stop_trigger'])
|
139
|
+
end
|
140
|
+
|
141
|
+
s = serializer['extensions']
|
142
|
+
t = serializer['extension_triggers']
|
143
|
+
@extensions.each do |name, entry|
|
144
|
+
if entry.extension.respond_to?(:serialize)
|
145
|
+
entry.extension.serialize(s[name])
|
146
|
+
end
|
147
|
+
if entry.trigger.respond_to?(:serialize)
|
148
|
+
entry.trigger.serialize(t[name])
|
149
|
+
end
|
150
|
+
end
|
151
|
+
if serializer.is_a?(Chainer::Serializer)
|
152
|
+
serializer.('_snapshot_elapsed_time', elapsed_time)
|
153
|
+
else
|
154
|
+
@snapshot_elapsed_time = serializer.('_snapshot_elapsed_time', 0.0)
|
155
|
+
end
|
156
|
+
end
|
134
157
|
end
|
135
158
|
end
|
136
159
|
end
|
@@ -8,18 +8,43 @@ module Chainer
|
|
8
8
|
@period = period
|
9
9
|
@unit = unit
|
10
10
|
@count = 0
|
11
|
+
|
12
|
+
@previous_iteration = 0
|
13
|
+
@previous_epoch_detail = 0.0
|
11
14
|
end
|
12
15
|
|
13
16
|
def call(trainer)
|
14
17
|
updater = trainer.updater
|
15
18
|
if @unit == 'epoch'
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
+
epoch_detail = updater.epoch_detail
|
20
|
+
previous_epoch_detail = @previous_epoch_detail
|
21
|
+
|
22
|
+
if previous_epoch_detail < 0
|
23
|
+
previous_epoch_detail = updater.previous_epoch_detail
|
24
|
+
end
|
25
|
+
|
26
|
+
@count = epoch_detail.div(@period).floor
|
27
|
+
|
28
|
+
fire = previous_epoch_detail.div(@period).floor != epoch_detail.div(@period).floor
|
19
29
|
else
|
20
30
|
iteration = updater.iteration
|
21
|
-
|
31
|
+
previous_iteration = @previous_iteration
|
32
|
+
if previous_iteration < 0
|
33
|
+
previous_iteration = iteration - 1
|
34
|
+
end
|
35
|
+
fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
|
22
36
|
end
|
37
|
+
|
38
|
+
# save current values
|
39
|
+
@previous_iteration = updater.iteration
|
40
|
+
@previous_epoch_detail = updater.epoch_detail
|
41
|
+
|
42
|
+
fire
|
43
|
+
end
|
44
|
+
|
45
|
+
def serialize(serializer)
|
46
|
+
@previous_iteration = serializer.('previous_iteration', @previous_iteration)
|
47
|
+
@previous_epoch_detail = serializer.('previous_epoch_detail', @previous_epoch_detail)
|
23
48
|
end
|
24
49
|
end
|
25
50
|
end
|
data/lib/chainer/version.rb
CHANGED
data/lib/chainer.rb
CHANGED
@@ -22,7 +22,10 @@ require 'chainer/variable_node'
|
|
22
22
|
require 'chainer/utils/initializer'
|
23
23
|
require 'chainer/utils/variable'
|
24
24
|
require 'chainer/utils/array'
|
25
|
+
require 'chainer/functions/activation/leaky_relu'
|
25
26
|
require 'chainer/functions/activation/relu'
|
27
|
+
require 'chainer/functions/activation/sigmoid'
|
28
|
+
require 'chainer/functions/activation/tanh'
|
26
29
|
require 'chainer/functions/activation/log_softmax'
|
27
30
|
require 'chainer/functions/evaluation/accuracy'
|
28
31
|
require 'chainer/functions/math/basic_math'
|
@@ -33,6 +36,7 @@ require 'chainer/training/extensions/evaluator'
|
|
33
36
|
require 'chainer/training/extensions/log_report'
|
34
37
|
require 'chainer/training/extensions/print_report'
|
35
38
|
require 'chainer/training/extensions/progress_bar'
|
39
|
+
require 'chainer/training/extensions/snapshot'
|
36
40
|
require 'chainer/training/trainer'
|
37
41
|
require 'chainer/training/updater'
|
38
42
|
require 'chainer/training/util'
|
@@ -44,6 +48,8 @@ require 'chainer/dataset/download'
|
|
44
48
|
require 'chainer/datasets/mnist'
|
45
49
|
require 'chainer/datasets/tuple_dataset'
|
46
50
|
require 'chainer/reporter'
|
51
|
+
require 'chainer/serializer'
|
52
|
+
require 'chainer/serializers/marshal'
|
47
53
|
|
48
54
|
require 'numo/narray'
|
49
55
|
|
data/red-chainer.gemspec
CHANGED
@@ -19,7 +19,7 @@ Gem::Specification.new do |spec|
|
|
19
19
|
spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
|
20
20
|
spec.require_paths = ["lib"]
|
21
21
|
|
22
|
-
spec.add_runtime_dependency "numo-narray", ">= 0.9.
|
22
|
+
spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
|
23
23
|
|
24
24
|
spec.add_development_dependency "bundler", "~> 1.15"
|
25
25
|
spec.add_development_dependency "rake", "~> 10.0"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-chainer
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Yusaku Hatanaka
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2018-02-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: 0.9.
|
19
|
+
version: 0.9.1.1
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: 0.9.
|
26
|
+
version: 0.9.1.1
|
27
27
|
- !ruby/object:Gem::Dependency
|
28
28
|
name: bundler
|
29
29
|
requirement: !ruby/object:Gem::Requirement
|
@@ -92,8 +92,11 @@ files:
|
|
92
92
|
- lib/chainer/datasets/mnist.rb
|
93
93
|
- lib/chainer/datasets/tuple_dataset.rb
|
94
94
|
- lib/chainer/function.rb
|
95
|
+
- lib/chainer/functions/activation/leaky_relu.rb
|
95
96
|
- lib/chainer/functions/activation/log_softmax.rb
|
96
97
|
- lib/chainer/functions/activation/relu.rb
|
98
|
+
- lib/chainer/functions/activation/sigmoid.rb
|
99
|
+
- lib/chainer/functions/activation/tanh.rb
|
97
100
|
- lib/chainer/functions/connection/linear.rb
|
98
101
|
- lib/chainer/functions/evaluation/accuracy.rb
|
99
102
|
- lib/chainer/functions/loss/softmax_cross_entropy.rb
|
@@ -112,11 +115,14 @@ files:
|
|
112
115
|
- lib/chainer/optimizers/adam.rb
|
113
116
|
- lib/chainer/parameter.rb
|
114
117
|
- lib/chainer/reporter.rb
|
118
|
+
- lib/chainer/serializer.rb
|
119
|
+
- lib/chainer/serializers/marshal.rb
|
115
120
|
- lib/chainer/training/extension.rb
|
116
121
|
- lib/chainer/training/extensions/evaluator.rb
|
117
122
|
- lib/chainer/training/extensions/log_report.rb
|
118
123
|
- lib/chainer/training/extensions/print_report.rb
|
119
124
|
- lib/chainer/training/extensions/progress_bar.rb
|
125
|
+
- lib/chainer/training/extensions/snapshot.rb
|
120
126
|
- lib/chainer/training/standard_updater.rb
|
121
127
|
- lib/chainer/training/trainer.rb
|
122
128
|
- lib/chainer/training/triggers/interval.rb
|
@@ -149,7 +155,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
149
155
|
version: '0'
|
150
156
|
requirements: []
|
151
157
|
rubyforge_project:
|
152
|
-
rubygems_version: 2.
|
158
|
+
rubygems_version: 2.7.3
|
153
159
|
signing_key:
|
154
160
|
specification_version: 4
|
155
161
|
summary: A flexible framework for neural network for Ruby
|