red-chainer 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (58) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,72 @@
1
require 'tempfile'
require 'json'
require 'fileutils'

module Chainer
  module Training
    module Extensions
      # Accumulates reported observations and, each time its trigger fires,
      # appends the per-interval means (plus epoch/iteration/elapsed_time) to an
      # in-memory log, optionally dumping the whole log as JSON under trainer.out.
      class LogReport < Extension
        # @return [Array<Hash>] one entry per trigger firing
        attr_reader :log

        # @param keys [Array<String, Symbol>, nil] observation keys to collect;
        #   nil collects everything in trainer.observation
        # @param trigger [Object] when to summarize; resolved through
        #   Chainer::Training::Util.get_trigger
        # @param postprocess [#call, nil] called with each entry before it is appended
        # @param log_name [String, nil] file name written under trainer.out; it is a
        #   format string that may reference entry keys (e.g. 'log_%{epoch}').
        #   nil disables writing.
        def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'log')
          @keys = keys
          @trigger = Chainer::Training::Util.get_trigger(trigger)
          @postprocess = postprocess
          @log_name = log_name
          @log = []

          init_summary
        end

        # @param trainer [Chainer::Training::Trainer]
        def call(trainer)
          add_observation(trainer.observation)

          # only summarize and output when the trigger fires
          return unless @trigger.(trainer)

          stats_cpu = {}
          @summary.compute_mean.each do |name, value|
            stats_cpu[name] = value.to_f # copy to CPU
          end

          updater = trainer.updater
          stats_cpu['epoch'] = updater.epoch
          stats_cpu['iteration'] = updater.iteration
          stats_cpu['elapsed_time'] = trainer.elapsed_time

          @postprocess.(stats_cpu) unless @postprocess.nil?

          @log << stats_cpu

          write_log(trainer.out, stats_cpu) unless @log_name.nil?

          init_summary
        end

        private

        def init_summary
          @summary = DictSummary.new
        end

        # Adds the current observation (or only the configured keys) to the summary.
        # Keys are matched regardless of whether they are strings or symbols.
        def add_observation(observation)
          if @keys.nil?
            @summary.add(observation)
            return
          end

          symbolized_observation = Hash[observation.map { |k, v| [k.to_sym, v] }]
          selected = @keys.each_with_object({}) do |k, hash|
            key = k.to_sym
            hash[k.to_s] = symbolized_observation[key] if symbolized_observation.key?(key)
          end
          @summary.add(selected)
        end

        # Writes the whole log as JSON to a temp file inside out_dir, then moves it
        # into place so a reader never sees a partially written file.
        def write_log(out_dir, stats)
          # example: sprintf("%{a}, %{b}", {a: "1", b: "2"})
          # => "1, 2"
          # Named references need symbol keys, while entry keys are strings.
          log_name = sprintf(@log_name, Hash[stats.map { |k, v| [k.to_sym, v] }])

          # basename and tmpdir are positional arguments of Tempfile.create
          temp_file = Tempfile.create(log_name, out_dir)
          temp_path = temp_file.path
          begin
            JSON.dump(@log, temp_file)
          ensure
            # flush buffered data and release the handle before moving the file
            temp_file.close
          end

          FileUtils.mv(temp_path, File.join(out_dir, log_name))
        end
      end
    end
  end
