red-chainer 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +60 -0
- data/Rakefile +8 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/examples/mnist.rb +42 -0
- data/lib/chainer.rb +59 -0
- data/lib/chainer/configuration.rb +10 -0
- data/lib/chainer/dataset/convert.rb +62 -0
- data/lib/chainer/dataset/download.rb +56 -0
- data/lib/chainer/dataset/iterator.rb +15 -0
- data/lib/chainer/datasets/mnist.rb +89 -0
- data/lib/chainer/datasets/tuple_dataset.rb +33 -0
- data/lib/chainer/function.rb +80 -0
- data/lib/chainer/functions/activation/log_softmax.rb +37 -0
- data/lib/chainer/functions/activation/relu.rb +23 -0
- data/lib/chainer/functions/connection/linear.rb +48 -0
- data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
- data/lib/chainer/functions/math/basic_math.rb +119 -0
- data/lib/chainer/gradient_method.rb +63 -0
- data/lib/chainer/hyperparameter.rb +23 -0
- data/lib/chainer/initializer.rb +12 -0
- data/lib/chainer/initializers/constant.rb +18 -0
- data/lib/chainer/initializers/init.rb +24 -0
- data/lib/chainer/initializers/normal.rb +28 -0
- data/lib/chainer/iterators/serial_iterator.rb +74 -0
- data/lib/chainer/link.rb +118 -0
- data/lib/chainer/links/connection/linear.rb +43 -0
- data/lib/chainer/links/model/classifier.rb +39 -0
- data/lib/chainer/optimizer.rb +69 -0
- data/lib/chainer/optimizers/adam.rb +62 -0
- data/lib/chainer/parameter.rb +53 -0
- data/lib/chainer/reporter.rb +130 -0
- data/lib/chainer/training/extension.rb +25 -0
- data/lib/chainer/training/extensions/evaluator.rb +26 -0
- data/lib/chainer/training/extensions/log_report.rb +72 -0
- data/lib/chainer/training/extensions/print_report.rb +62 -0
- data/lib/chainer/training/extensions/progress_bar.rb +89 -0
- data/lib/chainer/training/standard_updater.rb +63 -0
- data/lib/chainer/training/trainer.rb +136 -0
- data/lib/chainer/training/triggers/interval.rb +27 -0
- data/lib/chainer/training/updater.rb +33 -0
- data/lib/chainer/training/util.rb +13 -0
- data/lib/chainer/utils/array.rb +10 -0
- data/lib/chainer/utils/initializer.rb +14 -0
- data/lib/chainer/utils/variable.rb +20 -0
- data/lib/chainer/variable.rb +204 -0
- data/lib/chainer/variable_node.rb +71 -0
- data/lib/chainer/version.rb +4 -0
- data/red-chainer.gemspec +27 -0
- metadata +156 -0
module Chainer
  module Links
    module Connection
      # A fully-connected (affine) layer link holding the weight +w+ and
      # bias +b+ parameters for Chainer::Functions::Connection::LinearFunction.
      class Linear < ::Chainer::Link
        attr_reader :w, :b

        # @param in_size [Integer, nil] size of each input sample; when only one
        #   positional value is given it is treated as +out_size+ and the input
        #   size is inferred lazily from the first batch passed to {#call}.
        # @param out_size [Integer, nil] size of each output sample.
        # @param nobias [Boolean] when true, no bias parameter is created.
        # @param initial_w [Object, nil] initializer (or value) for the weight.
        # @param initial_bias [Object, nil] initializer (or value) for the bias;
        #   defaults to zero fill.
        def initialize(in_size, out_size: nil, nobias: false, initial_w: nil, initial_bias: nil)
          super()
          # Single-argument form: Linear.new(out_size) defers weight allocation.
          in_size, out_size = nil, in_size if out_size.nil?
          @out_size = out_size

          init_scope do
            @w = Chainer::Parameter.new(initializer: Chainer::Initializers.get_initializer(initial_w))
            initialize_params(in_size) unless in_size.nil?

            if nobias
              @b = nil
            else
              initial_bias = 0 if initial_bias.nil?
              @b = Chainer::Parameter.new(
                initializer: Chainer::Initializers.get_initializer(initial_bias),
                shape: out_size
              )
            end
          end
        end

        # Applies the affine transformation, allocating the weight on first use
        # when the input size was not supplied to the constructor.
        def call(x)
          # Feature count per sample: total element count / batch size.
          initialize_params(x.size.div(x.shape[0])) if @w.data.nil?
          Chainer::Functions::Connection::LinearFunction.linear(x, @w, @b)
        end

        private

        # Allocates the weight matrix with shape [out_size, in_size].
        def initialize_params(in_size)
          @w.init([@out_size, in_size])
        end
      end
    end
  end
end
|
module Chainer
  module Links
    module Model
      # A chain wrapping a predictor link: computes the loss (and, when
      # enabled, the accuracy) of the predictor's output and reports both
      # through Chainer::Reporter.
      class Classifier < Chain
        # Set to false to skip the accuracy computation on each call.
        attr_accessor :compute_accuracy

        # @param predictor [Chainer::Link] model producing raw scores.
        # @param lossfun [#call] loss callable invoked as lossfun.(y, t).
        # @param accfun [#call] accuracy callable invoked as accfun.(y, t).
        def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy))
          super()
          @lossfun = lossfun
          @accfun = accfun
          @y = nil
          @loss = nil
          @accuracy = nil
          @compute_accuracy = true

          init_scope do
            @predictor = predictor
          end
        end

        # Computes the loss for inputs +args[0..-2]+ against target +args[-1]+.
        # Reports :loss (and :accuracy when enabled) and returns the loss.
        def call(*args)
          target = args.pop
          inputs = args
          @y = nil
          @accuracy = nil
          @y = @predictor.(*inputs)

          @loss = @lossfun.call(@y, target)
          Chainer::Reporter.save_report({loss: @loss}, self)
          if @compute_accuracy
            @accuracy = @accfun.call(@y, target)
            Chainer::Reporter.save_report({accuracy: @accuracy}, self)
          end
          @loss
        end
      end
    end
  end
end
module Chainer
  # Base class for optimizers: holds the target link and invokes hook callables.
  class Optimizer
    attr_accessor :target

    # Binds the optimizer to +link+ and resets counters and hooks.
    def setup(link)
      @target = link
      @t = 0
      @epoch = 0

      @hooks = {}
    end

    # Invokes +hook+. A hook that exposes +call_for_each_param+ is applied
    # per-parameter as hook.(update_rule, param); otherwise it receives the
    # optimizer itself.
    def _call_hook(hook)
      if hook.methods.include?(:call_for_each_param)
        @target.params.each do |param|
          hook.(param.update_rule, param)
        end
      else
        # BUG FIX: `hook(self)` tried to invoke a *method* named `hook`
        # (NoMethodError); the local callable must be called instead.
        hook.(self)
      end
    end
  end

  # Base class of per-parameter update rules (one instance per Parameter).
  class UpdateRule
    attr_reader :state

    def initialize(parent_hyperparam:)
      @hooks = {}
      @state = nil
      @enabled = true
      @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)
      @t = 0
    end

    # Runs hooks, then the core update for +param+; counts invocations in @t.
    def update(param)
      return unless @enabled

      @t += 1
      prepare(param)
      @hooks.values.each do |hook|
        hook.call(param)
      end
      update_core(param)
    end

    def update_core(param)
      # TODO: support GPU
      update_core_cpu(param)
    end

    # BUG FIX: the stub previously took no argument although update_core calls
    # it with one, so a subclass missing the override raised ArgumentError
    # instead of the intended NotImplementedError.
    def update_core_cpu(param)
      raise NotImplementedError
    end

    # Subclasses populate @state (e.g. moment buffers) for +param+ here.
    def init_state(param)
      raise NotImplementedError
    end

    private

    # Lazily creates @state and drops entries that are not Numo arrays.
    def prepare(param)
      if @state.nil?
        @state = {}
        init_state(param)
      end
      @state.select! { |_, v| v.kind_of?(Numo::NArray) }
    end
  end
end
module Chainer
  module Optimizers
    # Per-parameter update rule for the Adam optimizer: keeps running moments
    # of the gradient (:m) and squared gradient (:v) in @state.
    class AdamRule < UpdateRule
      # Builds the rule with Adam's standard defaults; explicit keyword values
      # override the (possibly inherited) hyperparameters.
      def initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
        defaults = Hyperparameter.new
        defaults.instance_variable_set(:@alpha, 0.001)
        defaults.instance_variable_set(:@beta1, 0.9)
        defaults.instance_variable_set(:@beta2, 0.999)
        defaults.instance_variable_set(:@eps, 1e-8)

        super(parent_hyperparam: parent_hyperparam || defaults)

        { alpha: alpha, beta1: beta1, beta2: beta2, eps: eps }.each do |key, value|
          @hyperparam.instance_variable_set(:"@#{key}", value) if value
        end
      end

      # Allocates zeroed moment buffers shaped like the parameter data.
      def init_state(param)
        @state[:m] = param.data.new_zeros
        @state[:v] = param.data.new_zeros
      end

      # Applies one Adam step to +param+; no-op when there is no gradient.
      def update_core_cpu(param)
        grad = param.grad
        return if grad.nil?

        hp = @hyperparam

        @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
        @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
        param.data -= lr * @state[:m] / (Numo::NMath.sqrt(@state[:v]) + hp.eps)
      end

      # Bias-corrected learning rate for the current step count @t.
      def lr
        correction1 = 1.0 - @hyperparam.beta1 ** @t
        correction2 = 1.0 - @hyperparam.beta2 ** @t
        @hyperparam.alpha * Math.sqrt(correction2) / correction1
      end
    end

    # Adam optimizer: creates an AdamRule for every parameter of the model.
    class Adam < GradientMethod
      def initialize(alpha: nil, beta1: nil, beta2: nil, eps: nil)
        super()
        { alpha: alpha || 0.001, beta1: beta1 || 0.9, beta2: beta2 || 0.999, eps: eps || 1e-8 }
          .each { |key, value| @hyperparam.instance_variable_set(:"@#{key}", value) }
      end

      def create_update_rule
        AdamRule.new(parent_hyperparam: @hyperparam)
      end

      # Bias-corrected learning rate, mirroring AdamRule#lr.
      def lr
        correction1 = 1.0 - (@hyperparam.beta1 ** @t)
        correction2 = 1.0 - (@hyperparam.beta2 ** @t)
        @hyperparam.alpha * Math.sqrt(correction2) / correction1
      end
    end
  end
end
module Chainer
  # A Variable that is registered to a Link as a trainable parameter.
  class Parameter < Variable
    attr_accessor :initializer, :grad_initializer, :update_rule

    # @param initializer [Object, Numeric, Numo::NArray, nil] initializer
    #   object, scalar fill value, or a concrete data array.
    # @param shape [Array<Integer>, Integer, nil] when given, data is
    #   materialized immediately; otherwise allocation is deferred to {#init}.
    # @param name [String, nil] parameter name.
    def initialize(initializer: nil, shape: nil, name: nil)
      if initializer.nil?
        initializer = Chainer::Initializers.nan()
      elsif initializer.kind_of?(Numeric)
        initializer = Initializers::Constant.new(initializer)
      end

      if shape.nil?
        # BUG FIX: this previously tested `@initializer` (still nil at this
        # point), so a Numo::NArray passed as data was never forwarded to super.
        if initializer.kind_of?(Numo::NArray)
          super(initializer, name: name)
        else
          # Defer allocation: remember the initializers until #init is called.
          super(name: name)
          @initializer = initializer
          @grad_initializer = Chainer::Initializers.nan()
        end
      else
        if initializer.kind_of?(Numo::NArray)
          initializer = Initializers::Constant.new(initializer)
        end
        data = Chainer::Initializers.generate_array(initializer, shape)
        # NOTE(review): placeholder gradient filled with a sentinel value —
        # presumably a stand-in for NaN; confirm against Variable's contract.
        grad = Numo::NArray[*[1, 2]].new_fill(-922337203)
        super(data, name: name, grad: grad)
      end

      @update_rule = nil
    end

    # Clears the gradient; also drops the deferred grad initializer when the
    # data itself is still unallocated.
    def cleargrad
      super
      @grad_initializer = nil if self.data.nil?
    end

    # Materializes data (and grad, when a grad initializer is set) for +shape+.
    def init(shape)
      data = Chainer::Initializers.generate_array(@initializer, shape)
      ginit = @grad_initializer
      grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)

      @data[0] = data
      @node.grad = grad
    end

    # Applies this parameter's update rule, if any.
    def update
      if @update_rule
        @update_rule.update(self)
      end
    end
  end
