red-chainer 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,72 @@
1
+ require 'tempfile'
2
+ require 'json'
3
+
4
module Chainer
  module Training
    module Extensions
      # Trainer extension that accumulates the trainer's observations and,
      # whenever its trigger fires, appends their mean values to an in-memory
      # log and rewrites that log as a JSON file in the trainer's output
      # directory.
      class LogReport < Extension
        # @return [Array<Hash>] one entry per trigger firing, oldest first
        attr_reader :log

        # @param keys [Array, nil] observation keys to record; nil records
        #   every reported value
        # @param trigger trigger specification handed to
        #   Chainer::Training::Util.get_trigger
        # @param postprocess [#call, nil] called with each new log entry before
        #   it is appended, so it may modify the entry in place
        # @param log_name [String, nil] file name (a sprintf format) of the
        #   JSON log inside the output directory; nil disables file output
        def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'log')
          @keys = keys
          @trigger = Chainer::Training::Util.get_trigger(trigger)
          @postprocess = postprocess
          @log_name = log_name
          @log = []

          init_summary
        end

        # Accumulates the current observation; when the trigger fires, turns the
        # accumulated values into a log entry and writes the log file.
        #
        # @param trainer the trainer invoking this extension
        def call(trainer)
          accumulate(trainer.observation)

          # if trigger is true, output the result
          return unless @trigger.(trainer)

          stats_cpu = {}
          @summary.compute_mean.each do |name, value|
            stats_cpu[name] = value.to_f # store plain Floats in the log
          end

          updater = trainer.updater
          stats_cpu['epoch'] = updater.epoch
          stats_cpu['iteration'] = updater.iteration
          stats_cpu['elapsed_time'] = trainer.elapsed_time

          @postprocess.(stats_cpu) unless @postprocess.nil?

          @log << stats_cpu

          write_log(trainer.out, stats_cpu) unless @log_name.nil?

          init_summary
        end

        private

        def init_summary
          @summary = DictSummary.new
        end

        # Adds the observation (or only the requested keys of it) to the
        # running summary.
        def accumulate(observation)
          if @keys.nil?
            @summary.add(observation)
            return
          end

          # Compare by symbol so that keys match regardless of whether the
          # observation or @keys use strings or symbols.
          by_symbol = observation.each_with_object({}) { |(k, v), hash| hash[k.to_sym] = v }
          selected = @keys.each_with_object({}) do |key, hash|
            sym = key.to_sym
            hash[key.to_s] = by_symbol[sym] if by_symbol.key?(sym)
          end
          @summary.add(selected)
        end

        # Serializes the whole log to a temporary file in out_dir and then
        # moves it over the final name, so readers never see a half-written
        # file.
        def write_log(out_dir, stats)
          # example: sprintf("%{a}, %{b}", {a: "1", b: "2"})
          # => "1, 2"
          # Named references need symbol keys, while the entry uses strings.
          format_args = stats.each_with_object({}) { |(k, v), hash| hash[k.to_sym] = v }
          log_name = sprintf(@log_name, format_args)

          # Tempfile.create takes the basename and the directory positionally.
          temp_file = Tempfile.create(log_name, out_dir)
          begin
            JSON.dump(@log, temp_file)
            temp_file.close # flush to disk before the file is moved
            FileUtils.mv(temp_file.path, File.join(out_dir, log_name))
          ensure
            temp_file.close unless temp_file.closed?
            # Only still present when dumping or moving failed.
            File.unlink(temp_file.path) if File.exist?(temp_file.path)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,62 @@