end
@@ -0,0 +1,62 @@
1
module Chainer
  module Training
    module Extensions
      # Prints selected entries of a LogReport's log as a fixed-width table,
      # writing a header once and then each not-yet-printed log entry.
      class PrintReport < Extension
        # @param entries [Array<String>] log keys to print, in column order
        # @param log_report [String, LogReport] a LogReport instance, or the name
        #   under which it was registered on the trainer
        # @param out [IO] destination stream
        def initialize(entries, log_report: 'Chainer::Training::Extensions::LogReport', out: STDOUT)
          @entries = entries
          @log_report = log_report
          @out = out

          @log_len = 0 # number of observations already printed

          # format information: each column is at least 10 characters wide
          entry_widths = entries.map { |s| [10, s.size].max }

          templates = []
          header = []
          entries.zip(entry_widths).each do |entry, w|
            header << sprintf("%-#{w}s", entry)
            # Header columns are joined by one space, so every cell occupies
            # w + 1 characters; the blank cell must match to keep columns aligned.
            templates << [entry, "%-#{w}g ", ' ' * (w + 1)]
          end
          @header = header.join(' ') + "\n"
          @templates = templates
        end

        # @param trainer [Chainer::Training::Trainer]
        # @raise [TypeError] if log_report is neither a String nor a LogReport
        def call(trainer)
          if @header
            @out.write(@header)
            @header = nil
          end

          log = resolve_log_report(trainer).log
          while log.size > @log_len
            @out.write("\033[J")
            print(log[@log_len])
            @log_len += 1
          end
        end

        private

        # Looks up the LogReport by name, or invokes a directly supplied instance.
        def resolve_log_report(trainer)
          case @log_report
          when String
            trainer.get_extension(@log_report)
          when LogReport
            @log_report.(trainer)
            @log_report
          else
            raise TypeError, "log report has a wrong type #{@log_report.class}"
          end
        end

        # Writes one table row; entries missing from the observation are left blank.
        def print(observation)
          @templates.each do |entry, template, empty|
            if observation.key?(entry)
              @out.write(sprintf(template, observation[entry]))
            else
              @out.write(empty)
            end
          end
          @out.write("\n")
        end
      end
    end
  end
end
@@ -0,0 +1,89 @@
1
require 'erb'

module Chainer
  module Training
    module Extensions
      # Draws a terminal progress bar for the whole run and for the current epoch,
      # plus throughput and an estimated time to finish, redrawn in place with
      # ANSI escape sequences every `update_interval` iterations.
      class ProgressBar < Extension
        # @param training_length [Array(Numeric, String), nil] [length, unit]; when nil
        #   it is taken from the trainer's stop trigger (must be an IntervalTrigger)
        # @param update_interval [Integer] redraw every this many iterations
        # @param bar_length [Integer] number of characters in each bar
        # @param out [IO] destination stream
        def initialize(training_length: nil, update_interval: 100, bar_length: 50, out: STDOUT)
          @training_length = training_length
          @status_template = nil
          @update_interval = update_interval
          @bar_length = bar_length
          @out = out
          @out.sync = true
          @recent_timing = []
        end

        # @param trainer [Chainer::Training::Trainer]
        # @raise [TypeError] if the training length cannot be derived from the stop trigger
        def call(trainer)
          if @training_length.nil?
            t = trainer.stop_trigger
            raise TypeError, "cannot retrieve the training length #{t.class}" unless t.is_a?(Chainer::Training::Triggers::IntervalTrigger)
            @training_length = [t.period, t.unit]
          end

          if @status_template.nil?
            @status_template = ERB.new("<%= sprintf('%10d', self.iteration) %> iter, <%= self.epoch %> epoch / #{@training_length[0]} #{@training_length[1]}s\n")
          end

          length, unit = @training_length
          iteration = trainer.updater.iteration

          # print the progress bar according to interval
          return unless iteration % @update_interval == 0

          epoch = trainer.updater.epoch_detail
          now = Time.now.to_f

          @recent_timing << [iteration, epoch, now]
          @out.write("\033[J")

          rate = unit == 'iteration' ? iteration.to_f / length : epoch.to_f / length
          @out.write(sprintf(" total [%s] %6.2f%%\n", render_bar(rate), rate * 100))

          epoch_rate = epoch - epoch.to_i
          @out.write(sprintf("this epoch [%s] %6.2f%%\n", render_bar(epoch_rate), epoch_rate * 100))

          @out.write(@status_template.result(trainer.updater.bind))

          # throughput is measured against the oldest retained timing sample
          old_t, old_e, old_sec = @recent_timing[0]
          span = now - old_sec

          if span.zero?
            speed_t = Float::INFINITY
            speed_e = Float::INFINITY
          else
            speed_t = (iteration - old_t) / span
            speed_e = (epoch - old_e) / span
          end

          estimated_time = unit == 'iteration' ? (length - iteration) / speed_t : (length - epoch) / speed_e

          @out.write(sprintf("%10.5g iters/sec. Estimated time to finish: %s.\n", speed_t, format_duration(estimated_time)))

          # move the cursor back to the head of the progress bar (4 lines written above)
          @out.write("\033[4A") # TODO: Support Windows
          @out.flush

          @recent_timing.shift if @recent_timing.size > 100
        end

        def finalize
          @out.write("\033[J") # TODO: Support Windows
          @out.flush
        end

        private

        # Renders `rate` (nominally 0.0..1.0) as '#' marks padded with '.'.
        # The mark count is clamped so overshooting the planned training length
        # does not produce a negative padding count (String#* raises on that).
        def render_bar(rate)
          marks = [[(rate * @bar_length).to_i, 0].max, @bar_length].min
          '#' * marks + '.' * (@bar_length - marks)
        end

        # Formats a duration in seconds as HH:MM:SS (hours are not wrapped at 24).
        # Non-finite estimates (e.g. 0/0 when no progress was measured) print 'unknown'.
        def format_duration(seconds)
          seconds = seconds.to_f
          return 'unknown' unless seconds.finite?

          total = [seconds, 0].max.round
          hours, rest = total.divmod(3600)
          minutes, secs = rest.divmod(60)
          sprintf('%02d:%02d:%02d', hours, minutes, secs)
        end
      end
    end
  end
