nn 2.0.1 → 2.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/document.txt +10 -0
- data/lib/nn.rb +20 -12
- data/nn.gemspec +1 -1
- data/sample/cifar10_program.rb +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8e26103663135846646308af8fa48851d0f56d34122bfdb6172e1a8bd4675bdf
|
4
|
+
data.tar.gz: cffd4270bae567540673a1ec327ea416b14c3794172b7db3c589c0053bf15466
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: abaa7793eb7e44cd44d9f49f63d7cd252f01e11e4c41d7e709cf6c8124b923d300ff6c2e2b07f410f4dbeaaa8f53a82f876e22582681ec4178364110128726ab
|
7
|
+
data.tar.gz: c8f830238afd0a767f5cc34a7be600c2fc481c591f8f8f433c13d0ad9647011e49f3b0ff4b87d5784a0c8ce2c01552445dc1540e252321cd856eea23c4a6f84d
|
data/document.txt
CHANGED
@@ -13,6 +13,11 @@ class NN
|
|
13
13
|
|
14
14
|
<クラスメソッド>
|
15
15
|
load(file_name) : NN
|
16
|
+
Marshal形式で保存された学習結果を読み込みます。
|
17
|
+
String file_name 読み込むMarshalファイル名
|
18
|
+
戻り値 NNのインスタンス
|
19
|
+
|
20
|
+
load_json(file_name) : NN
|
16
21
|
JSON形式で保存された学習結果を読み込みます。
|
17
22
|
String file_name 読み込むJSONファイル名
|
18
23
|
戻り値 NNのインスタンス
|
@@ -109,6 +114,10 @@ run(x) : SFloat
|
|
109
114
|
戻り値 出力ノードの値
|
110
115
|
|
111
116
|
save(file_name) : void
|
117
|
+
学習結果をMarshal形式で保存します。
|
118
|
+
String file_name 書き込むMarshalファイル名
|
119
|
+
|
120
|
+
save_json(file_name) : void
|
112
121
|
学習結果をJSON形式で保存します。
|
113
122
|
String file_name 書き込むJSONファイル名
|
114
123
|
|
@@ -142,3 +151,4 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
|
|
142
151
|
2018/5/4 バージョン1.8公開
|
143
152
|
2018/5/16 バージョン2.0公開
|
144
153
|
2018/6/10 バージョン2.0.1公開
|
154
|
+
2018/6/10 バージョン2.1.0公開
|
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "2.
|
5
|
+
VERSION = "2.1"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -44,6 +44,10 @@ class NN
|
|
44
44
|
end
|
45
45
|
|
46
46
|
def self.load(file_name)
|
47
|
+
Marshal.load(File.binread(file_name))
|
48
|
+
end
|
49
|
+
|
50
|
+
def self.load_json(file_name)
|
47
51
|
json = JSON.parse(File.read(file_name))
|
48
52
|
nn = self.new(json["num_nodes"],
|
49
53
|
learning_rate: json["learning_rate"],
|
@@ -150,6 +154,10 @@ class NN
|
|
150
154
|
end
|
151
155
|
|
152
156
|
def save(file_name)
|
157
|
+
File.binwrite(file_name, Marshal.dump(self))
|
158
|
+
end
|
159
|
+
|
160
|
+
def save_json(file_name)
|
153
161
|
json = {
|
154
162
|
"version" => VERSION,
|
155
163
|
"num_nodes" => @num_nodes,
|
@@ -240,8 +248,8 @@ class NN
|
|
240
248
|
|
241
249
|
def update_weight_and_bias
|
242
250
|
@layers.select{|layer| layer.is_a?(Affine)}.each.with_index do |layer, i|
|
243
|
-
weight_amount = layer.d_weight
|
244
|
-
bias_amount = layer.d_bias
|
251
|
+
weight_amount = layer.d_weight * @learning_rate
|
252
|
+
bias_amount = layer.d_bias * @learning_rate
|
245
253
|
if @momentum > 0
|
246
254
|
weight_amount += @momentum * @weight_amounts[i]
|
247
255
|
@weight_amounts[i] = weight_amount
|
@@ -255,8 +263,8 @@ class NN
|
|
255
263
|
|
256
264
|
def update_gamma_and_beta
|
257
265
|
@layers.select{|layer| layer.is_a?(BatchNorm)}.each.with_index do |layer, i|
|
258
|
-
gamma_amount = layer.d_gamma
|
259
|
-
beta_amount = layer.d_beta
|
266
|
+
gamma_amount = layer.d_gamma * @learning_rate
|
267
|
+
beta_amount = layer.d_beta * @learning_rate
|
260
268
|
if @momentum > 0
|
261
269
|
gamma_amount += @momentum * @gamma_amounts[i]
|
262
270
|
@gamma_amounts[i] = gamma_amount
|
@@ -290,20 +298,22 @@ class NN::Affine
|
|
290
298
|
|
291
299
|
def backward(dout)
|
292
300
|
x = @x.reshape(*@x.shape, 1)
|
293
|
-
@d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1]))
|
301
|
+
@d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1])).mean(0)
|
294
302
|
if @nn.weight_decay > 0
|
295
303
|
dridge = @nn.weight_decay * @nn.weights[@index]
|
296
304
|
@d_weight += dridge
|
297
305
|
end
|
298
|
-
@d_bias = dout
|
306
|
+
@d_bias = dout.mean
|
299
307
|
dout.dot(@nn.weights[@index].transpose)
|
300
308
|
end
|
301
309
|
end
|
302
310
|
|
303
311
|
|
304
312
|
class NN::Sigmoid
|
313
|
+
include Numo
|
314
|
+
|
305
315
|
def forward(x)
|
306
|
-
@out = 1.0 / (1 +
|
316
|
+
@out = 1.0 / (1 + NMath.exp(-x))
|
307
317
|
end
|
308
318
|
|
309
319
|
def backward(dout)
|
@@ -328,8 +338,6 @@ end
|
|
328
338
|
|
329
339
|
|
330
340
|
class NN::Identity
|
331
|
-
include Numo
|
332
|
-
|
333
341
|
def initialize(nn)
|
334
342
|
@nn = nn
|
335
343
|
end
|
@@ -419,8 +427,8 @@ class NN::BatchNorm
|
|
419
427
|
end
|
420
428
|
|
421
429
|
def backward(dout)
|
422
|
-
@d_beta = dout.sum(0)
|
423
|
-
@d_gamma = (@xn * dout).sum(0)
|
430
|
+
@d_beta = dout.sum(0).mean
|
431
|
+
@d_gamma = (@xn * dout).sum(0).mean
|
424
432
|
dxn = @nn.gammas[@index] * dout
|
425
433
|
dxc = dxn / @std
|
426
434
|
dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
|
data/nn.gemspec
CHANGED
data/sample/cifar10_program.rb
CHANGED