ruby-dnn 0.15.2 → 0.15.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/dnn/core/callbacks.rb +8 -3
- data/lib/dnn/core/losses.rb +8 -13
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: b213164ca3e4a7d781a673c4c49ec8dfbd6713468687a98c3739c0cf5629de73
|
4
|
+
data.tar.gz: 92a8133dc0ab085d387199f92c2a602bb3782b810c20a931f7ee8de6e247f2a4
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ea6652994f71d97142fe7357947701eff5f0c26708698a670a148e995fb10816b75b62b0bd60004d9ea99b6c837e5b599e790982ad12d010394e4ec7d6962135
|
7
|
+
data.tar.gz: 9d474a7bdd120f0e39fcb05e11f0f8370100dd21e79440fe0408bbccf676890726623f516ada78ae6b8829bcf0799ceeea78218f64d74a70ca7f5dd9286bdae3
|
data/lib/dnn/core/callbacks.rb
CHANGED
@@ -27,7 +27,8 @@ module DNN
|
|
27
27
|
|
28
28
|
# This callback wrap the lambda function.
|
29
29
|
class LambdaCallback < Callback
|
30
|
-
def initialize(event, lambda)
|
30
|
+
def initialize(event, lambda = nil, &block)
|
31
|
+
lambda = block unless lambda
|
31
32
|
instance_eval do
|
32
33
|
define_singleton_method(event) { lambda.call }
|
33
34
|
end
|
@@ -37,15 +38,19 @@ module DNN
|
|
37
38
|
# A callback that save the model at the after of the epoch.
|
38
39
|
# @param [String] base_file_name Base file name for saving.
|
39
40
|
# @param [Boolean] include_model When set a true, save data included model structure.
|
41
|
+
# @param [Integer] interval Save interval.
|
40
42
|
class CheckPoint < Callback
|
41
|
-
def initialize(base_file_name, include_model: true)
|
43
|
+
def initialize(base_file_name, include_model: true, interval: 1)
|
42
44
|
@base_file_name = base_file_name
|
43
45
|
@include_model = include_model
|
46
|
+
@interval = interval
|
44
47
|
end
|
45
48
|
|
46
49
|
def after_epoch
|
47
50
|
saver = Savers::MarshalSaver.new(@model, include_model: @include_model)
|
48
|
-
|
51
|
+
if @model.last_log[:epoch] % @interval == 0
|
52
|
+
saver.save(@base_file_name + "_epoch#{model.last_log[:epoch]}.marshal")
|
53
|
+
end
|
49
54
|
end
|
50
55
|
end
|
51
56
|
|
data/lib/dnn/core/losses.rb
CHANGED
@@ -17,7 +17,7 @@ module DNN
|
|
17
17
|
end
|
18
18
|
loss_value = forward(y, t)
|
19
19
|
loss_value += regularizers_forward(layers) if layers
|
20
|
-
loss_value
|
20
|
+
loss_value
|
21
21
|
end
|
22
22
|
|
23
23
|
def forward(y, t)
|
@@ -65,8 +65,7 @@ module DNN
|
|
65
65
|
|
66
66
|
class MeanSquaredError < Loss
|
67
67
|
def forward(y, t)
|
68
|
-
|
69
|
-
0.5 * ((y - t)**2).sum / batch_size
|
68
|
+
0.5 * ((y - t)**2).mean(0).sum
|
70
69
|
end
|
71
70
|
|
72
71
|
def backward(y, t)
|
@@ -76,8 +75,7 @@ module DNN
|
|
76
75
|
|
77
76
|
class MeanAbsoluteError < Loss
|
78
77
|
def forward(y, t)
|
79
|
-
|
80
|
-
(y - t).abs.sum / batch_size
|
78
|
+
(y - t).abs.mean(0).sum
|
81
79
|
end
|
82
80
|
|
83
81
|
def backward(y, t)
|
@@ -91,7 +89,7 @@ module DNN
|
|
91
89
|
class Hinge < Loss
|
92
90
|
def forward(y, t)
|
93
91
|
@a = 1 - y * t
|
94
|
-
Xumo::SFloat.maximum(0, @a)
|
92
|
+
Xumo::SFloat.maximum(0, @a).mean(0).sum
|
95
93
|
end
|
96
94
|
|
97
95
|
def backward(y, t)
|
@@ -119,13 +117,11 @@ module DNN
|
|
119
117
|
private
|
120
118
|
|
121
119
|
def loss_l1(y, t)
|
122
|
-
|
123
|
-
(y - t).abs.sum / batch_size
|
120
|
+
(y - t).abs.mean(0).sum
|
124
121
|
end
|
125
122
|
|
126
123
|
def loss_l2(y, t)
|
127
|
-
|
128
|
-
0.5 * ((y - t)**2).sum / batch_size
|
124
|
+
0.5 * ((y - t)**2).mean(0).sum
|
129
125
|
end
|
130
126
|
end
|
131
127
|
|
@@ -147,8 +143,7 @@ module DNN
|
|
147
143
|
|
148
144
|
def forward(y, t)
|
149
145
|
@x = SoftmaxCrossEntropy.softmax(y)
|
150
|
-
|
151
|
-
-(t * Xumo::NMath.log(@x + @eps)).sum / batch_size
|
146
|
+
-(t * Xumo::NMath.log(@x + @eps)).mean(0).sum
|
152
147
|
end
|
153
148
|
|
154
149
|
def backward(y, t)
|
@@ -182,7 +177,7 @@ module DNN
|
|
182
177
|
|
183
178
|
def forward(y, t)
|
184
179
|
@x = SigmoidCrossEntropy.sigmoid(y)
|
185
|
-
-(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps))
|
180
|
+
-(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps)).mean(0).sum
|
186
181
|
end
|
187
182
|
|
188
183
|
def backward(y, t)
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.15.2
|
4
|
+
version: 0.15.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-12-
|
11
|
+
date: 2019-12-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|