red-chainer 0.1.0

Files changed (58)
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,43 @@
+ module Chainer
+   module Links
+     module Connection
+       class Linear < ::Chainer::Link
+         attr_reader :w, :b
+
+         def initialize(in_size, out_size: nil, nobias: false, initial_w: nil, initial_bias: nil)
+           super()
+           in_size, out_size = nil, in_size if out_size.nil?
+           @out_size = out_size
+
+           init_scope do
+             w_initializer = Chainer::Initializers.get_initializer(initial_w)
+             @w = Chainer::Parameter.new(initializer: w_initializer)
+
+             initialize_params(in_size) unless in_size.nil?
+
+             if nobias
+               @b = nil
+             else
+               initial_bias = 0 if initial_bias.nil?
+               bias_initializer = Chainer::Initializers.get_initializer(initial_bias)
+               @b = Chainer::Parameter.new(initializer: bias_initializer, shape: out_size)
+             end
+           end
+         end
+
+         def call(x)
+           if @w.data.nil?
+             initialize_params(x.size.div(x.shape[0]))
+           end
+           Chainer::Functions::Connection::LinearFunction.linear(x, @w, @b)
+         end
+
+         private
+
+         def initialize_params(in_size)
+           @w.init([@out_size, in_size])
+         end
+       end
+     end
+   end
+ end
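For reference, a minimal usage sketch of this link (assuming a Numo::DFloat mini-batch as input and that LinearFunction.linear accepts raw NArray data, which the lazy initialization in call suggests):

    # Hypothetical example, not part of the gem: a 784 -> 10 linear layer.
    layer = Chainer::Links::Connection::Linear.new(784, out_size: 10)

    # A layer whose input size is inferred on the first call
    # (in_size is shifted into out_size when out_size is nil).
    lazy = Chainer::Links::Connection::Linear.new(10)

    x = Numo::DFloat.new(32, 784).rand   # batch of 32 rows
    y = layer.(x)                        # delegates to LinearFunction.linear(x, @w, @b)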
@@ -0,0 +1,39 @@
+ module Chainer
+   module Links
+     module Model
+       class Classifier < Chain
+         attr_accessor :compute_accuracy
+
+         def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy))
+           super()
+           @lossfun = lossfun
+           @accfun = accfun
+           @y = nil
+           @loss = nil
+           @accuracy = nil
+           @compute_accuracy = true
+
+           init_scope do
+             @predictor = predictor
+           end
+         end
+
+         def call(*args)
+           t = args.pop
+           x = args
+           @y = nil
+           @accuracy = nil
+           @y = @predictor.(*x)
+
+           @loss = @lossfun.call(@y, t)
+           Chainer::Reporter.save_report({loss: @loss}, self)
+           if @compute_accuracy
+             @accuracy = @accfun.call(@y, t)
+             Chainer::Reporter.save_report({accuracy: @accuracy}, self)
+           end
+           @loss
+         end
+       end
+     end
+   end
+ end
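A usage sketch tying the classifier to the Reporter defined later in this diff: the save_report calls above need an active reporter, so the forward pass is wrapped in a reporter scope. Here x and t are placeholders for a Numo mini-batch and its integer labels.

    predictor = Chainer::Links::Connection::Linear.new(784, out_size: 10)
    model = Chainer::Links::Model::Classifier.new(predictor)

    reporter = Chainer::Reporter.new
    reporter.add_observer('main', model)
    reporter.scope({}) do
      loss = model.(x, t)   # reports 'main/loss' and 'main/accuracy'
    end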
@@ -0,0 +1,69 @@
+ module Chainer
+   class Optimizer
+     attr_accessor :target
+
+     def setup(link)
+       @target = link
+       @t = 0
+       @epoch = 0
+
+       @hooks = {}
+     end
+
+     def _call_hook(hook)
+       if hook.methods.include?(:call_for_each_param)
+         @target.params.each do |param|
+           hook.(param.update_rule, param)
+         end
+       else
+         hook.(self)
+       end
+     end
+   end
+
+   class UpdateRule
+     attr_reader :state
+
+     def initialize(parent_hyperparam:)
+       @hooks = {}
+       @state = nil
+       @enabled = true
+       @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)
+       @t = 0
+     end
+
+     def update(param)
+       return unless @enabled
+
+       @t += 1
+       prepare(param)
+       @hooks.values.each do |hook|
+         hook.call(param)
+       end
+       update_core(param)
+     end
+
+     def update_core(param)
+       # TODO: support GPU
+       update_core_cpu(param)
+     end
+
+     def update_core_cpu(param)
+       raise NotImplementedError
+     end
+
+     def init_state(param)
+       raise NotImplementedError
+     end
+
+     private
+
+     def prepare(param)
+       if @state.nil?
+         @state = {}
+         init_state(param)
+       end
+       @state.select! { |_, v| v.kind_of?(Numo::NArray) }
+     end
+   end
+ end
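UpdateRule is abstract: update increments the step counter, lazily builds @state via prepare/init_state, runs any hooks, and delegates to update_core_cpu. A hypothetical vanilla-SGD rule, only to illustrate the subclassing contract (SGDRule and its lr hyperparameter are not part of the gem; the pattern mirrors AdamRule below):

    class SGDRule < Chainer::UpdateRule
      def initialize(lr: 0.01)
        hyperparam = Chainer::Hyperparameter.new
        hyperparam.instance_variable_set('@lr', lr)
        super(parent_hyperparam: hyperparam)
      end

      def init_state(param)
        # plain SGD keeps no per-parameter state
      end

      def update_core_cpu(param)
        return if param.grad.nil?
        param.data -= @hyperparam.lr * param.grad
      end
    end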
@@ -0,0 +1,62 @@
+ module Chainer
+   module Optimizers
+     class AdamRule < UpdateRule
+       def initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
+         hyperparam = Hyperparameter.new
+         hyperparam.instance_variable_set('@alpha', 0.001)
+         hyperparam.instance_variable_set('@beta1', 0.9)
+         hyperparam.instance_variable_set('@beta2', 0.999)
+         hyperparam.instance_variable_set('@eps', 1e-8)
+
+         super(parent_hyperparam: parent_hyperparam || hyperparam)
+
+         @hyperparam.instance_variable_set('@alpha', alpha) if alpha
+         @hyperparam.instance_variable_set('@beta1', beta1) if beta1
+         @hyperparam.instance_variable_set('@beta2', beta2) if beta2
+         @hyperparam.instance_variable_set('@eps', eps) if eps
+       end
+
+       def init_state(param)
+         @state[:m] = param.data.new_zeros
+         @state[:v] = param.data.new_zeros
+       end
+
+       def update_core_cpu(param)
+         grad = param.grad
+         return if grad.nil?
+
+         hp = @hyperparam
+
+         @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
+         @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
+         param.data -= lr * @state[:m] / (Numo::NMath.sqrt(@state[:v]) + hp.eps)
+       end
+
+       def lr
+         fix1 = 1.0 - @hyperparam.beta1 ** @t
+         fix2 = 1.0 - @hyperparam.beta2 ** @t
+         @hyperparam.alpha * Math.sqrt(fix2) / fix1
+       end
+     end
+
+     class Adam < GradientMethod
+       def initialize(alpha: nil, beta1: nil, beta2: nil, eps: nil)
+         super()
+         @hyperparam.instance_variable_set('@alpha', alpha || 0.001)
+         @hyperparam.instance_variable_set('@beta1', beta1 || 0.9)
+         @hyperparam.instance_variable_set('@beta2', beta2 || 0.999)
+         @hyperparam.instance_variable_set('@eps', eps || 1e-8)
+       end
+
+       def create_update_rule
+         AdamRule.new(parent_hyperparam: @hyperparam)
+       end
+
+       def lr
+         fix1 = 1.0 - (@hyperparam.beta1 ** @t)
+         fix2 = 1.0 - (@hyperparam.beta2 ** @t)
+         @hyperparam.alpha * Math.sqrt(fix2) / fix1
+       end
+     end
+   end
+ end
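The lr method implements Adam's bias correction: the effective step size is alpha * sqrt(1 - beta2**t) / (1 - beta1**t), so with the defaults it is about 0.001 * sqrt(0.001) / 0.1 ≈ 3.2e-4 at t = 1 and approaches alpha as t grows. A minimal setup sketch (assuming GradientMethod wires create_update_rule into Optimizer#setup, as in the upstream Chainer design; model is any link, e.g. the Classifier above):

    optimizer = Chainer::Optimizers::Adam.new   # alpha: 0.001, beta1: 0.9, beta2: 0.999, eps: 1e-8
    optimizer.setup(model)                      # Optimizer#setup records the target link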
@@ -0,0 +1,53 @@
+ module Chainer
+   class Parameter < Variable
+     attr_accessor :initializer, :grad_initializer, :update_rule
+
+     def initialize(initializer: nil, shape: nil, name: nil)
+       if initializer.nil?
+         initializer = Chainer::Initializers.nan()
+       elsif initializer.kind_of?(Numeric)
+         initializer = Initializers::Constant.new(initializer)
+       end
+
+       if shape.nil?
+         if initializer.kind_of?(Numo::NArray)
+           super(initializer, name: name)
+         else
+           super(name: name)
+           @initializer = initializer
+           dtype = initializer.respond_to?(:dtype) ? initializer.dtype : 'DFloat'
+           @grad_initializer = Chainer::Initializers.nan()
+         end
+       else
+         if initializer.kind_of?(Numo::NArray)
+           initializer = Initializers::Constant.new(initializer)
+         end
+         data = Chainer::Initializers.generate_array(initializer, shape)
+         grad = Numo::NArray[*[1, 2]].new_fill(-922337203)
+         super(data, name: name, grad: grad)
+       end
+
+       @update_rule = nil
+     end
+
+     def cleargrad
+       super
+       @grad_initializer = nil if self.data.nil?
+     end
+
+     def init(shape)
+       data = Chainer::Initializers.generate_array(@initializer, shape)
+       ginit = @grad_initializer
+       grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)
+
+       @data[0] = data
+       @node.grad = grad
+     end
+
+     def update
+       if @update_rule
+         @update_rule.update(self)
+       end
+     end
+   end
+ end
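A sketch of the two initialization paths (numeric initializers are wrapped in Initializers::Constant; with no shape, array allocation is deferred until init is called, which is how the Linear link above fills in the input size lazily):

    # Eager: shape given, data generated immediately.
    b = Chainer::Parameter.new(initializer: 0.0, shape: 10)

    # Deferred: no data until init supplies a shape.
    w = Chainer::Parameter.new(initializer: 0.5)
    w.init([3, 4])   # allocates a 3x4 array via the constant initializer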
@@ -0,0 +1,130 @@
+ module Chainer
+   module ReportService
+     @@reporters = []
+   end
+
+   class Reporter
+     include ReportService
+
+     def initialize
+       @observer_names = {}
+       @observation = {}
+     end
+
+     def self.save_report(values, observer=nil)
+       reporter = @@reporters[-1]
+       reporter.report(values, observer)
+     end
+
+     def report(values, observer=nil)
+       # TODO: keep_graph_on_report option
+       if observer
+         observer_id = observer.object_id
+         unless @observer_names.keys.include?(observer_id)
+           raise "Given observer is not registered to the reporter."
+         end
+         observer_name = @observer_names[observer_id]
+         values.each do |key, value|
+           name = "#{observer_name}/#{key}"
+           @observation[name] = value
+         end
+       else
+         @observation.update(values)
+       end
+     end
+
+     def add_observer(name, observer)
+       @observer_names[observer.object_id] = name
+     end
+
+     def scope(observation)
+       @@reporters << self
+       old = @observation
+       @observation = observation
+       yield
+       @observation = old
+       @@reporters.pop
+     end
+   end
+
+   class Summary
+     def initialize
+       @x = 0
+       @x2 = 0
+       @n = 0
+     end
+
+     # Adds a scalar value.
+     # Args:
+     #   value: Scalar value to accumulate.
+     def add(value)
+       @x += value
+       @x2 += value * value
+       @n += 1
+     end
+
+     # Computes the mean.
+     def compute_mean
+       @x.to_f / @n
+     end
+
+     # Computes and returns the mean and standard deviation values.
+     # Returns:
+     #   array: Mean and standard deviation values.
+     def make_statistics
+       mean = @x / @n
+       var = @x2 / @n - mean * mean
+       std = Math.sqrt(var)
+       [mean, std]
+     end
+   end
+
+   # Online summarization of a sequence of dictionaries.
+   # ``DictSummary`` computes the statistics of a given set of scalars online.
+   # It only computes the statistics for scalar values and variables of scalar values in the dictionaries.
+   class DictSummary
+     def initialize
+       @summaries = Hash.new { |h,k| h[k] = Summary.new }
+     end
+
+     # Adds a dictionary of scalars.
+     # Args:
+     #   d (dict): Dictionary of scalars to accumulate. Only elements of
+     #     scalars, zero-dimensional arrays, and variables of
+     #     zero-dimensional arrays are accumulated.
+     def add(d)
+       d.each do |k, v|
+         v = v.data if v.kind_of?(Chainer::Variable)
+         if v.class.method_defined?(:to_i) || (v.class.method_defined?(:ndim) && v.ndim == 0)
+           @summaries[k].add(v)
+         end
+       end
+     end
+
+     # Creates a dictionary of mean values.
+     # It returns a single dictionary that holds a mean value for each entry added to the summary.
+     #
+     # Returns:
+     #   dict: Dictionary of mean values.
+     def compute_mean
+       @summaries.each_with_object({}) { |(name, summary), h| h[name] = summary.compute_mean }
+     end
+
+     # Creates a dictionary of statistics.
+     # It returns a single dictionary that holds mean and standard deviation
+     # values for every entry added to the summary. For an entry of name
+     # ``'key'``, these values are added to the dictionary by names ``'key'`` and ``'key.std'``, respectively.
+     #
+     # Returns:
+     #   dict: Dictionary of statistics of all entries.
+     def make_statistics
+       stats = {}
+       @summaries.each do |name, summary|
+         mean, std = summary.make_statistics
+         stats[name] = mean
+         stats[name + '.std'] = std
+       end
+       stats
+     end
+   end
+ end
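A standalone sketch of the scalar aggregation in DictSummary (string keys, as produced by Reporter#report in the form 'observer/key'; values may also be scalar Chainer::Variable instances):

    summary = Chainer::DictSummary.new
    summary.add({ 'main/loss' => 0.9, 'main/accuracy' => 0.50 })
    summary.add({ 'main/loss' => 0.7, 'main/accuracy' => 0.60 })

    summary.compute_mean      # => { 'main/loss' => 0.8, 'main/accuracy' => 0.55 } (up to float rounding)
    summary.make_statistics   # also adds 'main/loss.std' and 'main/accuracy.std' entries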
@@ -0,0 +1,25 @@
+ module Chainer
+   module Training
+     class Extension
+       PRIORITY_WRITER = 300
+       PRIORITY_EDITOR = 200
+       PRIORITY_READER = 100
+
+       attr_accessor :name, :priority
+
+       def initialize
+       end
+
+       def call(trainer)
+       end
+
+       def default_name
+         self.class.to_s
+       end
+
+       def priority
+         @priority || PRIORITY_READER
+       end
+     end
+   end
+ end
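A hypothetical minimal subclass, only to show the override points (call, priority, default_name); the trainer's attributes are not shown in this diff, so the body deliberately avoids them:

    class HeartbeatExtension < Chainer::Training::Extension
      def initialize
        super
        @priority = Chainer::Training::Extension::PRIORITY_READER
      end

      def call(trainer)
        puts "extension #{default_name} invoked"
      end
    end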
@@ -0,0 +1,26 @@
+ module Chainer
+   module Training
+     module Extensions
+       class Evaluator < Extension
+         def initialize(iterator, target, converter: nil, device: nil, eval_hook: nil, eval_func: nil)
+           @priority = Extension::PRIORITY_WRITER
+
+           if iterator.kind_of?(Dataset::Iterator)
+             iterator = { main: iterator }
+           end
+           @iterators = iterator
+
+           if target.kind_of?(Link)
+             target = { main: target }
+           end
+           @targets = target
+
+           @converter = converter || Dataset::Convert.method(:concat_examples)
+           @device = device
+           @eval_hook = eval_hook
+           @eval_func = eval_func
+         end
+       end
+     end
+   end
+ end
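Construction sketch: a single iterator and a single link are normalized into { main: ... } hashes, and the converter defaults to Dataset::Convert.concat_examples. The SerialIterator constructor arguments below (dataset, batch size) are an assumption, as its signature is not shown in this diff, and test_dataset is a placeholder:

    test_iter = Chainer::Iterators::SerialIterator.new(test_dataset, 100)
    evaluator = Chainer::Training::Extensions::Evaluator.new(test_iter, model)
    # @iterators == { main: test_iter }, @targets == { main: model }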