red-chainer 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,43 @@
module Chainer
  module Links
    module Connection
      # Fully-connected link owning a weight parameter +w+ and an optional
      # bias parameter +b+, applied through
      # Chainer::Functions::Connection::LinearFunction.linear.
      #
      # When the input size is not given, +w+ is created uninitialized and
      # allocated on the first #call from the per-sample size of the input.
      class Linear < ::Chainer::Link
        attr_reader :w, :b

        # @param in_size [Integer, nil] input size; when +out_size+ is nil this
        #   value is taken as the output size and the input size is deferred
        # @param out_size [Integer, nil] output size
        # @param nobias [Boolean] when true, +b+ is nil
        # @param initial_w weight initializer spec, passed to
        #   Chainer::Initializers.get_initializer
        # @param initial_bias bias initializer spec; nil means 0
        def initialize(in_size, out_size: nil, nobias: false, initial_w: nil, initial_bias: nil)
          super()
          # One-argument form: Linear.new(out_size) with a lazily inferred input size.
          if out_size.nil?
            out_size = in_size
            in_size = nil
          end
          @out_size = out_size

          init_scope do
            @w = Chainer::Parameter.new(
              initializer: Chainer::Initializers.get_initializer(initial_w)
            )

            initialize_params(in_size) unless in_size.nil?

            @b =
              if nobias
                nil
              else
                bias_spec = initial_bias.nil? ? 0 : initial_bias
                Chainer::Parameter.new(
                  initializer: Chainer::Initializers.get_initializer(bias_spec),
                  shape: out_size
                )
              end
          end
        end

        # Applies the link to +x+, allocating +w+ first if it has no data yet.
        # The input size used for that allocation is x.size / x.shape[0],
        # i.e. the element count per entry along the first axis.
        def call(x)
          initialize_params(x.size.div(x.shape[0])) if @w.data.nil?
          Chainer::Functions::Connection::LinearFunction.linear(x, @w, @b)
        end

        private

        # Allocates +w+ with shape [out_size, in_size].
        def initialize_params(in_size)
          @w.init([@out_size, in_size])
        end
      end
    end
  end
end
@@ -0,0 +1,39 @@
module Chainer
  module Links
    module Model
      # Chain that wraps a predictor link and computes a loss (and optionally
      # an accuracy) against a target. Both values are sent to the active
      # reporter via Chainer::Reporter.save_report.
      class Classifier < Chain
        attr_accessor :compute_accuracy

        # @param predictor link called with all but the last #call argument
        # @param lossfun [#call] called as lossfun.call(y, t)
        # @param accfun [#call] called as accfun.call(y, t)
        def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy))
          super()
          @lossfun = lossfun
          @accfun = accfun
          @y = nil
          @loss = nil
          @accuracy = nil
          @compute_accuracy = true

          init_scope do
            @predictor = predictor
          end
        end

        # Runs the predictor and computes the loss for one batch.
        #
        # Every argument except the last is forwarded to the predictor; the
        # last argument is the target passed to lossfun/accfun.
        #
        # @return the value returned by lossfun
        # @raise [ArgumentError] when called with no arguments (no target)
        def call(*args)
          raise ArgumentError, 'Classifier#call requires at least a target argument' if args.empty?

          *x, t = args
          # Reset all per-call results first so a failure below cannot leave
          # values from a previous call behind.
          @y = nil
          @loss = nil
          @accuracy = nil

          @y = @predictor.(*x)

          @loss = @lossfun.call(@y, t)
          Chainer::Reporter.save_report({loss: @loss}, self)
          if @compute_accuracy
            @accuracy = @accfun.call(@y, t)
            Chainer::Reporter.save_report({accuracy: @accuracy}, self)
          end
          @loss
        end
      end
    end
  end
end
@@ -0,0 +1,69 @@
module Chainer
  # Base class of optimizers: keeps the target link, step/epoch counters
  # and the registered hooks.
  class Optimizer
    attr_accessor :target

    # Binds the optimizer to +link+ and resets counters and hooks.
    # @param link the link whose parameters will be optimized
    def setup(link)
      @target = link
      @t = 0
      @epoch = 0

      @hooks = {}
    end

    # Invokes one optimizer hook.
    #
    # A hook responding to +call_for_each_param+ is called once per target
    # parameter as hook.call(param.update_rule, param); any other hook is
    # called once with the optimizer itself.
    def _call_hook(hook)
      if hook.respond_to?(:call_for_each_param)
        @target.params.each do |param|
          hook.(param.update_rule, param)
        end
      else
        # Must invoke the hook object; a bare `hook(self)` would look up a
        # method named `hook` on the optimizer instead.
        hook.call(self)
      end
    end
  end

  # Base class of per-parameter update rules. Subclasses implement
  # #init_state and #update_core_cpu.
  class UpdateRule
    attr_reader :state

    # @param parent_hyperparam passed as +parent:+ to Chainer::Hyperparameter.new
    def initialize(parent_hyperparam:)
      @hooks = {}
      @state = nil
      @enabled = true
      @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)
      @t = 0
    end

    # Performs one update step on +param+ (no-op when disabled): increments
    # the step counter, lazily initializes state, runs hooks, then updates.
    def update(param)
      return unless @enabled

      @t += 1
      prepare(param)
      @hooks.each_value { |hook| hook.call(param) }
      update_core(param)
    end

    def update_core(param)
      # TODO: support GPU
      update_core_cpu(param)
    end

    # CPU update for +param+; must be overridden.
    # Takes +param+ so the call from #update_core reaches the
    # NotImplementedError below instead of failing with ArgumentError.
    # @raise [NotImplementedError] in the base class
    def update_core_cpu(param)
      raise NotImplementedError
    end

    # Fills @state for +param+; must be overridden.
    # @raise [NotImplementedError] in the base class
    def init_state(param)
      raise NotImplementedError
    end

    private

    # Creates and initializes @state on first use, then keeps only
    # Numo::NArray entries.
    def prepare(param)
      if @state.nil?
        @state = {}
        init_state(param)
      end
      # NOTE(review): upstream Chainer only skips non-array state entries
      # (for device transfer); this removes them. Confirm this is intended.
      @state.select! { |_, v| v.kind_of?(Numo::NArray) }
    end
  end
end
@@ -0,0 +1,62 @@
module Chainer
  module Optimizers
    # Per-parameter Adam update rule.
    class AdamRule < UpdateRule
      # @param parent_hyperparam hyperparameter object to inherit from; when
      #   nil, a fresh Hyperparameter holding alpha=0.001, beta1=0.9,
      #   beta2=0.999, eps=1e-8 is used instead
      # @param alpha [Float, nil] per-rule override (ignored when nil)
      # @param beta1 [Float, nil] per-rule override (ignored when nil)
      # @param beta2 [Float, nil] per-rule override (ignored when nil)
      # @param eps [Float, nil] per-rule override (ignored when nil)
      def initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
        defaults = Hyperparameter.new
        { '@alpha' => 0.001, '@beta1' => 0.9, '@beta2' => 0.999, '@eps' => 1e-8 }.each do |ivar, value|
          defaults.instance_variable_set(ivar, value)
        end

        super(parent_hyperparam: parent_hyperparam || defaults)

        { '@alpha' => alpha, '@beta1' => beta1, '@beta2' => beta2, '@eps' => eps }.each do |ivar, value|
          @hyperparam.instance_variable_set(ivar, value) if value
        end
      end

      # Allocates zero-filled :m and :v buffers shaped like param.data.
      def init_state(param)
        %i[m v].each { |key| @state[key] = param.data.new_zeros }
      end

      # One Adam step on +param+; skipped when the parameter has no gradient.
      # :m and :v track running averages of grad and grad * grad.
      def update_core_cpu(param)
        grad = param.grad
        return if grad.nil?

        hp = @hyperparam
        @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
        @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
        denominator = Numo::NMath.sqrt(@state[:v]) + hp.eps
        param.data -= lr * @state[:m] / denominator
      end

      # Bias-corrected step size: alpha * sqrt(1 - beta2**t) / (1 - beta1**t).
      def lr
        correction1 = 1.0 - @hyperparam.beta1**@t
        correction2 = 1.0 - @hyperparam.beta2**@t
        @hyperparam.alpha * Math.sqrt(correction2) / correction1
      end
    end

    # Adam optimizer; every parameter gets an AdamRule sharing this
    # optimizer's hyperparameters as its parent.
    class Adam < GradientMethod
      # Each argument falls back to its default when nil.
      def initialize(alpha: nil, beta1: nil, beta2: nil, eps: nil)
        super()
        {
          '@alpha' => alpha || 0.001,
          '@beta1' => beta1 || 0.9,
          '@beta2' => beta2 || 0.999,
          '@eps' => eps || 1e-8
        }.each { |ivar, value| @hyperparam.instance_variable_set(ivar, value) }
      end

      def create_update_rule
        AdamRule.new(parent_hyperparam: @hyperparam)
      end

      # Same bias-corrected step size as AdamRule#lr, using the optimizer's @t.
      # NOTE(review): at @t == 0 this evaluates 0.0 / 0.0 (NaN).
      def lr
        correction1 = 1.0 - (@hyperparam.beta1**@t)
        correction2 = 1.0 - (@hyperparam.beta2**@t)
        @hyperparam.alpha * Math.sqrt(correction2) / correction1
      end
    end
  end
