red-chainer 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +12 -0
- data/.rspec +2 -0
- data/.travis.yml +5 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +23 -0
- data/README.md +60 -0
- data/Rakefile +8 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/examples/mnist.rb +42 -0
- data/lib/chainer.rb +59 -0
- data/lib/chainer/configuration.rb +10 -0
- data/lib/chainer/dataset/convert.rb +62 -0
- data/lib/chainer/dataset/download.rb +56 -0
- data/lib/chainer/dataset/iterator.rb +15 -0
- data/lib/chainer/datasets/mnist.rb +89 -0
- data/lib/chainer/datasets/tuple_dataset.rb +33 -0
- data/lib/chainer/function.rb +80 -0
- data/lib/chainer/functions/activation/log_softmax.rb +37 -0
- data/lib/chainer/functions/activation/relu.rb +23 -0
- data/lib/chainer/functions/connection/linear.rb +48 -0
- data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
- data/lib/chainer/functions/math/basic_math.rb +119 -0
- data/lib/chainer/gradient_method.rb +63 -0
- data/lib/chainer/hyperparameter.rb +23 -0
- data/lib/chainer/initializer.rb +12 -0
- data/lib/chainer/initializers/constant.rb +18 -0
- data/lib/chainer/initializers/init.rb +24 -0
- data/lib/chainer/initializers/normal.rb +28 -0
- data/lib/chainer/iterators/serial_iterator.rb +74 -0
- data/lib/chainer/link.rb +118 -0
- data/lib/chainer/links/connection/linear.rb +43 -0
- data/lib/chainer/links/model/classifier.rb +39 -0
- data/lib/chainer/optimizer.rb +69 -0
- data/lib/chainer/optimizers/adam.rb +62 -0
- data/lib/chainer/parameter.rb +53 -0
- data/lib/chainer/reporter.rb +130 -0
- data/lib/chainer/training/extension.rb +25 -0
- data/lib/chainer/training/extensions/evaluator.rb +26 -0
- data/lib/chainer/training/extensions/log_report.rb +72 -0
- data/lib/chainer/training/extensions/print_report.rb +62 -0
- data/lib/chainer/training/extensions/progress_bar.rb +89 -0
- data/lib/chainer/training/standard_updater.rb +63 -0
- data/lib/chainer/training/trainer.rb +136 -0
- data/lib/chainer/training/triggers/interval.rb +27 -0
- data/lib/chainer/training/updater.rb +33 -0
- data/lib/chainer/training/util.rb +13 -0
- data/lib/chainer/utils/array.rb +10 -0
- data/lib/chainer/utils/initializer.rb +14 -0
- data/lib/chainer/utils/variable.rb +20 -0
- data/lib/chainer/variable.rb +204 -0
- data/lib/chainer/variable_node.rb +71 -0
- data/lib/chainer/version.rb +4 -0
- data/red-chainer.gemspec +27 -0
- metadata +156 -0
@@ -0,0 +1,119 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Functions
|
3
|
+
module Math
|
4
|
+
|
5
|
+
class Neg < ::Chainer::Function
  # Element-wise negation of the single input.
  #
  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list holding -x[0]
  def forward(x)
    # An empty list is handed to retain_inputs; presumably this tells the
    # base class that backward needs none of the inputs -- confirm in Function.
    retain_inputs([])
    negated = -x[0]
    [Utils::Array.force_array(negated)]
  end

  # Gradient of negation: the upstream gradient with its sign flipped.
  #
  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list holding -gy[0]
  def backward(x, gy)
    flipped = -gy[0]
    [Utils::Array.force_array(flipped)]
  end
end
|
15
|
+
|
16
|
+
class Add < ::Chainer::Function
  # Sum of the two inputs.
  #
  # @param x [Array] inputs; x[0] and x[1] are added
  # @return [Array] one-element list holding x[0] + x[1]
  def forward(x)
    # Empty retain list: backward below never touches the inputs.
    retain_inputs([])
    total = x[0] + x[1]
    [Utils::Array.force_array(total)]
  end

  # Both operands receive the upstream gradient unchanged.
  #
  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] two-element list, gy[0] for each operand
  def backward(x, gy)
    upstream = gy[0]
    [upstream, upstream]
  end
end
|
26
|
+
|
27
|
+
class AddConstant < ::Chainer::Function
  # @param value the constant added to the input in forward
  def initialize(value)
    @value = value
  end

  # Adds the stored constant to the single input.
  #
  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list holding x[0] + @value
  def forward(x)
    # Empty retain list: backward below never touches the inputs.
    retain_inputs([])
    shifted = x[0] + @value
    [Utils::Array.force_array(shifted)]
  end

  # Adding a constant leaves the gradient untouched.
  #
  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list holding gy[0]
  def backward(x, gy)
    upstream = gy[0]
    [upstream]
  end
end
|
41
|
+
|
42
|
+
class Sub < ::Chainer::Function
  # Difference of the two inputs.
  #
  # @param x [Array] inputs; computes x[0] - x[1]
  # @return [Array] one-element list holding the difference
  def forward(x)
    # Empty retain list: backward below never touches the inputs.
    retain_inputs([])
    difference = x[0] - x[1]
    [Utils::Array.force_array(difference)]
  end

  # The minuend gets the upstream gradient, the subtrahend its negation.
  #
  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] [gy[0], -gy[0]]
  def backward(x, gy)
    upstream = gy[0]
    negated = Utils::Array.force_array(-upstream)
    [upstream, negated]
  end
end
|
52
|
+
|
53
|
+
class Mul < ::Chainer::Function
  # Product of the two inputs.
  #
  # @param x [Array] inputs; computes x[0] * x[1]
  # @return [Array] one-element list holding the product
  def forward(x)
    product = x[0] * x[1]
    [Utils::Array.force_array(product)]
  end

  # Each operand's gradient is the upstream gradient times the other operand.
  #
  # @param x [Array] forward inputs; x[0] and x[1] are read
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] [gy[0] * x[1], gy[0] * x[0]]
  def backward(x, gy)
    upstream = gy[0]
    grad_lhs = Utils::Array.force_array(upstream * x[1])
    grad_rhs = Utils::Array.force_array(upstream * x[0])
    [grad_lhs, grad_rhs]
  end
end
|
62
|
+
|
63
|
+
class MulConstant < ::Chainer::Function
  # @param value the constant factor applied in forward and backward
  def initialize(value)
    @value = value
  end

  # Scales the single input by the stored constant.
  #
  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list holding @value * x[0]
  def forward(x)
    scaled = @value * x[0]
    [Utils::Array.force_array(scaled)]
  end

  # The gradient is the upstream gradient scaled by the same constant.
  #
  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list holding @value * gy[0]
  def backward(x, gy)
    scaled_grad = @value * gy[0]
    [Utils::Array.force_array(scaled_grad)]
  end
end
|
76
|
+
|
77
|
+
class Div < ::Chainer::Function
  # Quotient of the two inputs.
  #
  # @param x [Array] inputs; computes x[0] / x[1]
  # @return [Array] one-element list holding the quotient
  def forward(x)
    quotient = x[0] / x[1]
    [Utils::Array.force_array(quotient)]
  end

  # Gradients of x0 / x1: gy / x1 for the numerator and
  # -gy * x0 / x1**2 for the denominator (computed from the first result).
  #
  # @param x [Array] forward inputs; x[0] and x[1] are read
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] [gradient wrt x[0], gradient wrt x[1]]
  def backward(x, gy)
    grad_num = Utils::Array.force_array(gy[0] / x[1])
    # Same left-to-right evaluation order as a single chained expression.
    scaled = -1 * grad_num * x[0]
    grad_den = Utils::Array.force_array(scaled / x[1])
    [grad_num, grad_den]
  end
end
|
87
|
+
|
88
|
+
class PowVarVar < ::Chainer::Function
  # Raises the first input to the power of the second.
  # The result is cached in @y because backward reuses it.
  #
  # @param x [Array] inputs; computes x[0] ** x[1]
  # @return [Array] one-element list holding the power
  def forward(x)
    @y = Utils::Array.force_array(x[0] ** x[1])
    [@y]
  end

  # Gradients of x0 ** x1:
  #   wrt x0: x1 * x0 ** (x1 - 1) * gy
  #   wrt x1: log(x0) * (x0 ** x1) * gy
  #
  # @param x [Array] forward inputs; x[0] and x[1] are read
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] [gradient wrt x[0], gradient wrt x[1]]
  def backward(x, gy)
    base, exponent = x[0], x[1]
    upstream = gy[0]
    # A "one" taken from the exponent's own array class.
    # NOTE(review): relies on `.ones` being callable without a shape -- confirm.
    unit = exponent.class.ones[0]
    grad_base = Utils::Array.force_array(exponent * (base ** (exponent - unit)) * upstream)
    grad_exponent = Utils::Array.force_array(Numo::NMath.log(base) * @y * upstream)
    [grad_base, grad_exponent]
  end
end
|
101
|
+
|
102
|
+
class PowVarConst < ::Chainer::Function
  # @param value the constant exponent
  def initialize(value)
    @value = value
  end

  # Raises the single input to the stored constant power.
  #
  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list holding x[0] ** @value
  def forward(x)
    powered = x[0] ** @value
    [Utils::Array.force_array(powered)]
  end

  # Gradient of x ** c: c * x ** (c - 1) * gy.
  #
  # @param x [Array] forward inputs; only x[0] is read
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list holding the gradient
  def backward(x, gy)
    reduced_exponent = @value - 1
    grad = @value * (x[0] ** reduced_exponent) * gy[0]
    [Utils::Array.force_array(grad)]
  end
end
|
117
|
+
end
|
118
|
+
end
|
119
|
+
end
|
@@ -0,0 +1,63 @@
|
|
1
|
+
module Chainer
|
2
|
+
# Base class for optimizers that update parameters from their gradients.
# Subclasses must implement #create_update_rule.
#
# This class reads @target, @hooks and @t and calls _call_hook; none of them
# is defined here (presumably they come from Chainer::Optimizer -- confirm).
class GradientMethod < Chainer::Optimizer
  def initialize
    super()
    @hyperparam = Hyperparameter.new
  end

  # Registers the link with the optimizer and gives every parameter of the
  # link its own update rule built by #create_update_rule.
  #
  # @param link the link whose parameters will be optimized
  def setup(link)
    super(link)
    link.params do |param|
      param.update_rule = create_update_rule
    end
  end

  # Gives every initialized parameter whose gradient is nil a fresh
  # zero-filled gradient array.
  def reallocate_cleared_grads
    @target.namedparams(include_uninit: false) do |(_name, param)|
      next unless param.grad.nil?

      # Zero array built from the parameter's data.
      # NOTE(review): goes through Numo::NArray[] instead of
      # param.data.new_zeros; kept as-is because the dtype handling may differ.
      param.grad = Numo::NArray.[](*param.data).new_zeros
    end
  end

  # Runs every registered hook; gradients are re-allocated after each one.
  def call_hooks
    @hooks.each do |hook|
      _call_hook(hook)
      reallocate_cleared_grads
    end
  end

  # Performs one optimization step.
  #
  # When +lossfun+ is given it is called with +args+/+kwds+, gradients of the
  # target are cleared (or zeroed when #use_cleargrads exists and returns a
  # falsy value) and +backward+ is called on the returned loss. Afterwards
  # missing gradients are re-allocated, hooks run, @t is incremented and every
  # parameter's #update is invoked.
  #
  # @param lossfun [#call, nil] callable returning an object that responds to
  #   +backward+; when nil the gradients already present are used
  # @param args positional arguments forwarded to +lossfun+
  # @param kwds keyword arguments forwarded to +lossfun+
  def update(lossfun=nil, *args, **kwds)
    if lossfun
      use_cleargrads = self.methods.include?(:use_cleargrads) ? self.use_cleargrads : true

      # The keyword splat is only passed when keywords exist, so an empty hash
      # is never forwarded. The first branch also covers the call with no
      # arguments at all, which previously left +loss+ nil.
      loss = if kwds.empty?
               lossfun.(*args)
             else
               lossfun.(*args, **kwds)
             end

      if use_cleargrads
        @target.cleargrads()
      else
        @target.zerograds()
      end
      loss.backward()
    end

    reallocate_cleared_grads

    call_hooks

    @t += 1
    @target.params do |param|
      param.update
    end
  end

  # Builds the update rule assigned to each parameter in #setup.
  #
  # @raise [NotImplementedError] always; subclasses must override
  def create_update_rule
    raise NotImplementedError
  end
end
|
63
|
+
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Chainer
|
2
|
+
# Set of hyperparameters stored as instance variables, with an optional
# parent that supplies defaults.
#
# A hyperparameter +foo+ is stored as the instance variable +@foo+ and is read
# with +hp.foo+. Lookup checks the object itself first and then walks up the
# parent chain, which is the same precedence #get_dict uses.
class Hyperparameter
  attr_reader :parent

  # @param parent [Object, nil] object consulted for values missing here
  def initialize(parent: nil)
    @parent = parent
  end

  # Resolves +hp.name+ to the value of +@name+ on this object or, failing
  # that, on the nearest ancestor that defines it.
  #
  # @return the value, or nil when no object in the chain defines it
  # @raise [NoMethodError] when called with arguments or a block, or when the
  #   name cannot be an instance-variable name (e.g. +foo=+, +foo?+)
  def method_missing(name, *args, &block)
    return super unless args.empty? && block.nil? && hyperparameter_name?(name)

    lookup_hyperparameter("@#{name}")
  end

  # Keeps respond_to? in line with method_missing: true only for names that
  # currently resolve to a defined hyperparameter.
  def respond_to_missing?(name, include_private = false)
    (hyperparameter_name?(name) && !hyperparameter_holder("@#{name}").nil?) || super
  end

  # Collects all hyperparameters into a hash keyed by name (without the "@").
  # Parent values come first and are overridden by values defined here.
  #
  # @return [Hash{String => Object}]
  def get_dict
    d = @parent.nil? ? {} : @parent.get_dict
    instance_variables.each do |ivar|
      next if ivar == :@parent

      d[ivar.to_s.delete('@')] = instance_variable_get(ivar)
    end
    d
  end

  private

  # True when +name+ can be turned into an instance-variable name.
  def hyperparameter_name?(name)
    !(/\A[A-Za-z_][A-Za-z0-9_]*\z/ =~ name.to_s).nil?
  end

  # Nearest object in the parent chain (starting with self) defining +ivar+.
  def hyperparameter_holder(ivar)
    holder = self
    while holder
      return holder if holder.instance_variable_defined?(ivar)

      # A parent may be any object; stop when it has no parent of its own.
      holder = holder.respond_to?(:parent) ? holder.parent : nil
    end
    nil
  end

  # Value of +ivar+ from the nearest holder, or nil when nobody defines it.
  def lookup_hyperparameter(ivar)
    holder = hyperparameter_holder(ivar)
    holder.nil? ? nil : holder.instance_variable_get(ivar)
  end
end
|
23
|
+
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Initializers
|
3
|
+
# Initializer that fills an array with a fixed value.
class Constant < ::Chainer::Initializer
  # @param fill_value value stored into the array by #call
  # @param dtype optional dtype forwarded to the base initializer
  def initialize(fill_value, dtype: nil)
    @fill_value = fill_value
    super(dtype: dtype)
  end

  # Stores the fill value into +array+ in place and returns it.
  #
  # @param array the array to fill (must respond to +store+)
  # @return the same array object
  # @raise [ArgumentError] when @dtype is set and differs from array.dtype
  #   (@dtype is not assigned here; presumably by the base class -- confirm)
  def call(array)
    raise ArgumentError if @dtype && !(array.dtype == @dtype)

    array.store(@fill_value)
    array
  end
end
|
17
|
+
end
|
18
|
+
end
|
@@ -0,0 +1,24 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Initializers
|
3
|
+
# Builds a Numo::DFloat array of the given shape, fills it with +rand+ and
# hands it to the initializer.
#
# @param initializer [#call] object invoked with the freshly created array
# @param shape shape passed to Numo::DFloat.new
# @return whatever the initializer returns
def self.generate_array(initializer, shape)
  scratch = Numo::DFloat.new(shape).rand
  initializer.call(scratch)
end
|
7
|
+
|
8
|
+
# Normalizes the accepted initializer specifications to a callable.
#
# * nil                     -> HeNormal with scale 1 / sqrt(2)
# * Numeric or Numo::NArray -> Constant filled with that value
# * anything callable       -> returned unchanged
#
# @param initializer [nil, Numeric, Numo::NArray, #call]
# @return [#call] an initializer object
# @raise [TypeError] when the argument is none of the above
def self.get_initializer(initializer)
  return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
  return Constant.new(initializer) if initializer.kind_of?(Numeric)
  return Constant.new(initializer) if initializer.kind_of?(Numo::NArray)

  # respond_to? asks the object itself. Module#method_defined? only exists on
  # classes and modules, so it raised NoMethodError for ordinary instances.
  unless initializer.respond_to?(:call)
    raise TypeError, "invalid type of initializer: #{initializer.class}"
  end

  initializer
end
|
19
|
+
|
20
|
+
# Convenience constructor for a Constant initializer that fills with NaN.
#
# @param dtype optional dtype forwarded to Constant
# @return [Constant]
def self.nan(dtype: nil)
  not_a_number = Float::NAN
  Constant.new(not_a_number, dtype: dtype)
end
|
23
|
+
end
|
24
|
+
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Initializers
|
3
|
+
# Initializer drawing values from a normal distribution with mean 0.
class Normal < ::Chainer::Initializer
  # @param scale standard deviation passed to +rand_norm+
  # @param dtype optional dtype forwarded to the base initializer
  def initialize(scale: 0.05, dtype: nil)
    @scale = scale
    super(dtype: dtype)
  end

  # Returns a NEW Numo::DFloat array with the same shape as +array+, filled
  # by +rand_norm(0.0, @scale)+. The argument is only read for its shape and
  # is not modified, so callers must use the return value.
  #
  # @param array an object responding to +shape+
  # @return [Numo::DFloat]
  def call(array)
    Numo::DFloat.new(array.shape).rand_norm(0.0, @scale)
  end
end
|
14
|
+
|
15
|
+
# Initializer that delegates to Normal with a standard deviation of
# scale * sqrt(2 / fan_in).
class HeNormal < ::Chainer::Initializer
  # @param scale multiplier applied to sqrt(2 / fan_in)
  # @param dtype optional dtype forwarded to the base initializer
  def initialize(scale: 1.0, dtype: nil)
    @scale = scale
    super(dtype: dtype)
  end

  # Derives the standard deviation from the fan-in of +array+'s shape and
  # returns the result of a Normal initializer built with it.
  #
  # @param array an object responding to +shape+
  # @return whatever Normal#call returns
  def call(array)
    # get_fans returns a pair; only the first element is needed here.
    fan_in, _fan_out = Chainer::Utils::Initializer.get_fans(array.shape)
    std = @scale * Numo::NMath.sqrt(2.0 / fan_in)
    Normal.new(scale: std).call(array)
  end
