red-chainer 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +60 -0
- data/Rakefile +8 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/examples/mnist.rb +42 -0
- data/lib/chainer.rb +59 -0
- data/lib/chainer/configuration.rb +10 -0
- data/lib/chainer/dataset/convert.rb +62 -0
- data/lib/chainer/dataset/download.rb +56 -0
- data/lib/chainer/dataset/iterator.rb +15 -0
- data/lib/chainer/datasets/mnist.rb +89 -0
- data/lib/chainer/datasets/tuple_dataset.rb +33 -0
- data/lib/chainer/function.rb +80 -0
- data/lib/chainer/functions/activation/log_softmax.rb +37 -0
- data/lib/chainer/functions/activation/relu.rb +23 -0
- data/lib/chainer/functions/connection/linear.rb +48 -0
- data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
- data/lib/chainer/functions/math/basic_math.rb +119 -0
- data/lib/chainer/gradient_method.rb +63 -0
- data/lib/chainer/hyperparameter.rb +23 -0
- data/lib/chainer/initializer.rb +12 -0
- data/lib/chainer/initializers/constant.rb +18 -0
- data/lib/chainer/initializers/init.rb +24 -0
- data/lib/chainer/initializers/normal.rb +28 -0
- data/lib/chainer/iterators/serial_iterator.rb +74 -0
- data/lib/chainer/link.rb +118 -0
- data/lib/chainer/links/connection/linear.rb +43 -0
- data/lib/chainer/links/model/classifier.rb +39 -0
- data/lib/chainer/optimizer.rb +69 -0
- data/lib/chainer/optimizers/adam.rb +62 -0
- data/lib/chainer/parameter.rb +53 -0
- data/lib/chainer/reporter.rb +130 -0
- data/lib/chainer/training/extension.rb +25 -0
- data/lib/chainer/training/extensions/evaluator.rb +26 -0
- data/lib/chainer/training/extensions/log_report.rb +72 -0
- data/lib/chainer/training/extensions/print_report.rb +62 -0
- data/lib/chainer/training/extensions/progress_bar.rb +89 -0
- data/lib/chainer/training/standard_updater.rb +63 -0
- data/lib/chainer/training/trainer.rb +136 -0
- data/lib/chainer/training/triggers/interval.rb +27 -0
- data/lib/chainer/training/updater.rb +33 -0
- data/lib/chainer/training/util.rb +13 -0
- data/lib/chainer/utils/array.rb +10 -0
- data/lib/chainer/utils/initializer.rb +14 -0
- data/lib/chainer/utils/variable.rb +20 -0
- data/lib/chainer/variable.rb +204 -0
- data/lib/chainer/variable_node.rb +71 -0
- data/lib/chainer/version.rb +4 -0
- data/red-chainer.gemspec +27 -0
- metadata +156 -0
@@ -0,0 +1,119 @@
|
|
1
|
+
module Chainer
  module Functions
    module Math
      # Elementwise negation: y = -x.
      class Neg < ::Chainer::Function
        # #backward never reads the inputs, so none are asked to be retained.
        def forward(inputs)
          retain_inputs([])
          x = inputs.first
          [Utils::Array.force_array(-x)]
        end

        # dy/dx = -1: the input gradient is the negated output gradient.
        def backward(inputs, grad_outputs)
          [Utils::Array.force_array(-grad_outputs.first)]
        end
      end

      # Elementwise sum of two inputs: y = a + b.
      class Add < ::Chainer::Function
        def forward(inputs)
          retain_inputs([])
          lhs, rhs = inputs
          [Utils::Array.force_array(lhs + rhs)]
        end

        # Both operands receive the output gradient unchanged.
        def backward(inputs, grad_outputs)
          gy = grad_outputs.first
          [gy, gy]
        end
      end

      # Adds a constant given at construction time: y = x + value.
      class AddConstant < ::Chainer::Function
        def initialize(value)
          @value = value
        end

        def forward(inputs)
          retain_inputs([])
          [Utils::Array.force_array(inputs.first + @value)]
        end

        # The constant has no gradient; the input gets the output gradient as is.
        def backward(inputs, grad_outputs)
          [grad_outputs.first]
        end
      end

      # Elementwise difference: y = a - b.
      class Sub < ::Chainer::Function
        def forward(inputs)
          retain_inputs([])
          lhs, rhs = inputs
          [Utils::Array.force_array(lhs - rhs)]
        end

        # dy/da = 1, dy/db = -1.
        def backward(inputs, grad_outputs)
          gy = grad_outputs.first
          [gy, Utils::Array.force_array(-gy)]
        end
      end

      # Elementwise product: y = a * b. Inputs are needed by #backward.
      class Mul < ::Chainer::Function
        def forward(inputs)
          lhs, rhs = inputs
          [Utils::Array.force_array(lhs * rhs)]
        end

        # dy/da = b, dy/db = a.
        def backward(inputs, grad_outputs)
          lhs, rhs = inputs
          gy = grad_outputs.first
          grad_lhs = Utils::Array.force_array(gy * rhs)
          grad_rhs = Utils::Array.force_array(gy * lhs)
          [grad_lhs, grad_rhs]
        end
      end

      # Scales the input by a constant: y = value * x.
      class MulConstant < ::Chainer::Function
        def initialize(value)
          @value = value
        end

        def forward(inputs)
          [Utils::Array.force_array(@value * inputs.first)]
        end

        # dy/dx = value.
        def backward(inputs, grad_outputs)
          [Utils::Array.force_array(@value * grad_outputs.first)]
        end
      end

      # Elementwise quotient: y = a / b.
      class Div < ::Chainer::Function
        def forward(inputs)
          numerator, denominator = inputs
          [Utils::Array.force_array(numerator / denominator)]
        end

        # dy/da = 1 / b, dy/db = -a / b**2 (computed as -(gy / b) * a / b).
        def backward(inputs, grad_outputs)
          numerator, denominator = inputs
          grad_numerator = Utils::Array.force_array(grad_outputs.first / denominator)
          grad_denominator = Utils::Array.force_array(-1 * grad_numerator * numerator / denominator)
          [grad_numerator, grad_denominator]
        end
      end

      # Elementwise power with a variable exponent: y = a ** b.
      class PowVarVar < ::Chainer::Function
        # The output is cached in @y because dy/db reuses it.
        def forward(inputs)
          base, exponent = inputs
          @y = Utils::Array.force_array(base ** exponent)
          [@y]
        end

        # dy/da = b * a**(b - 1), dy/db = log(a) * y.
        def backward(inputs, grad_outputs)
          base, exponent = inputs
          gy = grad_outputs.first
          one = exponent.class.ones[0]
          grad_base = Utils::Array.force_array(exponent * (base ** (exponent - one)) * gy)
          grad_exponent = Utils::Array.force_array(Numo::NMath.log(base) * @y * gy)
          [grad_base, grad_exponent]
        end
      end

      # Elementwise power with a constant exponent: y = x ** value.
      class PowVarConst < ::Chainer::Function
        def initialize(value)
          @value = value
        end

        def forward(inputs)
          [Utils::Array.force_array(inputs.first ** @value)]
        end

        # dy/dx = value * x**(value - 1).
        def backward(inputs, grad_outputs)
          exponent_minus_one = @value - 1
          gx = @value * (inputs.first ** exponent_minus_one) * grad_outputs.first
          [Utils::Array.force_array(gx)]
        end
      end
    end
  end
