red-chainer 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (58) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,119 @@
1
+ module Chainer
2
+ module Functions
3
+ module Math
4
+
5
# Function that negates its single input (y = -x[0]).
class Neg < ::Chainer::Function
  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list with the negated input
  def forward(x)
    # backward never reads x, so an empty list is passed to retain_inputs
    retain_inputs([])
    negated = -x[0]
    [Utils::Array.force_array(negated)]
  end

  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list with the negated upstream gradient
  def backward(x, gy)
    upstream = gy[0]
    [Utils::Array.force_array(-upstream)]
  end
end
15
+
16
# Function that adds its two inputs (y = x[0] + x[1]).
class Add < ::Chainer::Function
  # @param x [Array] inputs; x[0] and x[1] are summed
  # @return [Array] one-element list with the sum
  def forward(x)
    # backward never reads x, so an empty list is passed to retain_inputs
    retain_inputs([])
    lhs = x[0]
    rhs = x[1]
    [Utils::Array.force_array(lhs + rhs)]
  end

  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] the upstream gradient, once per input
  def backward(x, gy)
    upstream = gy[0]
    [upstream, upstream]
  end
end
26
+
27
# Function that adds a fixed value to its single input (y = x[0] + value).
class AddConstant < ::Chainer::Function
  # @param value the constant added in forward
  def initialize(value)
    @value = value
  end

  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list with x[0] + value
  def forward(x)
    # backward never reads x, so an empty list is passed to retain_inputs
    retain_inputs([])
    shifted = x[0] + @value
    [Utils::Array.force_array(shifted)]
  end

  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list with the upstream gradient unchanged
  def backward(x, gy)
    upstream = gy[0]
    [upstream]
  end
end
41
+
42
# Function that subtracts its second input from its first (y = x[0] - x[1]).
class Sub < ::Chainer::Function
  # @param x [Array] inputs; x[0] is the minuend, x[1] the subtrahend
  # @return [Array] one-element list with the difference
  def forward(x)
    # backward never reads x, so an empty list is passed to retain_inputs
    retain_inputs([])
    minuend = x[0]
    subtrahend = x[1]
    [Utils::Array.force_array(minuend - subtrahend)]
  end

  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] gy[0] for the first input and -gy[0] for the second
  def backward(x, gy)
    upstream = gy[0]
    negated = Utils::Array.force_array(-gy[0])
    [upstream, negated]
  end
end
52
+
53
# Function that multiplies its two inputs (y = x[0] * x[1]).
class Mul < ::Chainer::Function
  # @param x [Array] inputs; x[0] and x[1] are multiplied
  # @return [Array] one-element list with the product
  def forward(x)
    product = x[0] * x[1]
    [Utils::Array.force_array(product)]
  end

  # @param x [Array] forward inputs; both are read here
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] gy[0] * x[1] for the first input, gy[0] * x[0] for the second
  def backward(x, gy)
    grad_first = Utils::Array.force_array(gy[0] * x[1])
    grad_second = Utils::Array.force_array(gy[0] * x[0])
    [grad_first, grad_second]
  end
end
62
+
63
# Function that scales its single input by a fixed value (y = value * x[0]).
class MulConstant < ::Chainer::Function
  # @param value the constant factor applied in forward and backward
  def initialize(value)
    @value = value
  end

  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list with value * x[0]
  def forward(x)
    scaled = @value * x[0]
    [Utils::Array.force_array(scaled)]
  end

  # @param x [Array] forward inputs (unused)
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list with value * gy[0]
  def backward(x, gy)
    scaled_grad = @value * gy[0]
    [Utils::Array.force_array(scaled_grad)]
  end
end
76
+
77
# Function that divides its first input by its second (y = x[0] / x[1]).
class Div < ::Chainer::Function
  # @param x [Array] inputs; x[0] is the numerator, x[1] the denominator
  # @return [Array] one-element list with the quotient
  def forward(x)
    quotient = x[0] / x[1]
    [Utils::Array.force_array(quotient)]
  end

  # @param x [Array] forward inputs; both are read here
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] gy[0] / x[1] for the numerator and
  #   -1 * (gy[0] / x[1]) * x[0] / x[1] for the denominator
  def backward(x, gy)
    grad_numerator = Utils::Array.force_array(gy[0] / x[1])
    # The denominator gradient reuses the numerator gradient computed above.
    grad_denominator = Utils::Array.force_array(-1 * grad_numerator * x[0] / x[1])
    [grad_numerator, grad_denominator]
  end
end
87
+
88
# Function that raises its first input to the power of its second
# (y = x[0] ** x[1]).
class PowVarVar < ::Chainer::Function
  # Stores the result in @y because backward reuses it.
  #
  # @param x [Array] inputs; x[0] is the base, x[1] the exponent
  # @return [Array] one-element list with x[0] ** x[1]
  def forward(x)
    @y = Utils::Array.force_array(x[0] ** x[1])
    [@y]
  end

  # @param x [Array] forward inputs; both are read here
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] x[1] * x[0] ** (x[1] - one) * gy[0] for the base and
  #   log(x[0]) * y * gy[0] for the exponent
  def backward(x, gy)
    # NOTE(review): presumably a scalar 1 of the exponent's element type,
    # obtained from the exponent's class -- confirm against Numo's `ones`.
    one = x[1].class.ones[0]
    grad_base = Utils::Array.force_array(x[1] * (x[0] ** (x[1] - one)) * gy[0])
    grad_exponent = Utils::Array.force_array(Numo::NMath.log(x[0]) * @y * gy[0])
    [grad_base, grad_exponent]
  end
end
101
+
102
# Function that raises its single input to a fixed power (y = x[0] ** value).
class PowVarConst < ::Chainer::Function
  # @param value the constant exponent
  def initialize(value)
    @value = value
  end

  # @param x [Array] inputs; only x[0] is read
  # @return [Array] one-element list with x[0] ** value
  def forward(x)
    powered = x[0] ** @value
    [Utils::Array.force_array(powered)]
  end

  # @param x [Array] forward inputs; x[0] is read here
  # @param gy [Array] upstream gradients; only gy[0] is read
  # @return [Array] one-element list with value * x[0] ** (value - 1) * gy[0]
  def backward(x, gy)
    reduced_exponent = @value - 1
    grad = @value * (x[0] ** reduced_exponent) * gy[0]
    [Utils::Array.force_array(grad)]
  end
end
117
+ end
118
+ end
119
+ end
@@ -0,0 +1,63 @@
1
module Chainer
  # Optimizer base class that assigns an update rule to every parameter of
  # the target link and applies those rules on each #update.
  #
  # Relies on @target, @hooks, @t and _call_hook, which are not defined in
  # this class -- presumably provided by Chainer::Optimizer; verify there.
  class GradientMethod < Chainer::Optimizer
    def initialize
      super()
      @hyperparam = Hyperparameter.new
    end

    # Registers the link with the superclass, then gives each of its
    # parameters an update rule built by #create_update_rule.
    #
    # @param link the link whose parameters will be optimized
    def setup(link)
      super(link)
      link.params do |param|
        param.update_rule = create_update_rule
      end
    end

    # Gives every initialized parameter whose gradient is nil a zero-filled
    # gradient derived from its data.
    def reallocate_cleared_grads
      @target.namedparams(include_uninit: false) do |(_name, param)|
        next unless param.grad.nil?

        param.grad = Numo::NArray.[](*param.data).new_zeros
      end
    end

    # Runs every registered hook; gradients are re-allocated after each one
    # in case the hook cleared them.
    def call_hooks
      @hooks.each do |hook|
        _call_hook(hook)
        reallocate_cleared_grads
      end
    end

    # Performs one optimization step.
    #
    # When lossfun is given it is called with args/kwds, the target's
    # gradients are cleared (or zeroed when #use_cleargrads exists and
    # returns a falsy value) and the returned loss is back-propagated.
    # Afterwards hooks run, @t is incremented and every parameter is updated.
    #
    # @param lossfun [#call, nil] callable returning an object responding to
    #   #backward; skipped when nil
    # @param args positional arguments forwarded to lossfun
    # @param kwds keyword arguments forwarded to lossfun
    def update(lossfun=nil, *args, **kwds)
      if lossfun
        use_cleargrads = self.methods.include?(:use_cleargrads) ? self.use_cleargrads : true

        # Empty splats are never forwarded, so lossfun receives exactly the
        # arguments the caller supplied. The final branch covers a lossfun
        # that takes no arguments; without it `loss` stayed nil and
        # `loss.backward` raised NoMethodError.
        loss =
          if !args.empty? && !kwds.empty?
            lossfun.(*args, **kwds)
          elsif !args.empty?
            lossfun.(*args)
          elsif !kwds.empty?
            lossfun.(**kwds)
          else
            lossfun.()
          end

        if use_cleargrads
          @target.cleargrads()
        else
          @target.zerograds()
        end
        loss.backward()
      end

      reallocate_cleared_grads

      call_hooks

      @t += 1
      @target.params do |param|
        param.update
      end
    end

    # Builds the update rule attached to each parameter in #setup.
    #
    # @raise [NotImplementedError] always; subclasses must override
    def create_update_rule
      raise NotImplementedError
    end
  end
end
@@ -0,0 +1,23 @@
1
module Chainer
  # Set of named settings stored as instance variables, optionally chained
  # to a parent whose values act as defaults.
  #
  # Settings are read with plain method calls (e.g. `hp.alpha`), resolved by
  # #method_missing: the receiver's own value wins, otherwise the parent
  # chain is consulted. This matches #get_dict, where own values override
  # the parent's.
  class Hyperparameter
    # Names that can be both a reader method and an instance variable.
    NAME_PATTERN = /\A[A-Za-z_]\w*\z/

    attr_reader :parent

    # @param parent [Object, nil] object supplying default values
    def initialize(parent: nil)
      @parent = parent
    end

    # Resolves `hp.name` to the value of @name on the receiver or, failing
    # that, on its parent chain.
    #
    # @return the stored value, or nil when no object in the chain defines it
    # @raise [NoMethodError] for calls with arguments or for names that
    #   cannot be instance variables (setters, predicates, operators)
    def method_missing(name, *args)
      return super unless args.empty? && readable_name?(name)

      ivar = "@#{name}"
      if instance_variable_defined?(ivar)
        instance_variable_get(ivar)
      elsif @parent.kind_of?(Hyperparameter)
        # Re-dispatch so the parent applies the same own-then-parent lookup.
        @parent.public_send(name)
      elsif @parent.nil?
        nil
      else
        @parent.instance_variable_get(ivar)
      end
    end

    # Reports a setting as available only when some object in the chain
    # actually defines it, so implicit conversions (to_ary, to_str, ...) are
    # not claimed by accident.
    def respond_to_missing?(name, include_private = false)
      (readable_name?(name) && resolvable?("@#{name}")) || super
    end

    # @return [Hash{String => Object}] the parent's dictionary (empty without
    #   a parent) overlaid with this object's own settings; keys are the
    #   instance variable names without the leading '@'
    def get_dict
      d = @parent.nil? ? {} : @parent.get_dict
      instance_variables.each do |m|
        next if m == :@parent

        d[m.to_s.delete('@')] = instance_variable_get(m)
      end
      d
    end

    protected

    # True when the receiver or any ancestor defines the instance variable.
    def resolvable?(ivar)
      return true if instance_variable_defined?(ivar)
      return false if @parent.nil?
      return @parent.resolvable?(ivar) if @parent.kind_of?(Hyperparameter)

      @parent.instance_variable_defined?(ivar)
    end

    private

    # True when name is usable as an instance variable name.
    def readable_name?(name)
      !NAME_PATTERN.match(name.to_s).nil?
    end
  end
