red-chainer 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,33 @@
1
module Chainer
  module Training
    # Base class for training updaters.
    #
    # The lifecycle hooks (#connect_trainer, #finalize) are no-ops here; the
    # remaining public methods are abstract and raise NotImplementedError until
    # a subclass overrides them.
    class Updater
      # Hook invoked when this updater is attached to a trainer. Does nothing.
      #
      # @param trainer [Object] the trainer being connected
      # @return [nil]
      def connect_trainer(trainer); end

      # Hook for end-of-training cleanup. Does nothing.
      #
      # @return [nil]
      def finalize; end

      # @abstract
      # @param name [Object] key of the optimizer to look up
      # @raise [NotImplementedError] always, in this base class
      def get_optimizer(name)
        raise NotImplementedError
      end

      # @abstract
      # @raise [NotImplementedError] always, in this base class
      def get_all_optimizers
        raise NotImplementedError
      end

      # @abstract
      # @raise [NotImplementedError] always, in this base class
      def update
        raise NotImplementedError
      end

      # @abstract
      # @param serializer [Object]
      # @raise [NotImplementedError] always, in this base class
      def serialize(serializer)
        raise NotImplementedError
      end

      # Hands out this updater's Binding so that ERB templates can evaluate
      # expressions with the updater as +self+.
      #
      # @example
      #   ERB.new("<%= self %>").result(updater.bind)
      # @return [Binding]
      def bind
        binding
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Chainer
  module Training
    module Util
      # Normalizes a trigger specification.
      #
      # @param trigger [nil, #call, Array] one of:
      #   * +nil+ -- yields +false+;
      #   * an object responding to +call+ (e.g. a lambda or a trigger
      #     instance) -- returned unchanged;
      #   * anything else -- splatted into Triggers::IntervalTrigger.new,
      #     e.g. +[1, 'epoch']+.
      # @return [false, #call, Triggers::IntervalTrigger]
      def self.get_trigger(trigger)
        if trigger.nil?
          false
        elsif trigger.respond_to?(:call)
          # Already usable as a trigger. Splatting a non-array object would
          # pass the object itself as IntervalTrigger's first argument.
          trigger
        else
          Triggers::IntervalTrigger.new(*trigger)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Chainer
  module Utils
    module Array
      # Ensures +x+ is a Numo::NArray.
      #
      # An NArray is returned unchanged (same object), so its element class
      # and shape are preserved. Splatting it into Numo::NArray[] would
      # round-trip through Ruby arrays, letting Numo re-infer the element class
      # (e.g. SFloat -> DFloat) and mangling 0-d arrays. That makes
      # Variable#backward's grad type/shape check fail on accumulated
      # gradients. Any other value is splatted into Numo::NArray[] as before.
      #
      # @param x [Numo::NArray, Array, Numeric]
      # @param dtype [nil] currently ignored
      # @return [Numo::NArray]
      def self.force_array(x, dtype=nil)
        # TODO: conversion by dtype
        return x if x.is_a?(Numo::NArray)

        Numo::NArray.[](*x)
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module Chainer
  module Utils
    module Initializer
      # Computes fan-in and fan-out for an array of the given shape.
      #
      # shape[0] contributes to fan-out, shape[1] to fan-in. Every dimension
      # after the first two is multiplied into a receptive-field size that
      # scales both fans.
      #
      # @param shape [Array<Integer>] at least two dimensions
      # @return [Array(Integer, Integer)] +[fan_in, fan_out]+
      # @raise [RuntimeError] if +shape+ has fewer than two dimensions
      def self.get_fans(shape)
        raise "shape must be of length >= 2: shape=#{shape}" if shape.size < 2

        # Plain Ruby integer product: arbitrary precision (no Int32 overflow),
        # and 1 for the empty trailing part of a 2-D shape.
        receptive_field_size = shape.drop(2).reduce(1, :*)
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
        [fan_in, fan_out]
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
module Chainer
  module Utils
    module Variable
      # Verifies that a gradient is compatible with the data it belongs to.
      #
      # Nothing is checked when either +x.data+ or +gx+ is nil.
      #
      # @param func [Object, nil] function that produced +gx+; not used by these checks
      # @param x [#data] holder of the reference data (a Variable or VariableNode
      #   at the call sites in this gem)
      # @param gx [Object, nil] candidate gradient
      # @raise [TypeError] if +gx+ is not exactly of +x.data+'s class, or if the
      #   shapes differ
      # @return [nil]
      def self.check_grad_type(func, x, gx)
        data = x.data
        return if data.nil? || gx.nil?

        # Report the classes/shapes actually compared, not the holder's class.
        unless gx.instance_of?(data.class)
          raise TypeError, "Type of data and grad mismatch\n#{data.class} != #{gx.class}"
        end

        unless gx.shape == data.shape
          raise TypeError, "Shape of data and grad mismatch\n#{data.shape} != #{gx.shape}"
        end
      end
    end
  end
end
20
+
@@ -0,0 +1,204 @@
1
module Chainer
  # Wraps an array and records how it was computed, so gradients can be
  # back-propagated through it.
  #
  # The array is stored on this object. Name, creator, rank and gradient are
  # stored on the associated VariableNode (+node+).
  class Variable
    # +data+ and +grad+ have hand-written accessors below. Generating them here
    # too would only cause method-redefinition warnings.
    attr_accessor :requires_grad, :node

    # @param data [Numo::NArray, nil] the array to wrap
    # @param name [String, nil] name stored on the node
    # @param grad [Object, nil] initial gradient, handed to the node
    # @param requires_grad [Boolean] copied onto the node; #backward skips
    #   inputs whose requires_grad is false
    # @raise [TypeError] if +data+ is neither nil nor a Numo::NArray
    def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
      unless data.nil? || data.is_a?(Numo::NArray)
        raise TypeError, "Numo::NArray are expected."
      end

      # One-element box; #data and #data= read and write slot 0.
      @data = [data]
      @grad = grad
      @requires_grad = requires_grad
      @node = VariableNode.new(variable: self, name: name, grad: grad)
    end

    # @return [Numo::NArray, nil] the wrapped array
    def data
      @data[0]
    end

    # Replaces the wrapped array and refreshes the node's recorded dtype/shape.
    def data=(d)
      @data[0] = d
      @node.set_data_type(d)
    end

    def name
      @node.name
    end

    def name=(n)
      @node.name = n
    end

    def label
      @node.label
    end

    # @return [Object, nil] the function that produced this variable (nil for leaves)
    def creator
      @node.creator
    end

    def creator=(func)
      @node.creator = func
    end

    def grad
      @node.grad
    end

    # Stores a gradient on the node after checking its class and shape
    # against #data.
    def grad=(g)
      @node.set_grad_with_check(g, nil, self)
    end

    def shape
      self.data.shape
    end

    def ndim
      self.data.ndim
    end

    def size
      self.data.size
    end

    # @return [Class] the class of the wrapped array (e.g. Numo::DFloat)
    def dtype
      self.data.class
    end

    def rank
      @node.rank
    end

    # Drops the gradient stored on the node.
    def cleargrad
      @node.grad = nil
    end

    # Back-propagates gradients from this variable through its creators.
    #
    # A single-element variable with no gradient is seeded with ones.
    # Functions are processed highest rank first, and each function at most
    # once. Output nodes get rank = creator.rank + 1. Assuming a function's
    # rank is not lower than its inputs' ranks (TODO confirm in Function), a
    # function therefore runs only after every consumer of its outputs has
    # contributed gradient.
    #
    # @param retain_grad [Boolean] when false, gradients of intermediate
    #   outputs are cleared once consumed (this variable's node keeps its own)
    # @raise [RuntimeError] if a function's backward returns a different
    #   number of gradients than it has inputs
    def backward(retain_grad: false)
      return if self.creator.nil?

      if self.data.size == 1 && self.grad.nil?
        self.grad = self.data.new_ones
      end

      candidates = []
      seen_funcs = {}.compare_by_identity
      # Bookkeeping shared by the whole traversal, keyed by node identity.
      # seen_vars: non-leaf nodes whose gradient has been set at least once.
      # need_copy: nodes whose gradient is still the array returned by a
      #   backward call; the next accumulation goes through force_array.
      seen_vars = {}.compare_by_identity
      need_copy = {}.compare_by_identity

      add_candidate = lambda do |func|
        next if seen_funcs.key?(func)

        seen_funcs[func] = true
        candidates << func
      end
      add_candidate.(self.creator)

      until candidates.empty?
        func = pop_highest_rank(candidates)

        outputs = func.outputs.map(&:__getobj__)
        in_data = func.inputs.map(&:data)
        out_grad = outputs.map { |y| y.nil? ? nil : y.grad }

        func.output_data = outputs.map { |y| y.nil? ? nil : y.data }
        gxs = func.backward(in_data, out_grad)

        unless gxs.size == in_data.size
          raise "#{func.class}#backward returned #{gxs.size} gradients for #{in_data.size} inputs"
        end

        func.output_data = nil unless func.retain_after_backward

        unless retain_grad
          outputs.each do |y|
            y.grad = nil unless y.nil? || y.equal?(@node)
          end
        end

        func.inputs.zip(gxs).each do |x, gx|
          next if gx.nil?
          next unless x.requires_grad

          Utils::Variable.check_grad_type(func, x, gx)

          if x.creator.nil? # leaf
            if x.grad.nil?
              x.grad = gx
              need_copy[x] = true
            elsif need_copy.delete(x)
              x.grad = Utils::Array.force_array(x.grad + gx)
            else
              x.grad += gx
            end
          else # not leaf
            add_candidate.(x.creator)
            if !seen_vars.key?(x)
              x.grad = gx
              seen_vars[x] = true
              need_copy[x] = true
            elsif need_copy.delete(x)
              x.grad = Utils::Array.force_array(gx + x.grad)
            else
              x.grad += gx
            end
          end
        end
      end
    end

    def -@
      Functions::Math::Neg.new.(self)
    end

    # Operators dispatch on is_a? so Variable subclasses (e.g. parameters) are
    # treated as graph inputs rather than as constants.
    def +(other)
      if other.is_a?(Chainer::Variable)
        Functions::Math::Add.new.(self, other)
      else
        Functions::Math::AddConstant.new(other).(self)
      end
    end

    def -(other)
      if other.is_a?(Chainer::Variable)
        Functions::Math::Sub.new.(self, other)
      else
        Functions::Math::AddConstant.new(-other).(self)
      end
    end

    def *(other)
      if other.is_a?(Chainer::Variable)
        Functions::Math::Mul.new.(self, other)
      else
        Functions::Math::MulConstant.new(other).(self)
      end
    end

    def /(other)
      if other.is_a?(Chainer::Variable)
        Functions::Math::Div.new.(self, other)
      else
        # Float numerator: with an Integer constant, 1 / other truncates
        # (1 / 2 == 0) and would zero out the result.
        Functions::Math::MulConstant.new(1.0 / other).(self)
      end
    end

    def **(other)
      if other.is_a?(Chainer::Variable)
        Functions::Math::PowVarVar.new.(self, other)
      else
        Functions::Math::PowVarConst.new(other).(self)
      end
    end

    # Copies the current array onto the node (node.data), which also updates
    # the node's recorded dtype/shape.
    def retain_data
      @node.data = @data[0]
    end

    # Called by Ruby when a Numeric is on the left-hand side and a
    # Chainer::Variable on the right (e.g. +2 * var+). The numeric is wrapped
    # in an array of the same class as #data, then in a Variable with
    # requires_grad false.
    def coerce(other)
      other = self.data.class[*other] if other.kind_of?(Numeric)
      [Chainer::Variable.new(other, requires_grad: false), self]
    end

    private

    # Removes and returns the highest-ranked function from +funcs+; among equal
    # ranks, the one added earliest wins.
    def pop_highest_rank(funcs)
      _, index = funcs.each_with_index.min_by { |func, i| [-func.rank, i] }
      funcs.delete_at(index)
    end
  end
end
204
+
@@ -0,0 +1,71 @@
1
module Chainer
  # Graph-side record for a Variable: name, creator, rank, gradient, the
  # dtype/shape of its data, and optionally a retained copy of that data.
  class VariableNode
    # data, grad and creator have explicit writers below (they do extra work),
    # so only readers are generated for them.
    attr_reader :dtype, :shape, :data, :grad, :creator
    attr_accessor :name, :rank, :requires_grad, :variable

    # @param variable [Chainer::Variable] the described variable; held through
    #   a WeakRef so the node does not keep it alive
    # @param name [String, nil]
    # @param grad [Object, nil] initial gradient, stored without a type check
    def initialize(variable: , name:, grad: nil)
      @variable = WeakRef.new(variable)
      @creator = nil
      @data = nil
      @rank = 0
      @name = name
      @requires_grad = variable.requires_grad

      set_data_type(variable.data)

      @grad = grad
    end

    # Sets the creating function; a non-nil creator also sets rank to
    # creator.rank + 1.
    def creator=(func)
      @creator = func
      unless func.nil?
        @rank = func.rank + 1
      end
    end

    # Stores retained data and records its dtype/shape.
    def data=(data)
      @data = data
      set_data_type(data)
    end

    # Sets the gradient after checking it against this node's retained data.
    # The check is skipped while no data is retained.
    def grad=(g)
      Utils::Variable.check_grad_type(nil, self, g)
      @grad = g
    end

    # @return [String] dtype alone, or "(d0, d1, ...), dtype" when a non-empty
    #   shape is known
    def label
      if @shape.nil? || @shape.empty?
        @dtype.to_s
      else
        "(#{@shape.join(', ')}), #{@dtype.to_s}"
      end
    end

    # Detaches this node from its creator.
    def unchain
      @creator = nil
    end

    # Copies the variable's current data onto this node (also updating
    # dtype/shape).
    #
    # @return [Object] the retained data
    # @raise [RuntimeError] if the variable has been garbage-collected or unset
    def retain_data
      variable = live_variable
      if variable.nil?
        raise "cannot retain variable data: the variable has been already released"
      end
      self.data = variable.data
    end

    # Records the class and shape of +data+ (both nil when +data+ is nil).
    def set_data_type(data)
      if data.nil?
        @dtype = nil
        @shape = nil
      else
        @dtype = data.class
        @shape = data.shape
      end
    end

    # Sets the gradient after checking it against +var+'s data.
    #
    # @param g [Object] the gradient
    # @param func [Object, nil] passed through to the check
    # @param var [#data] object whose data the gradient must match
    def set_grad_with_check(g, func, var)
      Utils::Variable.check_grad_type(func, var, g)
      @grad = g
    end

    private

    # Dereferences the WeakRef, returning nil when it is unset or its target
    # has been collected. WeakRef#nil? is never true (Kernel#nil? on the
    # delegator), so nil is checked here and RefError is rescued.
    def live_variable
      return nil if @variable.nil?

      @variable.__getobj__
    rescue WeakRef::RefError
      nil
    end
  end
