nn 2.0.1 → 2.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5fd7f6dd2b015169de254f1fc68d8566a5ffa7e14bfc8bd9581239c96a6f6965
4
- data.tar.gz: 232ee0dbbd7a16fce35e56e58ea40b79dc76b066a2c93dff576d48a34958ab66
3
+ metadata.gz: 8e26103663135846646308af8fa48851d0f56d34122bfdb6172e1a8bd4675bdf
4
+ data.tar.gz: cffd4270bae567540673a1ec327ea416b14c3794172b7db3c589c0053bf15466
5
5
  SHA512:
6
- metadata.gz: 6323e2716a4ef835665b8955ce699ee5843649dfbaa469257d4b164c2702f39a508f2034510c001785e671dd781e36b4afc863f16d1305865114925876aa807a
7
- data.tar.gz: da603d96fc83ab378cd34233624487b39fbd137037af5e4ca2d306a5c63cafc2193585d94405fb3f624ab7f9749e8587e86fc842a7cb5845e37001c70476f122
6
+ metadata.gz: abaa7793eb7e44cd44d9f49f63d7cd252f01e11e4c41d7e709cf6c8124b923d300ff6c2e2b07f410f4dbeaaa8f53a82f876e22582681ec4178364110128726ab
7
+ data.tar.gz: c8f830238afd0a767f5cc34a7be600c2fc481c591f8f8f433c13d0ad9647011e49f3b0ff4b87d5784a0c8ce2c01552445dc1540e252321cd856eea23c4a6f84d
data/document.txt CHANGED
@@ -13,6 +13,11 @@ class NN
13
13
 
14
14
  <クラスメソッド>
15
15
  load(file_name) : NN
16
+ Marshal形式で保存された学習結果を読み込みます。
17
+ String file_name 読み込むMarshalファイル名
18
+ 戻り値 NNのインスタンス
19
+
20
+ load_json(file_name) : NN
16
21
  JSON形式で保存された学習結果を読み込みます。
17
22
  String file_name 読み込むJSONファイル名
18
23
  戻り値 NNのインスタンス
@@ -109,6 +114,10 @@ run(x) : SFloat
109
114
  戻り値 出力ノードの値
110
115
 
111
116
  save(file_name) : void
117
+ 学習結果をMarshal形式で保存します。
118
+ String file_name 書き込むMarshalファイル名
119
+
120
+ save_json(file_name) : void
112
121
  学習結果をJSON形式で保存します。
113
122
  String file_name 書き込むJSONファイル名
114
123
 
@@ -142,3 +151,4 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
142
151
  2018/5/4 バージョン1.8公開
143
152
  2018/5/16 バージョン2.0公開
144
153
  2018/6/10 バージョン2.0.1公開
154
+ 2018/6/10 バージョン2.1.0公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "2.0"
5
+ VERSION = "2.1"
6
6
 
7
7
  include Numo
8
8
 
@@ -44,6 +44,10 @@ class NN
44
44
  end
45
45
 
46
46
  def self.load(file_name)
47
+ Marshal.load(File.binread(file_name))
48
+ end
49
+
50
+ def self.load_json(file_name)
47
51
  json = JSON.parse(File.read(file_name))
48
52
  nn = self.new(json["num_nodes"],
49
53
  learning_rate: json["learning_rate"],
@@ -150,6 +154,10 @@ class NN
150
154
  end
151
155
 
152
156
  def save(file_name)
157
+ File.binwrite(file_name, Marshal.dump(self))
158
+ end
159
+
160
+ def save_json(file_name)
153
161
  json = {
154
162
  "version" => VERSION,
155
163
  "num_nodes" => @num_nodes,
@@ -240,8 +248,8 @@ class NN
240
248
 
241
249
  def update_weight_and_bias
242
250
  @layers.select{|layer| layer.is_a?(Affine)}.each.with_index do |layer, i|
243
- weight_amount = layer.d_weight.mean(0) * @learning_rate
244
- bias_amount = layer.d_bias.mean * @learning_rate
251
+ weight_amount = layer.d_weight * @learning_rate
252
+ bias_amount = layer.d_bias * @learning_rate
245
253
  if @momentum > 0
246
254
  weight_amount += @momentum * @weight_amounts[i]
247
255
  @weight_amounts[i] = weight_amount
@@ -255,8 +263,8 @@ class NN
255
263
 
256
264
  def update_gamma_and_beta
257
265
  @layers.select{|layer| layer.is_a?(BatchNorm)}.each.with_index do |layer, i|
258
- gamma_amount = layer.d_gamma.mean * @learning_rate
259
- beta_amount = layer.d_beta.mean * @learning_rate
266
+ gamma_amount = layer.d_gamma * @learning_rate
267
+ beta_amount = layer.d_beta * @learning_rate
260
268
  if @momentum > 0
261
269
  gamma_amount += @momentum * @gamma_amounts[i]
262
270
  @gamma_amounts[i] = gamma_amount
@@ -290,20 +298,22 @@ class NN::Affine
290
298
 
291
299
  def backward(dout)
292
300
  x = @x.reshape(*@x.shape, 1)
293
- @d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1]))
301
+ @d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1])).mean(0)
294
302
  if @nn.weight_decay > 0
295
303
  dridge = @nn.weight_decay * @nn.weights[@index]
296
304
  @d_weight += dridge
297
305
  end
298
- @d_bias = dout
306
+ @d_bias = dout.mean
299
307
  dout.dot(@nn.weights[@index].transpose)
300
308
  end
301
309
  end
302
310
 
303
311
 
304
312
  class NN::Sigmoid
313
+ include Numo
314
+
305
315
  def forward(x)
306
- @out = 1.0 / (1 + Numo::NMath.exp(-x))
316
+ @out = 1.0 / (1 + NMath.exp(-x))
307
317
  end
308
318
 
309
319
  def backward(dout)
@@ -328,8 +338,6 @@ end
328
338
 
329
339
 
330
340
  class NN::Identity
331
- include Numo
332
-
333
341
  def initialize(nn)
334
342
  @nn = nn
335
343
  end
@@ -419,8 +427,8 @@ class NN::BatchNorm
419
427
  end
420
428
 
421
429
  def backward(dout)
422
- @d_beta = dout.sum(0)
423
- @d_gamma = (@xn * dout).sum(0)
430
+ @d_beta = dout.sum(0).mean
431
+ @d_gamma = (@xn * dout).sum(0).mean
424
432
  dxn = @nn.gammas[@index] * dout
425
433
  dxc = dxn / @std
426
434
  dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
data/nn.gemspec CHANGED
@@ -5,7 +5,7 @@ require "nn"
5
5
 
6
6
  Gem::Specification.new do |spec|
7
7
  spec.name = "nn"
8
- spec.version = NN::VERSION + ".1"
8
+ spec.version = NN::VERSION + ".0"
9
9
  spec.authors = ["unagiootoro"]
10
10
  spec.email = ["ootoro838861@outlook.jp"]
11
11
 
@@ -32,7 +32,7 @@ func = -> x, y do
32
32
  [x, y]
33
33
  end
34
34
 
35
- nn.train(x_train, y_train, 50, func) do |epoch|
35
+ nn.train(x_train, y_train, 20, func) do |epoch|
36
36
  nn.test(x_test, y_test, &func)
37
37
  nn.learning_rate *= 0.99
38
38
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: 2.0.1
4
+ version: 2.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro