red-chainer 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,119 @@
module Chainer
  module Functions
    # Elementwise arithmetic functions together with their backward passes.
    #
    # Each #forward receives the array of inputs and returns the array of
    # outputs; each #backward receives the inputs and the output gradients and
    # returns one gradient per input.
    module Math
      # y = -x
      class Neg < ::Chainer::Function
        def forward(x)
          # The gradient does not depend on the input, so nothing is retained.
          retain_inputs([])
          [Utils::Array.force_array(-x[0])]
        end

        def backward(x, gy)
          [Utils::Array.force_array(-gy[0])]
        end
      end

      # y = x0 + x1
      class Add < ::Chainer::Function
        def forward(x)
          retain_inputs([])
          [Utils::Array.force_array(x[0] + x[1])]
        end

        def backward(x, gy)
          [gy[0], gy[0]]
        end
      end

      # y = x + value, with +value+ a constant fixed at construction.
      class AddConstant < ::Chainer::Function
        def initialize(value)
          @value = value
        end

        def forward(x)
          retain_inputs([])
          [Utils::Array.force_array(x[0] + @value)]
        end

        def backward(x, gy)
          [gy[0]]
        end
      end

      # y = x0 - x1
      class Sub < ::Chainer::Function
        def forward(x)
          retain_inputs([])
          [Utils::Array.force_array(x[0] - x[1])]
        end

        def backward(x, gy)
          [gy[0], Utils::Array.force_array(-gy[0])]
        end
      end

      # y = x0 * x1 (elementwise); backward needs both inputs.
      class Mul < ::Chainer::Function
        def forward(x)
          [Utils::Array.force_array(x[0] * x[1])]
        end

        def backward(x, gy)
          [Utils::Array.force_array(gy[0] * x[1]), Utils::Array.force_array(gy[0] * x[0])]
        end
      end

      # y = value * x, with +value+ a constant fixed at construction.
      class MulConstant < ::Chainer::Function
        def initialize(value)
          @value = value
        end

        def forward(x)
          [Utils::Array.force_array(@value * x[0])]
        end

        def backward(x, gy)
          [Utils::Array.force_array(@value * gy[0])]
        end
      end

      # y = x0 / x1 (elementwise)
      class Div < ::Chainer::Function
        def forward(x)
          [Utils::Array.force_array(x[0] / x[1])]
        end

        def backward(x, gy)
          # dy/dx0 = 1 / x1 ; dy/dx1 = -x0 / x1**2 = -(gy / x1) * x0 / x1
          gx0 = Utils::Array.force_array(gy[0] / x[1])
          [gx0, Utils::Array.force_array(-1 * gx0 * x[0] / x[1])]
        end
      end

      # y = x0 ** x1 where both operands are variables.
      class PowVarVar < ::Chainer::Function
        def forward(x)
          # The output is kept because d/dx1 reuses it.
          @y = Utils::Array.force_array(x[0] ** x[1])
          [@y]
        end

        def backward(x, gy)
          # d/dx0 = x1 * x0 ** (x1 - 1). A plain scalar 1 gives the same result
          # as the former `x[1].class.ones[0]` without allocating a throwaway
          # array, and it also works when x1 is a Ruby scalar.
          gx0 = Utils::Array.force_array(x[1] * (x[0] ** (x[1] - 1)) * gy[0])
          # d/dx1 = log(x0) * x0 ** x1
          gx1 = Utils::Array.force_array(Numo::NMath.log(x[0]) * @y * gy[0])
          [gx0, gx1]
        end
      end

      # y = x ** value, with +value+ a constant exponent fixed at construction.
      class PowVarConst < ::Chainer::Function
        def initialize(value)
          @value = value
        end

        def forward(x)
          [Utils::Array.force_array(x[0] ** @value)]
        end

        def backward(x, gy)
          # d/dx = value * x ** (value - 1)
          gx = @value * (x[0] ** (@value - 1)) * gy[0]
          [Utils::Array.force_array(gx)]
        end
      end
    end
  end