end
|
27
|
+
end
|
28
|
+
end
|
@@ -0,0 +1,74 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Iterators
|
3
|
+
# Iterator that walks a dataset in mini-batches, optionally shuffling the
# visiting order and optionally repeating over multiple epochs.
class SerialIterator < Chainer::Dataset::Iterator
  attr_reader :epoch

  # @param dataset indexable collection responding to +size+ and +[]+
  # @param batch_size [Integer] number of examples per batch
  # @param repeat [Boolean] keep iterating after the first epoch
  # @param shuffle [Boolean] visit examples in random order, reshuffled at
  #   the end of every epoch
  def initialize(dataset, batch_size, repeat: true, shuffle: true)
    @dataset = dataset
    @batch_size = batch_size
    @repeat = repeat
    @shuffle = shuffle

    reset
  end

  # Returns the next mini-batch as an Array of dataset items.
  #
  # When repeating and the batch runs past the end of the dataset, the batch
  # is completed with the first examples of the next epoch's order.
  #
  # @return [Array]
  # @raise [StopIteration] when +repeat+ is false and one epoch is finished
  def next
    raise StopIteration if !@repeat && @epoch > 0

    # Snapshot of the progress before this batch is consumed.
    @previous_epoch_detail = epoch_detail

    i = @current_position
    i_end = i + @batch_size
    n = @dataset.size

    batch = batch_for(i, i_end)

    if i_end >= n
      if @repeat
        rest = i_end - n
        # Only reshuffle when shuffling was requested; a sequential iterator
        # keeps its order across epochs.
        reshuffle_order if @shuffle
        # Extend (not replace) the batch with the head of the new order.
        batch.concat(batch_for(0, rest)) if rest > 0
        @current_position = rest
      else
        @current_position = 0
      end

      @epoch += 1
      @is_new_epoch = true
    else
      @is_new_epoch = false
      @current_position = i_end
    end

    batch
  end

  # Progress as a float: completed epochs plus the fraction of the current one.
  #
  # @return [Float]
  def epoch_detail
    @epoch + @current_position.to_f / @dataset.size
  end

  # Rebuilds the visiting order and rewinds all counters.
  def reset
    indices = @dataset.size.times.to_a
    indices.shuffle! if @shuffle
    @order = Numo::Int64[*indices]

    @current_position = 0
    @epoch = 0
    @is_new_epoch = false
    @previous_epoch_detail = -1.0
  end

  private

  # Dataset items for order positions [from, to), with +to+ clamped to the
  # number of available positions.
  def batch_for(from, to)
    to = [to, @order.size].min
    return [] if from >= to

    @order[from...to].to_a.map { |index| @dataset[index] }
  end

  # Replaces the order with a shuffled copy of the same array class.
  def reshuffle_order
    @order = @order.class[*@order.to_a.shuffle]
  end
