ruby-dnn 0.8.3 → 0.8.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/dnn/core/cnn_layers.rb +0 -5
- data/lib/dnn/core/layers.rb +20 -9
- data/lib/dnn/core/model.rb +7 -0
- data/lib/dnn/core/rnn_layers.rb +8 -35
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 61defa3dd2c7b95ac7d3c3484789abfc7ec792be0fd4dea728fee628479fcf76
|
4
|
+
data.tar.gz: 12e3a398119eb5a64eec7578b6fb3263dca4a5a1028300ab97b47194766bf65b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cf23b2ea6cb6d00fa2a6b822c309e0db6677f890be18c43f910aae41ea3fee0342734434f349c98e597620b2729d3a165aa8c135f4ecb94d53343bbe49709f9a
|
7
|
+
data.tar.gz: 44c73b4b066a2026a2fa5d6ca81c52fdefb4c411ac9a241290c254a5071a9aa35d3ad63c71bbe09583980523f24ed8d78988fb37656b098d6896eab267585f16
|
data/lib/dnn/core/cnn_layers.rb
CHANGED
@@ -114,11 +114,6 @@ module DNN
|
|
114
114
|
def backward(dout)
|
115
115
|
dout = dout.reshape(dout.shape[0..2].reduce(:*), dout.shape[3])
|
116
116
|
@weight.grad = @col.transpose.dot(dout)
|
117
|
-
if @l1_lambda > 0
|
118
|
-
@weight.grad += dlasso
|
119
|
-
elsif @l2_lambda > 0
|
120
|
-
@weight.grad += dridge
|
121
|
-
end
|
122
117
|
@bias.grad = dout.sum(0)
|
123
118
|
dcol = dout.dot(@weight.data.transpose)
|
124
119
|
dx = col2im(dcol, @x_shape, *@out_size, *@filter_size, @strides)
|
data/lib/dnn/core/layers.rb
CHANGED
@@ -144,13 +144,17 @@ module DNN
|
|
144
144
|
end
|
145
145
|
|
146
146
|
def dlasso
|
147
|
-
|
148
|
-
|
149
|
-
|
147
|
+
if @l1_lambda > 0
|
148
|
+
dlasso = Xumo::SFloat.ones(*@weight.data.shape)
|
149
|
+
dlasso[@weight.data < 0] = -1
|
150
|
+
@weight.grad += @l1_lambda * dlasso
|
151
|
+
end
|
150
152
|
end
|
151
153
|
|
152
154
|
def dridge
|
153
|
-
@l2_lambda
|
155
|
+
if @l2_lambda > 0
|
156
|
+
@weight.grad += @l2_lambda * @weight.data
|
157
|
+
end
|
154
158
|
end
|
155
159
|
|
156
160
|
def to_hash(merge_hash)
|
@@ -197,11 +201,6 @@ module DNN
|
|
197
201
|
|
198
202
|
def backward(dout)
|
199
203
|
@weight.grad = @x.transpose.dot(dout)
|
200
|
-
if @l1_lambda > 0
|
201
|
-
@weight.grad += dlasso
|
202
|
-
elsif @l2_lambda > 0
|
203
|
-
@weight.grad += dridge
|
204
|
-
end
|
205
204
|
@bias.grad = dout.sum(0)
|
206
205
|
dout.dot(@weight.data.transpose)
|
207
206
|
end
|
@@ -270,6 +269,18 @@ module DNN
|
|
270
269
|
|
271
270
|
|
272
271
|
class OutputLayer < Layer
|
272
|
+
# Classes that inherit from this class must implement this method.
|
273
|
+
def loss(x)
|
274
|
+
raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
|
275
|
+
end
|
276
|
+
|
277
|
+
def dloss
|
278
|
+
@model.layers.select { |layer| layer.is_a?(Connection) }.each do |layer|
|
279
|
+
layer.dlasso
|
280
|
+
layer.dridge
|
281
|
+
end
|
282
|
+
end
|
283
|
+
|
273
284
|
private
|
274
285
|
|
275
286
|
def lasso
|
data/lib/dnn/core/model.rb
CHANGED
@@ -2,6 +2,7 @@ require "json"
|
|
2
2
|
require "base64"
|
3
3
|
|
4
4
|
module DNN
|
5
|
+
|
5
6
|
# This class deals with the model of the network.
|
6
7
|
class Model
|
7
8
|
attr_accessor :layers # All layers possessed by the model
|
@@ -162,6 +163,7 @@ module DNN
|
|
162
163
|
forward(x, true)
|
163
164
|
loss_value = loss(y)
|
164
165
|
backward(y)
|
166
|
+
dloss
|
165
167
|
update
|
166
168
|
loss_value
|
167
169
|
end
|
@@ -230,6 +232,10 @@ module DNN
|
|
230
232
|
def loss(y)
|
231
233
|
@layers[-1].loss(y)
|
232
234
|
end
|
235
|
+
|
236
|
+
def dloss
|
237
|
+
@layers[-1].dloss
|
238
|
+
end
|
233
239
|
|
234
240
|
def backward(y)
|
235
241
|
dout = y
|
@@ -303,4 +309,5 @@ module DNN
|
|
303
309
|
end
|
304
310
|
end
|
305
311
|
end
|
312
|
+
|
306
313
|
end
|
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -94,24 +94,18 @@ module DNN
|
|
94
94
|
end
|
95
95
|
end
|
96
96
|
|
97
|
-
def dlasso
|
98
|
-
dlasso = Xumo::SFloat.ones(*@weight.data.shape)
|
99
|
-
dlasso[@weight.data < 0] = -1
|
100
|
-
@l1_lambda * dlasso
|
101
|
-
end
|
102
|
-
|
103
|
-
def dridge
|
104
|
-
@l2_lambda * @weight.data
|
105
|
-
end
|
106
|
-
|
107
97
|
def dlasso2
|
108
|
-
|
109
|
-
|
110
|
-
|
98
|
+
if @l1_lambda > 0
|
99
|
+
dlasso = Xumo::SFloat.ones(*@weight2.data.shape)
|
100
|
+
dlasso[@weight2.data < 0] = -1
|
101
|
+
@weight2.grad += @l1_lambda * dlasso
|
102
|
+
end
|
111
103
|
end
|
112
104
|
|
113
105
|
def dridge2
|
114
|
-
@l2_lambda
|
106
|
+
if @l2_lambda > 0
|
107
|
+
@weight2.grad += l2_lambda * @weight2.data
|
108
|
+
end
|
115
109
|
end
|
116
110
|
|
117
111
|
private
|
@@ -139,13 +133,6 @@ module DNN
|
|
139
133
|
dh2 = @activation.backward(dh2)
|
140
134
|
@rnn.weight.grad += @x.transpose.dot(dh2)
|
141
135
|
@rnn.weight2.grad += @h.transpose.dot(dh2)
|
142
|
-
if @rnn.l1_lambda > 0
|
143
|
-
@rnn.weight.grad += dlasso
|
144
|
-
@rnn.weight2.grad += dlasso2
|
145
|
-
elsif @rnn.l2_lambda > 0
|
146
|
-
@rnn.weight.grad += dridge
|
147
|
-
@rnn.weight2.grad += dridge2
|
148
|
-
end
|
149
136
|
@rnn.bias.grad += dh2.sum(0)
|
150
137
|
dx = dh2.dot(@rnn.weight.data.transpose)
|
151
138
|
dh = dh2.dot(@rnn.weight2.data.transpose)
|
@@ -250,13 +237,6 @@ module DNN
|
|
250
237
|
|
251
238
|
@rnn.weight.grad += @x.transpose.dot(da)
|
252
239
|
@rnn.weight2.grad += @h.transpose.dot(da)
|
253
|
-
if @rnn.l1_lambda > 0
|
254
|
-
@rnn.weight.grad += dlasso
|
255
|
-
@rnn.weight2.grad += dlasso2
|
256
|
-
elsif @rnn.l2_lambda > 0
|
257
|
-
@rnn.weight.grad += dridge
|
258
|
-
@rnn.weight2.grad += dridge2
|
259
|
-
end
|
260
240
|
@rnn.bias.grad += da.sum(0)
|
261
241
|
dx = da.dot(@rnn.weight.data.transpose)
|
262
242
|
dh = da.dot(@rnn.weight2.data.transpose)
|
@@ -401,13 +381,6 @@ module DNN
|
|
401
381
|
|
402
382
|
@rnn.weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
|
403
383
|
@rnn.weight2.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
|
404
|
-
if @rnn.l1_lambda > 0
|
405
|
-
@rnn.weight.grad += dlasso
|
406
|
-
@rnn.weight2.grad += dlasso2
|
407
|
-
elsif @rnn.l2_lambda > 0
|
408
|
-
@rnn.weight.grad += dridge
|
409
|
-
@rnn.weight2.grad += dridge2
|
410
|
-
end
|
411
384
|
@rnn.bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h])
|
412
385
|
[dx, dh]
|
413
386
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.
|
4
|
+
version: 0.8.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-02-
|
11
|
+
date: 2019-02-21 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|