red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
@@ -1,23 +1,40 @@
1
1
  module Chainer
2
2
  class VariableNode
3
- attr_reader :dtype, :shape
4
- attr_accessor :data, :name, :grad, :rank, :creator, :requires_grad, :variable
3
+ attr_reader :dtype, :shape, :data
4
+ attr_accessor :name, :requires_grad, :variable, :creator_node, :rank, :old_style_grad_generator
5
5
 
6
- def initialize(variable: , name:, grad: nil)
6
+ def initialize(variable: , name:)
7
7
  @variable = WeakRef.new(variable)
8
- @creator = nil
8
+ @creator_node = nil
9
9
  @data = nil
10
10
  @rank = 0
11
11
  @name = name
12
12
  @requires_grad = variable.requires_grad
13
13
 
14
+ @old_style_grad_generator = nil
15
+
14
16
  set_data_type(variable.data)
17
+ end
15
18
 
16
- @grad = grad
19
+ def creator
20
+ node = @creator_node
21
+ if node.nil?
22
+ return nil
23
+ end
24
+
25
+ if node.is_a?(Chainer::FunctionAdapter)
26
+ return node.function
27
+ end
28
+ node
17
29
  end
18
30
 
19
31
  def creator=(func)
20
- @creator = func
32
+ self.creator_node = func
33
+ end
34
+
35
+ def creator_node=(func)
36
+ func = func.node if func.is_a?(Chainer::Function)
37
+ @creator_node = func
21
38
  unless func.nil?
22
39
  @rank = func.rank + 1
23
40
  end
@@ -28,9 +45,16 @@ module Chainer
28
45
  set_data_type(data)
29
46
  end
30
47
 
31
- def grad=(g)
32
- Utils::Variable.check_grad_type(nil, self, g)
33
- @grad = g
48
+ # Gradient array of the corresponding variable.
49
+ def grad
50
+ var = get_variable
51
+ var.nil? ? nil : var.grad
52
+ end
53
+
54
+ # Gradient variable of the corresponding variable.<Paste>
55
+ def grad_var
56
+ var = get_variable
57
+ var.nil? ? nil : var.grad_var
34
58
  end
35
59
 
36
60
  def label
@@ -41,8 +65,32 @@ module Chainer
41
65
  end
42
66
  end
43
67
 
68
+ # Returns the corresponding :class:`Variable` object.
69
+ #
70
+ # @return [Chainer::Variable] The variable object that refers this node.
71
+ def get_variable
72
+ var = @variable
73
+ # workaround: check weakref_alive?, because weakref sometimes delegates references by GC
74
+ return var.__getobj__ if !var.nil? && var.weakref_alive?
75
+
76
+ var = Chainer::Variable.new(@data, name: @name, requires_grad: @requires_grad)
77
+ var.node = self
78
+ var
79
+ end
80
+
81
+ def set_creator(creator)
82
+ self.creator = creator
83
+ end
84
+
85
+ # Sets a `FunctionNode` object that created this node.
86
+ #
87
+ # @param [Chainer::FunctionNode] creator_node Function node that has this variable as an output.
88
+ def set_creator_node(creator_node)
89
+ self.creator_node = creator_node
90
+ end
91
+
44
92
  def unchain
45
- @creator = nil
93
+ self.creator_node = nil
46
94
  end
47
95
 
48
96
  def retain_data
@@ -67,5 +115,11 @@ module Chainer
67
115
  Utils::Variable.check_grad_type(func, var, g)
68
116
  @grad = g
69
117
  end
118
+
119
+ def check_old_style_gradient
120
+ if @old_style_grad_generator
121
+ raise RuntimeError, "cannot twice-differentiate an old style Function #{@old_style_grad_generator}"
122
+ end
123
+ end
70
124
  end
71
125
  end
@@ -1,4 +1,4 @@
1
1
  module Chainer
2
- VERSION = "0.3.2"
2
+ VERSION = "0.4.0"
3
3
  end
4
4
 
data/red-chainer.gemspec CHANGED
@@ -20,9 +20,10 @@ Gem::Specification.new do |spec|
20
20
  spec.require_paths = ["lib"]
21
21
 
22
22
  spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
23
- spec.add_runtime_dependency "red-datasets", ">= 0.0.5"
23
+ spec.add_runtime_dependency "red-datasets", ">= 0.0.6"
24
24
 
25
- spec.add_development_dependency "bundler", "~> 1.15"
25
+ spec.add_development_dependency "bundler"
26
26
  spec.add_development_dependency "rake", "~> 10.0"
27
- spec.add_development_dependency "test-unit"
27
+ spec.add_development_dependency "test-unit", ">= 3.2.9"
28
+ spec.add_development_dependency "yard", ">= 0.9.10"
28
29
  end
@@ -0,0 +1,40 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <%= erb(:headers) %>
5
+
6
+ <!-- Additional settings for MathJax are from here. -->
7
+ <script type="text/x-mathjax-config">
8
+ MathJax.Hub.Config({
9
+ tex2jax:{
10
+ inlineMath: [ ['$','$'], ["\\(","\\)"] ],
11
+ displayMath: [ ['$$','$$'], ["\\[","\\]"] ]
12
+ }
13
+ });
14
+ </script>
15
+ <script type="text/javascript"
16
+ src="http://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML">
17
+ </script>
18
+ <meta http-equiv="X-UA-Compatible" CONTENT="IE=EmulateIE7" />
19
+ <!-- Additional settings for MathJax are over here. -->
20
+
21
+ </head>
22
+ <body>
23
+ <div class="nav_wrap">
24
+ <iframe id="nav" src="<%= @nav_url %>?1"></iframe>
25
+ <div id="resizer"></div>
26
+ </div>
27
+
28
+ <div id="main" tabindex="-1">
29
+ <div id="header">
30
+ <%= erb(:breadcrumb) %>
31
+ <%= erb(:search) %>
32
+ <div class="clear"></div>
33
+ </div>
34
+
35
+ <div id="content"><%= yieldall %></div>
36
+
37
+ <%= erb(:footer) %>
38
+ </div>
39
+ </body>
40
+ </html>
@@ -0,0 +1,33 @@
1
+ <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
2
+ "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
3
+ <html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
4
+ <head>
5
+ <meta http-equiv="Content-Type" content="text/html; charset=<%= charset %>" />
6
+ <title><%= defined?(@title) ? @title : '' %></title>
7
+ <%= erb(:headers) %>
8
+
9
+ <!-- Additional settings for MathJax are from here. -->
10
+ <script type="text/x-mathjax-config">
11
+ MathJax.Hub.Config({
12
+ tex2jax:{
13
+ inlineMath: [ ['$','$'], ["\\(","\\)"] ],
14
+ displayMath: [ ['$$','$$'], ["\\[","\\]"] ]
15
+ }
16
+ });
17
+ </script>
18
+ <script type="text/javascript"
19
+ src="http://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML">
20
+ </script>
21
+ <meta http-equiv="X-UA-Compatible" CONTENT="IE=EmulateIE7" />
22
+ <!-- Additional settings for MathJax are over here. -->
23
+
24
+ </head>
25
+ <body>
26
+ <div id="content">
27
+ <h1><%= defined?(@title) ? @title : '' %></h1>
28
+ <%= yieldall %>
29
+ </div>
30
+
31
+ <%= erb(:footer) %>
32
+ </body>
33
+ </html>
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-chainer
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.2
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Yusaku Hatanaka
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-06-27 00:00:00.000000000 Z
11
+ date: 2019-03-28 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,28 +30,28 @@ dependencies:
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: 0.0.5
33
+ version: 0.0.6
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: 0.0.5
40
+ version: 0.0.6
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: bundler
43
43
  requirement: !ruby/object:Gem::Requirement
44
44
  requirements:
45
- - - "~>"
45
+ - - ">="
46
46
  - !ruby/object:Gem::Version
47
- version: '1.15'
47
+ version: '0'
48
48
  type: :development
49
49
  prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
- - - "~>"
52
+ - - ">="
53
53
  - !ruby/object:Gem::Version
54
- version: '1.15'
54
+ version: '0'
55
55
  - !ruby/object:Gem::Dependency
56
56
  name: rake
57
57
  requirement: !ruby/object:Gem::Requirement
@@ -72,14 +72,28 @@ dependencies:
72
72
  requirements:
73
73
  - - ">="
74
74
  - !ruby/object:Gem::Version
75
- version: '0'
75
+ version: 3.2.9
76
76
  type: :development
77
77
  prerelease: false
78
78
  version_requirements: !ruby/object:Gem::Requirement
79
79
  requirements:
80
80
  - - ">="
81
81
  - !ruby/object:Gem::Version
82
- version: '0'
82
+ version: 3.2.9
83
+ - !ruby/object:Gem::Dependency
84
+ name: yard
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: 0.9.10
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: 0.9.10
83
97
  description: ''
84
98
  email:
85
99
  - hatappi@hatappi.me
@@ -90,6 +104,7 @@ files:
90
104
  - ".gitignore"
91
105
  - ".rspec"
92
106
  - ".travis.yml"
107
+ - ".yardopts"
93
108
  - Gemfile
94
109
  - LICENSE.txt
95
110
  - README.md
@@ -102,27 +117,42 @@ files:
102
117
  - examples/iris/iris.rb
103
118
  - examples/mnist/mnist.rb
104
119
  - lib/chainer.rb
120
+ - lib/chainer/backend.rb
105
121
  - lib/chainer/configuration.rb
106
122
  - lib/chainer/cuda.rb
107
123
  - lib/chainer/dataset/convert.rb
108
- - lib/chainer/dataset/download.rb
109
124
  - lib/chainer/dataset/iterator.rb
110
125
  - lib/chainer/datasets/cifar.rb
111
126
  - lib/chainer/datasets/mnist.rb
112
127
  - lib/chainer/datasets/tuple_dataset.rb
128
+ - lib/chainer/device.rb
113
129
  - lib/chainer/function.rb
130
+ - lib/chainer/function_node.rb
114
131
  - lib/chainer/functions/activation/leaky_relu.rb
115
132
  - lib/chainer/functions/activation/log_softmax.rb
116
133
  - lib/chainer/functions/activation/relu.rb
134
+ - lib/chainer/functions/activation/relu_grad2.rb
117
135
  - lib/chainer/functions/activation/sigmoid.rb
136
+ - lib/chainer/functions/activation/sigmoid_grad.rb
118
137
  - lib/chainer/functions/activation/tanh.rb
138
+ - lib/chainer/functions/array/broadcast_to.rb
139
+ - lib/chainer/functions/array/cast.rb
140
+ - lib/chainer/functions/array/reshape.rb
141
+ - lib/chainer/functions/array/rollaxis.rb
142
+ - lib/chainer/functions/array/select_item.rb
143
+ - lib/chainer/functions/array/squeeze.rb
144
+ - lib/chainer/functions/array/transpose.rb
119
145
  - lib/chainer/functions/connection/convolution_2d.rb
146
+ - lib/chainer/functions/connection/convolution_2d_grad_w.rb
147
+ - lib/chainer/functions/connection/deconvolution_2d.rb
120
148
  - lib/chainer/functions/connection/linear.rb
121
149
  - lib/chainer/functions/evaluation/accuracy.rb
122
150
  - lib/chainer/functions/loss/mean_squared_error.rb
123
151
  - lib/chainer/functions/loss/softmax_cross_entropy.rb
124
152
  - lib/chainer/functions/math/basic_math.rb
153
+ - lib/chainer/functions/math/exp.rb
125
154
  - lib/chainer/functions/math/identity.rb
155
+ - lib/chainer/functions/math/sum.rb
126
156
  - lib/chainer/functions/noise/dropout.rb
127
157
  - lib/chainer/functions/normalization/batch_normalization.rb
128
158
  - lib/chainer/functions/pooling/average_pooling_2d.rb
@@ -135,6 +165,7 @@ files:
135
165
  - lib/chainer/initializers/constant.rb
136
166
  - lib/chainer/initializers/init.rb
137
167
  - lib/chainer/initializers/normal.rb
168
+ - lib/chainer/initializers/uniform.rb
138
169
  - lib/chainer/iterators/serial_iterator.rb
139
170
  - lib/chainer/link.rb
140
171
  - lib/chainer/links/connection/convolution_2d.rb
@@ -170,6 +201,8 @@ files:
170
201
  - lib/chainer/variable_node.rb
171
202
  - lib/chainer/version.rb
172
203
  - red-chainer.gemspec
204
+ - templates/default/layout/html/layout.erb
205
+ - templates/default/onefile/html/layout.erb
173
206
  homepage: https://github.com/red-data-tools/red-chainer
174
207
  licenses:
175
208
  - MIT
@@ -1,56 +0,0 @@
1
- require "open-uri"
2
- require "pstore"
3
-
4
- module Chainer
5
- module Dataset
6
- module Download
7
- DATASET_ROOT = ENV.fetch("RED_CHAINER_DATASET_ROOT", File.expand_path(".red-chainer/dataset", "~"))
8
-
9
- def self.cached_download(url)
10
- cache_root = File.expand_path('_dl_cache', DATASET_ROOT)
11
- FileUtils.mkdir_p(cache_root)
12
- lock_path = File.expand_path('_dl_lock', cache_root)
13
- urlhash = Digest::MD5.hexdigest(url)
14
- cache_path = File.expand_path(urlhash, cache_root)
15
-
16
- return cache_path if File.exist?(cache_path)
17
-
18
- temp_root = Dir.mktmpdir(nil, cache_root)
19
- temp_path = File.expand_path('dl', temp_root)
20
- open(url) do |f|
21
- puts "Downloading from #{url}"
22
- open(temp_path, "w+b") do |out|
23
- out.write(f.read)
24
- end
25
- FileUtils.mv(temp_path, cache_path)
26
- FileUtils.rm_r(temp_root)
27
- end
28
- cache_path
29
- end
30
-
31
- def self.get_dataset_directory(dataset_name, create_directory: true)
32
- path = File.expand_path(dataset_name, DATASET_ROOT)
33
- FileUtils.mkdir_p(path) if create_directory
34
- path
35
- end
36
-
37
- def self.cache_or_load_file(path, &creator)
38
- raise 'Please set dataset creator on block' if creator.nil?
39
-
40
- return PStore.new(path).transaction { |t| t['data'] } if File.exist?(path)
41
-
42
- data = creator.call
43
- PStore.new(path).transaction do |t|
44
- t['data'] = data
45
- end
46
- data
47
- rescue TypeError => e
48
- puts e.message
49
- FileUtils.rm_f(path)
50
- cache_or_load_file(path) do
51
- creator.call
52
- end
53
- end
54
- end
55
- end
56
- end