red-chainer 0.3.2 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (81) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
@@ -1,23 +1,40 @@
1
1
  module Chainer
2
2
  class VariableNode
3
- attr_reader :dtype, :shape
4
- attr_accessor :data, :name, :grad, :rank, :creator, :requires_grad, :variable
3
+ attr_reader :dtype, :shape, :data
4
+ attr_accessor :name, :requires_grad, :variable, :creator_node, :rank, :old_style_grad_generator
5
5
 
6
- def initialize(variable: , name:, grad: nil)
6
+ def initialize(variable: , name:)
7
7
  @variable = WeakRef.new(variable)
8
- @creator = nil
8
+ @creator_node = nil
9
9
  @data = nil
10
10
  @rank = 0
11
11
  @name = name
12
12
  @requires_grad = variable.requires_grad
13
13
 
14
+ @old_style_grad_generator = nil
15
+
14
16
  set_data_type(variable.data)
17
+ end
15
18
 
16
- @grad = grad
19
+ def creator
20
+ node = @creator_node
21
+ if node.nil?
22
+ return nil
23
+ end
24
+
25
+ if node.is_a?(Chainer::FunctionAdapter)
26
+ return node.function
27
+ end
28
+ node
17
29
  end
18
30
 
19
31
  def creator=(func)
20
- @creator = func
32
+ self.creator_node = func
33
+ end
34
+
35
+ def creator_node=(func)
36
+ func = func.node if func.is_a?(Chainer::Function)
37
+ @creator_node = func
21
38
  unless func.nil?
22
39
  @rank = func.rank + 1
23
40
  end
@@ -28,9 +45,16 @@ module Chainer
28
45
  set_data_type(data)
29
46
  end
30
47
 
31
- def grad=(g)
32
- Utils::Variable.check_grad_type(nil, self, g)
33
- @grad = g
48
+ # Gradient array of the corresponding variable.
49
+ def grad
50
+ var = get_variable
51
+ var.nil? ? nil : var.grad
52
+ end
53
+
54
+ # Gradient variable of the corresponding variable.
55
+ def grad_var
56
+ var = get_variable
57
+ var.nil? ? nil : var.grad_var
34
58
  end
35
59
 
36
60
  def label
@@ -41,8 +65,32 @@ module Chainer
41
65
  end
42
66
  end
43
67
 
68
+ # Returns the corresponding {Chainer::Variable} object, or a new one bound to this node if the original has been garbage-collected.
69
+ #
70
+ # @return [Chainer::Variable] The variable object that refers to this node.
71
+ def get_variable
72
+ var = @variable
73
+ # workaround: check weakref_alive? before dereferencing, because the referent may already have been reclaimed by GC
74
+ return var.__getobj__ if !var.nil? && var.weakref_alive?
75
+
76
+ var = Chainer::Variable.new(@data, name: @name, requires_grad: @requires_grad)
77
+ var.node = self
78
+ var
79
+ end
80
+
81
+ def set_creator(creator)
82
+ self.creator = creator
83
+ end
84
+
85
+ # Sets a `FunctionNode` object that created this node.
86
+ #
87
+ # @param [Chainer::FunctionNode] creator_node Function node that has this variable as an output.
88
+ def set_creator_node(creator_node)
89
+ self.creator_node = creator_node
90
+ end
91
+
44
92
  def unchain
45
- @creator = nil
93
+ self.creator_node = nil
46
94
  end
47
95
 
48
96
  def retain_data
@@ -67,5 +115,11 @@ module Chainer
67
115
  Utils::Variable.check_grad_type(func, var, g)
68
116
  @grad = g
69
117
  end
118
+
119
+ def check_old_style_gradient
120
+ if @old_style_grad_generator
121
+ raise RuntimeError, "cannot twice-differentiate an old style Function #{@old_style_grad_generator}"
122
+ end
123
+ end
70
124
  end
71
125
  end
@@ -1,4 +1,4 @@
1
1
  module Chainer
2
- VERSION = "0.3.2"
2
+ VERSION = "0.4.0"
3
3
  end
4
4
 
data/red-chainer.gemspec CHANGED
@@ -20,9 +20,10 @@ Gem::Specification.new do |spec|
20
20
  spec.require_paths = ["lib"]
21
21
 
22
22
  spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
23
- spec.add_runtime_dependency "red-datasets", ">= 0.0.5"
23
+ spec.add_runtime_dependency "red-datasets", ">= 0.0.6"
24
24
 
25
- spec.add_development_dependency "bundler", "~> 1.15"
25
+ spec.add_development_dependency "bundler"
26
26
  spec.add_development_dependency "rake", "~> 10.0"
27
- spec.add_development_dependency "test-unit"
27
+ spec.add_development_dependency "test-unit", ">= 3.2.9"
28
+ spec.add_development_dependency "yard", ">= 0.9.10"
28
29
  end
@@ -0,0 +1,40 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <%= erb(:headers) %>
5
+
6
+ <!-- Additional settings for MathJax are from here. -->
7
+ <script type="text/x-mathjax-config">
8
+ MathJax.Hub.Config({
9
+ tex2jax:{
10
+ inlineMath: [ ['$','$'], ["\\(","\\)"] ],
11
+ displayMath: [ ['$$','$$'], ["\\[","\\]"] ]
12
+ }
13
+ });
14
+ </script>
15
+ <script type="text/javascript"
16
+ src="http://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML">
17
+ </script>
18
+ <meta http-equiv="X-UA-Compatible" CONTENT="IE=EmulateIE7" />
19
+ <!-- Additional settings for MathJax are over here. -->
20
+
21
+ </head>
22
+ <body>
23
+ <div class="nav_wrap">
24
+ <iframe id="nav" src="<%= @nav_url %>?1"></iframe>
25
+ <div id="resizer"></div>
26
+ </div>
27
+
28
+ <div id="main" tabindex="-1">
29
+ <div id="header">
30
+ <%= erb(:breadcrumb) %>
31
+ <%= erb(:search) %>
32
+ <div class="clear"></div>
33
+ </div>
34
+
35
+ <div id="content"><%= yieldall %></div>
36
+
37
+ <%= erb(:footer) %>
38
+ </div>
39
+ </body>
40
+ </html>
@@ -0,0 +1,33 @@
1
+ <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
2
+ "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
3
+ <html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
4
+ <head>
5
+ <meta http-equiv="Content-Type" content="text/html; charset=<%= charset %>" />
6
+ <title><%= defined?(@title) ? @title : '' %></title>
7
+ <%= erb(:headers) %>
8
+
9
+ <!-- Additional settings for MathJax are from here. -->
10
+ <script type="text/x-mathjax-config">
11
+ MathJax.Hub.Config({
12
+ tex2jax:{
13
+ inlineMath: [ ['$','$'], ["\\(","\\)"] ],
14
+ displayMath: [ ['$$','$$'], ["\\[","\\]"] ]
15
+ }
16
+ });
17
+ </script>
18
+ <script type="text/javascript"
19
+ src="http://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML">
20
+ </script>
21
+ <meta http-equiv="X-UA-Compatible" CONTENT="IE=EmulateIE7" />
22
+ <!-- Additional settings for MathJax are over here. -->
23
+
24
+ </head>
25
+ <body>
26
+ <div id="content">
27
+ <h1><%= defined?(@title) ? @title : '' %></h1>
28
+ <%= yieldall %>
29
+ </div>
30
+
31
+ <%= erb(:footer) %>
32
+ </body>
33
+ </html>
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-chainer
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.2
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Yusaku Hatanaka
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-06-27 00:00:00.000000000 Z
11
+ date: 2019-03-28 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,28 +30,28 @@ dependencies:
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: 0.0.5
33
+ version: 0.0.6
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: 0.0.5
40
+ version: 0.0.6
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: bundler
43
43
  requirement: !ruby/object:Gem::Requirement
44
44
  requirements:
45
- - - "~>"
45
+ - - ">="
46
46
  - !ruby/object:Gem::Version
47
- version: '1.15'
47
+ version: '0'
48
48
  type: :development
49
49
  prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
- - - "~>"
52
+ - - ">="
53
53
  - !ruby/object:Gem::Version
54
- version: '1.15'
54
+ version: '0'
55
55
  - !ruby/object:Gem::Dependency
56
56
  name: rake
57
57
  requirement: !ruby/object:Gem::Requirement
@@ -72,14 +72,28 @@ dependencies:
72
72
  requirements:
73
73
  - - ">="
74
74
  - !ruby/object:Gem::Version
75
- version: '0'
75
+ version: 3.2.9
76
76
  type: :development
77
77
  prerelease: false
78
78
  version_requirements: !ruby/object:Gem::Requirement
79
79
  requirements:
80
80
  - - ">="
81
81
  - !ruby/object:Gem::Version
82
- version: '0'
82
+ version: 3.2.9
83
+ - !ruby/object:Gem::Dependency
84
+ name: yard
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: 0.9.10
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: 0.9.10
83
97
  description: ''
84
98
  email:
85
99
  - hatappi@hatappi.me
@@ -90,6 +104,7 @@ files:
90
104
  - ".gitignore"
91
105
  - ".rspec"
92
106
  - ".travis.yml"
107
+ - ".yardopts"
93
108
  - Gemfile
94
109
  - LICENSE.txt
95
110
  - README.md
@@ -102,27 +117,42 @@ files:
102
117
  - examples/iris/iris.rb
103
118
  - examples/mnist/mnist.rb
104
119
  - lib/chainer.rb
120
+ - lib/chainer/backend.rb
105
121
  - lib/chainer/configuration.rb
106
122
  - lib/chainer/cuda.rb
107
123
  - lib/chainer/dataset/convert.rb
108
- - lib/chainer/dataset/download.rb
109
124
  - lib/chainer/dataset/iterator.rb
110
125
  - lib/chainer/datasets/cifar.rb
111
126
  - lib/chainer/datasets/mnist.rb
