ruby-dnn 0.8.3 → 0.8.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9c34f355a4fdde6dce0cdc8c84f0eb3c670d61aa939b6e33b3c1b21b96ff84e2
4
- data.tar.gz: 27362633d8824d7b5a1e765a50c5f248c3bc680ee9261a0cb2e103fde9eca42c
3
+ metadata.gz: 61defa3dd2c7b95ac7d3c3484789abfc7ec792be0fd4dea728fee628479fcf76
4
+ data.tar.gz: 12e3a398119eb5a64eec7578b6fb3263dca4a5a1028300ab97b47194766bf65b
5
5
  SHA512:
6
- metadata.gz: 3fac1e61fffaca5126c4caa488bdadc02d613bd88a368316ce7924b162394d9cc1d9e0b4f8ec6b56078cea5cb23d02e89bed95ac0b0bba4c374290034a3550c2
7
- data.tar.gz: 4df5ad737f827fc5ed063be6cc8d22bee34ef95e35c9c5522dc69b7bd453658b0af04b0f6adb1bb7c94d4f52f16d44040ee1097f83e533fa01717f1736a1a8f6
6
+ metadata.gz: cf23b2ea6cb6d00fa2a6b822c309e0db6677f890be18c43f910aae41ea3fee0342734434f349c98e597620b2729d3a165aa8c135f4ecb94d53343bbe49709f9a
7
+ data.tar.gz: 44c73b4b066a2026a2fa5d6ca81c52fdefb4c411ac9a241290c254a5071a9aa35d3ad63c71bbe09583980523f24ed8d78988fb37656b098d6896eab267585f16
@@ -114,11 +114,6 @@ module DNN
114
114
  def backward(dout)
115
115
  dout = dout.reshape(dout.shape[0..2].reduce(:*), dout.shape[3])
116
116
  @weight.grad = @col.transpose.dot(dout)
117
- if @l1_lambda > 0
118
- @weight.grad += dlasso
119
- elsif @l2_lambda > 0
120
- @weight.grad += dridge
121
- end
122
117
  @bias.grad = dout.sum(0)
123
118
  dcol = dout.dot(@weight.data.transpose)
124
119
  dx = col2im(dcol, @x_shape, *@out_size, *@filter_size, @strides)
@@ -144,13 +144,17 @@ module DNN
144
144
  end
145
145
 
146
146
  def dlasso
147
- dlasso = Xumo::SFloat.ones(*@weight.data.shape)
148
- dlasso[@weight.data < 0] = -1
149
- @l1_lambda * dlasso
147
+ if @l1_lambda > 0
148
+ dlasso = Xumo::SFloat.ones(*@weight.data.shape)
149
+ dlasso[@weight.data < 0] = -1
150
+ @weight.grad += @l1_lambda * dlasso
151
+ end
150
152
  end
151
153
 
152
154
  def dridge
153
- @l2_lambda * @weight.data
155
+ if @l2_lambda > 0
156
+ @weight.grad += @l2_lambda * @weight.data
157
+ end
154
158
  end
155
159
 
156
160
  def to_hash(merge_hash)
@@ -197,11 +201,6 @@ module DNN
197
201
 
198
202
  def backward(dout)
199
203
  @weight.grad = @x.transpose.dot(dout)
200
- if @l1_lambda > 0
201
- @weight.grad += dlasso
202
- elsif @l2_lambda > 0
203
- @weight.grad += dridge
204
- end
205
204
  @bias.grad = dout.sum(0)
206
205
  dout.dot(@weight.data.transpose)
207
206
  end
@@ -270,6 +269,18 @@ module DNN
270
269
 
271
270
 
272
271
  class OutputLayer < Layer
272
+ # Classes that inherit from this class must implement this method.
273
+ def loss(x)
274
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
275
+ end
276
+
277
+ def dloss
278
+ @model.layers.select { |layer| layer.is_a?(Connection) }.each do |layer|
279
+ layer.dlasso
280
+ layer.dridge
281
+ end
282
+ end
283
+
273
284
  private
274
285
 
275
286
  def lasso
@@ -2,6 +2,7 @@ require "json"
2
2
  require "base64"
3
3
 
4
4
  module DNN