end
@@ -0,0 +1,53 @@
module Chainer
  # Variable holding trainable data. It is created either eagerly (from an
  # initial array, or an initializer plus a shape) or lazily (initializer
  # only; data is allocated later by #init).
  class Parameter < Variable
    attr_accessor :initializer, :grad_initializer, :update_rule

    # @param initializer nil (uses Chainer::Initializers.nan), a Numeric
    #   (wrapped in Initializers::Constant), a Numo::NArray (initial data),
    #   or an initializer object
    # @param shape shape to allocate immediately; when nil and +initializer+
    #   is not an array, allocation is deferred until #init
    # @param name [String, nil]
    def initialize(initializer: nil, shape: nil, name: nil)
      if initializer.nil?
        initializer = Chainer::Initializers.nan()
      elsif initializer.kind_of?(Numeric)
        initializer = Initializers::Constant.new(initializer)
      end

      if shape.nil?
        # Test the local argument: @initializer is not assigned yet at this
        # point, so checking it would always be false.
        if initializer.kind_of?(Numo::NArray)
          # Parameter initialized directly from the given array.
          super(initializer, name: name)
        else
          # Uninitialized parameter; data is created by #init.
          super(name: name)
          @initializer = initializer
          @grad_initializer = Chainer::Initializers.nan()
        end
      else
        if initializer.kind_of?(Numo::NArray)
          initializer = Initializers::Constant.new(initializer)
        end
        data = Chainer::Initializers.generate_array(initializer, shape)
        # Gradient placeholder with the same shape/type as the data. The
        # sentinel value is kept from the original code.
        # NOTE(review): upstream Chainer fills with NaN.
        grad = data.new_fill(-922337203)
        super(data, name: name, grad: grad)
      end

      @update_rule = nil
    end

    # Clears the gradient; a parameter without data also drops its
    # gradient initializer.
    def cleargrad
      super
      @grad_initializer = nil if self.data.nil?
    end

    # Allocates data from @initializer with +shape+, and a gradient from
    # @grad_initializer when one is set (nil otherwise).
    def init(shape)
      data = Chainer::Initializers.generate_array(@initializer, shape)
      ginit = @grad_initializer
      grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)

      @data[0] = data
      @node.grad = grad
    end

    # Applies the attached update rule, if any.
    def update
      @update_rule.update(self) if @update_rule
    end
  end
end
@@ -0,0 +1,130 @@
module Chainer
  # Holds the stack of active reporters shared by Reporter instances.
  module ReportService
    @@reporters = []
  end

  # Collects reported values into the current observation hash. Values
  # reported with an observer are keyed "<observer name>/<key>".
  class Reporter
    include ReportService

    def initialize
      @observer_names = {}
      @observation = {}
    end

    # Reports +values+ to the innermost active reporter.
    #
    # Does nothing when no reporter scope is active, so links that report
    # can also be called outside of a Reporter#scope block.
    def self.save_report(values, observer=nil)
      reporter = @@reporters[-1]
      return if reporter.nil?

      reporter.report(values, observer)
    end

    # Stores +values+ in the current observation.
    #
    # @param values [Hash] key/value pairs to record
    # @param observer object previously registered with #add_observer, or nil
    # @raise [RuntimeError] if +observer+ is given but not registered
    def report(values, observer=nil)
      # TODO: keep_graph_on_report option
      if observer
        observer_id = observer.object_id
        unless @observer_names.key?(observer_id)
          raise "Given observer is not registered to the reporter."
        end
        observer_name = @observer_names[observer_id]
        values.each do |key, value|
          @observation["#{observer_name}/#{key}"] = value
        end
      else
        @observation.update(values)
      end
    end

    # Registers +observer+ under +name+; lookup is by object identity.
    def add_observer(name, observer)
      @observer_names[observer.object_id] = name
    end

    # Makes this reporter current and sends reports into +observation+ while
    # the block runs. The previous observation and the reporter stack are
    # restored even if the block raises.
    #
    # @return the reporter popped from the stack (this reporter when scopes
    #   are properly nested)
    def scope(observation)
      @@reporters << self
      old = @observation
      @observation = observation
      begin
        yield
      ensure
        @observation = old
        popped = @@reporters.pop
      end
      popped
    end
  end

  # Online accumulator of the sum, sum of squares and count of values.
  class Summary
    def initialize
      @x = 0
      @x2 = 0
      @n = 0
    end

    # Adds a scalar value.
    # @param value scalar value to accumulate
    def add(value)
      @x += value
      @x2 += value * value
      @n += 1
    end

    # Computes the mean (float division).
    def compute_mean
      @x.to_f / @n
    end

    # Computes the mean and standard deviation.
    #
    # Uses float division, as #compute_mean does, so integer inputs are not
    # truncated. A variance that is slightly negative because of rounding is
    # clamped to zero so Math.sqrt does not raise.
    #
    # @return [Array] mean and standard deviation
    def make_statistics
      mean = @x.to_f / @n
      var = @x2.to_f / @n - mean * mean
      var = 0.0 if var < 0
      std = Math.sqrt(var)
      [mean, std]
    end
  end

  # Online summarization of a sequence of hashes.
  # Only scalar values, and Variables / arrays holding zero-dimensional data,
  # are accumulated; everything else is ignored.
  class DictSummary
    def initialize
      @summaries = Hash.new { |h, k| h[k] = Summary.new }
    end

    # Adds a hash of scalars.
    # @param d [Hash] values to accumulate. Chainer::Variable values are
    #   unwrapped to their data. Numeric values and objects whose ndim is 0
    #   are added; other values (nil, strings, multi-dimensional arrays) are
    #   skipped.
    def add(d)
      d.each do |k, v|
        v = v.data if v.kind_of?(Chainer::Variable)
        if v.kind_of?(Numeric) || (v.class.method_defined?(:ndim) && v.ndim == 0)
          @summaries[k].add(v)
        end
      end
    end

    # Creates a hash of mean values, one per key added to the summary.
    # @return [Hash]
    def compute_mean
      @summaries.each_with_object({}) { |(name, summary), h| h[name] = summary.compute_mean }
    end

    # Creates a hash of statistics. For each key, the mean is stored under
    # the key itself and the standard deviation under "<key>.std" (a String
    # key, for both String and Symbol input keys).
    # @return [Hash]
    def make_statistics
      stats = {}
      @summaries.each do |name, summary|
        mean, std = summary.make_statistics
        stats[name] = mean
        stats["#{name}.std"] = std
      end
      stats
    end
  end
end
@@ -0,0 +1,25 @@
module Chainer
  module Training
    # Base class of training extensions.
    class Extension
      PRIORITY_WRITER = 300
      PRIORITY_EDITOR = 200
      PRIORITY_READER = 100

      attr_accessor :name
      # Only the writer is generated: #priority is defined below with a
      # default, and attr_accessor would define the reader twice
      # ("method redefined" warning).
      attr_writer :priority

      def initialize
      end

      # Extension entry point; the base implementation does nothing.
      # @param trainer the trainer running this extension (type not shown here)
      def call(trainer)
      end

      # @return [String] the fully qualified class name
      def default_name
        self.class.to_s
      end

      # @return [Integer] the explicitly set priority, or PRIORITY_READER
      def priority
        @priority || PRIORITY_READER
      end
    end
  end
end
@@ -0,0 +1,26 @@
module Chainer
  module Training
    module Extensions
      # Extension storing the iterators, target links, converter and hooks
      # used for evaluation. It runs with writer priority.
      class Evaluator < Extension
        # @param iterator a Dataset::Iterator (stored as { main: iterator })
        #   or a hash of iterators (stored as given)
        # @param target a Link (stored as { main: target }) or a hash of
        #   links (stored as given)
        # @param converter [#call, nil] batch converter; defaults to
        #   Dataset::Convert.method(:concat_examples)
        # @param device stored as given
        # @param eval_hook stored as given
        # @param eval_func stored as given
        def initialize(iterator, target, converter: nil, device: nil, eval_hook: nil, eval_func: nil)
          @priority = Extension::PRIORITY_WRITER

          @iterators = iterator.kind_of?(Dataset::Iterator) ? { main: iterator } : iterator
          @targets = target.kind_of?(Link) ? { main: target } : target

          @converter = converter || Dataset::Convert.method(:concat_examples)
          @device = device
          @eval_hook = eval_hook
          @eval_func = eval_func
        end
      end
    end
  end
end