red-chainer 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +60 -0
- data/Rakefile +8 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/examples/mnist.rb +42 -0
- data/lib/chainer.rb +59 -0
- data/lib/chainer/configuration.rb +10 -0
- data/lib/chainer/dataset/convert.rb +62 -0
- data/lib/chainer/dataset/download.rb +56 -0
- data/lib/chainer/dataset/iterator.rb +15 -0
- data/lib/chainer/datasets/mnist.rb +89 -0
- data/lib/chainer/datasets/tuple_dataset.rb +33 -0
- data/lib/chainer/function.rb +80 -0
- data/lib/chainer/functions/activation/log_softmax.rb +37 -0
- data/lib/chainer/functions/activation/relu.rb +23 -0
- data/lib/chainer/functions/connection/linear.rb +48 -0
- data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
- data/lib/chainer/functions/math/basic_math.rb +119 -0
- data/lib/chainer/gradient_method.rb +63 -0
- data/lib/chainer/hyperparameter.rb +23 -0
- data/lib/chainer/initializer.rb +12 -0
- data/lib/chainer/initializers/constant.rb +18 -0
- data/lib/chainer/initializers/init.rb +24 -0
- data/lib/chainer/initializers/normal.rb +28 -0
- data/lib/chainer/iterators/serial_iterator.rb +74 -0
- data/lib/chainer/link.rb +118 -0
- data/lib/chainer/links/connection/linear.rb +43 -0
- data/lib/chainer/links/model/classifier.rb +39 -0
- data/lib/chainer/optimizer.rb +69 -0
- data/lib/chainer/optimizers/adam.rb +62 -0
- data/lib/chainer/parameter.rb +53 -0
- data/lib/chainer/reporter.rb +130 -0
- data/lib/chainer/training/extension.rb +25 -0
- data/lib/chainer/training/extensions/evaluator.rb +26 -0
- data/lib/chainer/training/extensions/log_report.rb +72 -0
- data/lib/chainer/training/extensions/print_report.rb +62 -0
- data/lib/chainer/training/extensions/progress_bar.rb +89 -0
- data/lib/chainer/training/standard_updater.rb +63 -0
- data/lib/chainer/training/trainer.rb +136 -0
- data/lib/chainer/training/triggers/interval.rb +27 -0
- data/lib/chainer/training/updater.rb +33 -0
- data/lib/chainer/training/util.rb +13 -0
- data/lib/chainer/utils/array.rb +10 -0
- data/lib/chainer/utils/initializer.rb +14 -0
- data/lib/chainer/utils/variable.rb +20 -0
- data/lib/chainer/variable.rb +204 -0
- data/lib/chainer/variable_node.rb +71 -0
- data/lib/chainer/version.rb +4 -0
- data/red-chainer.gemspec +27 -0
- metadata +156 -0
module Chainer
  module Links
    module Connection
      # Fully-connected (dense) layer link wrapping
      # Chainer::Functions::Connection::LinearFunction.
      class Linear < ::Chainer::Link
        attr_reader :w, :b

        # @param in_size [Integer, nil] size of each input sample; when only one
        #   positional value is given, it is treated as out_size and the input
        #   size is inferred lazily from the first batch passed to #call.
        # @param out_size [Integer, nil] size of each output sample
        # @param nobias [Boolean] when true, no bias parameter is created
        # @param initial_w [Object, nil] initializer for the weight matrix
        # @param initial_bias [Object, nil] initializer for the bias vector
        def initialize(in_size, out_size: nil, nobias: false, initial_w: nil, initial_bias: nil)
          super()
          # Single-argument form: Linear.new(out_size) defers in_size inference.
          in_size, out_size = nil, in_size if out_size.nil?
          @out_size = out_size

          init_scope do
            @w = Chainer::Parameter.new(initializer: Chainer::Initializers.get_initializer(initial_w))
            initialize_params(in_size) unless in_size.nil?

            if nobias
              @b = nil
            else
              initial_bias = 0 if initial_bias.nil?
              @b = Chainer::Parameter.new(
                initializer: Chainer::Initializers.get_initializer(initial_bias),
                shape: out_size
              )
            end
          end
        end

        # Applies the linear transformation to a minibatch.
        #
        # @param x [Chainer::Variable] input minibatch, first axis is the batch
        # @return [Chainer::Variable] output of the linear function
        def call(x)
          # Lazily size the weights from the flattened per-sample feature count.
          initialize_params(x.size.div(x.shape[0])) if @w.data.nil?
          Chainer::Functions::Connection::LinearFunction.linear(x, @w, @b)
        end

        private

        # Allocates the weight matrix with shape [out_size, in_size].
        def initialize_params(in_size)
          @w.init([@out_size, in_size])
        end
      end
    end
  end
end
module Chainer
  module Links
    module Model
      # A simple classifier chain: wraps a predictor link and computes the
      # loss (and, optionally, accuracy) against a target, reporting both.
      class Classifier < Chain
        # Whether accuracy is computed and reported on each call.
        attr_accessor :compute_accuracy

        # @param predictor [Chainer::Link] model producing predictions from the inputs
        # @param lossfun [#call] loss function applied to (prediction, target)
        # @param accfun [#call] accuracy function applied to (prediction, target)
        def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy))
          super()
          @lossfun = lossfun
          @accfun = accfun
          @y = nil
          @loss = nil
          @accuracy = nil
          @compute_accuracy = true

          init_scope do
            @predictor = predictor
          end
        end

        # Computes the loss for a minibatch.
        # All arguments except the last are fed to the predictor; the last
        # argument is the ground-truth target.
        #
        # @return [Chainer::Variable] the loss value
        def call(*args)
          *inputs, target = args
          @y = nil
          @accuracy = nil
          @y = @predictor.(*inputs)

          @loss = @lossfun.call(@y, target)
          Chainer::Reporter.save_report({loss: @loss}, self)
          if @compute_accuracy
            @accuracy = @accfun.call(@y, target)
            Chainer::Reporter.save_report({accuracy: @accuracy}, self)
          end
          @loss
        end
      end
    end
  end
end
module Chainer
  # Base class of all numerical optimizers.
  class Optimizer
    attr_accessor :target

    # Binds the optimizer to a target link and resets internal counters.
    #
    # @param link [Chainer::Link] link whose parameters will be optimized
    def setup(link)
      @target = link
      @t = 0
      @epoch = 0

      @hooks = {}
    end

    # Invokes a single hook.
    # A hook that responds to +call_for_each_param+ is applied once per
    # parameter of the target (with the parameter's update rule); any other
    # hook is called once with the optimizer itself.
    def _call_hook(hook)
      if hook.respond_to?(:call_for_each_param)
        @target.params.each do |param|
          hook.(param.update_rule, param)
        end
      else
        # BUG FIX: was `hook(self)`, which called a (nonexistent) method named
        # `hook` instead of invoking the hook object.
        hook.call(self)
      end
    end
  end

  # Base class of per-parameter update rules.
  class UpdateRule
    attr_reader :state

    # @param parent_hyperparam [Chainer::Hyperparameter] hyperparameters
    #   shared with the owning optimizer
    def initialize(parent_hyperparam:)
      @hooks = {}
      @state = nil
      @enabled = true
      @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)
      @t = 0
    end

    # Runs all registered hooks and then the core update on the parameter.
    def update(param)
      return unless @enabled

      @t += 1
      prepare(param)
      @hooks.values.each do |hook|
        hook.call(param)
      end
      update_core(param)
    end

    # Dispatches to the device-specific implementation.
    def update_core(param)
      # TODO: support GPU
      update_core_cpu(param)
    end

    # CPU update; subclasses must override.
    # BUG FIX: now accepts the parameter, matching the call in #update_core and
    # the signatures of overriding subclasses (the original zero-arity
    # definition made the base dispatch raise ArgumentError instead of
    # NotImplementedError).
    def update_core_cpu(param)
      raise NotImplementedError
    end

    # Initializes @state entries for the given parameter; subclasses override.
    def init_state(param)
      raise NotImplementedError
    end

    private

    # Lazily creates the state hash, then drops any entry that is not a
    # Numo::NArray (e.g. leftovers from deserialization).
    def prepare(param)
      if @state.nil?
        @state = {}
        init_state(param)
      end
      @state.select! { |_, v| v.kind_of?(Numo::NArray) }
    end
  end
end
module Chainer
  module Optimizers
    # Per-parameter update rule implementing the Adam algorithm
    # (Kingma & Ba, "Adam: A Method for Stochastic Optimization").
    class AdamRule < UpdateRule
      # @param parent_hyperparam [Chainer::Hyperparameter, nil] shared hyperparameters
      # @param alpha [Float, nil] step size
      # @param beta1 [Float, nil] decay rate of the first moment estimate
      # @param beta2 [Float, nil] decay rate of the second moment estimate
      # @param eps [Float, nil] small constant for numerical stability
      def initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
        defaults = Hyperparameter.new
        { alpha: 0.001, beta1: 0.9, beta2: 0.999, eps: 1e-8 }.each do |key, value|
          defaults.instance_variable_set("@#{key}", value)
        end

        super(parent_hyperparam: parent_hyperparam || defaults)

        # Explicit per-rule overrides win over the (possibly parent) defaults.
        { alpha: alpha, beta1: beta1, beta2: beta2, eps: eps }.each do |key, value|
          @hyperparam.instance_variable_set("@#{key}", value) if value
        end
      end

      # Allocates zeroed first/second moment accumulators shaped like the data.
      def init_state(param)
        @state[:m] = param.data.new_zeros
        @state[:v] = param.data.new_zeros
      end

      # CPU Adam step: refresh the biased moment estimates, then apply the
      # bias-corrected learning rate. No-op when the gradient is absent.
      def update_core_cpu(param)
        grad = param.grad
        return if grad.nil?

        hp = @hyperparam
        @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
        @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
        param.data -= lr * @state[:m] / (Numo::NMath.sqrt(@state[:v]) + hp.eps)
      end

      # Bias-corrected learning rate for the current step count @t.
      def lr
        fix1 = 1.0 - @hyperparam.beta1 ** @t
        fix2 = 1.0 - @hyperparam.beta2 ** @t
        @hyperparam.alpha * Math.sqrt(fix2) / fix1
      end
    end

    # Adam optimizer: hands out an AdamRule for every parameter.
    class Adam < GradientMethod
      def initialize(alpha: nil, beta1: nil, beta2: nil, eps: nil)
        super()
        { alpha: alpha || 0.001, beta1: beta1 || 0.9,
          beta2: beta2 || 0.999, eps: eps || 1e-8 }.each do |key, value|
          @hyperparam.instance_variable_set("@#{key}", value)
        end
      end

      # Builds the per-parameter rule sharing this optimizer's hyperparameters.
      def create_update_rule
        AdamRule.new(parent_hyperparam: @hyperparam)
      end

      # Bias-corrected learning rate (mirrors AdamRule#lr).
      def lr
        fix1 = 1.0 - (@hyperparam.beta1 ** @t)
        fix2 = 1.0 - (@hyperparam.beta2 ** @t)
        @hyperparam.alpha * Math.sqrt(fix2) / fix1
      end
    end
  end
end
module Chainer
  # A Variable registered to a Link whose data is (possibly lazily)
  # initialized and updated by an optimizer's UpdateRule.
  class Parameter < Variable
    attr_accessor :initializer, :grad_initializer, :update_rule

    # @param initializer [Object, Numeric, Numo::NArray, nil] initializer for
    #   the data array (an object with an initializer protocol, a scalar, or a
    #   concrete array)
    # @param shape [Array<Integer>, nil] shape of the parameter; when nil,
    #   allocation is deferred until #init is called
    # @param name [String, nil] name of the parameter
    def initialize(initializer: nil, shape: nil, name: nil)
      if initializer.nil?
        initializer = Chainer::Initializers.nan()
      elsif initializer.kind_of?(Numeric)
        initializer = Initializers::Constant.new(initializer)
      end

      if shape.nil?
        # BUG FIX: originally tested `@initializer` here, which is always nil
        # at this point, so a concrete NArray initializer never reached the
        # eager branch below.
        if initializer.kind_of?(Numo::NArray)
          super(initializer, name: name)
        else
          # Defer allocation until #init is called with a concrete shape.
          super(name: name)
          @initializer = initializer
          @grad_initializer = Chainer::Initializers.nan()
        end
      else
        if initializer.kind_of?(Numo::NArray)
          initializer = Initializers::Constant.new(initializer)
        end
        data = Chainer::Initializers.generate_array(initializer, shape)
        # NOTE(review): placeholder gradient of shape [2] filled with a magic
        # sentinel value -- looks suspicious; confirm the intended shape and
        # fill against Variable's grad contract.
        grad = Numo::NArray[*[1, 2]].new_fill(-922337203)
        super(data, name: name, grad: grad)
      end

      @update_rule = nil
    end

    # Clears the gradient; also drops the deferred gradient initializer when
    # the data itself is still uninitialized.
    def cleargrad
      super
      @grad_initializer = nil if self.data.nil?
    end

    # Materializes the data (and gradient, when a gradient initializer is
    # pending) with the given shape using the deferred initializers.
    def init(shape)
      data = Chainer::Initializers.generate_array(@initializer, shape)
      ginit = @grad_initializer
      grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)

      @data[0] = data
      @node.grad = grad
    end

    # Applies the attached update rule, if any.
    def update
      @update_rule.update(self) if @update_rule
    end
  end