5
+
5
6
  # This class deals with the model of the network.
6
7
  class Model
7
8
  attr_accessor :layers # All layers possessed by the model
@@ -162,6 +163,7 @@ module DNN
162
163
  forward(x, true)
163
164
  loss_value = loss(y)
164
165
  backward(y)
166
+ dloss
165
167
  update
166
168
  loss_value
167
169
  end
@@ -230,6 +232,10 @@ module DNN
230
232
  def loss(y)
231
233
  @layers[-1].loss(y)
232
234
  end
235
+
236
+ def dloss
237
+ @layers[-1].dloss
238
+ end
233
239
 
234
240
  def backward(y)
235
241
  dout = y
@@ -303,4 +309,5 @@ module DNN
303
309
  end
304
310
  end
305
311
  end
312
+
306
313
  end
@@ -94,24 +94,18 @@ module DNN
94
94
  end
95
95
  end
96
96
 
97
- def dlasso
98
- dlasso = Xumo::SFloat.ones(*@weight.data.shape)
99
- dlasso[@weight.data < 0] = -1
100
- @l1_lambda * dlasso
101
- end
102
-
103
- def dridge
104
- @l2_lambda * @weight.data
105
- end
106
-
107
97
  def dlasso2
108
- dlasso = Xumo::SFloat.ones(*@weight2.data.shape)
109
- dlasso[@weight2.data < 0] = -1
110
- @l1_lambda * dlasso
98
+ if @l1_lambda > 0
99
+ dlasso = Xumo::SFloat.ones(*@weight2.data.shape)
100
+ dlasso[@weight2.data < 0] = -1
101
+ @weight2.grad += @l1_lambda * dlasso
102
+ end
111
103
  end
112
104
 
113
105
  def dridge2
114
- @l2_lambda * @weight2.data
106
+ if @l2_lambda > 0
107
+ @weight2.grad += l2_lambda * @weight2.data
108
+ end
115
109
  end
116
110
 
117
111
  private
@@ -139,13 +133,6 @@ module DNN
139
133
  dh2 = @activation.backward(dh2)
140
134
  @rnn.weight.grad += @x.transpose.dot(dh2)
141
135
  @rnn.weight2.grad += @h.transpose.dot(dh2)
142
- if @rnn.l1_lambda > 0
143
- @rnn.weight.grad += dlasso
144
- @rnn.weight2.grad += dlasso2
145
- elsif @rnn.l2_lambda > 0
146
- @rnn.weight.grad += dridge
147
- @rnn.weight2.grad += dridge2
148
- end
149
136
  @rnn.bias.grad += dh2.sum(0)
150
137
  dx = dh2.dot(@rnn.weight.data.transpose)
151
138
  dh = dh2.dot(@rnn.weight2.data.transpose)
@@ -250,13 +237,6 @@ module DNN
250
237
 
251
238
  @rnn.weight.grad += @x.transpose.dot(da)
252
239
  @rnn.weight2.grad += @h.transpose.dot(da)
253
- if @rnn.l1_lambda > 0
254
- @rnn.weight.grad += dlasso
255
- @rnn.weight2.grad += dlasso2
256
- elsif @rnn.l2_lambda > 0
257
- @rnn.weight.grad += dridge
258
- @rnn.weight2.grad += dridge2
259
- end
260
240
  @rnn.bias.grad += da.sum(0)
261
241
  dx = da.dot(@rnn.weight.data.transpose)
262
242
  dh = da.dot(@rnn.weight2.data.transpose)
@@ -401,13 +381,6 @@ module DNN
401
381
 
402
382
  @rnn.weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
403
383
  @rnn.weight2.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
404
- if @rnn.l1_lambda > 0
405
- @rnn.weight.grad += dlasso
406
- @rnn.weight2.grad += dlasso2
407
- elsif @rnn.l2_lambda > 0
408
- @rnn.weight.grad += dridge
409
- @rnn.weight2.grad += dridge2
410
- end
411
384
  @rnn.bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h])
412
385
  [dx, dh]
413
386
  end
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.8.3"
2
+ VERSION = "0.8.4"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.3
4
+ version: 0.8.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-02-18 00:00:00.000000000 Z
11
+ date: 2019-02-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray