ruby-dnn 0.15.2 → 0.15.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/dnn/core/callbacks.rb +8 -3
- data/lib/dnn/core/losses.rb +8 -13
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: b213164ca3e4a7d781a673c4c49ec8dfbd6713468687a98c3739c0cf5629de73
|
4
|
+
data.tar.gz: 92a8133dc0ab085d387199f92c2a602bb3782b810c20a931f7ee8de6e247f2a4
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ea6652994f71d97142fe7357947701eff5f0c26708698a670a148e995fb10816b75b62b0bd60004d9ea99b6c837e5b599e790982ad12d010394e4ec7d6962135
|
7
|
+
data.tar.gz: 9d474a7bdd120f0e39fcb05e11f0f8370100dd21e79440fe0408bbccf676890726623f516ada78ae6b8829bcf0799ceeea78218f64d74a70ca7f5dd9286bdae3
|
data/lib/dnn/core/callbacks.rb
CHANGED
@@ -27,7 +27,8 @@ module DNN
|
|
27
27
|
|
28
28
|
# This callback wraps the lambda function.
|
29
29
|
class LambdaCallback < Callback
|
30
|
-
def initialize(event, lambda)
|
30
|
+
def initialize(event, lambda = nil, &block)
|
31
|
+
lambda = block unless lambda
|
31
32
|
instance_eval do
|
32
33
|
define_singleton_method(event) { lambda.call }
|
33
34
|
end
|
@@ -37,15 +38,19 @@ module DNN
|
|
37
38
|
# A callback that saves the model at the end of each epoch.
|
38
39
|
# @param [String] base_file_name Base file name for saving.
|
39
40
|
# @param [Boolean] include_model When set to true, saves data including the model structure.
|
41
|
+
# @param [Integer] interval Save interval.
|
40
42
|
class CheckPoint < Callback
|
41
|
-
def initialize(base_file_name, include_model: true)
|
43
|
+
def initialize(base_file_name, include_model: true, interval: 1)
|
42
44
|
@base_file_name = base_file_name
|
43
45
|
@include_model = include_model
|
46
|
+
@interval = interval
|
44
47
|
end
|
45
48
|
|
46
49
|
def after_epoch
|
47
50
|
saver = Savers::MarshalSaver.new(@model, include_model: @include_model)
|
48
|
-
|
51
|
+
if @model.last_log[:epoch] % @interval == 0
|
52
|
+
saver.save(@base_file_name + "_epoch#{model.last_log[:epoch]}.marshal")
|
53
|
+
end
|
49
54
|
end
|
50
55
|
end
|
51
56
|
|
data/lib/dnn/core/losses.rb
CHANGED
@@ -17,7 +17,7 @@ module DNN
|
|
17
17
|
end
|
18
18
|
loss_value = forward(y, t)
|
19
19
|
loss_value += regularizers_forward(layers) if layers
|
20
|
-
loss_value
|
20
|
+
loss_value
|
21
21
|
end
|
22
22
|
|
23
23
|
def forward(y, t)
|
@@ -65,8 +65,7 @@ module DNN
|
|
65
65
|
|
66
66
|
class MeanSquaredError < Loss
|
67
67
|
def forward(y, t)
|
68
|
-
|
69
|
-
0.5 * ((y - t)**2).sum / batch_size
|
68
|
+
0.5 * ((y - t)**2).mean(0).sum
|
70
69
|
end
|
71
70
|
|
72
71
|
def backward(y, t)
|
@@ -76,8 +75,7 @@ module DNN
|
|
76
75
|
|
77
76
|
class MeanAbsoluteError < Loss
|
78
77
|
def forward(y, t)
|
79
|
-
|
80
|
-
(y - t).abs.sum / batch_size
|
78
|
+
(y - t).abs.mean(0).sum
|
81
79
|
end
|
82
80
|
|
83
81
|
def backward(y, t)
|
@@ -91,7 +89,7 @@ module DNN
|
|
91
89
|
class Hinge < Loss
|
92
90
|
def forward(y, t)
|
93
91
|
@a = 1 - y * t
|
94
|
-
Xumo::SFloat.maximum(0, @a)
|
92
|
+
Xumo::SFloat.maximum(0, @a).mean(0).sum
|
95
93
|
end
|
96
94
|
|
97
95
|
def backward(y, t)
|
@@ -119,13 +117,11 @@ module DNN
|
|
119
117
|
private
|
120
118
|
|
121
119
|
def loss_l1(y, t)
|
122
|
-
|
123
|
-
(y - t).abs.sum / batch_size
|
120
|
+
(y - t).abs.mean(0).sum
|
124
121
|
end
|
125
122
|
|
126
123
|
def loss_l2(y, t)
|
127
|
-
|
128
|
-
0.5 * ((y - t)**2).sum / batch_size
|
124
|
+
0.5 * ((y - t)**2).mean(0).sum
|
129
125
|
end
|
130
126
|
end
|
131
127
|
|
@@ -147,8 +143,7 @@ module DNN
|
|
147
143
|
|
148
144
|
def forward(y, t)
|
149
145
|
@x = SoftmaxCrossEntropy.softmax(y)
|
150
|
-
|
151
|
-
-(t * Xumo::NMath.log(@x + @eps)).sum / batch_size
|
146
|
+
-(t * Xumo::NMath.log(@x + @eps)).mean(0).sum
|
152
147
|
end
|
153
148
|
|
154
149
|
def backward(y, t)
|
@@ -182,7 +177,7 @@ module DNN
|
|
182
177
|
|
183
178
|
def forward(y, t)
|
184
179
|
@x = SigmoidCrossEntropy.sigmoid(y)
|
185
|
-
-(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps))
|
180
|
+
-(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps)).mean(0).sum
|
186
181
|
end
|
187
182
|
|
188
183
|
def backward(y, t)
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.15.2
|
4
|
+
version: 0.15.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-12-
|
11
|
+
date: 2019-12-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|