end
@@ -0,0 +1,63 @@
1
module Chainer
  module Training
    # Default updater: each step pulls one batch from the :main iterator,
    # converts it with the converter, and passes it to the :main optimizer.
    class StandardUpdater < Updater
      attr_accessor :iteration

      # @param iterator [Dataset::Iterator, Hash] a single iterator (stored under :main)
      #   or a hash of named iterators
      # @param optimizer [Object, Hash] a single optimizer (stored under :main)
      #   or a hash of named optimizers
      # @param converter [#call, nil] batch converter; defaults to
      #   Dataset::Convert.concat_examples
      # @param device [Object, nil] forwarded to the converter as `device:`
      # @param loss_func [#call, nil] loss to optimize; defaults to the optimizer's target
      def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
        @iterators = iterator.kind_of?(Dataset::Iterator) ? { main: iterator } : iterator
        @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }

        @converter = converter || Dataset::Convert.method(:concat_examples)
        @loss_func = loss_func
        @device = device
        @iteration = 0
      end

      # @return [Hash] name => optimizer
      def get_all_optimizers
        @optimizers.to_h
      end

      # Runs one update step and advances the iteration counter.
      def update
        update_core
        @iteration += 1
      end

      def epoch
        @iterators[:main].epoch
      end

      def epoch_detail
        @iterators[:main].epoch_detail
      end

      # Splats the converted batch into optimizer.update according to its shape:
      # positional for an Array, keywords for a Hash, a single argument otherwise.
      def update_core
        main_iterator = @iterators[:main]
        in_arrays = @converter.call(main_iterator.next, device: @device)

        main_optimizer = @optimizers[:main]
        target_loss = @loss_func || main_optimizer.target

        case in_arrays
        when Array then main_optimizer.update(target_loss, *in_arrays)
        when Hash then main_optimizer.update(target_loss, **in_arrays)
        else main_optimizer.update(target_loss, in_arrays)
        end
      end

      # Finalizes every registered iterator.
      def finalize
        @iterators.each { |_name, it| it.finalize }
      end
    end
  end
end
@@ -0,0 +1,136 @@
1
require 'fileutils'

module Chainer
  module Training
    # Registration record for an extension attached to a Trainer.
    class ExtensionEntry
      attr_accessor :extension, :trigger, :invoke_before_training, :priority

      def initialize(extension, priority, trigger, invoke_before_training)
        @extension = extension
        @trigger = trigger
        @invoke_before_training = invoke_before_training
        @priority = priority
      end
    end

    # Drives the training loop: calls the updater and the registered extensions
    # inside a reporter scope until the stop trigger fires.
    class Trainer
      attr_accessor :updater, :stop_trigger, :observation, :out

      # @param updater [Updater] must respond to get_all_optimizers, connect_trainer,
      #   update and finalize
      # @param stop_trigger [Object] resolved through Chainer::Training::Util.get_trigger
      # @param out [String] output directory, created when #run starts
      def initialize(updater, stop_trigger: nil, out: 'result')
        @updater = updater
        @stop_trigger = Chainer::Training::Util.get_trigger(stop_trigger)
        @observation = {}
        @out = out

        # register every optimizer target and its named sub-links as observers
        reporter = Reporter.new
        updater.get_all_optimizers.each do |(name, optimizer)|
          reporter.add_observer(name, optimizer.target)
          optimizer.target.namedlinks(skipself: true) do |suffix, observer|
            observer_name = name.to_s + suffix
            reporter.add_observer(observer_name, observer)
          end
        end
        @reporter = reporter

        @done = false
        @extensions = {}

        @start_at = nil
        @snapshot_elapsed_time = 0.0
        @final_elapsed_time = nil

        updater.connect_trainer(self)
      end

      # Seconds since #run started, plus @snapshot_elapsed_time (0.0 here;
      # presumably time carried over from a snapshot — TODO confirm).
      # Frozen at the final value once training has finished.
      # @raise [RuntimeError] if training has not been started yet
      def elapsed_time
        return @final_elapsed_time if @done
        raise 'training has not been started yet' if @start_at.nil?

        Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
      end

      # Registers an extension. Missing options fall back to the extension's own
      # `trigger` / `priority` / `invoke_before_training` when it responds to them.
      # A name already in use is disambiguated as name_1, name_2, ...
      # @raise [ArgumentError] if no name can be determined
      # @return [ExtensionEntry]
      def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
        if name.nil?
          name = extension.name || extension.default_name
          raise ArgumentError, 'name is not given for the extension' unless name
        end

        raise 'the name "training" is prohibited as an extension name' if name == 'training'

        if trigger.nil?
          trigger = extension.respond_to?(:trigger) ? extension.trigger : [1, 'iteration']
        end
        trigger = Chainer::Training::Util.get_trigger(trigger)

        if priority.nil?
          priority = extension.priority if extension.respond_to?(:priority)
          # an extension may expose `priority` without having set it
          priority ||= Extension::PRIORITY_READER
        end

        if invoke_before_training.nil?
          invoke_before_training = extension.respond_to?(:invoke_before_training) ? extension.invoke_before_training : false
        end

        modified_name = name
        ordinal = 0
        while @extensions.key?(modified_name)
          ordinal += 1
          modified_name = "#{name}_#{ordinal}"
        end

        extension.name = modified_name
        @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger, invoke_before_training)
      end

      # @raise [RuntimeError] if no extension is registered under `name`
      def get_extension(name)
        raise "extension #{name} not found" unless @extensions.key?(name)

        @extensions[name].extension
      end

      # Runs the training loop once; extensions are finalized even if it raises.
      # @raise [RuntimeError] if called more than once
      def run
        raise 'cannot run training loop multiple times' if @done
        FileUtils.mkdir_p(@out)

        extensions = sorted_extensions

        @start_at = Time.now.to_f

        extensions.each do |(_, entry)|
          entry.extension.init(self) if entry.extension.respond_to?(:init)
        end

        update = @updater.method(:update)
        reporter = @reporter
        stop_trigger = @stop_trigger

        begin
          until stop_trigger.(self)
            @observation = {}
            reporter.scope(@observation) do
              update.call
              extensions.each do |(_, entry)|
                entry.extension.(self) if entry.trigger.(self)
              end
            end
          end
        ensure
          extensions.each do |(_, entry)|
            entry.extension.finalize if entry.extension.respond_to?(:finalize)
          end
          @updater.finalize
        end

        # compute before setting @done, since elapsed_time returns the frozen value afterwards
        @final_elapsed_time = elapsed_time
        @done = true
      end

      private

      # [name, entry] pairs in descending priority (higher priority runs first),
      # keeping registration order among equal priorities (sort_by alone is unstable).
      def sorted_extensions
        @extensions.to_a
                   .each_with_index
                   .sort_by { |(_, entry), index| [-entry.priority, index] }
                   .map(&:first)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module Chainer
  module Training
    module Triggers
      # Fires once every `period` units of progress, where the unit is either
      # 'epoch' (fractional epochs allowed) or anything else, treated as iterations.
      class IntervalTrigger
        attr_reader :period, :unit, :count

        # @param period [Numeric] interval length
        # @param unit [String] 'epoch' or 'iteration'
        def initialize(period, unit)
          @period = period
          @unit = unit
          @count = 0
        end

        # @param trainer [#updater] whose updater exposes iteration and epoch_detail
        # @return [Boolean] true when a new interval boundary has been reached
        def call(trainer)
          updater = trainer.updater
          return epoch_boundary_crossed?(updater.epoch_detail) if @unit == 'epoch'

          current = updater.iteration
          current > 0 && (current % @period).zero?
        end

        private

        # Tracks how many whole periods have elapsed; fires when that number changes.
        def epoch_boundary_crossed?(epoch_detail)
          previous = @count
          @count = epoch_detail.div(@period)
          @count != previous
        end
      end
    end
  end
end