ruby-dnn 0.9.0 → 0.9.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/API-Reference.ja.md +6 -6
- data/lib/dnn.rb +3 -1
- data/lib/dnn/core/activations.rb +0 -8
- data/lib/dnn/core/layers.rb +2 -2
- data/lib/dnn/core/losses.rb +3 -6
- data/lib/dnn/core/model.rb +8 -9
- data/lib/dnn/core/optimizers.rb +4 -4
- data/lib/dnn/core/utils.rb +0 -2
- data/lib/dnn/lib/image.rb +1 -1
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: '02929b358bbd4ff8c3107c54be0ce37ae86c20a3d345c0090f4be5bcd2ad8b32'
|
4
|
+
data.tar.gz: 7641e5072f9bcdd4eb1bd93d173f78fe0fa11769ecca614ce44155dc7e310b96
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: df47d323eda15b0f11dcf2153083bf05a8bd6c158227c3c93c0c9e0ab1f4679769fa8ad031a8bcbb886e5610880b2c95e26ad425a49ccdfe1b79b0ee280628de
|
7
|
+
data.tar.gz: 86766ea8873229cb665d3e93e70142c246ad8b3e197543093d06ce674d66319a2f9a94a78cf089a0677df4a4966fdbfcfb0ffc95337e6465eabd2ab4d9ba7a1c
|
data/API-Reference.ja.md
CHANGED
@@ -939,25 +939,25 @@ Numo::SFloat y
|
|
939
939
|
### return
|
940
940
|
損失関数の値。
|
941
941
|
|
942
|
-
## abstruct def
|
943
|
-
|
942
|
+
## abstruct def backward(y)
|
943
|
+
損失関数の逆伝搬を行います。全ての損失関数のクラスは、このメソッドを実装する必要があります。
|
944
944
|
### arguments
|
945
945
|
Numo::SFloat y
|
946
946
|
出力データ。
|
947
947
|
### return
|
948
948
|
損失関数の値。
|
949
949
|
|
950
|
-
# class MeanSquaredError <
|
950
|
+
# class MeanSquaredError < Loss
|
951
951
|
二乗誤差の損失関数です。
|
952
952
|
|
953
953
|
|
954
|
-
# class
|
954
|
+
# class MeanAbsoluteError < Loss
|
955
955
|
平均絶対誤差の損失関数です。
|
956
956
|
|
957
957
|
|
958
|
-
# class SoftmaxCrossEntropy <
|
958
|
+
# class SoftmaxCrossEntropy < Loss
|
959
959
|
ソフトマックス関数とクロスエントロピー誤差を合わせた損失関数です。
|
960
960
|
|
961
961
|
|
962
|
-
# class SigmoidCrossEntropy <
|
962
|
+
# class SigmoidCrossEntropy < Loss
|
963
963
|
シグモイド関数とクロスエントロピー誤差を合わせた損失関数です。
|
data/lib/dnn.rb
CHANGED
data/lib/dnn/core/activations.rb
CHANGED
@@ -13,8 +13,6 @@ module DNN
|
|
13
13
|
|
14
14
|
|
15
15
|
class Tanh < Layers::Layer
|
16
|
-
NMath = Xumo::NMath
|
17
|
-
|
18
16
|
def forward(x)
|
19
17
|
@out = NMath.tanh(x)
|
20
18
|
end
|
@@ -38,8 +36,6 @@ module DNN
|
|
38
36
|
|
39
37
|
|
40
38
|
class Softplus < Layers::Layer
|
41
|
-
NMath = Xumo::NMath
|
42
|
-
|
43
39
|
def forward(x)
|
44
40
|
@x = x
|
45
41
|
NMath.log(1 + NMath.exp(x))
|
@@ -52,8 +48,6 @@ module DNN
|
|
52
48
|
|
53
49
|
|
54
50
|
class Swish < Layers::Layer
|
55
|
-
NMath = Xumo::NMath
|
56
|
-
|
57
51
|
def forward(x)
|
58
52
|
@x = x
|
59
53
|
@out = x * (1 / (1 + NMath.exp(-x)))
|
@@ -111,8 +105,6 @@ module DNN
|
|
111
105
|
|
112
106
|
|
113
107
|
class ELU < Layers::Layer
|
114
|
-
NMath = Xumo::NMath
|
115
|
-
|
116
108
|
attr_reader :alpha
|
117
109
|
|
118
110
|
def self.load_hash(hash)
|
data/lib/dnn/core/layers.rb
CHANGED
@@ -318,14 +318,14 @@ module DNN
|
|
318
318
|
mean = x.mean(0)
|
319
319
|
@xc = x - mean
|
320
320
|
var = (@xc**2).mean(0)
|
321
|
-
@std =
|
321
|
+
@std = NMath.sqrt(var + 1e-7)
|
322
322
|
xn = @xc / @std
|
323
323
|
@xn = xn
|
324
324
|
@running_mean.data = @momentum * @running_mean.data + (1 - @momentum) * mean
|
325
325
|
@running_var.data = @momentum * @running_var.data + (1 - @momentum) * var
|
326
326
|
else
|
327
327
|
xc = x - @running_mean.data
|
328
|
-
xn = xc /
|
328
|
+
xn = xc / NMath.sqrt(@running_var.data + 1e-7)
|
329
329
|
end
|
330
330
|
@gamma.data * xn + @beta.data
|
331
331
|
end
|
data/lib/dnn/core/losses.rb
CHANGED
@@ -57,11 +57,12 @@ module DNN
|
|
57
57
|
|
58
58
|
|
59
59
|
class HuberLoss < Loss
|
60
|
-
def forward(out, y)
|
60
|
+
def forward(out, y, layers)
|
61
61
|
@out = out
|
62
62
|
loss = loss_l1(y)
|
63
63
|
loss = loss > 1 ? loss : loss_l2(y)
|
64
|
-
|
64
|
+
#@loss = loss + regularize(layers)
|
65
|
+
@loss = loss
|
65
66
|
end
|
66
67
|
|
67
68
|
def backward(y)
|
@@ -88,8 +89,6 @@ module DNN
|
|
88
89
|
|
89
90
|
|
90
91
|
class SoftmaxCrossEntropy < Loss
|
91
|
-
NMath = Xumo::NMath
|
92
|
-
|
93
92
|
def forward(x, y)
|
94
93
|
@out = Utils.softmax(x)
|
95
94
|
batch_size = y.shape[0]
|
@@ -103,8 +102,6 @@ module DNN
|
|
103
102
|
|
104
103
|
|
105
104
|
class SigmoidCrossEntropy < Loss
|
106
|
-
NMath = Xumo::NMath
|
107
|
-
|
108
105
|
def forward(x, y)
|
109
106
|
@out = Utils.sigmoid(x)
|
110
107
|
batch_size = y.shape[0]
|
data/lib/dnn/core/model.rb
CHANGED
@@ -43,11 +43,7 @@ module DNN
|
|
43
43
|
hash_params.each do |key, (shape, base64_param)|
|
44
44
|
bin = Base64.decode64(base64_param)
|
45
45
|
data = Xumo::SFloat.from_binary(bin).reshape(*shape)
|
46
|
-
|
47
|
-
layer.params[key].data = data
|
48
|
-
else
|
49
|
-
layer.params[key] = data
|
50
|
-
end
|
46
|
+
layer.params[key].data = data
|
51
47
|
end
|
52
48
|
has_param_layers_index += 1
|
53
49
|
end
|
@@ -154,9 +150,7 @@ module DNN
|
|
154
150
|
verbose: true,
|
155
151
|
batch_proc: nil,
|
156
152
|
&epoch_proc)
|
157
|
-
unless compiled?
|
158
|
-
raise DNN_Error.new("The model is not compiled.")
|
159
|
-
end
|
153
|
+
raise DNN_Error.new("The model is not compiled.") unless compiled?
|
160
154
|
check_xy_type(x, y)
|
161
155
|
dataset = Dataset.new(x, y)
|
162
156
|
num_train_datas = x.shape[0]
|
@@ -194,11 +188,16 @@ module DNN
|
|
194
188
|
end
|
195
189
|
|
196
190
|
def train_on_batch(x, y, &batch_proc)
|
191
|
+
raise DNN_Error.new("The model is not compiled.") unless compiled?
|
197
192
|
check_xy_type(x, y)
|
198
193
|
input_data_shape_check(x, y)
|
199
194
|
x, y = batch_proc.call(x, y) if batch_proc
|
200
195
|
out = forward(x, true)
|
201
|
-
loss_value =
|
196
|
+
loss_value = if @loss.is_a?(HuberLoss)
|
197
|
+
@loss.forward(out, y, get_all_layers)
|
198
|
+
else
|
199
|
+
@loss.forward(out, y) + @loss.regularize(get_all_layers)
|
200
|
+
end
|
202
201
|
dout = @loss.backward(y)
|
203
202
|
backward(dout, true)
|
204
203
|
@loss.d_regularize(get_all_layers)
|
data/lib/dnn/core/optimizers.rb
CHANGED
@@ -88,7 +88,7 @@ module DNN
|
|
88
88
|
params.select { |key, param| param.grad }.each_value do |param|
|
89
89
|
@g[param] ||= 0
|
90
90
|
@g[param] += param.grad**2
|
91
|
-
param.data -= (@learning_rate /
|
91
|
+
param.data -= (@learning_rate / NMath.sqrt(@g[param] + 1e-7)) * param.grad
|
92
92
|
end
|
93
93
|
end
|
94
94
|
end
|
@@ -111,7 +111,7 @@ module DNN
|
|
111
111
|
params.select { |key, param| param.grad }.each_value do |param|
|
112
112
|
@g[param] ||= 0
|
113
113
|
@g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad**2
|
114
|
-
param.data -= (@learning_rate /
|
114
|
+
param.data -= (@learning_rate / NMath.sqrt(@g[param] + 1e-7)) * param.grad
|
115
115
|
end
|
116
116
|
end
|
117
117
|
|
@@ -140,7 +140,7 @@ module DNN
|
|
140
140
|
@h[param] ||= Xumo::SFloat.zeros(*param.data.shape)
|
141
141
|
@s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
|
142
142
|
@h[param] = @rho * @h[param] + (1 - @rho) * param.grad**2
|
143
|
-
v = (
|
143
|
+
v = (NMath.sqrt(@s[param] + 1e-6) / NMath.sqrt(@h[param] + 1e-6)) * param.grad
|
144
144
|
@s[param] = @rho * @s[param] + (1 - @rho) * v**2
|
145
145
|
param.data -= v
|
146
146
|
end
|
@@ -177,7 +177,7 @@ module DNN
|
|
177
177
|
@v[param] ||= 0
|
178
178
|
@m[param] += (1 - @beta1) * (param.grad - @m[param])
|
179
179
|
@v[param] += (1 - @beta2) * (param.grad**2 - @v[param])
|
180
|
-
param.data -= lr * @m[param] /
|
180
|
+
param.data -= lr * @m[param] / NMath.sqrt(@v[param] + 1e-7)
|
181
181
|
end
|
182
182
|
end
|
183
183
|
|
data/lib/dnn/core/utils.rb
CHANGED
data/lib/dnn/lib/image.rb
CHANGED
@@ -14,7 +14,7 @@ module DNN
|
|
14
14
|
if img.shape.length == 2
|
15
15
|
img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
|
16
16
|
elsif img.shape[2] == 1
|
17
|
-
img = img.
|
17
|
+
img = img.reshape(img.shape[0], img.shape[1])
|
18
18
|
img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
|
19
19
|
end
|
20
20
|
h, w, ch = img.shape
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.9.
|
4
|
+
version: 0.9.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-05-
|
11
|
+
date: 2019-05-04 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|