ruby-dnn 0.8.3 → 0.8.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9c34f355a4fdde6dce0cdc8c84f0eb3c670d61aa939b6e33b3c1b21b96ff84e2
4
- data.tar.gz: 27362633d8824d7b5a1e765a50c5f248c3bc680ee9261a0cb2e103fde9eca42c
3
+ metadata.gz: 61defa3dd2c7b95ac7d3c3484789abfc7ec792be0fd4dea728fee628479fcf76
4
+ data.tar.gz: 12e3a398119eb5a64eec7578b6fb3263dca4a5a1028300ab97b47194766bf65b
5
5
  SHA512:
6
- metadata.gz: 3fac1e61fffaca5126c4caa488bdadc02d613bd88a368316ce7924b162394d9cc1d9e0b4f8ec6b56078cea5cb23d02e89bed95ac0b0bba4c374290034a3550c2
7
- data.tar.gz: 4df5ad737f827fc5ed063be6cc8d22bee34ef95e35c9c5522dc69b7bd453658b0af04b0f6adb1bb7c94d4f52f16d44040ee1097f83e533fa01717f1736a1a8f6
6
+ metadata.gz: cf23b2ea6cb6d00fa2a6b822c309e0db6677f890be18c43f910aae41ea3fee0342734434f349c98e597620b2729d3a165aa8c135f4ecb94d53343bbe49709f9a
7
+ data.tar.gz: 44c73b4b066a2026a2fa5d6ca81c52fdefb4c411ac9a241290c254a5071a9aa35d3ad63c71bbe09583980523f24ed8d78988fb37656b098d6896eab267585f16
@@ -114,11 +114,6 @@ module DNN
114
114
  def backward(dout)
115
115
  dout = dout.reshape(dout.shape[0..2].reduce(:*), dout.shape[3])
116
116
  @weight.grad = @col.transpose.dot(dout)
117
- if @l1_lambda > 0
118
- @weight.grad += dlasso
119
- elsif @l2_lambda > 0
120
- @weight.grad += dridge
121
- end
122
117
  @bias.grad = dout.sum(0)
123
118
  dcol = dout.dot(@weight.data.transpose)
124
119
  dx = col2im(dcol, @x_shape, *@out_size, *@filter_size, @strides)
@@ -144,13 +144,17 @@ module DNN
144
144
  end
145
145
 
146
146
  def dlasso
147
- dlasso = Xumo::SFloat.ones(*@weight.data.shape)
148
- dlasso[@weight.data < 0] = -1
149
- @l1_lambda * dlasso
147
+ if @l1_lambda > 0
148
+ dlasso = Xumo::SFloat.ones(*@weight.data.shape)
149
+ dlasso[@weight.data < 0] = -1
150
+ @weight.grad += @l1_lambda * dlasso
151
+ end
150
152
  end
151
153
 
152
154
  def dridge
153
- @l2_lambda * @weight.data
155
+ if @l2_lambda > 0
156
+ @weight.grad += @l2_lambda * @weight.data
157
+ end
154
158
  end
155
159
 
156
160
  def to_hash(merge_hash)
@@ -197,11 +201,6 @@ module DNN
197
201
 
198
202
  def backward(dout)
199
203
  @weight.grad = @x.transpose.dot(dout)
200
- if @l1_lambda > 0
201
- @weight.grad += dlasso
202
- elsif @l2_lambda > 0
203
- @weight.grad += dridge
204
- end
205
204
  @bias.grad = dout.sum(0)
206
205
  dout.dot(@weight.data.transpose)
207
206
  end
@@ -270,6 +269,18 @@ module DNN
270
269
 
271
270
 
272
271
  class OutputLayer < Layer
272
+ # Classes that inherit from this class must implement this method.
273
+ def loss(x)
274
+ raise NotImplementedError.new("Class '#{self.class.name}' has implement method 'forward'")
275
+ end
276
+
277
+ def dloss
278
+ @model.layers.select { |layer| layer.is_a?(Connection) }.each do |layer|
279
+ layer.dlasso
280
+ layer.dridge
281
+ end
282
+ end
283
+
273
284
  private
274
285
 
275
286
  def lasso
@@ -2,6 +2,7 @@ require "json"
2
2
  require "base64"
3
3
 
4
4
  module DNN
5
+
5
6
  # This class deals with the model of the network.
6
7
  class Model
7
8
  attr_accessor :layers # All layers possessed by the model
@@ -162,6 +163,7 @@ module DNN
162
163
  forward(x, true)
163
164
  loss_value = loss(y)
164
165
  backward(y)
166
+ dloss
165
167
  update
166
168
  loss_value
167
169
  end
@@ -230,6 +232,10 @@ module DNN
230
232
  def loss(y)
231
233
  @layers[-1].loss(y)
232
234
  end
235
+
236
+ def dloss
237
+ @layers[-1].dloss
238
+ end
233
239
 
234
240
  def backward(y)
235
241
  dout = y
@@ -303,4 +309,5 @@ module DNN
303
309
  end
304
310
  end
305
311
  end
312
+
306
313
  end
@@ -94,24 +94,18 @@ module DNN
94
94
  end
95
95
  end
96
96
 
97
- def dlasso
98
- dlasso = Xumo::SFloat.ones(*@weight.data.shape)
99
- dlasso[@weight.data < 0] = -1
100
- @l1_lambda * dlasso
101
- end
102
-
103
- def dridge
104
- @l2_lambda * @weight.data
105
- end
106
-
107
97
  def dlasso2
108
- dlasso = Xumo::SFloat.ones(*@weight2.data.shape)
109
- dlasso[@weight2.data < 0] = -1
110
- @l1_lambda * dlasso
98
+ if @l1_lambda > 0
99
+ dlasso = Xumo::SFloat.ones(*@weight2.data.shape)
100
+ dlasso[@weight2.data < 0] = -1
101
+ @weight2.grad += @l1_lambda * dlasso
102
+ end
111
103
  end
112
104
 
113
105
  def dridge2
114
- @l2_lambda * @weight2.data
106
+ if @l2_lambda > 0
107
+ @weight2.grad += l2_lambda * @weight2.data
108
+ end
115
109
  end
116
110
 
117
111
  private
@@ -139,13 +133,6 @@ module DNN
139
133
  dh2 = @activation.backward(dh2)
140
134
  @rnn.weight.grad += @x.transpose.dot(dh2)
141
135
  @rnn.weight2.grad += @h.transpose.dot(dh2)
142
- if @rnn.l1_lambda > 0
143
- @rnn.weight.grad += dlasso
144
- @rnn.weight2.grad += dlasso2
145
- elsif @rnn.l2_lambda > 0
146
- @rnn.weight.grad += dridge
147
- @rnn.weight2.grad += dridge2
148
- end
149
136
  @rnn.bias.grad += dh2.sum(0)
150
137
  dx = dh2.dot(@rnn.weight.data.transpose)
151
138
  dh = dh2.dot(@rnn.weight2.data.transpose)
@@ -250,13 +237,6 @@ module DNN
250
237
 
251
238
  @rnn.weight.grad += @x.transpose.dot(da)
252
239
  @rnn.weight2.grad += @h.transpose.dot(da)
253
- if @rnn.l1_lambda > 0
254
- @rnn.weight.grad += dlasso
255
- @rnn.weight2.grad += dlasso2
256
- elsif @rnn.l2_lambda > 0
257
- @rnn.weight.grad += dridge
258
- @rnn.weight2.grad += dridge2
259
- end
260
240
  @rnn.bias.grad += da.sum(0)
261
241
  dx = da.dot(@rnn.weight.data.transpose)
262
242
  dh = da.dot(@rnn.weight2.data.transpose)
@@ -401,13 +381,6 @@ module DNN
401
381
 
402
382
  @rnn.weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
403
383
  @rnn.weight2.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
404
- if @rnn.l1_lambda > 0
405
- @rnn.weight.grad += dlasso
406
- @rnn.weight2.grad += dlasso2
407
- elsif @rnn.l2_lambda > 0
408
- @rnn.weight.grad += dridge
409
- @rnn.weight2.grad += dridge2
410
- end
411
384
  @rnn.bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h])
412
385
  [dx, dh]
413
386
  end
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.8.3"
2
+ VERSION = "0.8.4"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.3
4
+ version: 0.8.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-02-18 00:00:00.000000000 Z
11
+ date: 2019-02-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray