nn 2.0.1 → 2.1.0
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/document.txt +10 -0
- data/lib/nn.rb +20 -12
- data/nn.gemspec +1 -1
- data/sample/cifar10_program.rb +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8e26103663135846646308af8fa48851d0f56d34122bfdb6172e1a8bd4675bdf
|
4
|
+
data.tar.gz: cffd4270bae567540673a1ec327ea416b14c3794172b7db3c589c0053bf15466
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: abaa7793eb7e44cd44d9f49f63d7cd252f01e11e4c41d7e709cf6c8124b923d300ff6c2e2b07f410f4dbeaaa8f53a82f876e22582681ec4178364110128726ab
|
7
|
+
data.tar.gz: c8f830238afd0a767f5cc34a7be600c2fc481c591f8f8f433c13d0ad9647011e49f3b0ff4b87d5784a0c8ce2c01552445dc1540e252321cd856eea23c4a6f84d
|
data/document.txt
CHANGED
@@ -13,6 +13,11 @@ class NN
|
|
13
13
|
|
14
14
|
<クラスメソッド>
|
15
15
|
load(file_name) : NN
|
16
|
+
Marshal形式で保存された学習結果を読み込みます。
|
17
|
+
String file_name 読み込むMarshalファイル名
|
18
|
+
戻り値 NNのインスタンス
|
19
|
+
|
20
|
+
load_json(file_name) : NN
|
16
21
|
JSON形式で保存された学習結果を読み込みます。
|
17
22
|
String file_name 読み込むJSONファイル名
|
18
23
|
戻り値 NNのインスタンス
|
@@ -109,6 +114,10 @@ run(x) : SFloat
|
|
109
114
|
戻り値 出力ノードの値
|
110
115
|
|
111
116
|
save(file_name) : void
|
117
|
+
学習結果をMarshal形式で保存します。
|
118
|
+
String file_name 書き込むMarshalファイル名
|
119
|
+
|
120
|
+
save_json(file_name) : void
|
112
121
|
学習結果をJSON形式で保存します。
|
113
122
|
String file_name 書き込むJSONファイル名
|
114
123
|
|
@@ -142,3 +151,4 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
|
|
142
151
|
2018/5/4 バージョン1.8公開
|
143
152
|
2018/5/16 バージョン2.0公開
|
144
153
|
2018/6/10 バージョン2.0.1公開
|
154
|
+
2018/6/10 バージョン2.1.0公開
|
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "2.
|
5
|
+
VERSION = "2.1"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -44,6 +44,10 @@ class NN
|
|
44
44
|
end
|
45
45
|
|
46
46
|
def self.load(file_name)
|
47
|
+
Marshal.load(File.binread(file_name))
|
48
|
+
end
|
49
|
+
|
50
|
+
def self.load_json(file_name)
|
47
51
|
json = JSON.parse(File.read(file_name))
|
48
52
|
nn = self.new(json["num_nodes"],
|
49
53
|
learning_rate: json["learning_rate"],
|
@@ -150,6 +154,10 @@ class NN
|
|
150
154
|
end
|
151
155
|
|
152
156
|
def save(file_name)
|
157
|
+
File.binwrite(file_name, Marshal.dump(self))
|
158
|
+
end
|
159
|
+
|
160
|
+
def save_json(file_name)
|
153
161
|
json = {
|
154
162
|
"version" => VERSION,
|
155
163
|
"num_nodes" => @num_nodes,
|
@@ -240,8 +248,8 @@ class NN
|
|
240
248
|
|
241
249
|
def update_weight_and_bias
|
242
250
|
@layers.select{|layer| layer.is_a?(Affine)}.each.with_index do |layer, i|
|
243
|
-
weight_amount = layer.d_weight
|
244
|
-
bias_amount = layer.d_bias
|
251
|
+
weight_amount = layer.d_weight * @learning_rate
|
252
|
+
bias_amount = layer.d_bias * @learning_rate
|
245
253
|
if @momentum > 0
|
246
254
|
weight_amount += @momentum * @weight_amounts[i]
|
247
255
|
@weight_amounts[i] = weight_amount
|
@@ -255,8 +263,8 @@ class NN
|
|
255
263
|
|
256
264
|
def update_gamma_and_beta
|
257
265
|
@layers.select{|layer| layer.is_a?(BatchNorm)}.each.with_index do |layer, i|
|
258
|
-
gamma_amount = layer.d_gamma
|
259
|
-
beta_amount = layer.d_beta
|
266
|
+
gamma_amount = layer.d_gamma * @learning_rate
|
267
|
+
beta_amount = layer.d_beta * @learning_rate
|
260
268
|
if @momentum > 0
|
261
269
|
gamma_amount += @momentum * @gamma_amounts[i]
|
262
270
|
@gamma_amounts[i] = gamma_amount
|
@@ -290,20 +298,22 @@ class NN::Affine
|
|
290
298
|
|
291
299
|
def backward(dout)
|
292
300
|
x = @x.reshape(*@x.shape, 1)
|
293
|
-
@d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1]))
|
301
|
+
@d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1])).mean(0)
|
294
302
|
if @nn.weight_decay > 0
|
295
303
|
dridge = @nn.weight_decay * @nn.weights[@index]
|
296
304
|
@d_weight += dridge
|
297
305
|
end
|
298
|
-
@d_bias = dout
|
306
|
+
@d_bias = dout.mean
|
299
307
|
dout.dot(@nn.weights[@index].transpose)
|
300
308
|
end
|
301
309
|
end
|
302
310
|
|
303
311
|
|
304
312
|
class NN::Sigmoid
|
313
|
+
include Numo
|
314
|
+
|
305
315
|
def forward(x)
|
306
|
-
@out = 1.0 / (1 +
|
316
|
+
@out = 1.0 / (1 + NMath.exp(-x))
|
307
317
|
end
|
308
318
|
|
309
319
|
def backward(dout)
|
@@ -328,8 +338,6 @@ end
|
|
328
338
|
|
329
339
|
|
330
340
|
class NN::Identity
|
331
|
-
include Numo
|
332
|
-
|
333
341
|
def initialize(nn)
|
334
342
|
@nn = nn
|
335
343
|
end
|
@@ -419,8 +427,8 @@ class NN::BatchNorm
|
|
419
427
|
end
|
420
428
|
|
421
429
|
def backward(dout)
|
422
|
-
@d_beta = dout.sum(0)
|
423
|
-
@d_gamma = (@xn * dout).sum(0)
|
430
|
+
@d_beta = dout.sum(0).mean
|
431
|
+
@d_gamma = (@xn * dout).sum(0).mean
|
424
432
|
dxn = @nn.gammas[@index] * dout
|
425
433
|
dxc = dxn / @std
|
426
434
|
dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
|
data/nn.gemspec
CHANGED
data/sample/cifar10_program.rb
CHANGED