end
@@ -0,0 +1,63 @@
module Chainer
  # Base class of optimizers that update each parameter through its own
  # update rule, created by #create_update_rule.
  class GradientMethod < Chainer::Optimizer
    def initialize
      super()
      @hyperparam = Hyperparameter.new
    end

    # Sets the optimizer up for +link+ and attaches a fresh update rule to
    # every parameter of the link.
    def setup(link)
      super(link)
      link.params do |param|
        param.update_rule = create_update_rule
      end
    end

    # Gives a zero gradient to every initialized parameter whose gradient is
    # nil (e.g. after cleargrads), so hooks and update rules see an array.
    def reallocate_cleared_grads
      @target.namedparams(include_uninit: false) do |(_name, param)|
        next unless param.grad.nil?

        # new_zeros keeps both the shape and the element class of the data.
        # Rebuilding via Numo::NArray[*param.data] went through Ruby arrays
        # and widened e.g. SFloat data to a DFloat gradient.
        param.grad = param.data.new_zeros
      end
    end

    # Runs every registered hook, re-creating cleared gradients after each.
    def call_hooks
      @hooks.each do |hook|
        _call_hook(hook)
        reallocate_cleared_grads
      end
    end

    # Performs one optimization step.
    #
    # When +lossfun+ is given, it is called with +args+ / +kwds+. The gradients
    # are then reset (cleargrads, or zerograds when use_cleargrads is false)
    # and the loss is back-propagated. Without +lossfun+, the gradients
    # already stored on the parameters are used as they are.
    #
    # @param lossfun [#call, nil] callable returning the loss variable
    def update(lossfun = nil, *args, **kwds)
      if lossfun
        use_cleargrads = self.methods.include?(:use_cleargrads) ? self.use_cleargrads : true
        loss = call_lossfun(lossfun, args, kwds)

        if use_cleargrads
          @target.cleargrads
        else
          @target.zerograds
        end
        loss.backward
      end

      reallocate_cleared_grads

      call_hooks

      @t += 1
      @target.params do |param|
        param.update
      end
    end

    # Builds the per-parameter update rule; subclasses must override this.
    #
    # @raise [NotImplementedError] always, on this base class
    def create_update_rule
      raise NotImplementedError
    end

    private

    # Calls +lossfun+ forwarding only the argument kinds that were supplied.
    # This means an empty keyword hash is never passed on. A call without any
    # arguments still yields a loss; previously it left the loss nil.
    def call_lossfun(lossfun, args, kwds)
      if kwds.empty?
        lossfun.call(*args)
      elsif args.empty?
        lossfun.call(**kwds)
      else
        lossfun.call(*args, **kwds)
      end
    end
  end
end
@@ -0,0 +1,23 @@
module Chainer
  # A set of hyperparameter values that can fall back on a parent set.
  #
  # Values are stored as instance variables (e.g. with
  # +instance_variable_set('@alpha', 0.001)+) and read as methods
  # (+hyperparam.alpha+). A value defined on this object takes precedence
  # over the parent's, and lookups walk up the whole parent chain.
  class Hyperparameter
    # Method names that can denote a hyperparameter (valid ivar names).
    NAME_PATTERN = /\A[A-Za-z_][A-Za-z0-9_]*\z/

    attr_reader :parent

    # @param parent [Hyperparameter, nil] set consulted for values not defined here
    def initialize(parent: nil)
      @parent = parent
    end

    # Reads hyperparameter +name+ from this object, or failing that from the
    # parent chain. Returns nil when no object in the chain defines it.
    #
    # Calls with arguments (including setters such as +hp.alpha = 1+) are not
    # hyperparameter reads and raise NoMethodError.
    def method_missing(name, *args)
      return super unless args.empty? && hyperparam_name?(name)

      fetch_hyperparam(:"@#{name}")
    end

    # Reports reader methods only for values defined somewhere in the chain.
    def respond_to_missing?(name, include_private = false)
      (hyperparam_name?(name) && hyperparam_defined?(:"@#{name}")) || super
    end

    # Returns a Hash of name (String) => value. Entries from the parent chain
    # are overwritten by values defined on this object.
    def get_dict
      d = @parent.nil? ? {} : @parent.get_dict
      instance_variables.each do |ivar|
        next if ivar == :@parent

        d[ivar.to_s.delete('@')] = instance_variable_get(ivar)
      end
      d
    end

    protected

    # Value of +ivar+ on the nearest object of the chain that defines it, or nil.
    def fetch_hyperparam(ivar)
      return instance_variable_get(ivar) if instance_variable_defined?(ivar)
      return @parent.fetch_hyperparam(ivar) if @parent.is_a?(Hyperparameter)

      # A non-Hyperparameter parent is read directly, as before.
      @parent.nil? ? nil : @parent.instance_variable_get(ivar)
    end

    # Whether any object of the chain defines +ivar+.
    def hyperparam_defined?(ivar)
      return true if instance_variable_defined?(ivar)
      return @parent.hyperparam_defined?(ivar) if @parent.is_a?(Hyperparameter)

      !@parent.nil? && @parent.instance_variable_defined?(ivar)
    end

    private

    # True when +name+ can be mapped onto an instance variable name.
    def hyperparam_name?(name)
      !(name.to_s =~ NAME_PATTERN).nil?
    end
  end
end
@@ -0,0 +1,12 @@
module Chainer
  # Abstract base class for parameter initializers.
  #
  # Concrete initializers override #call, which receives the array to set up.
  class Initializer
    # @param dtype [Object, nil] optional element type, stored for subclasses
    def initialize(dtype: nil)
      @dtype = dtype
    end

    # Initializes +_array+; every concrete subclass must override this.
    #
    # @raise [NotImplementedError] always, on this abstract base
    def call(_array)
      raise NotImplementedError
    end
  end
end
@@ -0,0 +1,18 @@
module Chainer
  module Initializers
    # Initializer that fills an array with one fixed value.
    class Constant < ::Chainer::Initializer
      # @param fill_value [Object] value handed to the array's +store+
      #   (a scalar, or an array that +store+ accepts)
      # @param dtype [Object, nil] required array type; nil accepts any array
      def initialize(fill_value, dtype: nil)
        @fill_value = fill_value
        super(dtype: dtype)
      end

      # Fills +array+ in place with the fill value and returns it.
      #
      # @raise [ArgumentError] when a dtype was given and +array+ has another type
      def call(array)
        if @dtype
          # Numo arrays express their element type through their class
          # (Numo::SFloat, Numo::DFloat, ...), so use the class when the
          # array has no #dtype method.
          actual = array.respond_to?(:dtype) ? array.dtype : array.class
          unless actual == @dtype
            raise ArgumentError, "dtype mismatch: expected #{@dtype}, got #{actual}"
          end
        end
        array.store(@fill_value)
        array
      end
    end
  end
end
@@ -0,0 +1,24 @@
module Chainer
  module Initializers
    # Creates a DFloat array of +shape+ and passes it through +initializer+.
    #
    # @param initializer [#call] initializer applied to the new array
    # @param shape [Array<Integer>] shape of the array to create
    # @return whatever +initializer+ returns
    def self.generate_array(initializer, shape)
      array = Numo::DFloat.new(shape).rand
      initializer.(array)
    end

    # Normalizes the accepted ways of specifying an initializer:
    #
    # * nil                      -> HeNormal with scale 1/sqrt(2)
    # * Numeric or Numo::NArray  -> Constant filling with that value
    # * anything responding to #call (Initializer instances, lambdas) -> itself
    #
    # @raise [TypeError] for any other value
    def self.get_initializer(initializer)
      return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
      return Constant.new(initializer) if initializer.is_a?(Numeric) || initializer.is_a?(Numo::NArray)

      # Initializers are instances, so ask the object itself. The former
      # Module#method_defined? check does not exist on instances and raised
      # NoMethodError.
      unless initializer.respond_to?(:call)
        raise TypeError, "invalid type of initializer: #{initializer.class}"
      end

      initializer
    end

    # Initializer that fills arrays with NaN.
    def self.nan(dtype: nil)
      Constant.new(Float::NAN, dtype: dtype)
    end
  end
end
@@ -0,0 +1,28 @@
module Chainer
  module Initializers
    # Initializes arrays with samples from a normal distribution with mean 0
    # and standard deviation +scale+.
    class Normal < ::Chainer::Initializer
      # @param scale [Float] standard deviation of the samples
      def initialize(scale: 0.05, dtype: nil)
        @scale = scale
        super(dtype: dtype)
      end

      # Fills +array+ in place with normal samples and returns it.
      #
      # Filling the given array keeps its element class (e.g. Numo::SFloat),
      # like the Constant initializer. Previously an unrelated new DFloat
      # array was returned instead.
      def call(array)
        array.rand_norm(0.0, @scale)
        array
      end
    end

    # Normal initializer whose standard deviation is
    # scale * sqrt(2 / fan_in), with fan_in derived from the array shape.
    class HeNormal < ::Chainer::Initializer
      # @param scale [Float] multiplier applied to sqrt(2 / fan_in)
      def initialize(scale: 1.0, dtype: nil)
        @scale = scale
        super(dtype: dtype)
      end

      # Fills +array+ via a Normal initializer with the computed deviation.
      def call(array)
        fan_in, _fan_out = Chainer::Utils::Initializer.get_fans(array.shape)
        s = @scale * Numo::NMath.sqrt(2.0 / fan_in)
        Normal.new(scale: s).(array)
      end
    end
  end
end
@@ -0,0 +1,74 @@
module Chainer
  module Iterators
    # Dataset iterator that returns minibatches of examples in sequential or
    # shuffled order.
    class SerialIterator < Chainer::Dataset::Iterator
      attr_reader :epoch, :is_new_epoch

      # @param dataset [#size, #[]] dataset indexed by integer position
      # @param batch_size [Integer] number of examples per minibatch
      # @param repeat [Boolean] whether to keep iterating after the first epoch
      # @param shuffle [Boolean] whether to visit examples in random order
      #   (a new order is drawn at every epoch boundary)
      def initialize(dataset, batch_size, repeat: true, shuffle: true)
        @dataset = dataset
        @batch_size = batch_size
        @repeat = repeat
        @shuffle = shuffle

        reset
      end

      # Returns the next minibatch as an Array of dataset examples.
      #
      # When a batch crosses the end of the dataset and +repeat+ is set, it is
      # completed with the first examples of the next epoch.
      #
      # @raise [StopIteration] when +repeat+ is false and an epoch has been consumed
      def next
        raise StopIteration if !@repeat && @epoch > 0

        # Progress before this batch (computed; no @epoch_detail ivar exists).
        @previous_epoch_detail = epoch_detail

        i = @current_position
        i_end = i + @batch_size
        n = @dataset.size

        # Clamp the slice so a final partial batch never indexes past the end.
        batch = indices_to_examples(@order[i...[i_end, n].min])

        if i_end >= n
          if @repeat
            rest = i_end - n
            # Only draw a new order when shuffling was requested; sequential
            # iteration must stay sequential across epochs.
            @order = @order.class[*@order.to_a.shuffle] if @shuffle
            # Top the batch up (rather than replace it) with the first
            # examples of the next epoch, looked up through the order.
            batch.concat(indices_to_examples(@order[0...[rest, n].min])) if rest > 0
            @current_position = rest
          else
            @current_position = 0
          end

          @epoch += 1
          @is_new_epoch = true
        else
          @is_new_epoch = false
          @current_position = i_end
        end

        batch
      end

      # Epochs completed so far, as a fraction (1.5 = halfway through epoch 2).
      def epoch_detail
        @epoch + @current_position.to_f / @dataset.size
      end

      # Rewinds to the beginning of the first epoch with a fresh order.
      def reset
        indices = Array.new(@dataset.size) { |index| index }
        indices.shuffle! if @shuffle
        @order = Numo::Int64[*indices]

        @current_position = 0
        @epoch = 0
        @is_new_epoch = false
        @previous_epoch_detail = -1.0
      end

      private

      # Maps positions taken from the order array to dataset examples.
      def indices_to_examples(indices)
        indices.to_a.map { |index| @dataset[index] }
      end
    end
  end