112
127
  - lib/chainer/datasets/tuple_dataset.rb
128
+ - lib/chainer/device.rb
113
129
  - lib/chainer/function.rb
130
+ - lib/chainer/function_node.rb
114
131
  - lib/chainer/functions/activation/leaky_relu.rb
115
132
  - lib/chainer/functions/activation/log_softmax.rb
116
133
  - lib/chainer/functions/activation/relu.rb
134
+ - lib/chainer/functions/activation/relu_grad2.rb
117
135
  - lib/chainer/functions/activation/sigmoid.rb
136
+ - lib/chainer/functions/activation/sigmoid_grad.rb
118
137
  - lib/chainer/functions/activation/tanh.rb
138
+ - lib/chainer/functions/array/broadcast_to.rb
139
+ - lib/chainer/functions/array/cast.rb
140
+ - lib/chainer/functions/array/reshape.rb
141
+ - lib/chainer/functions/array/rollaxis.rb
142
+ - lib/chainer/functions/array/select_item.rb
143
+ - lib/chainer/functions/array/squeeze.rb
144
+ - lib/chainer/functions/array/transpose.rb
119
145
  - lib/chainer/functions/connection/convolution_2d.rb
146
+ - lib/chainer/functions/connection/convolution_2d_grad_w.rb
147
+ - lib/chainer/functions/connection/deconvolution_2d.rb
120
148
  - lib/chainer/functions/connection/linear.rb
121
149
  - lib/chainer/functions/evaluation/accuracy.rb
122
150
  - lib/chainer/functions/loss/mean_squared_error.rb
123
151
  - lib/chainer/functions/loss/softmax_cross_entropy.rb
124
152
  - lib/chainer/functions/math/basic_math.rb
153
+ - lib/chainer/functions/math/exp.rb
125
154
  - lib/chainer/functions/math/identity.rb
155
+ - lib/chainer/functions/math/sum.rb
126
156
  - lib/chainer/functions/noise/dropout.rb
127
157
  - lib/chainer/functions/normalization/batch_normalization.rb
128
158
  - lib/chainer/functions/pooling/average_pooling_2d.rb
@@ -135,6 +165,7 @@ files:
135
165
  - lib/chainer/initializers/constant.rb
136
166
  - lib/chainer/initializers/init.rb
137
167
  - lib/chainer/initializers/normal.rb
168
+ - lib/chainer/initializers/uniform.rb
138
169
  - lib/chainer/iterators/serial_iterator.rb
139
170
  - lib/chainer/link.rb
140
171
  - lib/chainer/links/connection/convolution_2d.rb
@@ -170,6 +201,8 @@ files:
170
201
  - lib/chainer/variable_node.rb
171
202
  - lib/chainer/version.rb
172
203
  - red-chainer.gemspec
204
+ - templates/default/layout/html/layout.erb
205
+ - templates/default/onefile/html/layout.erb
173
206
  homepage: https://github.com/red-data-tools/red-chainer
174
207
  licenses:
175
208
  - MIT
@@ -1,56 +0,0 @@
1
- require "open-uri"
2
- require "pstore"
3
-
4
- module Chainer
5
- module Dataset
6
- module Download
7
- DATASET_ROOT = ENV.fetch("RED_CHAINER_DATASET_ROOT", File.expand_path(".red-chainer/dataset", "~"))
8
-
9
- def self.cached_download(url)
10
- cache_root = File.expand_path('_dl_cache', DATASET_ROOT)
11
- FileUtils.mkdir_p(cache_root)
12
- lock_path = File.expand_path('_dl_lock', cache_root)
13
- urlhash = Digest::MD5.hexdigest(url)
14
- cache_path = File.expand_path(urlhash, cache_root)
15
-
16
- return cache_path if File.exist?(cache_path)
17
-
18
- temp_root = Dir.mktmpdir(nil, cache_root)
19
- temp_path = File.expand_path('dl', temp_root)
20
- open(url) do |f|
21
- puts "Downloading from #{url}"
22
- open(temp_path, "w+b") do |out|
23
- out.write(f.read)
24
- end
25
- FileUtils.mv(temp_path, cache_path)
26
- FileUtils.rm_r(temp_root)
27
- end
28
- cache_path
29
- end
30
-
31
- def self.get_dataset_directory(dataset_name, create_directory: true)
32
- path = File.expand_path(dataset_name, DATASET_ROOT)
33
- FileUtils.mkdir_p(path) if create_directory
34
- path
35
- end
36
-
37
- def self.cache_or_load_file(path, &creator)
38
- raise 'Please set dataset creator on block' if creator.nil?
39
-
40
- return PStore.new(path).transaction { |t| t['data'] } if File.exist?(path)
41
-
42
- data = creator.call
43
- PStore.new(path).transaction do |t|
44
- t['data'] = data
45
- end
46
- data
47
- rescue TypeError => e
48
- puts e.message
49
- FileUtils.rm_f(path)
50
- cache_or_load_file(path) do
51
- creator.call
52
- end
53
- end
54
- end
55
- end
56
- end