red-chainer 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +60 -0
- data/Rakefile +8 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/examples/mnist.rb +42 -0
- data/lib/chainer.rb +59 -0
- data/lib/chainer/configuration.rb +10 -0
- data/lib/chainer/dataset/convert.rb +62 -0
- data/lib/chainer/dataset/download.rb +56 -0
- data/lib/chainer/dataset/iterator.rb +15 -0
- data/lib/chainer/datasets/mnist.rb +89 -0
- data/lib/chainer/datasets/tuple_dataset.rb +33 -0
- data/lib/chainer/function.rb +80 -0
- data/lib/chainer/functions/activation/log_softmax.rb +37 -0
- data/lib/chainer/functions/activation/relu.rb +23 -0
- data/lib/chainer/functions/connection/linear.rb +48 -0
- data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
- data/lib/chainer/functions/math/basic_math.rb +119 -0
- data/lib/chainer/gradient_method.rb +63 -0
- data/lib/chainer/hyperparameter.rb +23 -0
- data/lib/chainer/initializer.rb +12 -0
- data/lib/chainer/initializers/constant.rb +18 -0
- data/lib/chainer/initializers/init.rb +24 -0
- data/lib/chainer/initializers/normal.rb +28 -0
- data/lib/chainer/iterators/serial_iterator.rb +74 -0
- data/lib/chainer/link.rb +118 -0
- data/lib/chainer/links/connection/linear.rb +43 -0
- data/lib/chainer/links/model/classifier.rb +39 -0
- data/lib/chainer/optimizer.rb +69 -0
- data/lib/chainer/optimizers/adam.rb +62 -0
- data/lib/chainer/parameter.rb +53 -0
- data/lib/chainer/reporter.rb +130 -0
- data/lib/chainer/training/extension.rb +25 -0
- data/lib/chainer/training/extensions/evaluator.rb +26 -0
- data/lib/chainer/training/extensions/log_report.rb +72 -0
- data/lib/chainer/training/extensions/print_report.rb +62 -0
- data/lib/chainer/training/extensions/progress_bar.rb +89 -0
- data/lib/chainer/training/standard_updater.rb +63 -0
- data/lib/chainer/training/trainer.rb +136 -0
- data/lib/chainer/training/triggers/interval.rb +27 -0
- data/lib/chainer/training/updater.rb +33 -0
- data/lib/chainer/training/util.rb +13 -0
- data/lib/chainer/utils/array.rb +10 -0
- data/lib/chainer/utils/initializer.rb +14 -0
- data/lib/chainer/utils/variable.rb +20 -0
- data/lib/chainer/variable.rb +204 -0
- data/lib/chainer/variable_node.rb +71 -0
- data/lib/chainer/version.rb +4 -0
- data/red-chainer.gemspec +27 -0
- metadata +156 -0
@@ -0,0 +1,72 @@
|
|
1
|
+
require 'tempfile'
|
2
|
+
require 'json'
|
3
|
+
|
4
|
+
module Chainer
|
5
|
+
module Training
|
6
|
+
module Extensions
|
7
|
+
# Extension that accumulates observations and, whenever its trigger fires,
# appends their means (plus epoch / iteration / elapsed_time) to #log and
# optionally dumps the whole log as JSON under trainer.out.
class LogReport < Extension
  attr_reader :log

  # @param keys [Array<String, Symbol>, nil] observation keys to accumulate; nil keeps every key
  # @param trigger trigger spec (resolved via Util.get_trigger) deciding when an entry is emitted
  # @param postprocess [#call, nil] called with each entry Hash before it is appended to the log
  # @param log_name [String, nil] file name under trainer.out, used as a format string filled
  #   from the entry (e.g. 'log_%{epoch}'); nil disables writing the log file
  def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'log')
    @keys = keys
    @trigger = Chainer::Training::Util.get_trigger(trigger)
    @postprocess = postprocess
    @log_name = log_name
    @log = []

    init_summary
  end

  # Adds the trainer's current observation to the running summary and, when the
  # trigger fires, turns the summary into a log entry and starts a new summary.
  def call(trainer)
    observation = trainer.observation

    if @keys.nil?
      @summary.add(observation)
    else
      # Compare keys as Symbols so requested keys match whether the observation
      # (and @keys) use String or Symbol keys.
      symbolized_observation = Hash[observation.map { |k, v| [k.to_sym, v] }]
      filtered_keys = @keys.select { |k| symbolized_observation.key?(k.to_sym) }
      @summary.add(filtered_keys.each_with_object({}) { |k, hash| hash[k.to_s] = symbolized_observation[k.to_sym] })
    end

    # if trigger is true, output the result
    return unless @trigger.(trainer)

    stats = @summary.compute_mean
    stats_cpu = {}
    stats.each do |name, value|
      stats_cpu[name] = value.to_f # copy to CPU
    end

    updater = trainer.updater
    stats_cpu['epoch'] = updater.epoch
    stats_cpu['iteration'] = updater.iteration
    stats_cpu['elapsed_time'] = trainer.elapsed_time

    @postprocess.(stats_cpu) unless @postprocess.nil?

    @log << stats_cpu

    write_log(trainer.out, stats_cpu) unless @log_name.nil?

    init_summary
  end

  private

  def init_summary
    @summary = DictSummary.new
  end

  # Writes the whole log as JSON to out_dir/<formatted @log_name>.
  # The data is written to a temp file in out_dir and closed before being moved
  # into place, so the JSON is fully flushed and no file descriptor is leaked.
  def write_log(out_dir, stats)
    # example: sprintf("%{a}, %{b}", {a: "1", b: "2"})
    # => "1, 2"
    # %{name} references are looked up by Symbol, while stats uses String keys.
    log_name = sprintf(@log_name, Hash[stats.map { |k, v| [k.to_sym, v] }])
    new_path = File.join(out_dir, log_name)

    temp_file = Tempfile.create(log_name, out_dir)
    begin
      JSON.dump(@log, temp_file)
      temp_file.close
      FileUtils.mv(temp_file.path, new_path)
    rescue StandardError
      # Do not leave a stray temp file behind when writing or moving fails.
      temp_file.close unless temp_file.closed?
      FileUtils.rm_f(temp_file.path)
      raise
    end
  end
end
|
70
|
+
end
|
71
|
+
end
|
72
|
+
end
|
@@ -0,0 +1,62 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Training
|
3
|
+
module Extensions
|
4
|
+
# Extension that prints selected entries of a LogReport's log as a table,
# writing a header once and then one row per new log entry.
class PrintReport < Extension
  # @param entries [Array<String>] log keys to print, one column each
  # @param log_report [String, LogReport] name of a registered LogReport extension, or the instance itself
  # @param out [IO] stream the table is written to
  def initialize(entries, log_report: 'Chainer::Training::Extensions::LogReport', out: STDOUT)
    @entries = entries
    @log_report = log_report
    @out = out

    @log_len = 0 # number of observations already printed

    # format information: each column is at least 10 characters wide, or as wide as its title
    entry_widths = entries.map { |s| [10, s.size].max }

    templates = []
    header = []
    entries.zip(entry_widths).each do |entry, w|
      header << sprintf("%-#{w}s", entry)
      # A filled cell is w characters plus one separator space (matching the
      # single-space join of the header), so the blank filler must also be w + 1.
      templates << [entry, "%-#{w}g ", ' ' * (w + 1)]
    end
    @header = header.join(' ') + "\n"
    @templates = templates
  end

  # Prints the header on the first call, then every log entry not printed yet.
  # @raise [TypeError] if log_report is neither a String nor a LogReport
  def call(trainer)
    if @header
      @out.write(@header)
      @header = nil
    end

    log_report = @log_report
    if log_report.is_a?(String)
      log_report = trainer.get_extension(log_report)
    elsif log_report.is_a?(LogReport)
      log_report.(trainer) # update the log report before reading its log
    else
      raise TypeError, "log report has a wrong type #{log_report.class}"
    end

    log = log_report.log
    while log.size > @log_len
      @out.write("\033[J")
      print(log[@log_len])
      @log_len += 1
    end
  end

  private

  # Writes one table row for an observation Hash; missing entries become blanks.
  # (Private; shadows Kernel#print inside this class.)
  def print(observation)
    @templates.each do |entry, template, empty|
      if observation.key?(entry)
        @out.write(sprintf(template, observation[entry]))
      else
        @out.write(empty)
      end
    end
    @out.write("\n")
  end
end
|
60
|
+
end
|
61
|
+
end
|
62
|
+
end
|
@@ -0,0 +1,89 @@
|
|
1
|
+
require 'erb'
|
2
|
+
|
3
|
+
module Chainer
|
4
|
+
module Training
|
5
|
+
module Extensions
|
6
|
+
# Extension that draws a terminal progress display every `update_interval`
# iterations: overall and per-epoch bars, a status line, and speed / ETA.
class ProgressBar < Extension
  # @param training_length [Array(Numeric, String), nil] [length, unit]; taken from the
  #   trainer's stop trigger (which must be an IntervalTrigger) when nil
  # @param update_interval [Integer] redraw every this many iterations
  # @param bar_length [Integer] number of characters in each bar
  # @param out [IO] stream to draw on (put into sync mode)
  def initialize(training_length: nil, update_interval: 100, bar_length: 50, out: STDOUT)
    @training_length = training_length
    @status_template = nil
    @update_interval = update_interval
    @bar_length = bar_length
    @out = out
    @out.sync = true
    @recent_timing = []
  end

  # @raise [TypeError] if no training_length was given and the stop trigger is not an IntervalTrigger
  def call(trainer)
    if @training_length.nil?
      t = trainer.stop_trigger
      raise TypeError, "cannot retrieve the training length #{t.class}" unless t.is_a?(Chainer::Training::Triggers::IntervalTrigger)
      @training_length = [t.period, t.unit]
    end

    if @status_template.nil?
      @status_template = ERB.new("<%= sprintf('%10d', self.iteration) %> iter, <%= self.epoch %> epoch / #{@training_length[0]} #{@training_length[1]}s\n")
    end

    length, unit = @training_length
    iteration = trainer.updater.iteration

    # print the progress bar according to interval
    return unless iteration % @update_interval == 0

    epoch = trainer.updater.epoch_detail
    now = Time.now.to_f

    @recent_timing << [iteration, epoch, now]
    @out.write("\033[J")

    rate = unit == 'iteration' ? iteration.to_f / length : epoch.to_f / length
    @out.write(sprintf(" total [%s] %6.2f%%\n", progress_bar(rate), rate * 100))

    epoch_rate = epoch - epoch.to_i
    @out.write(sprintf("this epoch [%s] %6.2f%%\n", progress_bar(epoch_rate), epoch_rate * 100))

    status = @status_template.result(trainer.updater.bind)
    @out.write(status)

    # Speed is measured against the oldest sample kept in @recent_timing.
    old_t, old_e, old_sec = @recent_timing[0]
    span = now - old_sec

    if span.zero?
      speed_t = Float::INFINITY
      speed_e = Float::INFINITY
    else
      speed_t = (iteration - old_t) / span
      speed_e = (epoch - old_e) / span
    end

    estimated_time = unit == 'iteration' ? (length - iteration) / speed_t : (length - epoch) / speed_e

    @out.write(sprintf("%10.5g iters/sec. Estimated time to finish: %s.\n", speed_t, format_duration(estimated_time)))

    # move the cursor to the head of the progress bar
    @out.write("\033[4A") # TODO: Support Windows
    @out.flush

    @recent_timing.shift if @recent_timing.size > 100
  end

  def finalize
    @out.write("\033[J") # TODO: Support Windows
    @out.flush
  end

  private

  # Renders a completion fraction as '#' marks padded with '.' to @bar_length.
  # The mark count is clamped to [0, @bar_length] because a rate above 1 (e.g.
  # epoch_detail slightly past the target) would otherwise make the padding
  # count negative, which String#* rejects.
  def progress_bar(rate)
    filled = rate.to_f.finite? ? (rate * @bar_length).to_i : 0
    filled = [[filled, 0].max, @bar_length].min
    '#' * filled + '.' * (@bar_length - filled)
  end

  # Formats a duration in seconds as HH:MM:SS (hours are not wrapped at 24).
  # Negative estimates (already past the target) are shown as zero, and
  # non-finite estimates (e.g. zero measured speed) as '--:--:--'.
  def format_duration(seconds)
    seconds = seconds.to_f
    return '--:--:--' unless seconds.finite?

    total = [seconds, 0.0].max.to_i
    hours, rest = total.divmod(3600)
    minutes, secs = rest.divmod(60)
    sprintf('%02d:%02d:%02d', hours, minutes, secs)
  end
end
|
87
|
+
end
|
88
|
+
end
|
89
|
+
end
|
@@ -0,0 +1,63 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Training
|
3
|
+
# Updater that, on every step, pulls one batch from the :main iterator,
# converts it, and hands it to the :main optimizer's update.
class StandardUpdater < Updater
  attr_accessor :iteration

  # @param iterator [Dataset::Iterator, Hash] one iterator (stored under :main) or a Hash of iterators
  # @param optimizer [Object, Hash] one optimizer (stored under :main) or a Hash of optimizers
  # @param converter [#call, nil] batch converter; defaults to Dataset::Convert.concat_examples
  # @param device forwarded to the converter as `device:`
  # @param loss_func [#call, nil] passed to optimizer.update; defaults to the main optimizer's target
  def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
    @iterators = iterator.kind_of?(Dataset::Iterator) ? { main: iterator } : iterator
    @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }
    @converter = converter || Dataset::Convert.method(:concat_examples)
    @loss_func = loss_func
    @device = device
    @iteration = 0
  end

  def get_all_optimizers
    @optimizers.to_h
  end

  # Runs one step and returns the incremented iteration count.
  def update
    update_core
    @iteration += 1
  end

  def epoch
    main_iterator.epoch
  end

  def epoch_detail
    main_iterator.epoch_detail
  end

  # Converted Arrays are splatted, Hashes are passed as keywords, anything else as one argument.
  def update_core
    converted = @converter.call(main_iterator.next, device: @device)
    main_optimizer = @optimizers[:main]
    target_loss = @loss_func || main_optimizer.target

    case converted
    when Array then main_optimizer.update(target_loss, *converted)
    when Hash then main_optimizer.update(target_loss, **converted)
    else main_optimizer.update(target_loss, converted)
    end
  end

  def finalize
    @iterators.each { |_name, iter| iter.finalize }
  end

  private

  def main_iterator
    @iterators[:main]
  end
end
|
62
|
+
end
|
63
|
+
end
|
@@ -0,0 +1,136 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Training
|
3
|
+
# Record kept by Trainer for each registered extension.
class ExtensionEntry
  # The extension object itself.
  attr_accessor :extension
  # Trigger consulted each iteration to decide whether to call the extension.
  attr_accessor :trigger
  # Flag stored as given at registration time.
  attr_accessor :invoke_before_training
  # Value Trainer#run sorts extensions by.
  attr_accessor :priority

  def initialize(extension, priority, trigger, invoke_before_training)
    @priority = priority
    @extension = extension
    @trigger = trigger
    @invoke_before_training = invoke_before_training
  end
end
|
13
|
+
|
14
|
+
# Drives the training loop: repeatedly calls updater.update inside a reporter
# scope and invokes registered extensions whose triggers fire, until the stop
# trigger fires.
class Trainer
  attr_accessor :updater, :stop_trigger, :observation, :out

  # @param updater updater providing get_all_optimizers, connect_trainer, update and finalize
  # @param stop_trigger trigger spec resolved via Util.get_trigger; the loop ends when it fires
  # @param out [String] output directory, created by #run
  def initialize(updater, stop_trigger: nil, out: 'result')
    @updater = updater
    @stop_trigger = Chainer::Training::Util.get_trigger(stop_trigger)
    @observation = {}
    @out = out

    reporter = Reporter.new
    updater.get_all_optimizers.each do |(name, optimizer)|
      reporter.add_observer(name, optimizer.target)
      optimizer.target.namedlinks(skipself: true) do |suffix, observer|
        observer_name = name.to_s + suffix
        reporter.add_observer(observer_name, observer)
      end
    end
    @reporter = reporter

    @done = false
    @extensions = {}

    @start_at = nil
    @snapshot_elapsed_time = 0.0
    @final_elapsed_time = nil

    updater.connect_trainer(self)
  end

  # Seconds of training so far; after #run completes, the value recorded at the end of the loop.
  # @raise [RuntimeError] if #run has not been started yet
  def elapsed_time
    return @final_elapsed_time if @done
    raise 'training has not been started yet' if @start_at.nil?

    # @snapshot_elapsed_time (0.0 here; presumably restored from a snapshot
    # elsewhere) is time carried over, so it is added, not subtracted.
    Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
  end

  # Registers an extension and returns its ExtensionEntry.
  # The name defaults to extension.name, then extension.default_name. If the
  # name is already taken, "_1", "_2", ... is appended until it is unique, and
  # the final name is assigned back to extension.name. invoke_before_training
  # is stored on the entry; #run does not consult it.
  # @raise [ArgumentError] if no name can be determined
  # @raise [RuntimeError] if the name is "training"
  def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
    name ||= extension.name || extension.default_name
    raise ArgumentError, 'name is not given for the extension' unless name

    raise 'the name "training" is prohibited as an extension name' if name == 'training'

    if trigger.nil?
      trigger = extension.respond_to?(:trigger) ? extension.trigger : [1, 'iteration']
    end
    trigger = Chainer::Training::Util.get_trigger(trigger)

    priority = extension.priority if priority.nil? && extension.respond_to?(:priority)
    # An extension exposing a nil priority would otherwise break sorting in #run.
    priority = Extension::PRIORITY_READER if priority.nil?

    if invoke_before_training.nil?
      invoke_before_training = extension.respond_to?(:invoke_before_training) ? extension.invoke_before_training : false
    end

    modified_name = name
    ordinal = 0
    while @extensions.key?(modified_name)
      ordinal += 1
      modified_name = "#{name}_#{ordinal}"
    end

    extension.name = modified_name
    @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
  end

  # @raise [RuntimeError] if no extension is registered under name
  def get_extension(name)
    raise "extension #{name} not found" unless @extensions.key?(name)

    @extensions[name].extension
  end

  # Runs the training loop once. Each extension's init(trainer) is called
  # before the loop and finalize afterwards (also on error), followed by
  # updater.finalize.
  # @raise [RuntimeError] if called again after a completed run
  def run
    raise 'cannot run training loop multiple times' if @done
    FileUtils.mkdir_p(@out)

    extensions = sorted_extensions

    @start_at = Time.now.to_f

    extensions.each do |(_, entry)|
      entry.extension.init(self) if entry.extension.respond_to?(:init)
    end

    update = @updater.method(:update)
    reporter = @reporter
    stop_trigger = @stop_trigger

    begin
      until stop_trigger.(self)
        @observation = {}
        reporter.scope(@observation) do
          update.call
          extensions.each do |(_, entry)|
            entry.extension.(self) if entry.trigger.(self)
          end
        end
      end
    ensure
      extensions.each do |(_, entry)|
        entry.extension.finalize if entry.extension.respond_to?(:finalize)
      end
      @updater.finalize
    end

    # Must be computed before @done is set: once @done is true, #elapsed_time
    # just returns @final_elapsed_time.
    @final_elapsed_time = elapsed_time
    @done = true
  end

  private

  # [name, entry] pairs sorted by ascending priority. Entries with equal
  # priority keep their registration order, since sort_by alone is not stable.
  # NOTE(review): Chainer invokes extensions in descending priority order;
  # confirm the intended direction against Extension's PRIORITY_* constants.
  def sorted_extensions
    @extensions.each_with_index
               .sort_by { |(_, entry), index| [entry.priority, index] }
               .map(&:first)
  end
end
|
135
|
+
end
|
136
|
+
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Training
|
3
|
+
module Triggers
|
4
|
+
# Trigger that fires once every `period` epochs or iterations.
class IntervalTrigger
  attr_reader :period, :unit, :count

  # @param period [Numeric] interval length
  # @param unit [String] 'epoch'; any other value is treated as iterations
  def initialize(period, unit)
    @period = period
    @unit = unit
    @count = 0
  end

  # @return [Boolean] true when a new interval boundary has been reached
  def call(trainer)
    updater = trainer.updater

    case @unit
    when 'epoch'
      # Fires when the number of completed periods (floor of epoch_detail / period) changes.
      previous = @count
      @count = updater.epoch_detail.div(@period)
      @count != previous
    else
      current = updater.iteration
      return false unless current > 0

      current % @period == 0
    end
  end
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|