red-chainer 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +60 -0
- data/Rakefile +8 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/examples/mnist.rb +42 -0
- data/lib/chainer.rb +59 -0
- data/lib/chainer/configuration.rb +10 -0
- data/lib/chainer/dataset/convert.rb +62 -0
- data/lib/chainer/dataset/download.rb +56 -0
- data/lib/chainer/dataset/iterator.rb +15 -0
- data/lib/chainer/datasets/mnist.rb +89 -0
- data/lib/chainer/datasets/tuple_dataset.rb +33 -0
- data/lib/chainer/function.rb +80 -0
- data/lib/chainer/functions/activation/log_softmax.rb +37 -0
- data/lib/chainer/functions/activation/relu.rb +23 -0
- data/lib/chainer/functions/connection/linear.rb +48 -0
- data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
- data/lib/chainer/functions/math/basic_math.rb +119 -0
- data/lib/chainer/gradient_method.rb +63 -0
- data/lib/chainer/hyperparameter.rb +23 -0
- data/lib/chainer/initializer.rb +12 -0
- data/lib/chainer/initializers/constant.rb +18 -0
- data/lib/chainer/initializers/init.rb +24 -0
- data/lib/chainer/initializers/normal.rb +28 -0
- data/lib/chainer/iterators/serial_iterator.rb +74 -0
- data/lib/chainer/link.rb +118 -0
- data/lib/chainer/links/connection/linear.rb +43 -0
- data/lib/chainer/links/model/classifier.rb +39 -0
- data/lib/chainer/optimizer.rb +69 -0
- data/lib/chainer/optimizers/adam.rb +62 -0
- data/lib/chainer/parameter.rb +53 -0
- data/lib/chainer/reporter.rb +130 -0
- data/lib/chainer/training/extension.rb +25 -0
- data/lib/chainer/training/extensions/evaluator.rb +26 -0
- data/lib/chainer/training/extensions/log_report.rb +72 -0
- data/lib/chainer/training/extensions/print_report.rb +62 -0
- data/lib/chainer/training/extensions/progress_bar.rb +89 -0
- data/lib/chainer/training/standard_updater.rb +63 -0
- data/lib/chainer/training/trainer.rb +136 -0
- data/lib/chainer/training/triggers/interval.rb +27 -0
- data/lib/chainer/training/updater.rb +33 -0
- data/lib/chainer/training/util.rb +13 -0
- data/lib/chainer/utils/array.rb +10 -0
- data/lib/chainer/utils/initializer.rb +14 -0
- data/lib/chainer/utils/variable.rb +20 -0
- data/lib/chainer/variable.rb +204 -0
- data/lib/chainer/variable_node.rb +71 -0
- data/lib/chainer/version.rb +4 -0
- data/red-chainer.gemspec +27 -0
- metadata +156 -0
@@ -0,0 +1,72 @@
|
|
1
|
+
require 'tempfile'
|
2
|
+
require 'json'
|
3
|
+
|
4
|
+
module Chainer
  module Training
    module Extensions
      # Extension that accumulates trainer observations into a DictSummary and,
      # whenever its trigger fires, appends the per-key means (plus epoch,
      # iteration and elapsed time) to an in-memory log, optionally dumping the
      # whole log as JSON into the trainer's output directory.
      class LogReport < Extension
        # @return [Array<Hash>] one stats hash per trigger firing
        attr_reader :log

        # @param keys [Array<String, Symbol>, nil] observation keys to aggregate; nil aggregates every key
        # @param trigger trigger specification passed to Chainer::Training::Util.get_trigger
        # @param postprocess [#call, nil] called with each stats hash before it is appended to the log
        # @param log_name [String, nil] file name (a format string) written under trainer.out; nil disables writing
        def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'log')
          @keys = keys
          @trigger = Chainer::Training::Util.get_trigger(trigger)
          @postprocess = postprocess
          @log_name = log_name
          @log = []

          init_summary
        end

        def call(trainer)
          observation = trainer.observation

          if @keys.nil?
            @summary.add(observation)
          else
            # Match the requested keys whether they, or the observation's keys,
            # are Strings or Symbols; the summary always receives String keys.
            symbolized_observation = Hash[observation.map { |k, v| [k.to_sym, v] }]
            filtered = @keys.each_with_object({}) do |k, hash|
              hash[k.to_s] = symbolized_observation[k.to_sym] if symbolized_observation.key?(k.to_sym)
            end
            @summary.add(filtered)
          end

          # if trigger is true, output the result
          return unless @trigger.(trainer)

          stats_cpu = {}
          @summary.compute_mean.each do |name, value|
            stats_cpu[name] = value.to_f # copy to CPU
          end

          updater = trainer.updater
          stats_cpu['epoch'] = updater.epoch
          stats_cpu['iteration'] = updater.iteration
          stats_cpu['elapsed_time'] = trainer.elapsed_time

          @postprocess.(stats_cpu) unless @postprocess.nil?

          @log << stats_cpu

          write_log(trainer.out, stats_cpu) unless @log_name.nil?

          init_summary
        end

        private

        def init_summary
          @summary = DictSummary.new
        end

        # Dumps the whole log as JSON to out_dir/<formatted log_name>.
        # The JSON is written to a temporary file created in out_dir itself and
        # then moved onto the target path, so the target is replaced in one step
        # rather than being rewritten incrementally.
        def write_log(out_dir, stats)
          # Named references need Symbol keys, e.g.
          # sprintf("log_%{epoch}", {epoch: 1}) # => "log_1"
          log_name = sprintf(@log_name, Hash[stats.map { |k, v| [k.to_sym, v] }])
          new_path = File.join(out_dir, log_name)

          # Tempfile.create takes basename and tmpdir positionally; the returned
          # File must be closed by the caller (this also flushes the JSON).
          temp_file = Tempfile.create(log_name, out_dir)
          begin
            JSON.dump(@log, temp_file)
          ensure
            temp_file.close
          end
          FileUtils.mv(temp_file.path, new_path)
        end
      end
    end
  end
end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
module Chainer
  module Training
    module Extensions
      # Extension that prints selected entries of a LogReport's log as an
      # aligned text table, emitting each log row exactly once.
      class PrintReport < Extension
        # @param entries [Array<String>] log keys to print, one column each
        # @param log_report [String, LogReport] name of a LogReport registered on the
        #   trainer, or a LogReport instance that this extension invokes itself
        # @param out [IO] destination stream
        def initialize(entries, log_report: 'Chainer::Training::Extensions::LogReport', out: STDOUT)
          @entries = entries
          @log_report = log_report
          @out = out

          @log_len = 0 # number of observations already printed

          # format information: every column is at least 10 characters wide and
          # is followed by two separator spaces, so header, values and blanks
          # all occupy the same width and the columns line up.
          entry_widths = entries.map { |s| [10, s.size].max }

          templates = []
          header = []
          entries.zip(entry_widths).each do |entry, w|
            header << sprintf("%-#{w}s", entry)
            templates << [entry, "%-#{w}g  ", ' ' * (w + 2)]
          end
          @header = header.join('  ') + "\n"
          @templates = templates
        end

        def call(trainer)
          if @header
            @out.write(@header)
            @header = nil
          end

          log = resolve_log_report(trainer).log
          while log.size > @log_len
            @out.write("\033[J")
            print(log[@log_len])
            @log_len += 1
          end
        end

        private

        # Returns the LogReport whose log is printed.
        # @raise [TypeError] when @log_report is neither a String nor a LogReport
        def resolve_log_report(trainer)
          case @log_report
          when String
            trainer.get_extension(@log_report)
          when LogReport
            # A LogReport instance passed directly is driven from here on every call.
            @log_report.(trainer)
            @log_report
          else
            raise TypeError, "log report has a wrong type #{@log_report.class}"
          end
        end

        # Writes one table row; entries missing from the observation are left blank.
        def print(observation)
          @templates.each do |entry, template, empty|
            if observation.key?(entry)
              @out.write(sprintf(template, observation[entry]))
            else
              @out.write(empty)
            end
          end
          @out.write("\n")
        end
      end
    end
  end
end
|
@@ -0,0 +1,89 @@
|
|
1
|
+
require 'erb'
|
2
|
+
|
3
|
+
module Chainer
  module Training
    module Extensions
      # Extension that redraws a progress report in place on a terminal every
      # `update_interval` iterations: a bar for the whole run, a bar for the
      # current epoch, the iteration/epoch counters, throughput and an ETA.
      class ProgressBar < Extension
        # @param training_length [Array(Numeric, String), nil] [length, unit] of the run;
        #   when nil it is read from the trainer's IntervalTrigger stop trigger on the first call
        # @param update_interval [Integer] redraw every this many iterations
        # @param bar_length [Integer] number of cells in each bar
        # @param out [IO] destination stream (switched to sync mode)
        def initialize(training_length: nil, update_interval: 100, bar_length: 50, out: STDOUT)
          @training_length = training_length
          @status_template = nil
          @update_interval = update_interval
          @bar_length = bar_length
          @out = out
          @out.sync = true
          @recent_timing = []
        end

        # @raise [TypeError] when the training length must be inferred but the
        #   stop trigger is not an IntervalTrigger
        def call(trainer)
          if @training_length.nil?
            t = trainer.stop_trigger
            raise TypeError, "cannot retrieve the training length #{t.class}" unless t.is_a?(Chainer::Training::Triggers::IntervalTrigger)
            @training_length = [t.period, t.unit]
          end

          if @status_template.nil?
            @status_template = ERB.new("<%= sprintf('%10d', self.iteration) %> iter, <%= self.epoch %> epoch / #{@training_length[0]} #{@training_length[1]}s\n")
          end

          length, unit = @training_length
          iteration = trainer.updater.iteration

          # print the progress bar according to interval
          return unless iteration % @update_interval == 0

          epoch = trainer.updater.epoch_detail
          # Only differences between samples are used, so a monotonic clock is
          # immune to wall-clock adjustments.
          now = Process.clock_gettime(Process::CLOCK_MONOTONIC)

          @recent_timing << [iteration, epoch, now]
          @out.write("\033[J")

          rate = unit == 'iteration' ? iteration.to_f / length : epoch.to_f / length
          @out.write(sprintf(" total [%s] %6.2f%%\n", render_bar(rate), rate * 100))

          epoch_rate = epoch - epoch.to_i
          @out.write(sprintf("this epoch [%s] %6.2f%%\n", render_bar(epoch_rate), epoch_rate * 100))

          # NOTE(review): relies on the updater exposing a public `bind` that returns a Binding.
          @out.write(@status_template.result(trainer.updater.bind))

          # Throughput is measured against the oldest sample still kept in @recent_timing.
          old_t, old_e, old_sec = @recent_timing[0]
          span = now - old_sec

          if span.zero?
            speed_t = Float::INFINITY
            speed_e = Float::INFINITY
          else
            speed_t = (iteration - old_t) / span
            speed_e = (epoch - old_e) / span
          end

          estimated_time = unit == 'iteration' ? (length - iteration) / speed_t : (length - epoch) / speed_e

          @out.write(sprintf("%10.5g iters/sec. Estimated time to finish: %s.\n", speed_t, format_duration(estimated_time)))

          # move the cursor to the head of the progress bar (4 lines were written above)
          @out.write("\033[4A") # TODO: Support Windows
          @out.flush

          @recent_timing.shift if @recent_timing.size > 100
        end

        def finalize
          @out.write("\033[J") # TODO: Support Windows
          @out.flush
        end

        private

        # Renders @bar_length cells for a completion ratio. The fill count is
        # clamped to [0, @bar_length] so an overshooting ratio (> 1) never
        # produces a negative padding count.
        def render_bar(ratio)
          filled = [[(ratio * @bar_length).to_i, 0].max, @bar_length].min
          '#' * filled + '.' * (@bar_length - filled)
        end

        # Formats a duration in seconds as HH:MM:SS; hours are not wrapped at 24.
        # Non-finite or non-positive estimates (e.g. 0/0 when no progress has been
        # measured yet) are shown as 00:00:00.
        def format_duration(seconds)
          seconds = 0 unless seconds.finite? && seconds > 0
          total = seconds.to_i
          sprintf('%02d:%02d:%02d', total / 3600, total % 3600 / 60, total % 60)
        end
      end
    end
  end
end
|
@@ -0,0 +1,63 @@
|
|
1
|
+
module Chainer
  module Training
    # Updater that, on each step, pulls one batch from the :main iterator,
    # converts it, and asks the :main optimizer to update its target with it.
    class StandardUpdater < Updater
      attr_accessor :iteration

      # @param iterator [Dataset::Iterator, Hash] a single iterator (stored under :main) or a hash of iterators
      # @param optimizer [Object, Hash] a single optimizer (stored under :main) or a hash of optimizers
      # @param converter [#call, nil] batch converter; defaults to Dataset::Convert.concat_examples
      # @param device forwarded to the converter as `device:`
      # @param loss_func [#call, nil] loss to minimize; defaults to the :main optimizer's target
      def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
        @iterators = iterator.kind_of?(Dataset::Iterator) ? { main: iterator } : iterator
        @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }

        @converter = converter || Dataset::Convert.method(:concat_examples)
        @loss_func = loss_func
        @device = device
        @iteration = 0
      end

      # @return [Hash] name => optimizer
      def get_all_optimizers
        @optimizers.to_h
      end

      # Runs one update step, then advances the iteration counter.
      def update
        update_core
        @iteration += 1
      end

      def epoch
        @iterators[:main].epoch
      end

      def epoch_detail
        @iterators[:main].epoch_detail
      end

      # Converts the next batch and dispatches it to the optimizer: Arrays are
      # splatted positionally, Hashes as keywords, anything else passed as-is.
      def update_core
        batch = @iterators[:main].next
        in_arrays = @converter.call(batch, device: @device)

        optimizer = @optimizers[:main]
        loss_func = @loss_func || optimizer.target

        case in_arrays
        when Array then optimizer.update(loss_func, *in_arrays)
        when Hash then optimizer.update(loss_func, **in_arrays)
        else optimizer.update(loss_func, in_arrays)
        end
      end

      def finalize
        @iterators.each { |_key, iter| iter.finalize }
      end
    end
  end
end
|
@@ -0,0 +1,136 @@
|
|
1
|
+
module Chainer
  module Training
    # Bookkeeping record for an extension registered with a Trainer.
    class ExtensionEntry
      attr_accessor :extension, :trigger, :invoke_before_training, :priority

      def initialize(extension, priority, trigger, invoke_before_training)
        @extension = extension
        @trigger = trigger
        @invoke_before_training = invoke_before_training
        @priority = priority
      end
    end

    # Drives the training loop: calls the updater and the registered extensions
    # once per iteration until the stop trigger fires.
    class Trainer
      attr_accessor :updater, :stop_trigger, :observation, :out

      # @param updater updater providing #update, #finalize, #get_all_optimizers and #connect_trainer
      # @param stop_trigger trigger specification passed to Chainer::Training::Util.get_trigger
      # @param out [String] output directory (created by #run)
      def initialize(updater, stop_trigger: nil, out: 'result')
        @updater = updater
        @stop_trigger = Chainer::Training::Util.get_trigger(stop_trigger)
        @observation = {}
        @out = out

        # Register each optimizer's target link, and its named sub-links
        # (optimizer name + suffix), as reporter observers.
        reporter = Reporter.new
        updater.get_all_optimizers.each do |(name, optimizer)|
          reporter.add_observer(name, optimizer.target)
          optimizer.target.namedlinks(skipself: true) do |suffix, observer|
            observer_name = name.to_s + suffix
            reporter.add_observer(observer_name, observer)
          end
        end
        @reporter = reporter

        @done = false
        @extensions = {}

        @start_at = nil
        @snapshot_elapsed_time = 0.0
        @final_elapsed_time = nil

        updater.connect_trainer(self)
      end

      # Seconds spent training: wall-clock time since #run started plus
      # @snapshot_elapsed_time (presumably time restored from a snapshot).
      # Frozen once the run has finished.
      # @raise [RuntimeError] if training has not started
      def elapsed_time
        return @final_elapsed_time if @done
        raise "training has not been started yet" if @start_at.nil?

        Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
      end

      # Registers an extension. Unspecified options fall back to the
      # extension's own attributes, then to defaults. A name that is already
      # taken gets a numeric suffix (name_1, name_2, ...).
      # @raise [ArgumentError] when no name can be determined
      # @return [ExtensionEntry]
      def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
        if name.nil?
          name = if extension.name
                   extension.name
                 elsif extension.default_name
                   extension.default_name
                 else
                   raise ArgumentError, 'name is not given for the extension'
                 end
        end

        raise 'the name "training" is prohibited as an extension name' if name == 'training'

        if trigger.nil?
          trigger = extension.methods.include?(:trigger) ? extension.trigger : [1, 'iteration']
        end
        trigger = Chainer::Training::Util.get_trigger(trigger)

        if priority.nil?
          priority = extension.methods.include?(:priority) ? extension.priority : Extension::PRIORITY_READER
        end

        if invoke_before_training.nil?
          invoke_before_training = extension.methods.include?(:invoke_before_training) ? extension.invoke_before_training : false
        end

        modified_name = name
        ordinal = 0
        while @extensions.key?(modified_name)
          ordinal += 1
          modified_name = "#{name}_#{ordinal}"
        end

        extension.name = modified_name
        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
      end

      # @raise [RuntimeError] when no extension is registered under `name`
      def get_extension(name)
        raise "extension #{name} not found" unless @extensions.key?(name)

        @extensions[name].extension
      end

      # Runs the training loop once. Extensions' #init (if any) run before the
      # loop; their #finalize (if any) and the updater's #finalize always run afterwards.
      # @raise [RuntimeError] when called a second time
      def run
        raise 'cannot run training loop multiple times' if @done
        FileUtils.mkdir_p(@out)

        extensions = sorted_extensions

        @start_at = Time.now.to_f

        extensions.each do |(_, entry)|
          initializer = entry.extension.methods.include?(:init) ? entry.extension.method(:init) : nil
          initializer.call(self) if initializer
        end

        update = @updater.method(:update)
        reporter = @reporter
        stop_trigger = @stop_trigger

        begin
          until stop_trigger.(self) do
            @observation = {}
            reporter.scope(@observation) do
              update.call
              extensions.each do |(_, entry)|
                entry.extension.(self) if entry.trigger.(self)
              end
            end
          end
        ensure
          extensions.each do |(_, entry)|
            finalize = entry.extension.methods.include?(:finalize) ? entry.extension.method(:finalize) : nil
            finalize.() if finalize
          end
          @updater.finalize()
        end

        # Must be computed before @done is set: once done, elapsed_time returns
        # @final_elapsed_time.
        @final_elapsed_time = elapsed_time
        @done = true
      end

      private

      # Registered extensions as [name, entry] pairs, highest priority first
      # (Chainer's convention, so writers run before readers in an iteration).
      # Ties keep registration order; the index is part of the key because
      # Enumerable#sort_by is not stable.
      def sorted_extensions
        @extensions.each_with_index
                   .sort_by { |(_, entry), index| [-entry.priority, index] }
                   .map(&:first)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Chainer
  module Training
    module Triggers
      # Trigger that fires once every `period` units of training progress.
      # The unit 'epoch' tracks the updater's fractional epoch; any other unit
      # is treated as iterations.
      class IntervalTrigger
        attr_reader :period, :unit, :count

        # @param period [Numeric] interval length
        # @param unit [String] 'epoch' or 'iteration'
        def initialize(period, unit)
          @period = period
          @unit = unit
          @count = 0
        end

        # @param trainer object whose #updater exposes #iteration and #epoch_detail
        # @return [Boolean] true when this call crosses an interval boundary
        def call(trainer)
          progress = trainer.updater
          case @unit
          when 'epoch'
            # @count holds the last observed bucket (epoch_detail div period);
            # the trigger fires whenever that bucket changes.
            last_bucket = @count
            @count = progress.epoch_detail.div(@period)
            @count != last_bucket
          else
            step = progress.iteration
            step.positive? && (step % @period).zero?
          end
        end
      end
    end
  end
end
|