end
|
@@ -0,0 +1,63 @@
|
|
1
|
+
module Chainer
  # Optimizer that delegates each parameter's update to an update rule
  # produced by #create_update_rule, which subclasses must implement.
  class GradientMethod < Chainer::Optimizer
    def initialize
      super()
      @hyperparam = Hyperparameter.new
    end

    # Calls Optimizer#setup, then attaches a fresh update rule to every
    # parameter of +link+.
    def setup(link)
      super(link)
      link.params do |param|
        param.update_rule = create_update_rule
      end
    end

    # Gives each initialized parameter whose gradient is nil a zero-filled
    # gradient with the same shape and NArray class as its data.
    def reallocate_cleared_grads
      @target.namedparams(include_uninit: false) do |(_name, param)|
        # new_zeros keeps the data's NArray class; rebuilding through
        # Numo::NArray[*data] copied every element into Ruby and re-guessed
        # the element type.
        param.grad = param.data.new_zeros if param.grad.nil?
      end
    end

    # Runs each registered hook and re-fills cleared gradients after it.
    # NOTE(review): if Optimizer keeps @hooks as a Hash, #each yields
    # [name, hook] pairs rather than hooks -- confirm against Optimizer.
    def call_hooks
      @hooks.each do |hook|
        _call_hook(hook)
        reallocate_cleared_grads
      end
    end

    # Performs one optimization step.
    #
    # If +lossfun+ is given, it is called with +args+ and +kwds+. Gradients are
    # then cleared; they are zeroed instead when #use_cleargrads exists and
    # returns false. Finally the returned loss is backpropagated.
    # Afterwards cleared gradients are reallocated, hooks run, the step
    # counter @t is incremented and every parameter is updated.
    #
    # @param lossfun [#call, nil] callable returning a loss that responds to #backward
    def update(lossfun = nil, *args, **kwds)
      if lossfun
        clear_grads = respond_to?(:use_cleargrads) ? use_cleargrads : true
        # Keywords are forwarded only when present, so a loss function called
        # with no arguments at all is still invoked (loss is never nil here).
        loss = kwds.empty? ? lossfun.(*args) : lossfun.(*args, **kwds)

        if clear_grads
          @target.cleargrads
        else
          @target.zerograds
        end
        loss.backward
      end

      reallocate_cleared_grads
      call_hooks

      @t += 1
      @target.params(&:update)
    end

    # @raise [NotImplementedError] subclasses must return an update rule
    def create_update_rule
      raise NotImplementedError
    end
  end
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Chainer
  # A set of named hyperparameter values stored as instance variables.
  # Names this object does not define are looked up on its parent, then on
  # the parent's parent, and so on.
  class Hyperparameter
    # Names usable as instance-variable-backed accessors (hp.alpha / hp.alpha = v).
    ACCESSOR_NAME = /\A[A-Za-z_]\w*\z/

    attr_reader :parent

    # @param parent [Hyperparameter, nil] fallback consulted for names not set here
    def initialize(parent: nil)
      @parent = parent
    end

    # +hp.name+ returns @name from the first object in the self -> parent chain
    # that defines it, or nil when none does. +hp.name = value+ sets @name on
    # this object. Any other missing method raises NoMethodError.
    def method_missing(name, *args)
      key = name.to_s
      if args.empty? && ACCESSOR_NAME === key
        ivar = "@#{key}"
        owner = owner_of(ivar)
        owner && owner.instance_variable_get(ivar)
      elsif args.size == 1 && key.end_with?('=') && ACCESSOR_NAME === key.chomp('=')
        instance_variable_set("@#{key.chomp('=')}", args.first)
      else
        super
      end
    end

    # Reports getters only for names some object in the chain defines, and
    # setters for any valid name, so respond_to? agrees with #method_missing.
    def respond_to_missing?(name, include_private = false)
      key = name.to_s
      if ACCESSOR_NAME === key
        !owner_of("@#{key}").nil? || super
      elsif key.end_with?('=') && ACCESSOR_NAME === key.chomp('=')
        true
      else
        super
      end
    end

    # Flattens the chain into a Hash of name => value (String keys, no '@').
    # Values set on this object override those inherited from the parent.
    def get_dict
      dict = @parent.nil? ? {} : @parent.get_dict
      (instance_variables - [:@parent]).each do |ivar|
        dict[ivar.to_s.delete('@')] = instance_variable_get(ivar)
      end
      dict
    end

    private

    # First object in the self -> parent chain that defines +ivar+, or nil.
    # A non-Hyperparameter parent is consulted but not walked further.
    def owner_of(ivar)
      node = self
      until node.nil?
        return node if node.instance_variable_defined?(ivar)
        node = node.is_a?(Hyperparameter) ? node.parent : nil
      end
      nil
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Chainer
  module Initializers
    # Initializer that fills an array with a fixed value.
    class Constant < ::Chainer::Initializer
      # @param fill_value [Numeric, Numo::NArray] value written into the array via #store
      # @param dtype [Object, nil] expected element type of arrays passed to #call
      def initialize(fill_value, dtype: nil)
        @fill_value = fill_value
        super(dtype: dtype)
      end

      # Fills +array+ in place with the fill value and returns it.
      #
      # @raise [ArgumentError] if a dtype was given and +array+ does not match it
      def call(array)
        if @dtype
          # NOTE(review): Numo::NArray has no #dtype method; its class (e.g.
          # Numo::SFloat) identifies the element type, so compare by class
          # unless the array does provide #dtype.
          actual = array.respond_to?(:dtype) ? array.dtype : array.class
          raise ArgumentError, "expected dtype #{@dtype}, got #{actual}" unless actual == @dtype
        end
        array.store(@fill_value)
        array
      end
    end
  end
end
|
@@ -0,0 +1,24 @@
|
|
1
|
+
module Chainer
  module Initializers
    # Creates a DFloat array of +shape+ and returns whatever +initializer+
    # returns when called with it.
    def self.generate_array(initializer, shape)
      array = Numo::DFloat.new(shape).rand
      initializer.(array)
    end

    # Normalizes an initializer specification into a callable initializer.
    #
    # * nil                         -> HeNormal with scale 1/sqrt(2)
    # * Numeric or Numo::NArray     -> Constant filling with that value
    # * any object responding to #call (initializer instance, lambda) -> itself
    #
    # @raise [TypeError] for anything else
    def self.get_initializer(initializer)
      case initializer
      when nil
        HeNormal.new(scale: 1 / Numo::NMath.sqrt(2))
      when Numeric, Numo::NArray
        Constant.new(initializer)
      else
        # respond_to? works on instances; method_defined? is a Module method and
        # raised NoMethodError for every initializer object.
        raise TypeError, "invalid type of initializer: #{initializer.class}" unless initializer.respond_to?(:call)

        initializer
      end
    end

    # Initializer that fills arrays with NaN.
    def self.nan(dtype: nil)
      Constant.new(Float::NAN, dtype: dtype)
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Chainer
  module Initializers
    # Initializer drawing samples from a normal distribution N(0, scale**2).
    class Normal < ::Chainer::Initializer
      # @param scale [Float] standard deviation of the distribution
      def initialize(scale: 0.05, dtype: nil)
        @scale = scale
        super(dtype: dtype)
      end

      # Fills +array+ in place with normal samples and returns it, so the
      # caller's array and its NArray class are kept. Arrays that cannot
      # draw normal samples get a new DFloat of the same shape instead.
      def call(array)
        return array.rand_norm(0.0, @scale) if array.respond_to?(:rand_norm)

        Numo::DFloat.new(array.shape).rand_norm(0.0, @scale)
      end
    end

    # He-style initializer: normal samples with standard deviation
    # scale * sqrt(2 / fan_in).
    class HeNormal < ::Chainer::Initializer
      def initialize(scale: 1.0, dtype: nil)
        @scale = scale
        super(dtype: dtype)
      end

      def call(array)
        fan_in, _fan_out = Chainer::Utils::Initializer.get_fans(array.shape)
        s = @scale * Numo::NMath.sqrt(2.0 / fan_in)
        Normal.new(scale: s).(array)
      end
    end
  end
