red-chainer 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (58)
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
@@ -0,0 +1,33 @@
1
module Chainer
  module Training
    # Base class for updaters used by the training loop.
    #
    # #connect_trainer and #finalize are optional hooks that do nothing here;
    # the optimizer accessors, #update and #serialize have no base
    # implementation and raise NotImplementedError until a subclass overrides them.
    class Updater
      # Optional hook that receives the trainer; does nothing by default.
      def connect_trainer(trainer); end

      # Optional hook; does nothing by default.
      def finalize; end

      # @param name [Object] optimizer identifier
      # @raise [NotImplementedError] always, in this base class
      def get_optimizer(name)
        raise NotImplementedError
      end

      # @raise [NotImplementedError] always, in this base class
      def get_all_optimizers
        raise NotImplementedError
      end

      # @raise [NotImplementedError] always, in this base class
      def update
        raise NotImplementedError
      end

      # @param serializer [Object]
      # @raise [NotImplementedError] always, in this base class
      def serialize(serializer)
        raise NotImplementedError
      end

      # Exposes this updater's binding so ERB templates can be rendered
      # against it, e.g. ERB.new("<%= self %>").result(updater.bind).
      def bind
        binding
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Chainer
  module Training
    module Util
      # Normalizes a trigger specification.
      #
      # @param trigger [nil, #call, Array] nil; an object that already responds
      #   to +call+ (a trigger instance, a lambda, ...); or the argument list
      #   for Triggers::IntervalTrigger.new, e.g. [period, unit]
      # @return [false, #call] false for nil, the trigger itself when it is
      #   already callable, otherwise a new Triggers::IntervalTrigger
      def self.get_trigger(trigger)
        if trigger.nil?
          false
        elsif trigger.respond_to?(:call)
          # Reuse ready-made triggers; splatting one into IntervalTrigger.new
          # would hand the object itself over as the first argument.
          trigger
        else
          Triggers::IntervalTrigger.new(*trigger)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Chainer
  module Utils
    module Array
      # Coerces +x+ into a Numo::NArray.
      #
      # @param x [Numo::NArray, Numeric, Array] value to coerce
      # @param dtype [Class, nil] optional Numo array class to convert to
      #   (e.g. Numo::SFloat); nil keeps the element type of an existing array
      # @return [Numo::NArray] +x+ itself when it already is a Numo array of the
      #   requested type, a cast copy when +dtype+ differs, otherwise a new array
      #   built from +x+ (a scalar becomes a one-element array)
      def self.force_array(x, dtype = nil)
        if x.is_a?(Numo::NArray)
          # Rebuilding from Ruby elements would re-infer the element type
          # (e.g. SFloat -> DFloat) and break later grad type checks.
          return x if dtype.nil? || x.is_a?(dtype)
          return dtype.cast(x)
        end

        (dtype || Numo::NArray)[*x]
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module Chainer
  module Utils
    module Initializer
      # Computes fan-in and fan-out for a weight array of the given shape.
      #
      # shape[1] contributes to fan-in and shape[0] to fan-out; all trailing
      # dimensions form the receptive field, whose size multiplies both.
      #
      # @param shape [Array<Integer>] weight shape with at least two dimensions
      # @return [Array(Integer, Integer)] [fan_in, fan_out]
      # @raise [RuntimeError] if +shape+ has fewer than two dimensions
      def self.get_fans(shape)
        raise "shape must be of length >= 2: shape=#{shape.inspect}" if shape.size < 2

        # Plain Ruby product (1 when there are no trailing dims); no dependency
        # on Numo's fixed-width integer arrays.
        receptive_field_size = shape.drop(2).reduce(1, :*)
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
        [fan_in, fan_out]
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
module Chainer
  module Utils
    module Variable
      # Verifies that a gradient is compatible with the data it belongs to.
      #
      # @param func [Object, nil] function whose backward produced +gx+, or nil;
      #   when given, its class is named in the error message
      # @param x [#data] variable (or variable node) receiving the gradient
      # @param gx [Object, nil] candidate gradient
      # @return [nil] when x.data or +gx+ is nil, or when they match
      # @raise [TypeError] if +gx+ is not exactly of x.data's class, or if the
      #   shapes differ (TypeError is kept for the shape case as well so existing
      #   rescuers keep working)
      def self.check_grad_type(func, x, gx)
        data = x.data
        return if data.nil? || gx.nil?

        if !gx.instance_of?(data.class)
          msg = "Type of data and grad mismatch\n#{data.class} != #{gx.class}"
        elsif gx.shape != data.shape
          msg = "Shape of data and grad mismatch\n#{data.shape} != #{gx.shape}"
        else
          return
        end

        detail = func.nil? ? "" : "Function `#{func.class}` has a bug.\n"
        raise TypeError, detail + msg
      end
    end
  end
end
20
+
@@ -0,0 +1,204 @@
1
module Chainer
  # Array-holding node of the computational graph.
  #
  # The wrapped Numo::NArray is kept in a one-element Array so it can be
  # replaced in place; creator, rank, name and gradient are stored on the
  # associated Chainer::VariableNode and delegated to from here.
  class Variable
    attr_accessor :data, :grad, :requires_grad, :node

    # @param data [Numo::NArray, nil] array to wrap
    # @param name [String, nil] optional name, stored on the node
    # @param grad [Numo::NArray, nil] initial gradient, stored on the node
    # @param requires_grad [Boolean] whether #backward propagates a gradient into this variable
    # @raise [TypeError] if +data+ is neither nil nor a Numo::NArray
    def initialize(data=nil, name: nil, grad: nil, requires_grad: true)
      unless data.nil? || data.is_a?(Numo::NArray)
        raise TypeError, "Numo::NArray are expected."
      end

      @data = [data]
      @requires_grad = requires_grad
      # The gradient lives only on the node (see #grad). The node reads this
      # variable's data and requires_grad, so it has to be created last.
      @node = VariableNode.new(variable: self, name: name, grad: grad)
    end

    def data
      @data[0]
    end

    # Replaces the wrapped array and refreshes the node's cached dtype/shape.
    def data=(d)
      @data[0] = d
      @node.set_data_type(d)
    end

    def name
      @node.name
    end

    def name=(n)
      @node.name = n
    end

    def label
      @node.label
    end

    def creator
      @node.creator
    end

    def creator=(func)
      @node.creator = func
    end

    def grad
      @node.grad
    end

    # Sets the gradient after checking its class and shape against #data.
    def grad=(g)
      @node.set_grad_with_check(g, nil, self)
    end

    def shape
      data.shape
    end

    def ndim
      data.ndim
    end

    def size
      data.size
    end

    # Class of the wrapped array.
    def dtype
      data.class
    end

    def rank
      @node.rank
    end

    def cleargrad
      @node.grad = nil
    end

    # Back-propagates gradients from this variable through the graph.
    #
    # When this variable holds a single element and has no gradient yet, its
    # gradient is initialized with ones. Each function is queued at most once
    # and functions are processed from highest rank down, so every gradient
    # contribution to a function's outputs is accumulated before that
    # function's backward runs (assuming, as in Chainer, a function's rank is
    # the maximum rank of its inputs).
    #
    # @param retain_grad [Boolean] keep gradients of intermediate variables
    # @raise [RuntimeError] if a function returns a wrong number of gradients
    def backward(retain_grad: false)
      return if creator.nil?

      if data.size == 1 && grad.nil?
        self.grad = data.new_ones
      end

      funcs = []
      queued_funcs = {}.compare_by_identity
      add_cand = lambda do |cand|
        unless queued_funcs.key?(cand)
          queued_funcs[cand] = true
          funcs << cand
        end
      end
      add_cand.(creator)

      # Shared across all functions so that branches reaching the same
      # variable accumulate into its gradient instead of overwriting it.
      seen_vars = {}.compare_by_identity
      need_copy = {}.compare_by_identity

      until funcs.empty?
        # Highest rank first; max_by keeps the earliest-queued one on ties.
        func = funcs.delete_at(funcs.each_index.max_by { |i| funcs[i].rank })

        outputs = func.outputs.map(&:__getobj__)
        in_data = func.inputs.map(&:data)
        out_grad = outputs.map { |y| y.nil? ? nil : y.grad }

        func.output_data = outputs.map { |y| y.nil? ? nil : y.data }
        gxs = func.backward(in_data, out_grad)

        unless gxs.size == in_data.size
          raise "#{func.class}#backward returned #{gxs.size} gradients for #{in_data.size} inputs"
        end

        func.output_data = nil unless func.retain_after_backward

        unless retain_grad
          outputs.each do |y|
            y.grad = nil unless y.nil? || y.equal?(@node)
          end
        end

        func.inputs.zip(gxs).each do |x, gx|
          next if gx.nil?
          next unless x.requires_grad

          Utils::Variable.check_grad_type(func, x, gx)

          if x.creator.nil? # leaf
            if x.grad.nil?
              x.grad = gx
              need_copy[x] = true
            elsif need_copy.delete(x) # second contribution: build a fresh array
              x.grad = Utils::Array.force_array(x.grad + gx)
            else
              x.grad += gx
            end
          else # not leaf
            add_cand.(x.creator)
            if !seen_vars.key?(x) # first visit
              x.grad = gx
              seen_vars[x] = true
              need_copy[x] = true
            elsif need_copy.delete(x) # second visit
              x.grad = Utils::Array.force_array(gx + x.grad)
            else # third or later visit
              x.grad += gx
            end
          end
        end
      end
    end

    def -@
      Functions::Math::Neg.new.(self)
    end

    def +(other)
      if other.instance_of?(Chainer::Variable)
        Functions::Math::Add.new.(self, other)
      else
        Functions::Math::AddConstant.new(other).(self)
      end
    end

    def -(other)
      if other.instance_of?(Chainer::Variable)
        Functions::Math::Sub.new.(self, other)
      else
        Functions::Math::AddConstant.new(-other).(self)
      end
    end

    def *(other)
      if other.instance_of?(Chainer::Variable)
        Functions::Math::Mul.new.(self, other)
      else
        Functions::Math::MulConstant.new(other).(self)
      end
    end

    def /(other)
      if other.instance_of?(Chainer::Variable)
        Functions::Math::Div.new.(self, other)
      else
        # Float reciprocal: with an Integer constant, 1 / other would truncate
        # (x / 2 would multiply by 0).
        Functions::Math::MulConstant.new(1.0 / other).(self)
      end
    end

    def **(other)
      if other.instance_of?(Chainer::Variable)
        Functions::Math::PowVarVar.new.(self, other)
      else
        Functions::Math::PowVarConst.new(other).(self)
      end
    end

    # Makes the node keep a reference to the current data array.
    def retain_data
      @node.data = @data[0]
    end

    # Called when a Numeric is on the left of an operator (e.g. 2 * var): the
    # number is wrapped in an array of this variable's class and then in a
    # Variable that does not require a gradient.
    def coerce(other)
      other = data.class[*other] if other.kind_of?(Numeric)
      [Chainer::Variable.new(other, requires_grad: false), self]
    end
  end
end
204
+
@@ -0,0 +1,71 @@
1
module Chainer
  # Graph-side record of a Chainer::Variable.
  #
  # Stores the creator function, rank, name, gradient and (once retained) the
  # data array, and refers to the Variable itself only through a WeakRef.
  class VariableNode
    attr_reader :dtype, :shape
    attr_accessor :data, :name, :grad, :rank, :creator, :requires_grad, :variable

    # @param variable [Chainer::Variable] variable this node describes (held weakly)
    # @param name [String, nil] optional name
    # @param grad [Object, nil] initial gradient, stored without a type check
    def initialize(variable:, name:, grad: nil)
      @variable = WeakRef.new(variable)
      @creator = nil
      @data = nil
      @rank = 0
      @name = name
      @requires_grad = variable.requires_grad

      set_data_type(variable.data)

      @grad = grad
    end

    # Sets the creator function; the node's rank becomes the creator's rank + 1.
    def creator=(func)
      @creator = func
      @rank = func.rank + 1 unless func.nil?
    end

    # Stores +data+ on the node and refreshes the cached dtype/shape.
    def data=(data)
      @data = data
      set_data_type(data)
    end

    # Sets the gradient after checking it against the node's own (retained) data.
    def grad=(g)
      Utils::Variable.check_grad_type(nil, self, g)
      @grad = g
    end

    # "(d0, d1, ...), Class" description, or just the class when no shape is known.
    def label
      if @shape.nil? || @shape.empty?
        @dtype.to_s
      else
        "(#{@shape.join(', ')}), #{@dtype.to_s}"
      end
    end

    # Detaches the node from its creator function.
    def unchain
      @creator = nil
    end

    # Copies the referenced variable's current data onto this node (updating
    # dtype/shape) so the node keeps its own reference to the array.
    #
    # @return [Object] the retained data
    # @raise [RuntimeError] if the variable has already been released
    def retain_data
      variable = @variable
      # A WeakRef is never nil itself; ask it whether its referent is alive.
      alive = if variable.respond_to?(:weakref_alive?)
                variable.weakref_alive?
              else
                !variable.nil?
              end
      raise "cannot retain variable data: the variable has been already released" unless alive

      self.data = variable.data
    rescue WeakRef::RefError
      # The referent was collected between the liveness check and the read.
      raise "cannot retain variable data: the variable has been already released"
    end

    # Caches the class and shape of +data+ (both nil when +data+ is nil).
    def set_data_type(data)
      if data.nil?
        @dtype = nil
        @shape = nil
      else
        @dtype = data.class
        @shape = data.shape
      end
    end

    # Sets the gradient after checking it against +var+'s data.
    def set_grad_with_check(g, func, var)
      Utils::Variable.check_grad_type(func, var, g)
      @grad = g
    end
  end
end
@@ -0,0 +1,4 @@
1
module Chainer
  # Gem version string; red-chainer.gemspec reads it for spec.version.
  VERSION = '0.1.0'
end
4
+
@@ -0,0 +1,27 @@
1
# coding: utf-8
lib = File.expand_path("../lib", __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require "chainer/version"

Gem::Specification.new do |spec|
  spec.name          = "red-chainer"
  spec.version       = Chainer::VERSION
  spec.authors       = ["Yusaku Hatanaka"]
  spec.email         = ["hatappi@hatappi.me"]

  # Assign each field explicitly: `spec.summary, spec.description = "..."`
  # is a multiple assignment that leaves description nil (the generated
  # metadata ended up with `description: ''`).
  spec.summary       = "A flexible framework for neural network for Ruby"
  spec.description   = spec.summary
  spec.homepage      = "https://github.com/red-data-tools/red-chainer"
  spec.license       = "MIT"

  # Package every git-tracked file except the test/spec/features trees.
  spec.files         = `git ls-files -z`.split("\x0").reject do |f|
    f.match(%r{^(test|spec|features)/})
  end
  spec.bindir        = "exe"
  spec.executables   = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ["lib"]

  spec.add_runtime_dependency "numo-narray", ">= 0.9.0.8"

  spec.add_development_dependency "bundler", "~> 1.15"
  spec.add_development_dependency "rake", "~> 10.0"
  spec.add_development_dependency "test-unit"
end
metadata ADDED
@@ -0,0 +1,156 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: red-chainer
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Yusaku Hatanaka
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2017-11-18 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-narray
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: 0.9.0.8
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: 0.9.0.8
27
+ - !ruby/object:Gem::Dependency
28
+ name: bundler
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '1.15'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '1.15'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rake
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '10.0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '10.0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: test-unit
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ description: ''
70
+ email:
71
+ - hatappi@hatappi.me
72
+ executables: []
73
+ extensions: []
74
+ extra_rdoc_files: []
75
+ files:
76
+ - ".gitignore"
77
+ - ".rspec"
78
+ - ".travis.yml"
79
+ - CODE_OF_CONDUCT.md
80
+ - Gemfile
81
+ - LICENSE.txt
82
+ - README.md
83
+ - Rakefile
84
+ - bin/console
85
+ - bin/setup
86
+ - examples/mnist.rb
87
+ - lib/chainer.rb
88
+ - lib/chainer/configuration.rb
89
+ - lib/chainer/dataset/convert.rb
90
+ - lib/chainer/dataset/download.rb
91
+ - lib/chainer/dataset/iterator.rb
92
+ - lib/chainer/datasets/mnist.rb
93
+ - lib/chainer/datasets/tuple_dataset.rb
94
+ - lib/chainer/function.rb
95
+ - lib/chainer/functions/activation/log_softmax.rb
96
+ - lib/chainer/functions/activation/relu.rb
97
+ - lib/chainer/functions/connection/linear.rb
98
+ - lib/chainer/functions/evaluation/accuracy.rb
99
+ - lib/chainer/functions/loss/softmax_cross_entropy.rb
100
+ - lib/chainer/functions/math/basic_math.rb
101
+ - lib/chainer/gradient_method.rb
102
+ - lib/chainer/hyperparameter.rb
103
+ - lib/chainer/initializer.rb
104
+ - lib/chainer/initializers/constant.rb
105
+ - lib/chainer/initializers/init.rb
106
+ - lib/chainer/initializers/normal.rb
107
+ - lib/chainer/iterators/serial_iterator.rb
108
+ - lib/chainer/link.rb
109
+ - lib/chainer/links/connection/linear.rb
110
+ - lib/chainer/links/model/classifier.rb
111
+ - lib/chainer/optimizer.rb
112
+ - lib/chainer/optimizers/adam.rb
113
+ - lib/chainer/parameter.rb
114
+ - lib/chainer/reporter.rb
115
+ - lib/chainer/training/extension.rb
116
+ - lib/chainer/training/extensions/evaluator.rb
117
+ - lib/chainer/training/extensions/log_report.rb
118
+ - lib/chainer/training/extensions/print_report.rb
119
+ - lib/chainer/training/extensions/progress_bar.rb
120
+ - lib/chainer/training/standard_updater.rb
121
+ - lib/chainer/training/trainer.rb
122
+ - lib/chainer/training/triggers/interval.rb
123
+ - lib/chainer/training/updater.rb
124
+ - lib/chainer/training/util.rb
125
+ - lib/chainer/utils/array.rb
126
+ - lib/chainer/utils/initializer.rb
127
+ - lib/chainer/utils/variable.rb
128
+ - lib/chainer/variable.rb
129
+ - lib/chainer/variable_node.rb
130
+ - lib/chainer/version.rb
131
+ - red-chainer.gemspec
132
+ homepage: https://github.com/red-data-tools/red-chainer
133
+ licenses:
134
+ - MIT
135
+ metadata: {}
136
+ post_install_message:
137
+ rdoc_options: []
138
+ require_paths:
139
+ - lib
140
+ required_ruby_version: !ruby/object:Gem::Requirement
141
+ requirements:
142
+ - - ">="
143
+ - !ruby/object:Gem::Version
144
+ version: '0'
145
+ required_rubygems_version: !ruby/object:Gem::Requirement
146
+ requirements:
147
+ - - ">="
148
+ - !ruby/object:Gem::Version
149
+ version: '0'
150
+ requirements: []
151
+ rubyforge_project:
152
+ rubygems_version: 2.6.13
153
+ signing_key:
154
+ specification_version: 4
155
+ summary: A flexible framework for neural network for Ruby
156
+ test_files: []