ruby-dnn 0.16.1 → 0.16.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2e04eef16303e2e223bfdb1ac17e2a35fc20d0eafff022f71e412d8b41dc49c9
4
- data.tar.gz: d0cd83a209069eba6ea2c48fc6c3480e5edf5b99932b31cb9d7a73adb0eb573d
3
+ metadata.gz: b76b0eb5bf75a22e48726f93fa4faff413c196b1f1587ce145aba7042c84a532
4
+ data.tar.gz: f32d09cb89391583f51a557e5f72024f618d943af3becc25f8f19905fa395b3b
5
5
  SHA512:
6
- metadata.gz: 2b7126723c81495c603a1b98fc64fdd111d4853db78e4f0c840f75e925d60d97a3dd6ba22a9a9e2a8f643562a1284bfe502e38e0e9e4870c7699f16ed151df40
7
- data.tar.gz: c1c8900e92ec77a03f46853a9ff14d4b09c8442a597f6ae87d1ee049764a3316e1b7ff92cdb4c13c4b2787d2759ff5c8f6c28c11668503d1d0062a2f671a04b7
6
+ metadata.gz: f70784c49f71420df424c2077b430c0cb837b31b5493f62831cf49b296d834045796f2bae7749fa3b4bfb3ca365a97da856ccbebcc666c1489020eb55ec408fc
7
+ data.tar.gz: 52fd5cf850341b3d2bdd44da35f69254e508b0dd77e052a7217fe250a8d41a00d237b5ce0cb379e652bf2e6b7d0b99e95b3530ef5abe2d43a2b8f8c8f7dbca7e
@@ -0,0 +1,118 @@
1
+ require "dnn"
2
+ require "dnn/datasets/mnist"
3
+ require "dnn/image"
4
+ require "numo/linalg/autoloader"
5
+
6
+ include DNN::Models
7
+ include DNN::Layers
8
+ include DNN::Optimizers
9
+ include DNN::Losses
10
+
11
x_train, y_train = DNN::MNIST.load_train
x_test, y_test = DNN::MNIST.load_test

# Flatten every 28x28 image into a 784-element SFloat row and scale the
# pixel values down by 255.
x_train, x_test = [x_train, x_test].map do |images|
  Numo::SFloat.cast(images).reshape(images.shape[0], 784) / 255
end

# One-hot encode the digit labels over 10 classes.
y_train, y_test = [y_train, y_test].map do |labels|
  DNN::Utils.to_categorical(labels, 10, Numo::SFloat)
end

# Size of the latent code produced by the encoder.
$z_dim = 2
# VAE#forward stores the encoder's outputs here so that VAELoss#forward,
# which only receives (y, t), can compute the KL term from them.
$z_mean = nil
$z_sigma = nil
26
+
27
# Reparameterization-trick sampling layer: combines the encoder's mean and
# sigma outputs into a latent code z = tanh(mean + sigma * epsilon), where
# epsilon is standard-normal noise.
class Sampling < MergeLayer
  # @param z_mean [DNN::Tensor] mean head of the encoder
  # @param z_sigma [DNN::Tensor] sigma head of the encoder
  # @return [DNN::Tensor] latent code squashed into (-1, 1) by Tanh
  def forward(z_mean, z_sigma)
    # Draw independent N(0, 1) noise for every element of z_mean, i.e. a
    # separate noise vector per sample. A noise vector of shape [$z_dim]
    # would be broadcast across the batch, so every sample would share the
    # same epsilon.
    epsilon = DNN::Tensor.new(Numo::SFloat.new(*z_mean.data.shape).rand_norm(0, 1))
    # Tanh keeps the latent code inside (-1, 1), the same square the
    # generation loop of this script walks over.
    Tanh.(z_mean + z_sigma * epsilon)
  end
end
33
+
34
# Encoder half of the VAE. A flattened 784-value input passes through two
# Dense -> BatchNormalization -> ReLU stages (196 and 49 units). Two parallel
# Dense heads of size $z_dim then produce the latent mean and sigma.
class Encoder < Model
  def initialize
    super
    # Hidden stages.
    @l1 = Dense.new(196)
    @l2 = Dense.new(49)
    # Parallel output heads: @l3_1 -> z_mean, @l3_2 -> z_sigma.
    @l3_1 = Dense.new($z_dim)
    @l3_2 = Dense.new($z_dim)
    # Normalization applied after @l1 and @l2 respectively.
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
  end

  # @param x input batch, declared through a 784-unit InputLayer
  # @return [Array] two-element array [z_mean, z_sigma]
  def forward(x)
    h = InputLayer.new(784).(x)
    h = ReLU.(@bn1.(@l1.(h)))
    h = ReLU.(@bn2.(@l2.(h)))
    [@l3_1.(h), @l3_2.(h)]
  end
end
58
+
59
# Decoder half of the VAE: maps a latent code back to 784 output values via
# Dense(196) -> BatchNormalization -> ReLU -> Dense(784). The last Dense has
# no activation, so the outputs are logits. In this script, training applies
# the sigmoid inside SigmoidCrossEntropy and generation applies
# DNN::Utils.sigmoid explicitly.
class Decoder < Model
  def initialize
    super
    @l3 = Dense.new(196)
    @l4 = Dense.new(784)
    @bn1 = BatchNormalization.new
  end

  # @param z latent code batch
  # @return 784 logits per sample
  def forward(z)
    hidden = ReLU.(@bn1.(@l3.(z)))
    @l4.(hidden)
  end
end
75
+
76
# Full variational autoencoder: encode, sample a latent code, decode.
# Encoder and decoder instances may be passed in; otherwise fresh ones are built.
class VAE < Model
  attr_accessor :enc, :dec

  # @param enc [Encoder, nil] encoder to use; a new Encoder when nil
  # @param dec [Decoder, nil] decoder to use; a new Decoder when nil
  def initialize(enc = nil, dec = nil)
    super()
    @enc = enc || Encoder.new
    @dec = dec || Decoder.new
  end

  # Publishes the encoder outputs through $z_mean / $z_sigma so VAELoss can
  # read them, then decodes a latent code sampled from them.
  def forward(x)
    $z_mean, $z_sigma = @enc.(x)
    @dec.(Sampling.($z_mean, $z_sigma))
  end
end
94
+
95
# VAE objective: sigmoid cross-entropy reconstruction loss plus a KL term.
# The KL term reads $z_mean / $z_sigma (set by VAE#forward) and treats
# $z_sigma as a standard deviation. It computes
#   -0.5 * mean_axis0( sum_axis1( 1 + log(sigma^2) - mean^2 - sigma^2 ) )
class VAELoss < Loss
  def forward(y, t)
    kl_terms = 1 + Log.($z_sigma**2) - $z_mean**2 - $z_sigma**2
    kl = -0.5 * Mean.(Sum.(kl_terms, axis: 1), axis: 0)
    reconstruction = SigmoidCrossEntropy.(y, t)
    reconstruction + kl
  end
end
101
+
102
model = VAE.new
dec = model.dec
model.setup(Adam.new, VAELoss.new)

# The VAE reconstructs its input, so the images are both input and target.
model.train(x_train, x_train, 10, batch_size: 100)

# Generated images are written under img/. Create the directory first so the
# writes do not depend on it already existing.
Dir.mkdir("img") unless Dir.exist?("img")

# Walk a 10x10 grid over the latent square [-1, 1] x [-1, 1] (step 1 / 4.5)
# and decode each grid point into a 28x28 single-channel image.
10.times do |i|
  10.times do |j|
    z1 = (i / 4.5) - 1
    z2 = (j / 4.5) - 1
    z = Numo::SFloat[z1, z2]
    # The decoder outputs logits; the sigmoid maps them into [0, 1] before
    # scaling to 8-bit pixel values.
    out = DNN::Utils.sigmoid(dec.predict1(z))
    img = Numo::UInt8.cast(out * 255).reshape(28, 28, 1)
    DNN::Image.write("img/img_#{i}_#{j}.png", img)
  end
end
@@ -2,6 +2,8 @@ module DNN
2
2
  module Layers
3
3
 
4
4
  class Add < MergeLayer
5
+ include MergeLayerNode
6
+
5
7
  def forward_node(x1, x2)
6
8
  x1 + x2
7
9
  end
@@ -12,6 +14,8 @@ module DNN
12
14
  end
13
15
 
14
16
  class Sub < MergeLayer
17
+ include MergeLayerNode
18
+
15
19
  def forward_node(x1, x2)
16
20
  x1 - x2
17
21
  end
@@ -22,6 +26,8 @@ module DNN
22
26
  end
23
27
 
24
28
  class Mul < MergeLayer
29
+ include MergeLayerNode
30
+
25
31
  def forward_node(x1, x2)
26
32
  @x1, @x2 = x1, x2
27
33
  x1 * x2
@@ -33,6 +39,8 @@ module DNN
33
39
  end
34
40
 
35
41
  class Div < MergeLayer
42
+ include MergeLayerNode
43
+
36
44
  def forward_node(x1, x2)
37
45
  @x1, @x2 = x1, x2
38
46
  x1 / x2
@@ -46,6 +54,8 @@ module DNN
46
54
  end
47
55
 
48
56
  class Dot < MergeLayer
57
+ include MergeLayerNode
58
+
49
59
  def forward_node(x1, x2)
50
60
  @x1, @x2 = x1, x2
51
61
  x1.dot(x2)
@@ -60,12 +70,11 @@ module DNN
60
70
  include LayerNode
61
71
 
62
72
  def forward_node(x)
63
- @x = x
64
- Xumo::NMath.exp(x)
73
+ @y = Xumo::NMath.exp(x)
65
74
  end
66
75
 
67
76
  def backward_node(dy)
68
- dy * Xumo::NMath.exp(@x)
77
+ dy * @y
69
78
  end
70
79
  end
71
80
 
@@ -122,20 +131,16 @@ module DNN
122
131
  end
123
132
 
124
133
  def forward_node(x)
125
- if @axis
126
- @dim = x.shape[@axis]
127
- x.sum(axis: @axis, keepdims: true)
128
- else
129
- x.sum
130
- end
134
+ @x_shape = x.shape
135
+ @dim = x.shape[@axis]
136
+ x.sum(axis: @axis, keepdims: true)
131
137
  end
132
138
 
133
139
  def backward_node(dy)
134
- dx = dy.clone
135
- if @axis
136
- (@dim - 1).times do
137
- dx = dx.concatenate(dy, axis: @axis)
138
- end
140
+ return dy if @x_shape == dy.shape
141
+ dx = dy
142
+ (@dim - 1).times do
143
+ dx = dx.concatenate(dy, axis: @axis)
139
144
  end
140
145
  dx
141
146
  end
@@ -150,16 +155,16 @@ module DNN
150
155
  end
151
156
 
152
157
  def forward_node(x)
153
- @dim = @axis ? x.shape[@axis] : x.size
158
+ @x_shape = x.shape
159
+ @dim = x.shape[@axis]
154
160
  x.mean(axis: @axis, keepdims: true)
155
161
  end
156
162
 
157
163
  def backward_node(dy)
164
+ return dy / @dim if @x_shape == dy.shape
158
165
  dx = dy
159
- if @axis
160
- (@dim - 1).times do
161
- dx = dx.concatenate(dy, axis: @axis)
162
- end
166
+ (@dim - 1).times do
167
+ dx = dx.concatenate(dy, axis: @axis)
163
168
  end
164
169
  dx / @dim
165
170
  end
@@ -26,8 +26,6 @@ module DNN
26
26
  end
27
27
 
28
28
  class MergeLayer < Layers::Layer
29
- include MergeLayerNode
30
-
31
29
  def self.call(x1, x2, *args)
32
30
  new(*args).call(x1, x2)
33
31
  end
@@ -45,6 +43,8 @@ module DNN
45
43
  end
46
44
 
47
45
  class Concatenate < MergeLayer
46
+ include MergeLayerNode
47
+
48
48
  attr_reader :axis
49
49
 
50
50
  def initialize(axis: 1)
@@ -75,8 +75,8 @@ module DNN
75
75
  0.5 * ((y - t)**2).mean(0).sum
76
76
  end
77
77
 
78
- def backward_node(dy)
79
- @y - @t
78
+ def backward_node(d)
79
+ (@y - @t) / @y.shape[0]
80
80
  end
81
81
  end
82
82
 
@@ -90,10 +90,10 @@ module DNN
90
90
  end
91
91
 
92
92
  def backward_node(d)
93
- dy = @y - @t
93
+ dy = (@y - @t)
94
94
  dy[dy >= 0] = 1
95
95
  dy[dy < 0] = -1
96
- dy
96
+ dy / @y.shape[0]
97
97
  end
98
98
  end
99
99
 
@@ -109,7 +109,7 @@ module DNN
109
109
  def backward_node(d)
110
110
  a = Xumo::SFloat.ones(*@a.shape)
111
111
  a[@a <= 0] = 0
112
- a * -@t
112
+ (a * -@t) / a.shape[0]
113
113
  end
114
114
  end
115
115
 
@@ -124,12 +124,12 @@ module DNN
124
124
  end
125
125
 
126
126
  def backward_node(d)
127
- dy = @y - @t
127
+ dy = (@y - @t)
128
128
  if @loss_value > 1
129
129
  dy[dy >= 0] = 1
130
130
  dy[dy < 0] = -1
131
131
  end
132
- dy
132
+ dy / @y.shape[0]
133
133
  end
134
134
 
135
135
  private
@@ -168,7 +168,7 @@ module DNN
168
168
  end
169
169
 
170
170
  def backward_node(d)
171
- @x - @t
171
+ (@x - @t) / @x.shape[0]
172
172
  end
173
173
 
174
174
  def to_hash
@@ -205,7 +205,7 @@ module DNN
205
205
  end
206
206
 
207
207
  def backward_node(d)
208
- @x - @t
208
+ (@x - @t) / @x.shape[0]
209
209
  end
210
210
 
211
211
  def to_hash
@@ -262,7 +262,7 @@ module DNN
262
262
  DNN.learning_phase = true
263
263
  out = call(Tensor.convert(x))
264
264
  loss = @loss_func.loss(out, Tensor.convert(y), layers)
265
- loss.link.backward(Xumo::SFloat.zeros(y[0...1, false].shape))
265
+ loss.link.backward(Xumo::SFloat.ones(y[0...1, false].shape[0], 1))
266
266
  @optimizer.update(get_all_trainable_params)
267
267
  @last_log[:train_loss] = loss.data
268
268
  call_callbacks(:after_train_on_batch)
@@ -353,7 +353,12 @@ module DNN
353
353
  # @param [Numo::SFloat] x Input data. However, x is single data.
354
354
  def predict1(x, use_loss_activation: true)
355
355
  check_xy_type(x)
356
- predict(x.reshape(1, *x.shape), use_loss_activation: use_loss_activation)[0, false]
356
+ input = if x.is_a?(Array)
357
+ x.map { |v| v.reshape(1, *v.shape) }
358
+ else
359
+ x.reshape(1, *x.shape)
360
+ end
361
+ predict(input, use_loss_activation: use_loss_activation)[0, false]
357
362
  end
358
363
 
359
364
  # Add callback function.
@@ -458,7 +463,7 @@ module DNN
458
463
  if !x.is_a?(Xumo::SFloat) && !x.is_a?(Array)
459
464
  raise TypeError, "x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
460
465
  end
461
- if y && !y.is_a?(Xumo::SFloat) && !x.is_a?(Array)
466
+ if y && !y.is_a?(Xumo::SFloat) && !y.is_a?(Array)
462
467
  raise TypeError, "y:#{y.class.name} is not an instance of #{Xumo::SFloat.name} class or Array class."
463
468
  end
464
469
  end
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.16.1"
2
+ VERSION = "0.16.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.16.1
4
+ version: 0.16.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-01-06 00:00:00.000000000 Z
11
+ date: 2020-01-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -114,6 +114,7 @@ files:
114
114
  - examples/pix2pix/dcgan.rb
115
115
  - examples/pix2pix/imgen.rb
116
116
  - examples/pix2pix/train.rb
117
+ - examples/vae.rb
117
118
  - examples/xor_example.rb
118
119
  - ext/rb_stb_image/extconf.rb
119
120
  - ext/rb_stb_image/rb_stb_image.c