1
module Chainer
  module Training
    module Extensions
      # Trainer extension that prints selected entries of a LogReport's log as
      # a table: a header line once, then one row per log entry that has not
      # been printed yet.
      class PrintReport < Extension
        # Minimum width of a column, in characters.
        MIN_COLUMN_WIDTH = 10
        # Separator placed after every column.
        COLUMN_GAP = '  '

        # @param entries [Array<String>] log keys to print, in column order
        # @param log_report [String, LogReport] either the name under which a
        #   LogReport is registered on the trainer, or a LogReport instance
        #   that this extension drives itself
        # @param out [#write] destination stream
        def initialize(entries, log_report: 'Chainer::Training::Extensions::LogReport', out: STDOUT)
          @entries = entries
          @log_report = log_report
          @out = out

          @log_len = 0 # number of observations already printed

          # format information
          # Header cells, value cells and blank cells all occupy
          # width + COLUMN_GAP.size characters so the columns stay aligned.
          columns = entries.map { |entry| [entry, [MIN_COLUMN_WIDTH, entry.size].max] }

          @header = columns.map { |entry, w| sprintf("%-#{w}s", entry) }.join(COLUMN_GAP) + "\n"
          @templates = columns.map do |entry, w|
            [entry, "%-#{w}g#{COLUMN_GAP}", ' ' * (w + COLUMN_GAP.size)]
          end
        end

        # Prints the header on the first call and then every log entry that
        # has been added since the previous call.
        #
        # @param trainer the trainer invoking this extension
        # @raise [TypeError] when log_report is neither a String nor a LogReport
        def call(trainer)
          if @header
            @out.write(@header)
            @header = nil
          end

          log = resolve_log_report(trainer).log
          while log.size > @log_len
            @out.write("\033[J") # erase anything below the cursor first
            print(log[@log_len])
            @log_len += 1
          end
        end

        private

        # Returns the LogReport to read from. A LogReport instance given to
        # the constructor is invoked here so that its log is up to date.
        def resolve_log_report(trainer)
          case @log_report
          when String
            trainer.get_extension(@log_report)
          when LogReport
            @log_report.(trainer)
            @log_report
          else
            raise TypeError, "log report has a wrong type #{@log_report.class}"
          end
        end

        # Writes one table row; entries missing from the observation are left
        # blank.
        def print(observation)
          @templates.each do |entry, template, empty|
            if observation.key?(entry)
              @out.write(sprintf(template, observation[entry]))
            else
              @out.write(empty)
            end
          end
          @out.write("\n")
        end
      end
    end
  end
end
@@ -0,0 +1,89 @@
1
+ require 'erb'
2
+
3
module Chainer
  module Training
    module Extensions
      # Trainer extension that draws a four-line textual progress display
      # (overall bar, current-epoch bar, iteration/epoch counters, speed and
      # estimated remaining time) and redraws it in place.
      class ProgressBar < Extension
        # Number of timing samples kept for the speed estimate.
        RECENT_TIMING_WINDOW = 100

        # @param training_length [Array, nil] [length, unit] with unit
        #   'iteration' or 'epoch'; when nil it is taken from the trainer's
        #   stop trigger on the first call
        # @param update_interval [Integer] redraw every this many iterations
        # @param bar_length [Integer] number of characters inside each bar
        # @param out [IO] destination stream; it is switched to sync mode
        def initialize(training_length: nil, update_interval: 100, bar_length: 50, out: STDOUT)
          @training_length = training_length
          @update_interval = update_interval
          @bar_length = bar_length
          @out = out
          @out.sync = true
          @recent_timing = []
        end

        # Redraws the progress display when the current iteration is a
        # multiple of the update interval.
        #
        # @param trainer the trainer invoking this extension
        # @raise [TypeError] when no training length was given and the
        #   trainer's stop trigger is not an IntervalTrigger
        def call(trainer)
          @training_length = training_length_from(trainer) if @training_length.nil?

          length, unit = @training_length
          updater = trainer.updater
          iteration = updater.iteration

          # print the progress bar according to interval
          return unless (iteration % @update_interval).zero?

          epoch = updater.epoch_detail
          now = Time.now.to_f

          @recent_timing << [iteration, epoch, now]
          @out.write("\033[J")

          by_iteration = unit == 'iteration'
          rate = (by_iteration ? iteration : epoch).to_f / length

          @out.write(bar_line('     total', rate))
          @out.write(bar_line('this epoch', epoch - epoch.to_i))
          @out.write(sprintf("%10d iter, %s epoch / %s %ss\n", iteration, updater.epoch, length, unit))

          # Speed is measured against the oldest sample still in the window.
          old_t, old_e, old_sec = @recent_timing.first
          span = now - old_sec

          if span.zero?
            speed_t = Float::INFINITY
            speed_e = Float::INFINITY
          else
            speed_t = (iteration - old_t) / span
            speed_e = (epoch - old_e) / span
          end

          estimated_time = by_iteration ? (length - iteration) / speed_t : (length - epoch) / speed_e

          @out.write(sprintf("%10.5g iters/sec. Estimated time to finish: %s.\n", speed_t, format_duration(estimated_time)))

          # move the cursor to the head of the progress bar
          @out.write("\033[4A") # TODO: Support Windows
          @out.flush

          @recent_timing.shift if @recent_timing.size > RECENT_TIMING_WINDOW
        end

        # Erases the progress display when training ends.
        def finalize
          @out.write("\033[J") # TODO: Support Windows
          @out.flush
        end

        private

        # Derives [length, unit] from the trainer's stop trigger.
        def training_length_from(trainer)
          t = trainer.stop_trigger
          unless t.is_a?(Chainer::Training::Triggers::IntervalTrigger)
            raise TypeError, "cannot retrieve the training length #{t.class}"
          end

          [t.period, t.unit]
        end

        # Builds one labelled bar line for a completion rate in 0..1.
        def bar_line(label, rate)
          # Clamp so that a rate outside 0..1 cannot yield a negative repeat count.
          filled = [[(rate * @bar_length).to_i, 0].max, @bar_length].min
          sprintf("%s [%s%s] %6.2f%%\n", label, '#' * filled, '.' * (@bar_length - filled), rate * 100)
        end

        # Formats a number of seconds as HH:MM:SS. Hours are not wrapped at 24;
        # a value that is not finite is shown as a placeholder.
        def format_duration(seconds)
          secs = seconds.to_f
          return '--:--:--' if secs.nan? || secs.infinite?

          total = [secs, 0.0].max.to_i
          sprintf('%02d:%02d:%02d', total / 3600, (total % 3600) / 60, total % 60)
        end
      end
    end
  end
