nn 2.0.1 → 2.1.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5fd7f6dd2b015169de254f1fc68d8566a5ffa7e14bfc8bd9581239c96a6f6965
4
- data.tar.gz: 232ee0dbbd7a16fce35e56e58ea40b79dc76b066a2c93dff576d48a34958ab66
3
+ metadata.gz: 8e26103663135846646308af8fa48851d0f56d34122bfdb6172e1a8bd4675bdf
4
+ data.tar.gz: cffd4270bae567540673a1ec327ea416b14c3794172b7db3c589c0053bf15466
5
5
  SHA512:
6
- metadata.gz: 6323e2716a4ef835665b8955ce699ee5843649dfbaa469257d4b164c2702f39a508f2034510c001785e671dd781e36b4afc863f16d1305865114925876aa807a
7
- data.tar.gz: da603d96fc83ab378cd34233624487b39fbd137037af5e4ca2d306a5c63cafc2193585d94405fb3f624ab7f9749e8587e86fc842a7cb5845e37001c70476f122
6
+ metadata.gz: abaa7793eb7e44cd44d9f49f63d7cd252f01e11e4c41d7e709cf6c8124b923d300ff6c2e2b07f410f4dbeaaa8f53a82f876e22582681ec4178364110128726ab
7
+ data.tar.gz: c8f830238afd0a767f5cc34a7be600c2fc481c591f8f8f433c13d0ad9647011e49f3b0ff4b87d5784a0c8ce2c01552445dc1540e252321cd856eea23c4a6f84d
data/document.txt CHANGED
@@ -13,6 +13,11 @@ class NN
13
13
 
14
14
  <クラスメソッド>
15
15
  load(file_name) : NN
16
+ Marshal形式で保存された学習結果を読み込みます。
17
+ String file_name 読み込むMarshalファイル名
18
+ 戻り値 NNのインスタンス
19
+
20
+ load_json(file_name) : NN
16
21
  JSON形式で保存された学習結果を読み込みます。
17
22
  String file_name 読み込むJSONファイル名
18
23
  戻り値 NNのインスタンス
@@ -109,6 +114,10 @@ run(x) : SFloat
109
114
  戻り値 出力ノードの値
110
115
 
111
116
  save(file_name) : void
117
+ 学習結果をMarshal形式で保存します。
118
+ String file_name 書き込むMarshalファイル名
119
+
120
+ save_json(file_name) : void
112
121
  学習結果をJSON形式で保存します。
113
122
  String file_name 書き込むJSONファイル名
114
123
 
@@ -142,3 +151,4 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
142
151
  2018/5/4 バージョン1.8公開
143
152
  2018/5/16 バージョン2.0公開
144
153
  2018/6/10 バージョン2.0.1公開
154
+ 2018/6/10 バージョン2.1.0公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "2.0"
5
+ VERSION = "2.1"
6
6
 
7
7
  include Numo
8
8
 
@@ -44,6 +44,10 @@ class NN
44
44
  end
45
45
 
46
46
  def self.load(file_name)
47
+ Marshal.load(File.binread(file_name))
48
+ end
49
+
50
+ def self.load_json(file_name)
47
51
  json = JSON.parse(File.read(file_name))
48
52
  nn = self.new(json["num_nodes"],
49
53
  learning_rate: json["learning_rate"],
@@ -150,6 +154,10 @@ class NN
150
154
  end
151
155
 
152
156
  def save(file_name)
157
+ File.binwrite(file_name, Marshal.dump(self))
158
+ end
159
+
160
+ def save_json(file_name)
153
161
  json = {
154
162
  "version" => VERSION,
155
163
  "num_nodes" => @num_nodes,
@@ -240,8 +248,8 @@ class NN
240
248
 
241
249
  def update_weight_and_bias
242
250
  @layers.select{|layer| layer.is_a?(Affine)}.each.with_index do |layer, i|
243
- weight_amount = layer.d_weight.mean(0) * @learning_rate
244
- bias_amount = layer.d_bias.mean * @learning_rate
251
+ weight_amount = layer.d_weight * @learning_rate
252
+ bias_amount = layer.d_bias * @learning_rate
245
253
  if @momentum > 0
246
254
  weight_amount += @momentum * @weight_amounts[i]
247
255
  @weight_amounts[i] = weight_amount
@@ -255,8 +263,8 @@ class NN
255
263
 
256
264
  def update_gamma_and_beta
257
265
  @layers.select{|layer| layer.is_a?(BatchNorm)}.each.with_index do |layer, i|
258
- gamma_amount = layer.d_gamma.mean * @learning_rate
259
- beta_amount = layer.d_beta.mean * @learning_rate
266
+ gamma_amount = layer.d_gamma * @learning_rate
267
+ beta_amount = layer.d_beta * @learning_rate
260
268
  if @momentum > 0
261
269
  gamma_amount += @momentum * @gamma_amounts[i]
262
270
  @gamma_amounts[i] = gamma_amount
@@ -290,20 +298,22 @@ class NN::Affine
290
298
 
291
299
  def backward(dout)
292
300
  x = @x.reshape(*@x.shape, 1)
293
- @d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1]))
301
+ @d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1])).mean(0)
294
302
  if @nn.weight_decay > 0
295
303
  dridge = @nn.weight_decay * @nn.weights[@index]
296
304
  @d_weight += dridge
297
305
  end
298
- @d_bias = dout
306
+ @d_bias = dout.mean
299
307
  dout.dot(@nn.weights[@index].transpose)
300
308
  end
301
309
  end
302
310
 
303
311
 
304
312
  class NN::Sigmoid
313
+ include Numo
314
+
305
315
  def forward(x)
306
- @out = 1.0 / (1 + Numo::NMath.exp(-x))
316
+ @out = 1.0 / (1 + NMath.exp(-x))
307
317
  end
308
318
 
309
319
  def backward(dout)
@@ -328,8 +338,6 @@ end
328
338
 
329
339
 
330
340
  class NN::Identity
331
- include Numo
332
-
333
341
  def initialize(nn)
334
342
  @nn = nn
335
343
  end
@@ -419,8 +427,8 @@ class NN::BatchNorm
419
427
  end
420
428
 
421
429
  def backward(dout)
422
- @d_beta = dout.sum(0)
423
- @d_gamma = (@xn * dout).sum(0)
430
+ @d_beta = dout.sum(0).mean
431
+ @d_gamma = (@xn * dout).sum(0).mean
424
432
  dxn = @nn.gammas[@index] * dout
425
433
  dxc = dxn / @std
426
434
  dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
data/nn.gemspec CHANGED
@@ -5,7 +5,7 @@ require "nn"
5
5
 
6
6
  Gem::Specification.new do |spec|
7
7
  spec.name = "nn"
8
- spec.version = NN::VERSION + ".1"
8
+ spec.version = NN::VERSION + ".0"
9
9
  spec.authors = ["unagiootoro"]
10
10
  spec.email = ["ootoro838861@outlook.jp"]
11
11
 
@@ -32,7 +32,7 @@ func = -> x, y do
32
32
  [x, y]
33
33
  end
34
34
 
35
- nn.train(x_train, y_train, 50, func) do |epoch|
35
+ nn.train(x_train, y_train, 20, func) do |epoch|
36
36
  nn.test(x_test, y_test, &func)
37
37
  nn.learning_rate *= 0.99
38
38
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: 2.0.1
4
+ version: 2.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro