nn 1.5 → 1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (5)
  1. checksums.yaml +4 -4
  2. data/document.txt +6 -4
  3. data/lib/nn.rb +31 -23
  4. metadata +2 -3
  5. data/lib/nn/version.rb +0 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ae8f441634517a886dbdcee8b7186c5d02710c84dc52b381bb29d0d219f958f3
4
- data.tar.gz: 3af1f9f95a8727aec20100b1d82ef598dda3857915641a3b73ba16cf372841de
3
+ metadata.gz: 8afd54182615146fe3023709fdcecdbdf8ef1857e048611efcde6abab91f00fc
4
+ data.tar.gz: 4d1f66eadf5c5c33ba416d30a7898af4f999c5a541d6018d914f10b3d0840285
5
5
  SHA512:
6
- metadata.gz: 2774e79cddbc52530d9f00cd34bb981f7beaa7f6d1c402b6205a41dbe7949fb40dedd2e6c97e9f1a38abc64b5e3bf1bffc40843c943bd0f88dcba2e9bd52f202
7
- data.tar.gz: c45b37a70bfddb2a31f1682d4268e943098e228b16102d61b74d2d26dd526b2fd491c4ec4d7a4758cfe838a7511061d4e87b263652acb46a5e29e14d298676bf
6
+ metadata.gz: de6ec3e4db6512fb71d20ac10db29eae197a57e5893befce8ab3493d2a4f22fc1d6b0a4c7779ff2031114d3541068c5095da32b35daad35348e50c0209794b87
7
+ data.tar.gz: 6f0aa482117bd51dcc255e923450646dbf5ae2df3506b1da466a162285824b3924602b79f19aaf255b6bf741f3fc706ab0c5daccf8aef1c7353e3087a0efb9cf
@@ -35,7 +35,7 @@ initialize(num_nodes,
35
35
  learning_rate: 0.01,
36
36
  batch_size: 1,
37
37
  activation: [:relu, :identity],
38
- momentum: 0,
38
+ momentum: 0.9,
39
39
  weight_decay: 0,
40
40
  use_dropout: false,
41
41
  dropout_ratio: 0.5,
@@ -52,7 +52,8 @@ initialize(num_nodes,
52
52
  Float dropout_ratio ドロップアウトさせるノードの比率
53
53
  bool use_batch_norm バッチノーマライゼーションを使用するか否か
54
54
 
55
- train(x_train, y_train, x_test, y_test, epoch,
55
+ train(x_train, y_train, x_test, y_test, epochs,
56
+ learning_rate_decay: 0,
56
57
  save_dir: nil,
57
58
  save_interval: 1,
58
59
  test: nil,
@@ -62,7 +63,8 @@ train(x_train, y_train, x_test, y_test, epoch,
62
63
  学習を行います。
63
64
  Array<Array<Numeric>> | SFloat x_train トレーニング用入力データ。
64
65
  Array<Array<Numeric>> | SFloat y_train トレーニング用正解データ。
65
- Integer epoch 学習回数。入力データすべてを見たタイミングを1エポックとします。
66
+ Integer epochs 学習回数。入力データすべてを見たタイミングを1エポックとします。
67
+ Float learning_rate_decay 1エポックごとに学習率を減衰される値。
66
68
  String save_dir 学習中にセーブを行う場合、セーブするディレクトリを指定します。nilの場合、セーブを行いません。
67
69
  Integer save_interval 学習中にセーブするタイミングをエポック単位で指定します。
68
70
  Array<Array<Array<Numeric>> | SFloat> test テストで使用するデータ。[x_test, y_test]の形式で指定してください。
@@ -209,4 +211,4 @@ end
209
211
  2018/3/14 バージョン1.3公開
210
212
  2018/3/18 バージョン1.4公開
211
213
  2018/3/22 バージョン1.5公開
212
- 2018/3/27 RubyGemに公開
214
+ 2018/4/15 バージョン1.6公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "1.5"
5
+ VERSION = "1.6"
6
6
 
7
7
  include Numo
8
8
 
@@ -22,7 +22,7 @@ class NN
22
22
  learning_rate: 0.01,
23
23
  batch_size: 1,
24
24
  activation: %i(relu identity),
25
- momentum: 0,
25
+ momentum: 0.9,
26
26
  weight_decay: 0,
27
27
  use_dropout: false,
28
28
  dropout_ratio: 0.5,
@@ -64,29 +64,37 @@ class NN
64
64
  nn
65
65
  end
66
66
 
67
- def train(x_train, y_train, epoch,
68
- save_dir: nil, save_interval: 1, test: nil, border: nil, tolerance: 0.5, &block)
67
+ def train(x_train, y_train, epochs,
68
+ learning_rate_decay: 0,
69
+ save_dir: nil,
70
+ save_interval: 1,
71
+ test: nil,
72
+ border: nil,
73
+ tolerance: 0.5,
74
+ &block)
69
75
  num_train_data = x_train.is_a?(SFloat) ? x_train.shape[0] : x_train.length
70
- (epoch * num_train_data / @batch_size).times do |count|
71
- loss = learn(x_train, y_train, &block)
72
- if loss.nan?
73
- puts "loss is nan"
74
- break
75
- end
76
- if (count + 1) % (num_train_data / @batch_size) == 0
77
- now_epoch = (count + 1) / (num_train_data / @batch_size)
78
- if save_dir && now_epoch % save_interval == 0
79
- save("#{save_dir}/epoch#{now_epoch}.json")
80
- end
81
- msg = "epoch #{now_epoch}/#{epoch} loss: #{loss}"
82
- if test
83
- acc = accurate(*test, tolerance, &block)
84
- puts "#{msg} accurate: #{acc}"
85
- break if border && acc >= border
86
- else
87
- puts msg
76
+ (1..epochs).each do |epoch|
77
+ loss = nil
78
+ (num_train_data.to_f / @batch_size).ceil.times do
79
+ loss = learn(x_train, y_train, &block)
80
+ if loss.nan?
81
+ puts "loss is nan"
82
+ break
88
83
  end
89
84
  end
85
+ if save_dir && epoch % save_interval == 0
86
+ save("#{save_dir}/epoch#{epoch}.json")
87
+ end
88
+ msg = "epoch #{epoch}/#{epochs} loss: #{loss}"
89
+ if test
90
+ acc = accurate(*test, tolerance, &block)
91
+ puts "#{msg} accurate: #{acc}"
92
+ break if border && acc >= border
93
+ else
94
+ puts msg
95
+ end
96
+ @learning_rate -= learning_rate_decay
97
+ @learning_rate = 1e-7 if @learning_rate < 1e-7
90
98
  end
91
99
  end
92
100
 
@@ -99,7 +107,7 @@ class NN
99
107
  def accurate(x_test, y_test, tolerance = 0.5, &block)
100
108
  correct = 0
101
109
  num_test_data = x_test.is_a?(SFloat) ? x_test.shape[0] : x_test.length
102
- (num_test_data / @batch_size).times do |i|
110
+ (num_test_data.to_f / @batch_size).ceil.times do |i|
103
111
  x = SFloat.zeros(@batch_size, @num_nodes.first)
104
112
  y = SFloat.zeros(@batch_size, @num_nodes.last)
105
113
  @batch_size.times do |j|
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: '1.5'
4
+ version: '1.6'
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-03-27 00:00:00.000000000 Z
11
+ date: 2018-04-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -69,7 +69,6 @@ files:
69
69
  - document.txt
70
70
  - lib/nn.rb
71
71
  - lib/nn/mnist.rb
72
- - lib/nn/version.rb
73
72
  - nn.gemspec
74
73
  homepage: https://github.com/unagiootoro/nn.git
75
74
  licenses:
@@ -1,2 +0,0 @@
1
- require "nn"
2
-