red-chainer 0.3.1 → 0.3.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +23 -8
- data/examples/{iris.rb → iris/iris.rb} +12 -9
- data/examples/{mnist.rb → mnist/mnist.rb} +0 -0
- data/lib/chainer/datasets/cifar.rb +11 -12
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/reporter.rb +16 -0
- data/lib/chainer/training/extension.rb +7 -2
- data/lib/chainer/training/extensions/evaluator.rb +119 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +3 -3
- data/lib/chainer/training/extensions/log_report.rb +8 -8
- data/lib/chainer/training/extensions/print_report.rb +2 -2
- data/lib/chainer/training/extensions/progress_bar.rb +3 -3
- data/lib/chainer/training/extensions/snapshot.rb +2 -0
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -1
- metadata +6 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 33a95bf098a08c334e6a9d29ff791350e6ac0bd9de843054f3ed14f2a005b79b
|
4
|
+
data.tar.gz: 6f4e53b84e93d01e26363b5d43dba73ec4fc515ecbebbd1574ac6a864f52fb83
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 40054365541bb8956c4a8211fbd489e75d1557a9e5c23bcb49f2cc2b393f0d96574d18259bad8e147525a3b720a5329f904be18c7c4f0891f1d886241b72a65d
|
7
|
+
data.tar.gz: fcb8e641d0efc1ffacc2014c2954d1c5a1e5f947fdcc18306f8946fb307c659e37bb3ccf08b9661d17e13184463b0282ae3594a6c252813d23a53fa0c6182a67
|
data/README.md
CHANGED
@@ -1,11 +1,14 @@
|
|
1
|
-
#
|
1
|
+
# Red Chainer : A deep learning framework
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
Red Chainer
|
3
|
+
A flexible framework for neural network for Ruby
|
6
4
|
|
7
5
|
## Description
|
8
|
-
|
6
|
+
|
7
|
+
It ported python's [Chainer](https://github.com/chainer/chainer) with Ruby.
|
8
|
+
|
9
|
+
## Requirements
|
10
|
+
|
11
|
+
* Ruby 2.3 or later
|
9
12
|
|
10
13
|
## Installation
|
11
14
|
|
@@ -28,15 +31,27 @@ $ gem install red-chainer
|
|
28
31
|
```
|
29
32
|
|
30
33
|
## Usage
|
31
|
-
mnist sample program is [here](./examples/mnist.rb)
|
34
|
+
mnist sample program is [here](./examples/mnist/mnist.rb)
|
32
35
|
|
33
36
|
```bash
|
34
37
|
# when install Gemfile
|
35
|
-
$ bundle exec ruby examples/mnist.rb
|
38
|
+
$ bundle exec ruby examples/mnist/mnist.rb
|
36
39
|
# when install yourself
|
37
|
-
$ ruby examples/mnist.rb
|
40
|
+
$ ruby examples/mnist/mnist.rb
|
38
41
|
```
|
39
42
|
|
40
43
|
## License
|
41
44
|
|
42
45
|
The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
|
46
|
+
|
47
|
+
## Red Chainer implementation status
|
48
|
+
|
49
|
+
| | Chainer 2.0<br>(Initial ported version) | Red Chainer (0.3.1) | example |
|
50
|
+
| ---- | ---- | ---- | ---- |
|
51
|
+
| [activation](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/activation) | 15 | 5 | LogSoftmax, ReLU, LeakyReLU, Sigmoid, Tanh |
|
52
|
+
| [loss](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/loss) | 17 | 2 | SoftMax, MeanSquaredError |
|
53
|
+
| [optimizer](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/optimizers) | 9 | 2 | Adam, MomentumSGDRule |
|
54
|
+
| [connection](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/connection) | 12 | 2 | Linear, Convolution2D |
|
55
|
+
| [pooling](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/pooling) | 14 | 3 | Pooling2D, MaxPooling2D, AveragePooling2D |
|
56
|
+
| [example](https://github.com/red-data-tools/red-chainer/tree/master/examples) | 31 | 3 | MNIST, Iris, CIFAR |
|
57
|
+
| GPU | use cupy | ToDo | want to support [Cumo](https://github.com/sonots/cumo) |
|
@@ -31,10 +31,11 @@ optimizer = Chainer::Optimizers::Adam.new
|
|
31
31
|
optimizer.setup(model)
|
32
32
|
|
33
33
|
iris = Datasets::Iris.new
|
34
|
-
|
34
|
+
iris_table = iris.to_table
|
35
|
+
x = iris_table.fetch_values(:sepal_length, :sepal_width, :petal_length, :petal_width).transpose
|
35
36
|
|
36
37
|
# target
|
37
|
-
y_class =
|
38
|
+
y_class = iris_table[:class]
|
38
39
|
|
39
40
|
# class index array
|
40
41
|
# ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
|
@@ -45,17 +46,12 @@ y = y_class.map{|s|
|
|
45
46
|
}
|
46
47
|
|
47
48
|
# y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0],,, [0.0, 1.0, 0.0], ,, [0.0, 0.0, 1.0]]
|
48
|
-
y_onehot =
|
49
|
-
i = class_name.index(s)
|
50
|
-
a = Array.new(class_name.size, 0.0)
|
51
|
-
a[i] = 1.0
|
52
|
-
a
|
53
|
-
}
|
49
|
+
y_onehot = Numo::SFloat.eye(class_name.size)[y,false]
|
54
50
|
|
55
51
|
puts "Iris Datasets"
|
56
52
|
puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
|
57
53
|
x.each_with_index{|r, i|
|
58
|
-
puts "#{'%3d' % i} : [#{r.join(', ')}] #{y_onehot[i]} #=> #{y_class[i]}(#{y[i]})"
|
54
|
+
puts "#{'%3d' % i} : [#{r.join(', ')}] #{y_onehot[i, false].to_a} #=> #{y_class[i]}(#{y[i]})"
|
59
55
|
}
|
60
56
|
# [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] => 50 data
|
61
57
|
# [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
|
@@ -70,8 +66,13 @@ y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-ve
|
|
70
66
|
x_test = x[(0..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
|
71
67
|
y_test = y[(0..-1).step(2)] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
|
72
68
|
|
69
|
+
puts
|
70
|
+
|
73
71
|
# Train
|
72
|
+
print("Training ")
|
73
|
+
|
74
74
|
10000.times{|i|
|
75
|
+
print(".") if i % 1000 == 0
|
75
76
|
x = Chainer::Variable.new(x_train)
|
76
77
|
y = Chainer::Variable.new(y_train)
|
77
78
|
model.cleargrads()
|
@@ -80,6 +81,8 @@ y_test = y[(0..-1).step(2)] #=> 75 data (Iris-setosa : 25, Iris-ve
|
|
80
81
|
optimizer.update()
|
81
82
|
}
|
82
83
|
|
84
|
+
puts
|
85
|
+
|
83
86
|
# Test
|
84
87
|
xt = Chainer::Variable.new(x_test)
|
85
88
|
yt = model.fwd(xt)
|
File without changes
|
@@ -12,18 +12,17 @@ module Chainer
|
|
12
12
|
end
|
13
13
|
|
14
14
|
def self.get_cifar(n_classes, with_label, ndim, scale)
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
test_labels << (n_classes == 10 ? record.label : record.fine_label)
|
15
|
+
train_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).to_table
|
16
|
+
test_table = ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).to_table
|
17
|
+
|
18
|
+
train_data = train_table[:pixels]
|
19
|
+
test_data = test_table[:pixels]
|
20
|
+
if n_classes == 10
|
21
|
+
train_labels = train_table[:label]
|
22
|
+
test_labels = test_table[:label]
|
23
|
+
else
|
24
|
+
train_labels = train_table[:fine_label]
|
25
|
+
test_labels = test_table[:fine_label]
|
27
26
|
end
|
28
27
|
|
29
28
|
[
|
data/lib/chainer/reporter.rb
CHANGED
@@ -6,6 +6,8 @@ module Chainer
|
|
6
6
|
class Reporter
|
7
7
|
include ReportService
|
8
8
|
|
9
|
+
attr_accessor :observer_names, :observation
|
10
|
+
|
9
11
|
def initialize
|
10
12
|
@observer_names = {}
|
11
13
|
@observation = {}
|
@@ -16,6 +18,14 @@ module Chainer
|
|
16
18
|
reporter.report(values, observer)
|
17
19
|
end
|
18
20
|
|
21
|
+
def self.report_scope(observation)
|
22
|
+
current = @@reporters[-1]
|
23
|
+
old = current.observation
|
24
|
+
current.observation = observation
|
25
|
+
yield
|
26
|
+
current.observation = old
|
27
|
+
end
|
28
|
+
|
19
29
|
def report(values, observer=nil)
|
20
30
|
# TODO: keep_graph_on_report option
|
21
31
|
if observer
|
@@ -37,6 +47,12 @@ module Chainer
|
|
37
47
|
@observer_names[observer.object_id] = name
|
38
48
|
end
|
39
49
|
|
50
|
+
def add_observers(prefix, observers, skipself: true)
|
51
|
+
observers.call(skipself: skipself) do |name, observer|
|
52
|
+
@observer_names[observer.object_id] = "#{prefix}#{name}"
|
53
|
+
end
|
54
|
+
end
|
55
|
+
|
40
56
|
def scope(observation)
|
41
57
|
@@reporters << self
|
42
58
|
old = @observation
|
@@ -5,7 +5,8 @@ module Chainer
|
|
5
5
|
PRIORITY_EDITOR = 200
|
6
6
|
PRIORITY_READER = 100
|
7
7
|
|
8
|
-
attr_accessor :name
|
8
|
+
attr_accessor :name
|
9
|
+
attr_writer :trigger, :priority
|
9
10
|
|
10
11
|
def initialize
|
11
12
|
end
|
@@ -14,7 +15,11 @@ module Chainer
|
|
14
15
|
end
|
15
16
|
|
16
17
|
def default_name
|
17
|
-
self.class.
|
18
|
+
self.class.name.split('::').last
|
19
|
+
end
|
20
|
+
|
21
|
+
def trigger
|
22
|
+
@trigger || [1, 'iteration']
|
18
23
|
end
|
19
24
|
|
20
25
|
def priority
|
@@ -1,9 +1,48 @@
|
|
1
1
|
module Chainer
|
2
2
|
module Training
|
3
3
|
module Extensions
|
4
|
+
# Trainer extension to evaluate models on a validation set.
|
5
|
+
# This extension evaluates the current models by a given evaluation function.
|
6
|
+
#
|
7
|
+
# It creates a Chainer::Reporter object to store values observed in
|
8
|
+
# the evaluation function on each iteration. The report for all iterations
|
9
|
+
# are aggregated to Chainer::DictSummary. The collected mean values
|
10
|
+
# are further reported to the reporter object of the trainer, where the name
|
11
|
+
# of each observation is prefixed by the evaluator name. See
|
12
|
+
# Chainer::Reporter for details in naming rules of the reports.
|
13
|
+
#
|
14
|
+
# Evaluator has a structure to customize similar to that of Chainer::Training::StandardUpdater.
|
15
|
+
# The main differences are:
|
16
|
+
#
|
17
|
+
# - There are no optimizers in an evaluator. Instead, it holds links to evaluate.
|
18
|
+
# - An evaluation loop function is used instead of an update function.
|
19
|
+
# - Preparation routine can be customized, which is called before each evaluation.
|
20
|
+
# It can be used, e.g., to initialize the state of stateful recurrent networks.
|
21
|
+
#
|
22
|
+
# There are two ways to modify the evaluation behavior besides setting a custom evaluation function.
|
23
|
+
# One is by setting a custom evaluation loop via the `eval_func` argument.
|
24
|
+
# The other is by inheriting this class and overriding the `evaluate` method.
|
25
|
+
# In latter case, users have to create and handle a reporter object manually.
|
26
|
+
# Users also have to copy the iterators before using them, in order to reuse them at the next time of evaluation.
|
27
|
+
# In both cases, the functions are called in testing mode (i.e., `chainer.config.train` is set to `false`).
|
28
|
+
#
|
29
|
+
# This extension is called at the end of each epoch by default.
|
4
30
|
class Evaluator < Extension
|
31
|
+
# @param [Dataset::Iterator] iterator Dataset iterator for the validation dataset. It can also be a dictionary of iterators.
|
32
|
+
# If this is just an iterator, the iterator is registered by the name 'main'.
|
33
|
+
# @param [Chainer::Link] target Link object or a dictionary of links to evaluate.
|
34
|
+
# If this is just a link object, the link is registered by the name 'main'.
|
35
|
+
# @param [Dataset::Convert] converter Converter function to build input arrays.
|
36
|
+
# `Chainer::Dataset.concat_examples` is used by default.
|
37
|
+
# @param [integer] device Device to which the training data is sent. Negative value indicates the host memory (CPU).
|
38
|
+
# @param [Function] eval_hook Function to prepare for each evaluation process.
|
39
|
+
# It is called at the beginning of the evaluation.
|
40
|
+
# The evaluator extension object is passed at each call.
|
41
|
+
# @param [Function] eval_func Evaluation function called at each iteration.
|
42
|
+
# The target link to evaluate as a callable is used by default.
|
5
43
|
def initialize(iterator, target, converter: nil, device: nil, eval_hook: nil, eval_func: nil)
|
6
44
|
@priority = Extension::PRIORITY_WRITER
|
45
|
+
@trigger = [1, 'epoch']
|
7
46
|
|
8
47
|
if iterator.kind_of?(Dataset::Iterator)
|
9
48
|
iterator = { main: iterator }
|
@@ -20,6 +59,86 @@ module Chainer
|
|
20
59
|
@eval_hook = eval_hook
|
21
60
|
@eval_func = eval_func
|
22
61
|
end
|
62
|
+
|
63
|
+
# Executes the evaluator extension.
|
64
|
+
#
|
65
|
+
# Unlike usual extensions, this extension can be executed without passing a trainer object.
|
66
|
+
# This extension reports the performance on validation dataset using the `Chainer.report` function.
|
67
|
+
# Thus, users can use this extension independently from any trainer by manually configuring a `Chainer::Reporter` object.
|
68
|
+
#
|
69
|
+
# @param [Chainer::Training::Trainer] trainer Trainer object that invokes this extension.
|
70
|
+
# It can be omitted in case of calling this extension manually.
|
71
|
+
def call(trainer = nil)
|
72
|
+
reporter = Reporter.new
|
73
|
+
prefix = self.respond_to?(:name) ? "#{self.name}/" : ""
|
74
|
+
|
75
|
+
@targets.each do |name, target|
|
76
|
+
reporter.add_observer("#{prefix}#{name}", target)
|
77
|
+
reporter.add_observers("#{prefix}#{name}", target.method(:namedlinks), skipself: true)
|
78
|
+
end
|
79
|
+
|
80
|
+
result = nil
|
81
|
+
reporter.scope(reporter.observation) do
|
82
|
+
old_train = Chainer.configuration.train
|
83
|
+
Chainer.configuration.train = false
|
84
|
+
result = evaluate()
|
85
|
+
Chainer.configuration.train = old_train
|
86
|
+
end
|
87
|
+
|
88
|
+
Reporter.save_report(result)
|
89
|
+
return result
|
90
|
+
end
|
91
|
+
|
92
|
+
# Evaluates the model and returns a result dictionary.
|
93
|
+
# This method runs the evaluation loop over the validation dataset.
|
94
|
+
# It accumulates the reported values to `DictSummary` and returns a dictionary whose values are means computed by the summary.
|
95
|
+
#
|
96
|
+
# Users can override this method to customize the evaluation routine.
|
97
|
+
# @return dict Result dictionary. This dictionary is further reported via `Chainer.save_report` without specifying any observer.
|
98
|
+
def evaluate
|
99
|
+
iterator = @iterators[:main]
|
100
|
+
target = @targets[:main]
|
101
|
+
eval_func = @eval_func || target
|
102
|
+
|
103
|
+
@eval_hook.(self) if @eval_hook
|
104
|
+
|
105
|
+
if iterator.respond_to?(:reset)
|
106
|
+
iterator.reset
|
107
|
+
it = iterator
|
108
|
+
else
|
109
|
+
it = iterator.dup
|
110
|
+
end
|
111
|
+
|
112
|
+
summary = DictSummary.new
|
113
|
+
|
114
|
+
until it.is_new_epoch do
|
115
|
+
batch = it.next
|
116
|
+
observation = {}
|
117
|
+
Reporter.report_scope(observation) do
|
118
|
+
in_arrays = @converter.(batch, device: @device)
|
119
|
+
|
120
|
+
old_enable_backprop = Chainer.configuration.enable_backprop
|
121
|
+
Chainer.configuration.enable_backprop = false
|
122
|
+
|
123
|
+
if in_arrays.kind_of?(Array)
|
124
|
+
eval_func.(*in_arrays)
|
125
|
+
elsif in_arrays.kind_of?(Hash)
|
126
|
+
eval_func.(**in_arrays)
|
127
|
+
else
|
128
|
+
eval_func.(in_arrays)
|
129
|
+
end
|
130
|
+
|
131
|
+
Chainer.configuration.enable_backprop = old_enable_backprop
|
132
|
+
end
|
133
|
+
summary.add(observation)
|
134
|
+
end
|
135
|
+
|
136
|
+
summary.compute_mean()
|
137
|
+
end
|
138
|
+
|
139
|
+
def default_name
|
140
|
+
"validation"
|
141
|
+
end
|
23
142
|
end
|
24
143
|
end
|
25
144
|
end
|
@@ -2,12 +2,12 @@ module Chainer
|
|
2
2
|
module Training
|
3
3
|
module Extensions
|
4
4
|
# Trainer extension to exponentially shift an optimizer attribute.
|
5
|
-
#
|
5
|
+
#
|
6
6
|
# This extension exponentially increases or decreases the specified attribute of the optimizer.
|
7
7
|
# The typical use case is an exponential decay of the learning rate.
|
8
8
|
# This extension is also called before the training loop starts by default.
|
9
9
|
class ExponentialShift < Extension
|
10
|
-
attr_reader :last_value
|
10
|
+
attr_reader :last_value
|
11
11
|
|
12
12
|
# @param [string] attr Name of the attribute to shift
|
13
13
|
# @param [float] rate Rate of the exponential shift.
|
@@ -62,7 +62,7 @@ module Chainer
|
|
62
62
|
end
|
63
63
|
end
|
64
64
|
|
65
|
-
private
|
65
|
+
private
|
66
66
|
|
67
67
|
def get_optimizer(trainer)
|
68
68
|
@optimizer || trainer.updater.get_optimizer(:main)
|
@@ -9,7 +9,7 @@ module Chainer
|
|
9
9
|
|
10
10
|
def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'log')
|
11
11
|
@keys = keys
|
12
|
-
@
|
12
|
+
@_trigger = Chainer::Training::Util.get_trigger(trigger)
|
13
13
|
@postprocess = postprocess
|
14
14
|
@log_name = log_name
|
15
15
|
@log = []
|
@@ -25,11 +25,11 @@ module Chainer
|
|
25
25
|
else
|
26
26
|
symbolized_observation = Hash[observation.map{|(k,v)| [k.to_sym,v]}]
|
27
27
|
filterd_keys = @keys.select {|k| observation.keys.include?(k.to_sym) }
|
28
|
-
@summary.add(filterd_keys.each_with_object({}) {|k, hash| hash[k.to_s] = observation[k.to_sym] })
|
28
|
+
@summary.add(filterd_keys.each_with_object({}) {|k, hash| hash[k.to_s] = observation[k.to_sym] })
|
29
29
|
end
|
30
30
|
|
31
|
-
# if
|
32
|
-
return unless @
|
31
|
+
# if @_trigger is true, output the result
|
32
|
+
return unless @_trigger.(trainer)
|
33
33
|
|
34
34
|
stats = @summary.compute_mean
|
35
35
|
stats_cpu = {}
|
@@ -41,9 +41,9 @@ module Chainer
|
|
41
41
|
stats_cpu['epoch'] = updater.epoch
|
42
42
|
stats_cpu['iteration'] = updater.iteration
|
43
43
|
stats_cpu['elapsed_time'] = trainer.elapsed_time
|
44
|
-
|
44
|
+
|
45
45
|
@postprocess.(stats_cpu) unless @postprocess.nil?
|
46
|
-
|
46
|
+
|
47
47
|
@log << stats_cpu
|
48
48
|
|
49
49
|
unless @log_name.nil?
|
@@ -62,8 +62,8 @@ module Chainer
|
|
62
62
|
end
|
63
63
|
|
64
64
|
def serialize(serializer)
|
65
|
-
if @
|
66
|
-
@
|
65
|
+
if @_trigger.respond_to?(:serialize)
|
66
|
+
@_trigger.serialize(serializer['_trigger'])
|
67
67
|
end
|
68
68
|
# Note that this serialization may lose some information of small
|
69
69
|
# numerical differences.
|
@@ -2,7 +2,7 @@ module Chainer
|
|
2
2
|
module Training
|
3
3
|
module Extensions
|
4
4
|
class PrintReport < Extension
|
5
|
-
def initialize(entries, log_report: '
|
5
|
+
def initialize(entries, log_report: 'LogReport', out: STDOUT)
|
6
6
|
@entries = entries
|
7
7
|
@log_report = log_report
|
8
8
|
@out = out
|
@@ -27,7 +27,7 @@ module Chainer
|
|
27
27
|
@out.write(@header)
|
28
28
|
@header = nil
|
29
29
|
end
|
30
|
-
|
30
|
+
|
31
31
|
if @log_report.is_a?(String)
|
32
32
|
log_report = trainer.get_extension(@log_report)
|
33
33
|
elsif @log_report.is_a?(LogReport)
|
@@ -14,7 +14,7 @@ module Chainer
|
|
14
14
|
@recent_timing = []
|
15
15
|
end
|
16
16
|
|
17
|
-
def call(trainer)
|
17
|
+
def call(trainer)
|
18
18
|
if @training_length.nil?
|
19
19
|
t = trainer.stop_trigger
|
20
20
|
raise TypeError, "cannot retrieve the training length #{t.class}" unless t.is_a?(Chainer::Training::Triggers::IntervalTrigger)
|
@@ -27,7 +27,7 @@ module Chainer
|
|
27
27
|
|
28
28
|
length, unit = @training_length
|
29
29
|
iteration = trainer.updater.iteration
|
30
|
-
|
30
|
+
|
31
31
|
# print the progress bar according to interval
|
32
32
|
return unless iteration % @update_interval == 0
|
33
33
|
|
@@ -69,7 +69,7 @@ module Chainer
|
|
69
69
|
else
|
70
70
|
estimated_time = (length - epoch) / speed_e
|
71
71
|
end
|
72
|
-
|
72
|
+
|
73
73
|
@out.write(sprintf("%10.5g iters/sec. Estimated time to finish: %s.\n", speed_t, (Time.parse("1991/01/01") + (estimated_time)).strftime("%H:%m:%S")))
|
74
74
|
|
75
75
|
# move the cursor to the head of the progress bar
|
@@ -13,6 +13,8 @@ module Chainer
|
|
13
13
|
end
|
14
14
|
|
15
15
|
def initialize(save_class: nil, filename_proc: nil, target: nil)
|
16
|
+
@priority = -100
|
17
|
+
@trigger = [1, 'epoch']
|
16
18
|
@save_class = save_class || Chainer::Serializers::MarshalSerializer
|
17
19
|
@filename_proc = filename_proc || Proc.new { |trainer| "snapshot_iter_#{trainer.updater.iteration}" }
|
18
20
|
@target = target
|
data/lib/chainer/version.rb
CHANGED
data/red-chainer.gemspec
CHANGED
@@ -20,7 +20,7 @@ Gem::Specification.new do |spec|
|
|
20
20
|
spec.require_paths = ["lib"]
|
21
21
|
|
22
22
|
spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
|
23
|
-
spec.add_runtime_dependency "red-datasets"
|
23
|
+
spec.add_runtime_dependency "red-datasets", ">= 0.0.5"
|
24
24
|
|
25
25
|
spec.add_development_dependency "bundler", "~> 1.15"
|
26
26
|
spec.add_development_dependency "rake", "~> 10.0"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-chainer
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Yusaku Hatanaka
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-06-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - ">="
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version:
|
33
|
+
version: 0.0.5
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version:
|
40
|
+
version: 0.0.5
|
41
41
|
- !ruby/object:Gem::Dependency
|
42
42
|
name: bundler
|
43
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -99,8 +99,8 @@ files:
|
|
99
99
|
- examples/cifar/models/resnet18.rb
|
100
100
|
- examples/cifar/models/vgg.rb
|
101
101
|
- examples/cifar/train_cifar.rb
|
102
|
-
- examples/iris.rb
|
103
|
-
- examples/mnist.rb
|
102
|
+
- examples/iris/iris.rb
|
103
|
+
- examples/mnist/mnist.rb
|
104
104
|
- lib/chainer.rb
|
105
105
|
- lib/chainer/configuration.rb
|
106
106
|
- lib/chainer/cuda.rb
|