end
module Chainer
  # Shared stack of active reporters (innermost last).
  module ReportService
    @@reporters = []
  end

  # Collects observed values (e.g. loss/accuracy) from registered observers
  # into the current observation dictionary.
  class Reporter
    include ReportService

    def initialize
      @observer_names = {}
      @observation = {}
    end

    # Reports +values+ to the innermost active reporter (see #scope).
    def self.save_report(values, observer=nil)
      reporter = @@reporters[-1]
      reporter.report(values, observer)
    end

    # Records +values+; when +observer+ is given, its registered name prefixes
    # each key as "name/key". The observer must have been added beforehand.
    def report(values, observer=nil)
      # TODO: keep_graph_on_report option
      if observer
        observer_id = observer.object_id
        unless @observer_names.keys.include?(observer_id)
          raise "Given observer is not registered to the reporter."
        end
        observer_name = @observer_names[observer_id]
        values.each do |key, value|
          name = "#{observer_name}/#{key}"
          @observation[name] = value
        end
      else
        @observation.update(values)
      end
    end

    # Registers +observer+ under +name+ for prefixed reporting.
    def add_observer(name, observer)
      @observer_names[observer.object_id] = name
    end

    # Makes this reporter current and collects reports into +observation+
    # for the duration of the block.
    def scope(observation)
      @@reporters << self
      old = @observation
      @observation = observation
      begin
        yield
      ensure
        # BUG FIX: restore state even when the block raises, so the reporter
        # stack and observation dictionary never leak into later scopes.
        @observation = old
        @@reporters.pop
      end
    end
  end

  # Online mean/standard-deviation accumulator for a stream of scalar values.
  class Summary
    def initialize
      @x = 0
      @x2 = 0
      @n = 0
    end

    # Adds a scalar value.
    # Args:
    #   value: Scalar value to accumulate.
    def add(value)
      @x += value
      @x2 += value * value
      @n += 1
    end

    # Computes the mean.
    def compute_mean
      @x.to_f / @n
    end

    # Computes and returns the mean and standard deviation values.
    # Returns:
    #   array: Mean and standard deviation values.
    def make_statistics
      # BUG FIX: use float division; with integer samples `@x / @n` truncated
      # and disagreed with #compute_mean.
      mean = @x.to_f / @n
      var = @x2.to_f / @n - mean * mean
      # Guard against a tiny negative variance from floating-point rounding.
      var = 0.0 if var < 0
      [mean, Math.sqrt(var)]
    end
  end

  # Online summarization of a sequence of dictionaries.
  # ``DictSummary`` computes the statistics of a given set of scalars online.
  # It only computes the statistics for scalar values and variables of scalar values in the dictionaries.
  class DictSummary
    def initialize
      @summaries = Hash.new { |h,k| h[k] = Summary.new }
    end

    # Adds a dictionary of scalars.
    # Args:
    #   d (dict): Dictionary of scalars to accumulate. Only elements of
    #             scalars, zero-dimensional arrays, and variables of
    #             zero-dimensional arrays are accumulated.
    def add(d)
      d.each do |k, v|
        v = v.data if v.kind_of?(Chainer::Variable)
        if v.class.method_defined?(:to_i) || (v.class.method_defined?(:ndim) && v.ndim == 0)
          @summaries[k].add(v)
        end
      end
    end

    # Creates a dictionary of mean values.
    # It returns a single dictionary that holds a mean value for each entry added to the summary.
    #
    # Returns:
    #   dict: Dictionary of mean values.
    def compute_mean
      @summaries.each_with_object({}) { |(name, summary), h| h[name] = summary.compute_mean }
    end

    # Creates a dictionary of statistics.
    # For an entry of name ``'key'``, the mean and standard deviation are
    # stored under ``'key'`` and ``'key.std'``, respectively.
    #
    # Returns:
    #   dict: Dictionary of statistics of all entries.
    def make_statistics
      stats = {}
      @summaries.each do |name, summary|
        mean, std = summary.make_statistics
        stats[name] = mean
        stats[name + '.std'] = std
      end
      stats
    end
  end
end
module Chainer
  module Training
    # Base class of trainer extensions. Subclasses override #call, which the
    # trainer invokes according to the extension's trigger and priority.
    class Extension
      PRIORITY_WRITER = 300
      PRIORITY_EDITOR = 200
      PRIORITY_READER = 100

      attr_accessor :name
      # FIX: was part of attr_accessor, but the generated reader was
      # immediately shadowed by the #priority method below; keep only the writer.
      attr_writer :priority

      def initialize
      end

      # Invoked by the trainer; no-op in the base class.
      def call(trainer)
      end

      # Name used when the extension is registered without an explicit name.
      def default_name
        self.class.to_s
      end

      # Priority of the extension; defaults to PRIORITY_READER.
      def priority
        @priority || PRIORITY_READER
      end
    end
  end
end
module Chainer
  module Training
    module Extensions
      # Trainer extension that evaluates models over a dataset iterator.
      # Single iterators/targets are normalized to the hash form { main: ... }.
      class Evaluator < Extension
        # @param iterator [Dataset::Iterator, Hash] evaluation iterator(s).
        # @param target [Link, Hash] model(s) to evaluate.
        # @param converter [#call, nil] batch converter; defaults to
        #   Dataset::Convert.concat_examples.
        # @param device [Object, nil] device specifier.
        # @param eval_hook [#call, nil] hook run before evaluation.
        # @param eval_func [#call, nil] function evaluated per batch.
        def initialize(iterator, target, converter: nil, device: nil, eval_hook: nil, eval_func: nil)
          @priority = Extension::PRIORITY_WRITER

          iterator = { main: iterator } if iterator.kind_of?(Dataset::Iterator)
          @iterators = iterator

          target = { main: target } if target.kind_of?(Link)
          @targets = target

          @converter = converter || Dataset::Convert.method(:concat_examples)
          @device = device
          @eval_hook = eval_hook
          @eval_func = eval_func
        end
      end
    end
  end
end