end
@@ -0,0 +1,4 @@
1
module Chainer
  # Released version of the red-chainer gem (read by red-chainer.gemspec).
  # Frozen so the shared constant cannot be mutated in place.
  VERSION = "0.1.0".freeze
end
4
+
@@ -0,0 +1,27 @@
1
# coding: utf-8
lib = File.expand_path("../lib", __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require "chainer/version"

Gem::Specification.new do |spec|
  spec.name          = "red-chainer"
  spec.version       = Chainer::VERSION
  spec.authors       = ["Yusaku Hatanaka"]
  spec.email         = ["hatappi@hatappi.me"]

  # Summary and description share one text. A parallel assignment
  # (`spec.summary, spec.description = "..."`) would leave description nil,
  # which is why the published metadata showed an empty description.
  spec.summary       = "A flexible framework for neural network for Ruby"
  spec.description   = spec.summary
  spec.homepage      = "https://github.com/red-data-tools/red-chainer"
  spec.license       = "MIT"

  # Package every tracked file except test/spec/feature sources.
  spec.files         = `git ls-files -z`.split("\x0").reject do |f|
    f.match(%r{^(test|spec|features)/})
  end
  spec.bindir        = "exe"
  spec.executables   = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ["lib"]

  spec.add_runtime_dependency "numo-narray", ">= 0.9.0.8"

  spec.add_development_dependency "bundler", "~> 1.15"
  spec.add_development_dependency "rake", "~> 10.0"
  spec.add_development_dependency "test-unit"
end
metadata ADDED
@@ -0,0 +1,156 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: red-chainer
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Yusaku Hatanaka
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2017-11-18 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-narray
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: 0.9.0.8
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: 0.9.0.8
27
+ - !ruby/object:Gem::Dependency
28
+ name: bundler
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '1.15'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '1.15'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rake
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '10.0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '10.0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: test-unit
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ description: ''
70
+ email:
71
+ - hatappi@hatappi.me
72
+ executables: []
73
+ extensions: []
74
+ extra_rdoc_files: []
75
+ files:
76
+ - ".gitignore"
77
+ - ".rspec"
78
+ - ".travis.yml"
79
+ - CODE_OF_CONDUCT.md
80
+ - Gemfile
81
+ - LICENSE.txt
82
+ - README.md
83
+ - Rakefile
84
+ - bin/console
85
+ - bin/setup
86
+ - examples/mnist.rb
87
+ - lib/chainer.rb
88
+ - lib/chainer/configuration.rb
89
+ - lib/chainer/dataset/convert.rb
90
+ - lib/chainer/dataset/download.rb
91
+ - lib/chainer/dataset/iterator.rb
92
+ - lib/chainer/datasets/mnist.rb
93
+ - lib/chainer/datasets/tuple_dataset.rb
94
+ - lib/chainer/function.rb
95
+ - lib/chainer/functions/activation/log_softmax.rb
96
+ - lib/chainer/functions/activation/relu.rb
97
+ - lib/chainer/functions/connection/linear.rb
98
+ - lib/chainer/functions/evaluation/accuracy.rb
99
+ - lib/chainer/functions/loss/softmax_cross_entropy.rb
100
+ - lib/chainer/functions/math/basic_math.rb
101
+ - lib/chainer/gradient_method.rb
102
+ - lib/chainer/hyperparameter.rb
103
+ - lib/chainer/initializer.rb
104
+ - lib/chainer/initializers/constant.rb
105
+ - lib/chainer/initializers/init.rb
106
+ - lib/chainer/initializers/normal.rb
107
+ - lib/chainer/iterators/serial_iterator.rb
108
+ - lib/chainer/link.rb
109
+ - lib/chainer/links/connection/linear.rb
110
+ - lib/chainer/links/model/classifier.rb
111
+ - lib/chainer/optimizer.rb
112
+ - lib/chainer/optimizers/adam.rb
113
+ - lib/chainer/parameter.rb
114
+ - lib/chainer/reporter.rb
115
+ - lib/chainer/training/extension.rb
116
+ - lib/chainer/training/extensions/evaluator.rb
117
+ - lib/chainer/training/extensions/log_report.rb
118
+ - lib/chainer/training/extensions/print_report.rb
119
+ - lib/chainer/training/extensions/progress_bar.rb
120
+ - lib/chainer/training/standard_updater.rb
121
+ - lib/chainer/training/trainer.rb
122
+ - lib/chainer/training/triggers/interval.rb
123
+ - lib/chainer/training/updater.rb
124
+ - lib/chainer/training/util.rb
125
+ - lib/chainer/utils/array.rb
126
+ - lib/chainer/utils/initializer.rb
127
+ - lib/chainer/utils/variable.rb
128
+ - lib/chainer/variable.rb
129
+ - lib/chainer/variable_node.rb
130
+ - lib/chainer/version.rb
131
+ - red-chainer.gemspec
132
+ homepage: https://github.com/red-data-tools/red-chainer
133
+ licenses:
134
+ - MIT
135
+ metadata: {}
136
+ post_install_message:
137
+ rdoc_options: []
138
+ require_paths:
139
+ - lib
140
+ required_ruby_version: !ruby/object:Gem::Requirement
141
+ requirements:
142
+ - - ">="
143
+ - !ruby/object:Gem::Version
144
+ version: '0'
145
+ required_rubygems_version: !ruby/object:Gem::Requirement
146
+ requirements:
147
+ - - ">="
148
+ - !ruby/object:Gem::Version
149
+ version: '0'
150
+ requirements: []
151
+ rubyforge_project:
152
+ rubygems_version: 2.6.13
153
+ signing_key:
154
+ specification_version: 4
155
+ summary: A flexible framework for neural network for Ruby
156
+ test_files: []