red-chainer 0.2.1 → 0.3.0

Files changed (52)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/examples/cifar/models/vgg.rb +84 -0
  4. data/examples/cifar/train_cifar.rb +70 -0
  5. data/examples/iris.rb +103 -0
  6. data/lib/chainer.rb +17 -0
  7. data/lib/chainer/configuration.rb +2 -1
  8. data/lib/chainer/cuda.rb +18 -0
  9. data/lib/chainer/dataset/convert.rb +30 -9
  10. data/lib/chainer/datasets/cifar.rb +56 -0
  11. data/lib/chainer/datasets/mnist.rb +3 -3
  12. data/lib/chainer/datasets/tuple_dataset.rb +3 -1
  13. data/lib/chainer/function.rb +1 -0
  14. data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
  15. data/lib/chainer/functions/activation/log_softmax.rb +4 -4
  16. data/lib/chainer/functions/activation/relu.rb +3 -4
  17. data/lib/chainer/functions/activation/sigmoid.rb +4 -4
  18. data/lib/chainer/functions/activation/tanh.rb +5 -5
  19. data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
  20. data/lib/chainer/functions/connection/linear.rb +1 -1
  21. data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
  22. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
  23. data/lib/chainer/functions/math/identity.rb +26 -0
  24. data/lib/chainer/functions/noise/dropout.rb +45 -0
  25. data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
  26. data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
  27. data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
  28. data/lib/chainer/gradient_check.rb +240 -0
  29. data/lib/chainer/initializer.rb +2 -0
  30. data/lib/chainer/initializers/constant.rb +1 -1
  31. data/lib/chainer/initializers/init.rb +5 -1
  32. data/lib/chainer/initializers/normal.rb +1 -1
  33. data/lib/chainer/iterators/serial_iterator.rb +1 -1
  34. data/lib/chainer/link.rb +11 -0
  35. data/lib/chainer/links/connection/convolution_2d.rb +98 -0
  36. data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
  37. data/lib/chainer/optimizer.rb +40 -1
  38. data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
  39. data/lib/chainer/parameter.rb +1 -1
  40. data/lib/chainer/serializers/marshal.rb +7 -3
  41. data/lib/chainer/testing/array.rb +32 -0
  42. data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
  43. data/lib/chainer/training/extensions/snapshot.rb +1 -1
  44. data/lib/chainer/training/standard_updater.rb +4 -0
  45. data/lib/chainer/training/trainer.rb +1 -1
  46. data/lib/chainer/utils/array.rb +13 -2
  47. data/lib/chainer/utils/conv.rb +59 -0
  48. data/lib/chainer/utils/math.rb +72 -0
  49. data/lib/chainer/utils/variable.rb +7 -3
  50. data/lib/chainer/version.rb +1 -1
  51. data/red-chainer.gemspec +1 -0
  52. metadata +37 -3
data/lib/chainer/training/extensions/exponential_shift.rb ADDED
@@ -0,0 +1,78 @@
+ module Chainer
+   module Training
+     module Extensions
+       # Trainer extension to exponentially shift an optimizer attribute.
+       #
+       # This extension exponentially increases or decreases the specified attribute of the optimizer.
+       # The typical use case is an exponential decay of the learning rate.
+       # This extension is also called before the training loop starts by default.
+       class ExponentialShift < Extension
+         attr_reader :last_value
+
+         # @param [String] attr Name of the attribute to shift.
+         # @param [Float] rate Rate of the exponential shift.
+         # @param [Float] init Initial value of the attribute.
+         # @param [Float] target Target value of the attribute.
+         # @param [Chainer::Optimizer] optimizer Target optimizer to adjust the attribute.
+         def initialize(attr, rate, init: nil, target: nil, optimizer: nil)
+           @attr = attr
+           raise 'ExponentialShift does not support negative rate' if rate < 0
+           @rate = rate
+           @init = init
+           @target = target
+           @optimizer = optimizer
+           @t = 0
+           @last_value = nil
+         end
+
+         def init(trainer)
+           optimizer = get_optimizer(trainer)
+           @init = optimizer.send(@attr) if @init.nil?
+           if @last_value.nil?
+             update_value(optimizer, @init)
+           else
+             update_value(optimizer, @last_value)
+           end
+         end
+
+         def call(trainer)
+           @t += 1
+
+           optimizer = get_optimizer(trainer)
+           value = @init * (@rate ** @t)
+           if @target
+             if @rate > 1
+               if value / @target > 1
+                 value = @target
+               end
+             else
+               if value / @target < 1
+                 value = @target
+               end
+             end
+           end
+           update_value(optimizer, value)
+         end
+
+         def serialize(serializer)
+           @t = serializer.('t', @t)
+           @last_value = serializer.('last_value', @last_value)
+           if @last_value.is_a?(Numo::NArray)
+             @last_value = @last_value[0]
+           end
+         end
+
+         private
+
+         def get_optimizer(trainer)
+           @optimizer || trainer.updater.get_optimizer(:main)
+         end
+
+         def update_value(optimizer, value)
+           optimizer.send("#{@attr}=", value)
+           @last_value = value
+         end
+       end
+     end
+   end
+ end
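For reference, a minimal usage sketch of the new extension (the model, iterator, constructor arguments, and trigger values here are hypothetical; ExponentialShift resolves the :main optimizer through trainer.updater.get_optimizer, added below):

# Hypothetical setup: `model` and `train_iter` are assumed to exist already.
optimizer = Chainer::Optimizers::MomentumSGD.new(lr: 0.1)  # assumes the optimizer exposes an lr accessor
optimizer.setup(model)
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer)
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [300, 'epoch'])
# Multiply :lr by 0.5 each time the trigger fires (every 25 epochs here).
trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])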
data/lib/chainer/training/extensions/snapshot.rb CHANGED
@@ -23,7 +23,7 @@ module Chainer
   filename = filename_proc.call(trainer)
   prefix = "tmp#{filename}"
   temp_file = Tempfile.create(basename: prefix, tmpdir: trainer.out)
-  save_class.save_file(temp_file, trainer)
+  save_class.save_file(temp_file.path, trainer)
   FileUtils.move(temp_file.path, File.join(trainer.out, filename))
   end
   end
data/lib/chainer/training/standard_updater.rb CHANGED
@@ -20,6 +20,10 @@ module Chainer
   @iteration = 0
   end
 
+  def get_optimizer(name)
+    @optimizers[name]
+  end
+
   def get_all_optimizers
   @optimizers.to_h
   end
data/lib/chainer/training/trainer.rb CHANGED
@@ -54,7 +54,7 @@ module Chainer
   elsif extension.default_name
   extension.default_name
   else
-  raise ArgumentError 'name is not given for the extension'
+  raise ArgumentError, 'name is not given for the extension'
   end
   end
 
data/lib/chainer/utils/array.rb CHANGED
@@ -2,8 +2,19 @@ module Chainer
   module Utils
     module Array
       def self.force_array(x, dtype=nil)
-        # TODO: conversion by dtype
-        Numo::NArray.[](*x)
+        if x.is_a? Integer or x.is_a? Float
+          if dtype.nil?
+            Numo::NArray.cast(x)
+          else
+            dtype.cast(x.dup)
+          end
+        else
+          if dtype.nil?
+            x
+          else
+            dtype.cast(x)
+          end
+        end
       end
     end
   end
data/lib/chainer/utils/conv.rb ADDED
@@ -0,0 +1,59 @@
+ module Chainer
+   module Utils
+     module Conv
+       def self.get_conv_outsize(size, k, s, p, cover_all: false, d: 1)
+         dk = k + (k - 1) * (d - 1)
+         if cover_all
+           (size + p * 2 - dk + s - 1).div(s) + 1
+         else
+           (size + p * 2 - dk).div(s) + 1
+         end
+       end
+
+       def self.im2col_cpu(img, kh, kw, sy, sx, ph, pw, pval: 0, cover_all: false, dy: 1, dx: 1)
+         n, c, h, w = img.shape
+
+         out_h = self.get_conv_outsize(h, kh, sy, ph, cover_all: cover_all, d: dy)
+         raise 'Height in the output should be positive.' if out_h <= 0
+         out_w = self.get_conv_outsize(w, kw, sx, pw, cover_all: cover_all, d: dx)
+         raise 'Width in the output should be positive.' if out_w <= 0
+
+         # padding
+         # TODO: ref: numpy.pad
+         pad_bottom = ph + sy - 1
+         pad_right = pw + sx - 1
+         pad_img = img.class.new(n, c, (h + ph + pad_bottom), (w + pw + pad_right)).fill(pval)
+         pad_img[nil, nil, ph...(ph+h), pw...(pw+w)] = img
+
+         col = pad_img.class.new(n, c, kh, kw, out_h, out_w).rand(1)
+
+         kh.times do |j|
+           jdy = j * dy
+           j_lim = [jdy + sy * out_h, pad_img.shape[2]].min
+           kw.times do |i|
+             idx = i * dx
+             i_lim = [idx + sx * out_w, pad_img.shape[3]].min
+             col[nil, nil, j, i, nil, nil] = pad_img[nil, nil, (jdy...j_lim).step(sy), (idx...i_lim).step(sx)]
+           end
+         end
+
+         col
+       end
+
+       def self.col2im_cpu(col, sy, sx, ph, pw, h, w, dy: 1, dx: 1)
+         n, c, kh, kw, out_h, out_w = col.shape
+         img = col.class.zeros(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1)
+         kh.times do |j|
+           jdy = j * dy
+           j_lim = [jdy + sy * out_h, img.shape[2]].min
+           kw.times do |i|
+             idx = i * dx
+             i_lim = [idx + sx * out_w, img.shape[3]].min
+             img[nil, nil, (jdy...j_lim).step(sy), (idx...i_lim).step(sx)] += col[nil, nil, j, i, nil, nil]
+           end
+         end
+         return img[nil, nil, (ph...(h + ph)), (pw...(w + pw))]
+       end
+     end
+   end
+ end
data/lib/chainer/utils/math.rb ADDED
@@ -0,0 +1,72 @@
+ module Chainer
+   module Utils
+     module Math
+       def self.tensordot(a, b, axes)
+         if axes.is_a?(Integer)
+           axes_a = (-axes...0).to_a
+           axes_b = (0...axes).to_a
+         elsif axes.is_a?(Array)
+           axes_a, axes_b = axes
+         end
+
+         axes_a = Array(axes_a)
+         axes_b = Array(axes_b)
+         na = axes_a.size
+         nb = axes_b.size
+
+         as = a.shape
+         nda = a.ndim
+         bs = b.shape
+         ndb = b.ndim
+         equal = true
+         if na != nb
+           equal = false
+         else
+           na.times do |k|
+             if as[axes_a[k]] != bs[axes_b[k]]
+               equal = false
+               break
+             end
+
+             if axes_a[k] < 0
+               axes_a[k] += nda
+             end
+
+             if axes_b[k] < 0
+               axes_b[k] += ndb
+             end
+           end
+         end
+
+         raise "shape-mismatch for sum" unless equal
+
+         notin = (0...nda).reject { |i| axes_a.include?(i) }
+         newaxes_a = notin + axes_a
+         n2 = 1
+         axes_a.each do |axis|
+           n2 *= as[axis]
+         end
+         tmp = a.shape.reduce(:*) / n2
+         newshape_a = [tmp, n2]
+         olda = notin.map { |axis| as[axis] }
+
+         notin = (0...ndb).reject { |i| axes_b.include?(i) }
+         newaxes_b = axes_b + notin
+         n2 = 1
+         axes_b.each do |axis|
+           n2 *= bs[axis]
+         end
+         tmp = b.shape.reduce(:*) / n2
+         newshape_b = [n2, tmp]
+         oldb = notin.map { |axis| bs[axis] }
+
+         at = a.transpose(*newaxes_a).reshape(*newshape_a)
+         bt = b.transpose(*newaxes_b).reshape(*newshape_b)
+         res = at.dot(bt)
+
+         return res.reshape(*(olda + oldb))
+       end
+     end
+   end
+ end
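tensordot mirrors numpy.tensordot: each operand is transposed and reshaped into a 2-D matrix, the matrices are multiplied with dot, and the result is reshaped back. A 2-D sanity check, assuming Numo::DFloat inputs (axes = 1 contracts a's last axis with b's first):

a = Numo::DFloat.new(3, 4).seq
b = Numo::DFloat.new(4, 5).seq
c = Chainer::Utils::Math.tensordot(a, b, 1)
c.shape   # => [3, 5], identical to a.dot(b)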
data/lib/chainer/utils/variable.rb CHANGED
@@ -6,12 +6,16 @@ module Chainer
   return
   end
 
-  unless gx.instance_of?(x.data.class)
-    raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}"
+  unless gx.is_a?(x.data.class.superclass)
+    raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}"
+  end
+
+  unless gx.class == x.data.class
+    raise TypeError, "Dtype(Class) of data and grad mismatch\n#{x.data.class} != #{gx.class}"
   end
 
   unless gx.shape == x.data.shape
-    raise TypeError, "Shape of data and grad mismatch\n#{x.class} != #{gx.class}"
+    raise TypeError, "Shape of data and grad mismatch\n#{x.data.shape} != #{gx.shape}"
   end
   end
   end
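A sketch of what the three checks now distinguish, assuming check_grad_type runs when a gradient is attached to a Chainer::Variable:

x = Chainer::Variable.new(Numo::SFloat[1, 2, 3])
x.grad = Numo::DFloat[1, 2, 3]   # same NArray family, wrong dtype: "Dtype(Class) of data and grad mismatch"
x.grad = Numo::SFloat[1, 2]      # right dtype, wrong shape: "Shape of data and grad mismatch"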
data/lib/chainer/version.rb CHANGED
@@ -1,4 +1,4 @@
   module Chainer
-   VERSION = "0.2.1"
+   VERSION = "0.3.0"
   end
 
data/red-chainer.gemspec CHANGED
@@ -20,6 +20,7 @@ Gem::Specification.new do |spec|
   spec.require_paths = ["lib"]
 
   spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
+  spec.add_runtime_dependency "red-datasets"
 
   spec.add_development_dependency "bundler", "~> 1.15"
   spec.add_development_dependency "rake", "~> 10.0"
metadata CHANGED
@@ -1,14 +1,14 @@
   --- !ruby/object:Gem::Specification
   name: red-chainer
   version: !ruby/object:Gem::Version
-   version: 0.2.1
+   version: 0.3.0
   platform: ruby
   authors:
   - Yusaku Hatanaka
   autorequire:
   bindir: exe
   cert_chain: []
-  date: 2018-03-01 00:00:00.000000000 Z
+  date: 2018-05-19 00:00:00.000000000 Z
   dependencies:
   - !ruby/object:Gem::Dependency
     name: numo-narray
@@ -24,6 +24,20 @@ dependencies:
     - - ">="
       - !ruby/object:Gem::Version
         version: 0.9.1.1
+ - !ruby/object:Gem::Dependency
+   name: red-datasets
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
   - !ruby/object:Gem::Dependency
     name: bundler
     requirement: !ruby/object:Gem::Requirement
@@ -82,12 +96,17 @@ files:
   - Rakefile
   - bin/console
   - bin/setup
+ - examples/cifar/models/vgg.rb
+ - examples/cifar/train_cifar.rb
+ - examples/iris.rb
   - examples/mnist.rb
   - lib/chainer.rb
   - lib/chainer/configuration.rb
+ - lib/chainer/cuda.rb
   - lib/chainer/dataset/convert.rb
   - lib/chainer/dataset/download.rb
   - lib/chainer/dataset/iterator.rb
+ - lib/chainer/datasets/cifar.rb
   - lib/chainer/datasets/mnist.rb
   - lib/chainer/datasets/tuple_dataset.rb
   - lib/chainer/function.rb
@@ -96,10 +115,18 @@ files:
   - lib/chainer/functions/activation/relu.rb
   - lib/chainer/functions/activation/sigmoid.rb
   - lib/chainer/functions/activation/tanh.rb
+ - lib/chainer/functions/connection/convolution_2d.rb
   - lib/chainer/functions/connection/linear.rb
   - lib/chainer/functions/evaluation/accuracy.rb
+ - lib/chainer/functions/loss/mean_squared_error.rb
   - lib/chainer/functions/loss/softmax_cross_entropy.rb
   - lib/chainer/functions/math/basic_math.rb
+ - lib/chainer/functions/math/identity.rb
+ - lib/chainer/functions/noise/dropout.rb
+ - lib/chainer/functions/normalization/batch_normalization.rb
+ - lib/chainer/functions/pooling/max_pooling_2d.rb
+ - lib/chainer/functions/pooling/pooling_2d.rb
+ - lib/chainer/gradient_check.rb
   - lib/chainer/gradient_method.rb
   - lib/chainer/hyperparameter.rb
   - lib/chainer/initializer.rb
@@ -108,16 +135,21 @@ files:
   - lib/chainer/initializers/normal.rb
   - lib/chainer/iterators/serial_iterator.rb
   - lib/chainer/link.rb
+ - lib/chainer/links/connection/convolution_2d.rb
   - lib/chainer/links/connection/linear.rb
   - lib/chainer/links/model/classifier.rb
+ - lib/chainer/links/normalization/batch_normalization.rb
   - lib/chainer/optimizer.rb
   - lib/chainer/optimizers/adam.rb
+ - lib/chainer/optimizers/momentum_sgd.rb
   - lib/chainer/parameter.rb
   - lib/chainer/reporter.rb
   - lib/chainer/serializer.rb
   - lib/chainer/serializers/marshal.rb
+ - lib/chainer/testing/array.rb
   - lib/chainer/training/extension.rb
   - lib/chainer/training/extensions/evaluator.rb
+ - lib/chainer/training/extensions/exponential_shift.rb
   - lib/chainer/training/extensions/log_report.rb
   - lib/chainer/training/extensions/print_report.rb
   - lib/chainer/training/extensions/progress_bar.rb
@@ -128,7 +160,9 @@ files:
   - lib/chainer/training/updater.rb
   - lib/chainer/training/util.rb
   - lib/chainer/utils/array.rb
+ - lib/chainer/utils/conv.rb
   - lib/chainer/utils/initializer.rb
+ - lib/chainer/utils/math.rb
   - lib/chainer/utils/variable.rb
   - lib/chainer/variable.rb
   - lib/chainer/variable_node.rb
@@ -154,7 +188,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
       version: '0'
   requirements: []
   rubyforge_project:
-  rubygems_version: 2.7.3
+  rubygems_version: 2.7.6
   signing_key:
   specification_version: 4
   summary: A flexible framework for neural network for Ruby