end
|
@@ -0,0 +1,74 @@
|
|
1
|
+
module Chainer
  module Iterators
    # Iterates over a dataset in fixed-size mini-batches, optionally shuffling
    # the example order and optionally repeating forever.
    class SerialIterator < Chainer::Dataset::Iterator
      attr_reader :epoch, :is_new_epoch

      # @param dataset [#size, #[]] examples indexed by Integer
      # @param batch_size [Integer] number of examples per batch
      # @param repeat [Boolean] keep producing batches after the first epoch
      # @param shuffle [Boolean] use a random order, redrawn every epoch
      def initialize(dataset, batch_size, repeat: true, shuffle: true)
        @dataset = dataset
        @batch_size = batch_size
        @repeat = repeat
        @shuffle = shuffle

        reset
      end

      # Returns the next mini-batch as an Array of dataset examples. When
      # repeating, a batch crossing the epoch boundary is completed with
      # examples from the start of the next epoch's order.
      #
      # @raise [StopIteration] when +repeat+ is false and one epoch was consumed
      def next
        raise StopIteration if !@repeat && @epoch > 0

        @previous_epoch_detail = epoch_detail

        i = @current_position
        i_end = i + @batch_size
        n = @dataset.size

        batch = examples_at(i, i_end)

        if i_end >= n
          if @repeat
            rest = i_end - n
            # Only shuffling iterators draw a new order for the next epoch.
            @order = @order.class[*@order.to_a.shuffle] if @shuffle && !@order.nil?
            batch.concat(examples_at(0, rest)) if rest > 0
            @current_position = rest
          else
            @current_position = 0
          end

          @epoch += 1
          @is_new_epoch = true
        else
          @is_new_epoch = false
          @current_position = i_end
        end

        batch
      end

      # Completed epochs plus the fraction of the current epoch consumed.
      def epoch_detail
        @epoch + @current_position.to_f / @dataset.size
      end

      # #epoch_detail as it was just before the latest #next call, or nil
      # before the first call.
      def previous_epoch_detail
        @previous_epoch_detail < 0 ? nil : @previous_epoch_detail
      end

      # Rewinds to the beginning of epoch 0 with a new example order.
      def reset
        order = (0...@dataset.size).to_a
        order.shuffle! if @shuffle
        @order = Numo::Int64[*order]

        @current_position = 0
        @epoch = 0
        @is_new_epoch = false
        @previous_epoch_detail = -1.0
      end

      private

      # Examples at order positions [from, to); +to+ is clamped to the dataset size.
      def examples_at(from, to)
        to = [to, @dataset.size].min
        return [] if from >= to

        indices = @order.nil? ? (from...to).to_a : @order[from...to].to_a
        indices.map { |index| @dataset[index] }
      end
    end
  end