end
|
73
|
+
end
|
74
|
+
end
|
data/lib/chainer/link.rb
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
module Chainer
|
2
|
+
# Building block holding parameters. Parameters are instance variables whose
# value is a Chainer::Parameter; they are registered when an #init_scope
# block finishes.
class Link
  def initialize
    @params = []
    @persistent = []
    @within_init_scope = false
    @name = nil
  end

  # @return [Boolean] whether an #init_scope block is currently executing
  def within_init_scope
    @within_init_scope || false
  end

  # Runs the block, then registers every parameter assigned so far.
  # The previous scope flag is restored even when the block raises.
  def init_scope
    old_flag = within_init_scope
    @within_init_scope = true

    begin
      yield
      set_attr
    ensure
      @within_init_scope = old_flag
    end
  end

  # Scans the instance variables and records the names of those holding a
  # Chainer::Parameter. Each name is recorded once, so several init_scope
  # calls do not register (and later update) a parameter more than once.
  def set_attr
    instance_variables.each do |name|
      next unless instance_variable_get(name).instance_of?(Chainer::Parameter)

      @params << name unless @params.include?(name)
      @persistent.delete(name)
    end
  end

  # Calls +cleargrad+ on every parameter.
  def cleargrads
    params do |param|
      param.cleargrad
    end
  end

  # Yields each registered parameter.
  #
  # @param include_uninit [Boolean] when false, parameters whose +data+ is
  #   nil/false are skipped
  # @return [Enumerator] when no block is given
  def params(include_uninit: true)
    return enum_for(:params, include_uninit: include_uninit) unless block_given?

    @params.each do |name|
      param = instance_variable_get(name)
      yield param if include_uninit || param.data
    end
  end

  # Yields +[path, parameter]+ pairs; the path is "/" followed by the
  # instance-variable name (including its "@").
  #
  # @param include_uninit [Boolean] see #params
  # @return [Enumerator] when no block is given
  def namedparams(include_uninit: true)
    return enum_for(:namedparams, include_uninit: include_uninit) unless block_given?

    @params.each do |name|
      param = instance_variable_get(name)
      yield ['/' + name.to_s, param] if include_uninit || param.data
    end
  end

  # Yields +("/", self)+ unless +skipself+ is true.
  #
  # @param skipself [Boolean] suppress the entry for this link
  # @return [Enumerator] when no block is given
  def namedlinks(skipself: false)
    return enum_for(:namedlinks, skipself: skipself) unless block_given?

    yield('/', self) unless skipself
  end
end
|
63
|
+
|
64
|
+
# Link that can contain other links. Child links are instance variables whose
# value is a Chainer::Link; they are registered together with the parameters
# when an init_scope block finishes.
class Chain < Link
  def initialize
    super
    @children = []
  end

  # Records the names of instance variables holding a link, then lets the
  # parent class register the parameters. Each child is recorded once, so
  # repeated init_scope calls do not make it be visited several times.
  def set_attr
    instance_variables.each do |name|
      next unless instance_variable_get(name).kind_of?(Chainer::Link)

      @children << name unless @children.include?(name)
    end
    super
  end

  # Yields this chain's own parameters followed by the parameters of every
  # child link, in registration order.
  #
  # @param include_uninit [Boolean] forwarded to every link
  # @return [Enumerator] when no block is given
  def params(include_uninit: true)
    return enum_for(:params, include_uninit: include_uninit) unless block_given?

    super(include_uninit: include_uninit) { |param| yield param }

    @children.each do |name|
      instance_variable_get(name).params(include_uninit: include_uninit) { |param| yield param }
    end
  end

  # Yields +[path, parameter]+ pairs for this chain's own parameters and for
  # every child's parameters; child paths are prefixed with "/<ivar name>".
  #
  # @param include_uninit [Boolean] forwarded to every link
  # @return [Enumerator] when no block is given
  def namedparams(include_uninit: true)
    return enum_for(:namedparams, include_uninit: include_uninit) unless block_given?

    # The parent yields a single [path, param] pair, passed through untouched.
    super(include_uninit: include_uninit) { |pair| yield pair }

    @children.each do |name|
      prefix = "/#{name}"
      instance_variable_get(name).namedparams(include_uninit: include_uninit) do |(path, param)|
        yield [prefix + path, param]
      end
    end
  end

  # Yields +(path, link)+ for this chain (unless skipped), for each child and
  # for each child's descendants.
  #
  # @param skipself [Boolean] suppress the entry for this chain
  # @return [Enumerator] when no block is given
  def namedlinks(skipself: false)
    return enum_for(:namedlinks, skipself: skipself) unless block_given?

    yield('/', self) unless skipself

    @children.each do |name|
      child = instance_variable_get(name)
      prefix = '/' + name.to_s
      yield(prefix, child)
      # The child was just yielded above, so it must not report itself again.
      child.namedlinks(skipself: true) do |path, link|
        yield(prefix + path, link)
      end
    end
  end
end
|
118
|
+
end
|