end
@@ -0,0 +1,12 @@
1
module Chainer
  # Abstract base class for array initializers. Subclasses override #call.
  class Initializer
    # @param dtype [Object, nil] stored in @dtype for subclasses to consult
    def initialize(dtype: nil)
      @dtype = dtype
    end

    # Initializes the given array; abstract in this class.
    #
    # @param array the array to initialize
    # @raise [NotImplementedError] always
    def call(array)
      raise ::NotImplementedError
    end
  end
end
12
+
@@ -0,0 +1,18 @@
1
module Chainer
  module Initializers
    # Initializer that fills an array with a fixed value.
    class Constant < ::Chainer::Initializer
      # @param fill_value the value stored into arrays by #call
      # @param dtype [Object, nil] when set, #call rejects arrays whose
      #   dtype differs
      def initialize(fill_value, dtype: nil)
        @fill_value = fill_value
        super(dtype: dtype)
      end

      # Stores the fill value into the array in place.
      #
      # @param array the array to fill; must respond to #store (and to
      #   #dtype when a dtype was configured)
      # @return the same array
      # @raise [ArgumentError] when a dtype was configured and array.dtype
      #   is not equal to it
      def call(array)
        mismatch = @dtype && !(array.dtype == @dtype)
        raise ArgumentError if mismatch

        array.tap { |target| target.store(@fill_value) }
      end
    end
  end
end
@@ -0,0 +1,24 @@
1
module Chainer
  # Helpers for building initializers and the arrays they fill.
  module Initializers
    # Creates a Numo::DFloat array of the given shape, fills it with random
    # values and passes it to the initializer.
    #
    # @param initializer [#call] initializer applied to the new array
    # @param shape shape handed to Numo::DFloat.new
    # @return whatever the initializer returns
    def self.generate_array(initializer, shape)
      array = Numo::DFloat.new(shape).rand
      initializer.(array)
    end

    # Normalizes the accepted initializer specifications to a callable.
    #
    # @param initializer [nil, Numeric, Numo::NArray, #call]
    #   nil selects HeNormal with scale 1 / sqrt(2); a number or NArray is
    #   wrapped in Constant; any other object must respond to #call and is
    #   returned unchanged
    # @return [#call]
    # @raise [TypeError] when the object is none of the above
    def self.get_initializer(initializer)
      return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
      return Constant.new(initializer) if initializer.kind_of?(Numeric)
      return Constant.new(initializer) if initializer.kind_of?(Numo::NArray)

      # respond_to? works on any object; Module#method_defined? exists only
      # on classes/modules and raised NoMethodError for initializer instances.
      unless initializer.respond_to?(:call)
        raise TypeError, "invalid type of initializer: #{initializer.class}"
      end

      initializer
    end

    # @param dtype [Object, nil] forwarded to Constant
    # @return [Constant] initializer that fills arrays with Float::NAN
    def self.nan(dtype: nil)
      Constant.new(Float::NAN, dtype: dtype)
    end
  end
end
@@ -0,0 +1,28 @@
1
+ module Chainer
2
+ module Initializers
3
# Initializer that draws values with Numo's rand_norm using mean 0.0 and
# the configured scale.
class Normal < ::Chainer::Initializer
  # @param scale value passed to rand_norm as its second argument
  # @param dtype [Object, nil] forwarded to the base class; not consulted
  #   by #call
  def initialize(scale: 0.05, dtype: nil)
    @scale = scale
    super(dtype: dtype)
  end

  # Builds a fresh Numo::DFloat of the same shape as the given array. The
  # given array only supplies the shape and is not modified.
  #
  # @param array object responding to #shape
  # @return [Numo::DFloat] new array filled by rand_norm(0.0, scale)
  def call(array)
    Numo::DFloat.new(array.shape).rand_norm(0.0, @scale)
  end
end
14
+
15
# Initializer that delegates to Normal with a scale of
# scale * sqrt(2.0 / fan_in), where fan_in comes from
# Chainer::Utils::Initializer.get_fans.
class HeNormal < ::Chainer::Initializer
  # @param scale multiplier applied to sqrt(2.0 / fan_in)
  # @param dtype [Object, nil] forwarded to the base class; not consulted
  #   by #call
  def initialize(scale: 1.0, dtype: nil)
    @scale = scale
    super(dtype: dtype)
  end

  # @param array object responding to #shape
  # @return the result of Normal#call for the derived scale
  def call(array)
    # Only the first value returned by get_fans is needed.
    fan_in, _fan_out = Chainer::Utils::Initializer.get_fans(array.shape)
    s = @scale * Numo::NMath.sqrt(2.0 / fan_in)
    Normal.new(scale: s).(array)
  end
end
27
+ end
28
+ end
@@ -0,0 +1,74 @@
1
module Chainer
  module Iterators
    # Iterator that walks a dataset in batches following an index order,
    # optionally shuffled and optionally repeating across epochs.
    class SerialIterator < Chainer::Dataset::Iterator
      attr_reader :epoch

      # @param dataset object responding to #size and #[] with an integer
      # @param batch_size [Integer] number of examples per batch
      # @param repeat [Boolean] when false, #next raises StopIteration once
      #   the first epoch has completed
      # @param shuffle [Boolean] when true, the order is shuffled on reset
      #   and again at the end of every epoch
      def initialize(dataset, batch_size, repeat: true, shuffle: true)
        @dataset = dataset
        @batch_size = batch_size
        @repeat = repeat
        @shuffle = shuffle

        reset
      end

      # Returns the next batch and advances the position.
      #
      # When the batch crosses the end of the dataset and repeat is on, the
      # batch is completed with examples from the start of the next epoch's
      # order.
      #
      # @return [Array] examples of the batch
      # @raise [StopIteration] when repeat is off and an epoch has completed
      def next
        raise StopIteration if !@repeat && @epoch > 0

        # Record the progress before this batch (previously read from an
        # instance variable that was never assigned, so it was always nil).
        @previous_epoch_detail = epoch_detail

        i = @current_position
        i_end = i + @batch_size
        n = @dataset.size

        batch = examples_between(i, i_end)

        if i_end >= n
          if @repeat
            rest = i_end - n
            # Only reorder when shuffling was requested; a sequential
            # iterator must stay sequential across epochs.
            reshuffle_order if @shuffle
            # Append (not replace) the head of the next epoch, looked up
            # through the order like every other example.
            batch.concat(examples_between(0, rest)) if rest > 0
            @current_position = rest
          else
            @current_position = 0
          end

          @epoch += 1
          @is_new_epoch = true
        else
          @is_new_epoch = false
          @current_position = i_end
        end

        batch
      end

      # @return [Float] completed epochs plus the fraction of the current one
      def epoch_detail
        @epoch + @current_position.to_f / @dataset.size
      end

      # Rebuilds the index order and rewinds all progress counters.
      def reset
        order = @dataset.size.times.to_a
        order.shuffle! if @shuffle
        @order = Numo::Int64[*order]

        @current_position = 0
        @epoch = 0
        @is_new_epoch = false
        @previous_epoch_detail = -1.0
      end

      private

      # Examples whose order positions lie in from...to, with `to` clamped
      # to the dataset size so the slice never runs past the order.
      def examples_between(from, to)
        to = [to, @dataset.size].min
        return [] if from >= to

        @order[from...to].to_a.map { |index| @dataset[index] }
      end

      # Replaces the order with a shuffled copy of the same array class.
      def reshuffle_order
        @order = @order.class[*@order.to_a.shuffle]
      end
    end
  end