end
@@ -0,0 +1,63 @@
1
module Chainer
  module Training
    # Updater that pulls a batch from the :main iterator, converts it and
    # passes it to the :main optimizer, counting completed iterations.
    class StandardUpdater < Updater
      attr_accessor :iteration

      # @param iterator a Dataset::Iterator (stored under :main) or a
      #   collection already keyed by name
      # @param optimizer a single optimizer (stored under :main) or a Hash of
      #   optimizers
      # @param converter [#call, nil] turns a batch into optimizer arguments;
      #   defaults to Dataset::Convert.concat_examples
      # @param device passed through to the converter
      # @param loss_func loss function; defaults to the optimizer's target
      def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
        @iterators = iterator.kind_of?(Dataset::Iterator) ? { main: iterator } : iterator
        @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }

        @converter = converter || Dataset::Convert.method(:concat_examples)
        @loss_func = loss_func
        @device = device
        @iteration = 0
      end

      # @return [Hash] the optimizers keyed by name
      def get_all_optimizers
        @optimizers.to_h
      end

      # Runs one update step and then advances the iteration counter.
      def update
        update_core
        @iteration += 1
      end

      # Epoch count as reported by the :main iterator.
      def epoch
        main_iterator.epoch
      end

      # Epoch progress as reported by the :main iterator.
      def epoch_detail
        main_iterator.epoch_detail
      end

      # Fetches the next batch, converts it and hands it to the optimizer,
      # splatting the converted value according to its type.
      def update_core
        converted = @converter.call(main_iterator.next, device: @device)

        main_optimizer = @optimizers[:main]
        objective = @loss_func || main_optimizer.target

        case converted
        when Array
          main_optimizer.update(objective, *converted)
        when Hash
          main_optimizer.update(objective, **converted)
        else
          main_optimizer.update(objective, converted)
        end
      end

      # Finalizes every registered iterator.
      def finalize
        @iterators.each { |_name, each_iterator| each_iterator.finalize }
      end

      private

      def main_iterator
        @iterators[:main]
      end
    end
  end
end
@@ -0,0 +1,136 @@
1
require 'fileutils'

module Chainer
  module Training
    # Record describing one registered extension: the extension object, its
    # priority, the trigger deciding when it is invoked, and its
    # invoke_before_training flag.
    class ExtensionEntry
      attr_accessor :extension, :trigger, :invoke_before_training, :priority

      def initialize(extension, priority, trigger, invoke_before_training)
        @extension = extension
        @trigger = trigger
        @invoke_before_training = invoke_before_training
        @priority = priority
      end
    end

    # Runs the training loop: repeatedly calls the updater and then every
    # registered extension whose trigger fires, until the stop trigger fires.
    class Trainer
      attr_accessor :updater, :stop_trigger, :observation, :out

      # @param updater object providing get_all_optimizers, connect_trainer,
      #   update and finalize
      # @param stop_trigger trigger specification handed to
      #   Chainer::Training::Util.get_trigger; training ends when it fires
      # @param out [String] output directory, created by #run
      def initialize(updater, stop_trigger: nil, out: 'result')
        @updater = updater
        @stop_trigger = Chainer::Training::Util.get_trigger(stop_trigger)
        @observation = {}
        @out = out

        # Register every optimizer target and each of its named links as an
        # observer.
        reporter = Reporter.new
        updater.get_all_optimizers.each do |(name, optimizer)|
          reporter.add_observer(name, optimizer.target)
          optimizer.target.namedlinks(skipself: true) do |suffix, observer|
            reporter.add_observer(name.to_s + suffix, observer)
          end
        end
        @reporter = reporter

        @done = false
        @extensions = {}

        @start_at = nil
        @snapshot_elapsed_time = 0.0
        @final_elapsed_time = nil

        updater.connect_trainer(self)
      end

      # Seconds spent training so far, or the final total once #run finished.
      #
      # @raise [RuntimeError] when training has not been started
      def elapsed_time
        return @final_elapsed_time if @done
        raise "training has not been started yet" if @start_at.nil?

        # NOTE(review): @snapshot_elapsed_time is only ever 0.0 in this file;
        # if it is meant to carry time from a resumed run it presumably
        # should be added rather than subtracted - confirm.
        Time.now.to_f - @start_at - @snapshot_elapsed_time.to_f
      end

      # Registers an extension. Options left nil fall back to the
      # extension's own attribute when it has one, otherwise to a default.
      # When the name is already taken, "_1", "_2", ... is appended to make
      # it unique, and the final name is written back to extension.name.
      #
      # The invoke_before_training flag is only stored on the entry; #run
      # does not consult it.
      #
      # @param extension callable invoked with the trainer
      # @param name [String, nil] registration name
      # @param trigger trigger specification; default [1, 'iteration']
      # @param priority default Extension::PRIORITY_READER
      # @param invoke_before_training [Boolean, nil] default false
      # @raise [ArgumentError] when no name can be determined
      # @raise [RuntimeError] when the name is "training"
      def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
        name = extension.name || extension.default_name if name.nil?
        raise ArgumentError, 'name is not given for the extension' unless name

        raise 'the name "training" is prohibited as an extension name' if name == 'training'

        if trigger.nil?
          trigger = extension.respond_to?(:trigger) ? extension.trigger : [1, 'iteration']
        end
        trigger = Chainer::Training::Util.get_trigger(trigger)

        if priority.nil?
          priority = extension.respond_to?(:priority) ? extension.priority : Extension::PRIORITY_READER
        end

        if invoke_before_training.nil?
          invoke_before_training = extension.respond_to?(:invoke_before_training) ? extension.invoke_before_training : false
        end

        modified_name = name
        ordinal = 0
        while @extensions.key?(modified_name)
          ordinal += 1
          modified_name = "#{name}_#{ordinal}"
        end

        extension.name = modified_name
        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
      end

      # @param name [String] registration name
      # @return the extension registered under name
      # @raise [RuntimeError] when no such extension exists
      def get_extension(name)
        raise "extension #{name} not found" unless @extensions.key?(name)

        @extensions[name].extension
      end

      # Executes the training loop once. Extension finalizers and the
      # updater's finalize run even when the loop raises.
      #
      # @raise [RuntimeError] when called again after a completed run
      def run
        raise 'cannot run training loop multiple times' if @done
        FileUtils.mkdir_p(@out)

        # NOTE(review): extensions run in ascending priority order; the
        # PRIORITY_* values are not visible in this file - confirm that this
        # is the intended direction.
        # The registration index breaks ties so the order is deterministic.
        entries = @extensions.values.each_with_index.sort_by { |entry, index| [entry.priority, index] }.map(&:first)

        @start_at = Time.now.to_f

        entries.each do |entry|
          entry.extension.init(self) if entry.extension.respond_to?(:init)
        end

        begin
          until @stop_trigger.(self)
            @observation = {}
            @reporter.scope(@observation) do
              @updater.update
              entries.each do |entry|
                entry.extension.(self) if entry.trigger.(self)
              end
            end
          end
        ensure
          entries.each do |entry|
            entry.extension.finalize if entry.extension.respond_to?(:finalize)
          end
          @updater.finalize
        end

        # Capture the total before @done makes elapsed_time return the stored value.
        @final_elapsed_time = elapsed_time
        @done = true
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module Chainer
  module Training
    module Triggers
      # Trigger that fires at a fixed interval, measured either in epochs or
      # in iterations of the trainer's updater.
      class IntervalTrigger
        attr_reader :period, :unit, :count

        # @param period interval length
        # @param unit 'epoch' to measure in epochs; any other value measures
        #   in iterations
        def initialize(period, unit)
          @period, @unit, @count = period, unit, 0
        end

        # @param trainer object whose updater exposes epoch_detail / iteration
        # @return [Boolean] whether the interval boundary has been reached
        def call(trainer)
          source = trainer.updater

          unless @unit == 'epoch'
            current = source.iteration
            return current > 0 && current % @period == 0
          end

          # Fires when the number of completed periods changed since the
          # previous call.
          previous_count = @count
          @count = source.epoch_detail.div(@period)
          @count != previous_count
        end
      end
    end
  end
end