end
@@ -0,0 +1,130 @@
|
|
1
|
+
module Chainer
|
2
|
+
module ReportService
|
3
|
+
@@reporters = []
|
4
|
+
end
|
5
|
+
|
6
|
+
class Reporter
|
7
|
+
include ReportService
|
8
|
+
|
9
|
+
def initialize
|
10
|
+
@observer_names = {}
|
11
|
+
@observation = {}
|
12
|
+
end
|
13
|
+
|
14
|
+
def self.save_report(values, observer=nil)
|
15
|
+
reporter = @@reporters[-1]
|
16
|
+
reporter.report(values, observer)
|
17
|
+
end
|
18
|
+
|
19
|
+
def report(values, observer=nil)
|
20
|
+
# TODO: keep_graph_on_report option
|
21
|
+
if observer
|
22
|
+
observer_id = observer.object_id
|
23
|
+
unless @observer_names.keys.include?(observer_id)
|
24
|
+
raise "Given observer is not registered to the reporter."
|
25
|
+
end
|
26
|
+
observer_name = @observer_names[observer_id]
|
27
|
+
values.each do |key, value|
|
28
|
+
name = "#{observer_name}/#{key}"
|
29
|
+
@observation[name] = value
|
30
|
+
end
|
31
|
+
else
|
32
|
+
@observation.update(values)
|
33
|
+
end
|
34
|
+
end
|
35
|
+
|
36
|
+
def add_observer(name, observer)
|
37
|
+
@observer_names[observer.object_id] = name
|
38
|
+
end
|
39
|
+
|
40
|
+
def scope(observation)
|
41
|
+
@@reporters << self
|
42
|
+
old = @observation
|
43
|
+
@observation = observation
|
44
|
+
yield
|
45
|
+
@observation = old
|
46
|
+
@@reporters.pop
|
47
|
+
end
|
48
|
+
end
|
49
|
+
|
50
|
+
class Summary
|
51
|
+
def initialize
|
52
|
+
@x = 0
|
53
|
+
@x2 = 0
|
54
|
+
@n = 0
|
55
|
+
end
|
56
|
+
|
57
|
+
# Adds a scalar value.
|
58
|
+
# Args:
|
59
|
+
# value: Scalar value to accumulate.
|
60
|
+
def add(value)
|
61
|
+
@x += value
|
62
|
+
@x2 += value * value
|
63
|
+
@n += 1
|
64
|
+
end
|
65
|
+
|
66
|
+
# Computes the mean.
|
67
|
+
def compute_mean
|
68
|
+
@x.to_f / @n
|
69
|
+
end
|
70
|
+
|
71
|
+
# Computes and returns the mean and standard deviation values.
|
72
|
+
# Returns:
|
73
|
+
# array: Mean and standard deviation values.
|
74
|
+
def make_statistics
|
75
|
+
mean = @x / @n
|
76
|
+
var = @x2 / @n - mean * mean
|
77
|
+
std = Math.sqrt(var)
|
78
|
+
[mean, std]
|
79
|
+
end
|
80
|
+
end
|
81
|
+
|
82
|
+
# Online summarization of a sequence of dictionaries.
|
83
|
+
# ``DictSummary`` computes the statistics of a given set of scalars online.
|
84
|
+
# It only computes the statistics for scalar values and variables of scalar values in the dictionaries.
|
85
|
+
class DictSummary
|
86
|
+
def initialize
|
87
|
+
@summaries = Hash.new { |h,k| h[k] = Summary.new }
|
88
|
+
end
|
89
|
+
|
90
|
+
# Adds a dictionary of scalars.
|
91
|
+
# Args:
|
92
|
+
# d (dict): Dictionary of scalars to accumulate. Only elements of
|
93
|
+
# scalars, zero-dimensional arrays, and variables of
|
94
|
+
# zero-dimensional arrays are accumulated.
|
95
|
+
def add(d)
|
96
|
+
d.each do |k, v|
|
97
|
+
v = v.data if v.kind_of?(Chainer::Variable)
|
98
|
+
if v.class.method_defined?(:to_i) || (v.class.method_defined?(:ndim) && v.ndim == 0)
|
99
|
+
@summaries[k].add(v)
|
100
|
+
end
|
101
|
+
end
|
102
|
+
end
|
103
|
+
|
104
|
+
# Creates a dictionary of mean values.
|
105
|
+
# It returns a single dictionary that holds a mean value for each entry added to the summary.
|
106
|
+
#
|
107
|
+
# Returns:
|
108
|
+
# dict: Dictionary of mean values.
|
109
|
+
def compute_mean
|
110
|
+
@summaries.each_with_object({}) { |(name, summary), h| h[name] = summary.compute_mean }
|
111
|
+
end
|
112
|
+
|
113
|
+
# Creates a dictionary of statistics.
|
114
|
+
# It returns a single dictionary that holds mean and standard deviation
|
115
|
+
# values for every entry added to the summary. For an entry of name
|
116
|
+
# ``'key'``, these values are added to the dictionary by names ``'key'`` and ``'key.std'``, respectively.
|
117
|
+
#
|
118
|
+
# Returns:
|
119
|
+
# dict: Dictionary of statistics of all entries.
|
120
|
+
def make_statistics
|
121
|
+
stats = {}
|
122
|
+
@summaries.each do |name, summary|
|
123
|
+
mean, std = summary.make_statistics
|
124
|
+
stats[name] = mean
|
125
|
+
stats[name + '.std'] = std
|
126
|
+
end
|
127
|
+
stats
|
128
|
+
end
|
129
|
+
end
|
130
|
+
end
|
module Chainer
  module Training
    # Base class for trainer extensions. Subclasses override #call to perform
    # work when the extension is triggered by the trainer.
    class Extension
      # Invocation-order priorities: writers run before editors,
      # editors before readers.
      PRIORITY_WRITER = 300
      PRIORITY_EDITOR = 200
      PRIORITY_READER = 100

      attr_accessor :name, :priority

      def initialize
      end

      # Explicit reader shadows the attr_accessor-generated one so that an
      # unset priority falls back to PRIORITY_READER.
      def priority
        @priority || PRIORITY_READER
      end

      # Invoked by the trainer on each trigger; base implementation is a no-op.
      def call(trainer)
      end

      # Name used to register the extension when none is given explicitly.
      def default_name
        self.class.to_s
      end
    end
  end
end
module Chainer
  module Training
    module Extensions
      # Trainer extension that evaluates target links on a validation dataset.
      # Bare iterators/links are normalized into { main: ... } hashes.
      class Evaluator < Extension
        # @param iterator [Dataset::Iterator, Hash] dataset iterator(s)
        # @param target [Link, Hash] link(s) to evaluate
        # @param converter [#call, nil] batch converter; defaults to concat_examples
        # @param device [Object, nil] device spec
        # @param eval_hook [#call, nil] hook invoked before evaluation
        # @param eval_func [#call, nil] function evaluated per batch; defaults to the target
        def initialize(iterator, target, converter: nil, device: nil, eval_hook: nil, eval_func: nil)
          @priority = Extension::PRIORITY_WRITER

          iterator = { main: iterator } if iterator.kind_of?(Dataset::Iterator)
          @iterators = iterator

          target = { main: target } if target.kind_of?(Link)
          @targets = target

          @converter = converter || Dataset::Convert.method(:concat_examples)
          @device = device
          @eval_hook = eval_hook
          @eval_func = eval_func
        end
      end
    end
  end
end