end
@@ -0,0 +1,118 @@
module Chainer
  # Building block of models: a container of Parameter objects.
  #
  # Parameters are registered by assigning them to instance variables inside
  # #init_scope. They are then reachable through #params and #namedparams.
  class Link
    def initialize
      @params = []
      @persistent = []
      @within_init_scope = false
      @name = nil
    end

    # Whether the link is currently inside #init_scope.
    def within_init_scope
      @within_init_scope || false
    end

    # Runs the block in initialization scope, then registers the instance
    # variables it created. The previous flag is restored even on error.
    def init_scope
      old_flag = within_init_scope
      @within_init_scope = true

      begin
        yield
        set_attr
      ensure
        @within_init_scope = old_flag
      end
    end

    # Registers each instance variable holding a Parameter, once no matter
    # how often init_scope runs, and drops it from the persistent list.
    # Without the include? check, repeated init_scope calls registered the
    # same parameter several times, so it was yielded, cleared and updated
    # repeatedly.
    def set_attr
      instance_variables.each do |name|
        next unless instance_variable_get(name).is_a?(Chainer::Parameter)

        @params << name unless @params.include?(name)
        @persistent.delete(name)
      end
    end

    # Clears the gradient of every parameter reachable through #params.
    def cleargrads
      params { |param| param.cleargrad }
    end

    # Yields every registered parameter. Parameters whose data is nil are
    # skipped unless +include_uninit+ is true. Returns an Enumerator when
    # called without a block.
    def params(include_uninit: true)
      return enum_for(:params, include_uninit: include_uninit) unless block_given?

      @params.each do |name|
        param = instance_variable_get(name)
        yield param if include_uninit || param.data
      end
    end

    # Like #params, but yields [path, param] pairs whose path is "/" followed
    # by the instance variable name.
    def namedparams(include_uninit: true)
      return enum_for(:namedparams, include_uninit: include_uninit) unless block_given?

      @params.each do |name|
        param = instance_variable_get(name)
        yield ['/' + name.to_s, param] if include_uninit || param.data
      end
    end

    # Yields ('/', self) unless +skipself+; a plain Link has no child links.
    def namedlinks(skipself: false)
      return enum_for(:namedlinks, skipself: skipself) unless block_given?

      yield('/', self) unless skipself
    end
  end

  # A Link that also holds child links. Children are registered like
  # parameters: by assigning them to instance variables inside #init_scope.
  class Chain < Link
    def initialize
      super
      @children = []
    end

    # Registers each instance variable holding a Link as a child (once), then
    # lets Link#set_attr register the parameters.
    def set_attr
      instance_variables.each do |name|
        next unless instance_variable_get(name).is_a?(Chainer::Link)

        @children << name unless @children.include?(name)
      end
      super
    end

    # Yields this chain's own parameters, then those of every child link.
    def params(include_uninit: true)
      return enum_for(:params, include_uninit: include_uninit) unless block_given?

      super(include_uninit: include_uninit) { |param| yield param }
      @children.each do |name|
        instance_variable_get(name).params(include_uninit: include_uninit) { |param| yield param }
      end
    end

    # Yields [path, param] pairs for this chain's own parameters and for those
    # of every child link, with the child's "/" + ivar name prefixed to the path.
    def namedparams(include_uninit: true)
      return enum_for(:namedparams, include_uninit: include_uninit) unless block_given?

      super(include_uninit: include_uninit) { |pair| yield pair }
      @children.each do |name|
        prefix = "/#{name}"
        instance_variable_get(name).namedparams(include_uninit: include_uninit) do |(path, param)|
          yield [prefix + path, param]
        end
      end
    end

    # Yields (path, link) for this chain (unless +skipself+), for every child
    # and for the child's own descendants. Paths are built from the child
    # instance variable names.
    def namedlinks(skipself: false)
      return enum_for(:namedlinks, skipself: skipself) unless block_given?

      yield('/', self) unless skipself
      @children.each do |name|
        child = instance_variable_get(name)
        prefix = "/#{name}"
        yield(prefix, child)
        child.namedlinks(skipself: true) do |path, link|
          yield(prefix + path, link)
        end
      end
    end
  end
end