end
|
data/lib/chainer/link.rb
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
module Chainer
  # Building block of a model. Parameters are Chainer::Parameter objects held
  # in instance variables and registered by assigning them inside #init_scope.
  class Link
    def initialize
      @params = []
      @persistent = []
      @within_init_scope = false
      @name = nil
    end

    # Whether #init_scope is currently running on this link.
    def within_init_scope
      @within_init_scope || false
    end

    # Runs the block with the init-scope flag set, then registers the
    # attributes assigned so far via #set_attr. The previous flag value is
    # restored even if the block raises.
    def init_scope
      old_flag = within_init_scope
      @within_init_scope = true

      begin
        yield
        set_attr
      ensure
        @within_init_scope = old_flag
      end
    end

    # Registers every instance variable holding a Chainer::Parameter. Safe to
    # call repeatedly: already-registered names are not added again, so
    # parameters are never yielded (or updated) twice.
    def set_attr
      instance_variables.each do |name|
        next unless instance_variable_get(name).kind_of?(Chainer::Parameter)

        @params << name unless @params.include?(name)
        @persistent.delete(name)
      end
    end

    # Calls #cleargrad on every parameter.
    def cleargrads
      params(&:cleargrad)
    end

    # Yields each registered parameter; with +include_uninit+ false, only
    # those whose data is non-nil. Returns an Enumerator without a block.
    def params(include_uninit: true)
      return enum_for(__method__, include_uninit: include_uninit) unless block_given?

      @params.each do |name|
        param = instance_variable_get(name)
        yield param if include_uninit || param.data
      end
    end

    # Like #params but yields [path, param] pairs, where path is '/' followed
    # by the instance-variable name (e.g. "/@W").
    def namedparams(include_uninit: true)
      return enum_for(__method__, include_uninit: include_uninit) unless block_given?

      @params.each do |name|
        param = instance_variable_get(name)
        yield ['/' + name.to_s, param] if include_uninit || param.data
      end
    end

    # Yields ('/', self) unless +skipself+; a plain Link has no children.
    def namedlinks(skipself: false)
      return enum_for(__method__, skipself: skipself) unless block_given?

      yield('/', self) unless skipself
    end
  end

  # A Link that also owns child Links held in instance variables.
  class Chain < Link
    def initialize
      super
      @children = []
    end

    # Registers child Links (once each), then parameters via Link#set_attr.
    def set_attr
      instance_variables.each do |name|
        next unless instance_variable_get(name).kind_of?(Chainer::Link)

        @children << name unless @children.include?(name)
      end
      super
    end

    # Yields this chain's own parameters, then those of every child.
    def params(include_uninit: true, &block)
      return enum_for(__method__, include_uninit: include_uninit) unless block

      super(include_uninit: include_uninit, &block)
      @children.each do |name|
        instance_variable_get(name).params(include_uninit: include_uninit, &block)
      end
    end

    # Yields [path, param] pairs; children's paths are prefixed with
    # "/<child ivar name>".
    def namedparams(include_uninit: true, &block)
      return enum_for(__method__, include_uninit: include_uninit) unless block

      super(include_uninit: include_uninit, &block)
      @children.each do |name|
        prefix = "/#{name}"
        instance_variable_get(name).namedparams(include_uninit: include_uninit) do |(path, param)|
          yield [prefix + path, param]
        end
      end
    end

    # Yields (path, link) for this chain (unless +skipself+), each child and,
    # recursively, each child's descendants.
    def namedlinks(skipself: false)
      return enum_for(__method__, skipself: skipself) unless block_given?

      yield('/', self) unless skipself
      @children.each do |name|
        child = instance_variable_get(name)
        prefix = "/#{name}"
        yield(prefix, child)
        child.namedlinks(skipself: true) do |path, link|
          yield(prefix + path, link)
        end
      end
    end
  end
end
|