nn 1.5 → 1.6

Sign up to get free protection for your applications and to get access to all the features.
Files changed (5) hide show
  1. checksums.yaml +4 -4
  2. data/document.txt +6 -4
  3. data/lib/nn.rb +31 -23
  4. metadata +2 -3
  5. data/lib/nn/version.rb +0 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ae8f441634517a886dbdcee8b7186c5d02710c84dc52b381bb29d0d219f958f3
4
- data.tar.gz: 3af1f9f95a8727aec20100b1d82ef598dda3857915641a3b73ba16cf372841de
3
+ metadata.gz: 8afd54182615146fe3023709fdcecdbdf8ef1857e048611efcde6abab91f00fc
4
+ data.tar.gz: 4d1f66eadf5c5c33ba416d30a7898af4f999c5a541d6018d914f10b3d0840285
5
5
  SHA512:
6
- metadata.gz: 2774e79cddbc52530d9f00cd34bb981f7beaa7f6d1c402b6205a41dbe7949fb40dedd2e6c97e9f1a38abc64b5e3bf1bffc40843c943bd0f88dcba2e9bd52f202
7
- data.tar.gz: c45b37a70bfddb2a31f1682d4268e943098e228b16102d61b74d2d26dd526b2fd491c4ec4d7a4758cfe838a7511061d4e87b263652acb46a5e29e14d298676bf
6
+ metadata.gz: de6ec3e4db6512fb71d20ac10db29eae197a57e5893befce8ab3493d2a4f22fc1d6b0a4c7779ff2031114d3541068c5095da32b35daad35348e50c0209794b87
7
+ data.tar.gz: 6f0aa482117bd51dcc255e923450646dbf5ae2df3506b1da466a162285824b3924602b79f19aaf255b6bf741f3fc706ab0c5daccf8aef1c7353e3087a0efb9cf
@@ -35,7 +35,7 @@ initialize(num_nodes,
35
35
  learning_rate: 0.01,
36
36
  batch_size: 1,
37
37
  activation: [:relu, :identity],
38
- momentum: 0,
38
+ momentum: 0.9,
39
39
  weight_decay: 0,
40
40
  use_dropout: false,
41
41
  dropout_ratio: 0.5,
@@ -52,7 +52,8 @@ initialize(num_nodes,
52
52
  Float dropout_ratio ドロップアウトさせるノードの比率
53
53
  bool use_batch_norm バッチノーマライゼーションを使用するか否か
54
54
 
55
- train(x_train, y_train, x_test, y_test, epoch,
55
+ train(x_train, y_train, x_test, y_test, epochs,
56
+ learning_rate_decay: 0,
56
57
  save_dir: nil,
57
58
  save_interval: 1,
58
59
  test: nil,
@@ -62,7 +63,8 @@ train(x_train, y_train, x_test, y_test, epoch,
62
63
  学習を行います。
63
64
  Array<Array<Numeric>> | SFloat x_train トレーニング用入力データ。
64
65
  Array<Array<Numeric>> | SFloat y_train トレーニング用正解データ。
65
- Integer epoch 学習回数。入力データすべてを見たタイミングを1エポックとします。
66
+ Integer epochs 学習回数。入力データすべてを見たタイミングを1エポックとします。
67
+ Float learning_rate_decay 1エポックごとに学習率を減衰される値。
66
68
  String save_dir 学習中にセーブを行う場合、セーブするディレクトリを指定します。nilの場合、セーブを行いません。
67
69
  Integer save_interval 学習中にセーブするタイミングをエポック単位で指定します。
68
70
  Array<Array<Array<Numeric>> | SFloat> test テストで使用するデータ。[x_test, y_test]の形式で指定してください。
@@ -209,4 +211,4 @@ end
209
211
  2018/3/14 バージョン1.3公開
210
212
  2018/3/18 バージョン1.4公開
211
213
  2018/3/22 バージョン1.5公開
212
- 2018/3/27 RubyGemに公開
214
+ 2018/4/15 バージョン1.6公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "1.5"
5
+ VERSION = "1.6"
6
6
 
7
7
  include Numo
8
8
 
@@ -22,7 +22,7 @@ class NN
22
22
  learning_rate: 0.01,
23
23
  batch_size: 1,
24
24
  activation: %i(relu identity),
25
- momentum: 0,
25
+ momentum: 0.9,
26
26
  weight_decay: 0,
27
27
  use_dropout: false,
28
28
  dropout_ratio: 0.5,
@@ -64,29 +64,37 @@ class NN
64
64
  nn
65
65
  end
66
66
 
67
- def train(x_train, y_train, epoch,
68
- save_dir: nil, save_interval: 1, test: nil, border: nil, tolerance: 0.5, &block)
67
+ def train(x_train, y_train, epochs,
68
+ learning_rate_decay: 0,
69
+ save_dir: nil,
70
+ save_interval: 1,
71
+ test: nil,
72
+ border: nil,
73
+ tolerance: 0.5,
74
+ &block)
69
75
  num_train_data = x_train.is_a?(SFloat) ? x_train.shape[0] : x_train.length
70
- (epoch * num_train_data / @batch_size).times do |count|
71
- loss = learn(x_train, y_train, &block)
72
- if loss.nan?
73
- puts "loss is nan"
74
- break
75
- end
76
- if (count + 1) % (num_train_data / @batch_size) == 0
77
- now_epoch = (count + 1) / (num_train_data / @batch_size)
78
- if save_dir && now_epoch % save_interval == 0
79
- save("#{save_dir}/epoch#{now_epoch}.json")
80
- end
81
- msg = "epoch #{now_epoch}/#{epoch} loss: #{loss}"
82
- if test
83
- acc = accurate(*test, tolerance, &block)
84
- puts "#{msg} accurate: #{acc}"
85
- break if border && acc >= border
86
- else
87
- puts msg
76
+ (1..epochs).each do |epoch|
77
+ loss = nil
78
+ (num_train_data.to_f / @batch_size).ceil.times do
79
+ loss = learn(x_train, y_train, &block)
80
+ if loss.nan?
81
+ puts "loss is nan"
82
+ break
88
83
  end
89
84
  end
85
+ if save_dir && epoch % save_interval == 0
86
+ save("#{save_dir}/epoch#{epoch}.json")
87
+ end
88
+ msg = "epoch #{epoch}/#{epochs} loss: #{loss}"
89
+ if test
90
+ acc = accurate(*test, tolerance, &block)
91
+ puts "#{msg} accurate: #{acc}"
92
+ break if border && acc >= border
93
+ else
94
+ puts msg
95
+ end
96
+ @learning_rate -= learning_rate_decay
97
+ @learning_rate = 1e-7 if @learning_rate < 1e-7
90
98
  end
91
99
  end
92
100
 
@@ -99,7 +107,7 @@ class NN
99
107
  def accurate(x_test, y_test, tolerance = 0.5, &block)
100
108
  correct = 0
101
109
  num_test_data = x_test.is_a?(SFloat) ? x_test.shape[0] : x_test.length
102
- (num_test_data / @batch_size).times do |i|
110
+ (num_test_data.to_f / @batch_size).ceil.times do |i|
103
111
  x = SFloat.zeros(@batch_size, @num_nodes.first)
104
112
  y = SFloat.zeros(@batch_size, @num_nodes.last)
105
113
  @batch_size.times do |j|
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: '1.5'
4
+ version: '1.6'
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-03-27 00:00:00.000000000 Z
11
+ date: 2018-04-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -69,7 +69,6 @@ files:
69
69
  - document.txt
70
70
  - lib/nn.rb
71
71
  - lib/nn/mnist.rb
72
- - lib/nn/version.rb
73
72
  - nn.gemspec
74
73
  homepage: https://github.com/unagiootoro/nn.git
75
74
  licenses:
@@ -1,2 +0,0 @@
1
- require "nn"
2
-