end
@@ -0,0 +1,118 @@
1
+ module Chainer
2
# Building block that owns parameters. Parameters are instance variables
# holding a Chainer::Parameter; they are registered by name when an
# #init_scope block finishes.
class Link
  def initialize
    @params = []
    @persistent = []
    @within_init_scope = false
    @name = nil
  end

  # @return [Boolean] whether an #init_scope block is currently running
  def within_init_scope
    @within_init_scope || false
  end

  # Runs the block with the init-scope flag set, then registers the
  # instance variables it assigned. The previous flag is restored even when
  # the block raises.
  def init_scope
    old_flag = within_init_scope
    @within_init_scope = true

    begin
      yield
      set_attr
    ensure
      @within_init_scope = old_flag
    end
  end

  # Registers every instance variable holding a Chainer::Parameter (or a
  # subclass) and removes its name from @persistent.
  #
  # All instance variables are rescanned on every call, so a name is only
  # appended when it is not registered yet; otherwise a second #init_scope
  # would make #params yield the same parameter twice.
  def set_attr
    instance_variables.each do |name|
      next unless instance_variable_get(name).kind_of?(Chainer::Parameter)

      @params << name unless @params.include?(name)
      @persistent.delete(name)
    end
  end

  # Calls #cleargrad on every parameter.
  def cleargrads
    params(&:cleargrad)
  end

  # Yields each registered parameter.
  #
  # @param include_uninit [Boolean] when false, parameters whose #data is
  #   nil or false are skipped
  # @yieldparam param the parameter object
  # @return [Array] the block results, nil for skipped parameters
  def params(include_uninit: true)
    @params.map do |name|
      param = instance_variable_get(name)
      data = param.data
      yield param if include_uninit || data
    end
  end

  # Yields a [path, parameter] pair for each registered parameter. The path
  # is '/' followed by the instance variable name, including its '@'.
  #
  # @param include_uninit [Boolean] when false, parameters whose #data is
  #   nil or false are skipped
  def namedparams(include_uninit: true)
    @params.each do |name|
      param = instance_variable_get(name)
      next unless include_uninit || param.data

      yield ['/' + name.to_s, param]
    end
  end

  # Yields '/' and this link unless skipself is set.
  #
  # @param skipself [Boolean] when true nothing is yielded
  def namedlinks(skipself: false)
    yield('/', self) unless skipself
  end
end
63
+
64
# Link that can also hold child links. Children are instance variables
# holding a Chainer::Link; they are registered by name in #set_attr.
class Chain < Link
  def initialize
    super
    @children = []
  end

  # Registers every instance variable holding a Chainer::Link as a child,
  # then lets the superclass register parameters.
  #
  # All instance variables are rescanned on every call, so a name is only
  # appended when it is not registered yet; otherwise a second init scope
  # would make the traversals below visit the same child twice.
  def set_attr
    instance_variables.each do |name|
      next unless instance_variable_get(name).kind_of?(Chainer::Link)

      @children << name unless @children.include?(name)
    end
    super
  end

  # Yields this link's own parameters, then those of every child.
  #
  # @param include_uninit [Boolean] forwarded to the superclass and to
  #   each child
  def params(include_uninit: true)
    super(include_uninit: include_uninit) do |param|
      yield param
    end

    @children.each do |name|
      instance_variable_get(name).params(include_uninit: include_uninit) do |param|
        yield param
      end
    end
  end

  # Yields [path, parameter] pairs for this link's own parameters, then for
  # every child's parameters with the path prefixed by "/<child name>".
  #
  # @param include_uninit [Boolean] forwarded to the superclass and to
  #   each child
  def namedparams(include_uninit: true)
    super(include_uninit: include_uninit) do |pair|
      yield pair
    end

    @children.each do |name|
      prefix = "/#{name}"
      instance_variable_get(name).namedparams(include_uninit: include_uninit) do |(path, param)|
        yield [prefix + path, param]
      end
    end
  end

  # Yields (path, link) for this link unless skipself is set, then for
  # each child, followed by whatever the child's own #namedlinks yields
  # with skipself on, prefixed by the child's path.
  #
  # @param skipself [Boolean] when true this link itself is not yielded
  def namedlinks(skipself: false)
    yield('/', self) unless skipself

    @children.each do |name|
      child = instance_variable_get(name)
      prefix = '/' + name.to_s
      yield(prefix, child)
      child.namedlinks(skipself: true) do |path, link|
        yield(prefix + path, link)
      end
    end
  end
end
118
+ end