nn 1.5 → 1.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/document.txt +6 -4
- data/lib/nn.rb +31 -23
- metadata +2 -3
- data/lib/nn/version.rb +0 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8afd54182615146fe3023709fdcecdbdf8ef1857e048611efcde6abab91f00fc
|
4
|
+
data.tar.gz: 4d1f66eadf5c5c33ba416d30a7898af4f999c5a541d6018d914f10b3d0840285
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: de6ec3e4db6512fb71d20ac10db29eae197a57e5893befce8ab3493d2a4f22fc1d6b0a4c7779ff2031114d3541068c5095da32b35daad35348e50c0209794b87
|
7
|
+
data.tar.gz: 6f0aa482117bd51dcc255e923450646dbf5ae2df3506b1da466a162285824b3924602b79f19aaf255b6bf741f3fc706ab0c5daccf8aef1c7353e3087a0efb9cf
|
data/document.txt
CHANGED
@@ -35,7 +35,7 @@ initialize(num_nodes,
|
|
35
35
|
learning_rate: 0.01,
|
36
36
|
batch_size: 1,
|
37
37
|
activation: [:relu, :identity],
|
38
|
-
momentum: 0,
|
38
|
+
momentum: 0.9,
|
39
39
|
weight_decay: 0,
|
40
40
|
use_dropout: false,
|
41
41
|
dropout_ratio: 0.5,
|
@@ -52,7 +52,8 @@ initialize(num_nodes,
|
|
52
52
|
Float dropout_ratio ドロップアウトさせるノードの比率
|
53
53
|
bool use_batch_norm バッチノーマライゼーションを使用するか否か
|
54
54
|
|
55
|
-
train(x_train, y_train, x_test, y_test,
|
55
|
+
train(x_train, y_train, epochs,
|
56
|
+
learning_rate_decay: 0,
|
56
57
|
save_dir: nil,
|
57
58
|
save_interval: 1,
|
58
59
|
test: nil,
|
@@ -62,7 +63,8 @@ train(x_train, y_train, x_test, y_test, epoch,
|
|
62
63
|
学習を行います。
|
63
64
|
Array<Array<Numeric>> | SFloat x_train トレーニング用入力データ。
|
64
65
|
Array<Array<Numeric>> | SFloat y_train トレーニング用正解データ。
|
65
|
-
Integer
|
66
|
+
Integer epochs 学習回数。入力データすべてを見たタイミングを1エポックとします。
|
67
|
+
Float learning_rate_decay 1エポックごとに学習率を減衰させる値。各エポック終了時に学習率からこの値が減算されます（下限は1e-7）。
|
66
68
|
String save_dir 学習中にセーブを行う場合、セーブするディレクトリを指定します。nilの場合、セーブを行いません。
|
67
69
|
Integer save_interval 学習中にセーブするタイミングをエポック単位で指定します。
|
68
70
|
Array<Array<Array<Numeric>> | SFloat> test テストで使用するデータ。[x_test, y_test]の形式で指定してください。
|
@@ -209,4 +211,4 @@ end
|
|
209
211
|
2018/3/14 バージョン1.3公開
|
210
212
|
2018/3/18 バージョン1.4公開
|
211
213
|
2018/3/22 バージョン1.5公開
|
212
|
-
2018/
|
214
|
+
2018/4/15 バージョン1.6公開
|
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "1.
|
5
|
+
VERSION = "1.6"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -22,7 +22,7 @@ class NN
|
|
22
22
|
learning_rate: 0.01,
|
23
23
|
batch_size: 1,
|
24
24
|
activation: %i(relu identity),
|
25
|
-
momentum: 0,
|
25
|
+
momentum: 0.9,
|
26
26
|
weight_decay: 0,
|
27
27
|
use_dropout: false,
|
28
28
|
dropout_ratio: 0.5,
|
@@ -64,29 +64,37 @@ class NN
|
|
64
64
|
nn
|
65
65
|
end
|
66
66
|
|
67
|
-
def train(x_train, y_train,
|
68
|
-
|
67
|
+
def train(x_train, y_train, epochs,
|
68
|
+
learning_rate_decay: 0,
|
69
|
+
save_dir: nil,
|
70
|
+
save_interval: 1,
|
71
|
+
test: nil,
|
72
|
+
border: nil,
|
73
|
+
tolerance: 0.5,
|
74
|
+
&block)
|
69
75
|
num_train_data = x_train.is_a?(SFloat) ? x_train.shape[0] : x_train.length
|
70
|
-
(
|
71
|
-
loss =
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
now_epoch = (count + 1) / (num_train_data / @batch_size)
|
78
|
-
if save_dir && now_epoch % save_interval == 0
|
79
|
-
save("#{save_dir}/epoch#{now_epoch}.json")
|
80
|
-
end
|
81
|
-
msg = "epoch #{now_epoch}/#{epoch} loss: #{loss}"
|
82
|
-
if test
|
83
|
-
acc = accurate(*test, tolerance, &block)
|
84
|
-
puts "#{msg} accurate: #{acc}"
|
85
|
-
break if border && acc >= border
|
86
|
-
else
|
87
|
-
puts msg
|
76
|
+
(1..epochs).each do |epoch|
|
77
|
+
loss = nil
|
78
|
+
(num_train_data.to_f / @batch_size).ceil.times do
|
79
|
+
loss = learn(x_train, y_train, &block)
|
80
|
+
if loss.nan?
|
81
|
+
puts "loss is nan"
|
82
|
+
break
|
88
83
|
end
|
89
84
|
end
|
85
|
+
if save_dir && epoch % save_interval == 0
|
86
|
+
save("#{save_dir}/epoch#{epoch}.json")
|
87
|
+
end
|
88
|
+
msg = "epoch #{epoch}/#{epochs} loss: #{loss}"
|
89
|
+
if test
|
90
|
+
acc = accurate(*test, tolerance, &block)
|
91
|
+
puts "#{msg} accurate: #{acc}"
|
92
|
+
break if border && acc >= border
|
93
|
+
else
|
94
|
+
puts msg
|
95
|
+
end
|
96
|
+
@learning_rate -= learning_rate_decay
|
97
|
+
@learning_rate = 1e-7 if @learning_rate < 1e-7
|
90
98
|
end
|
91
99
|
end
|
92
100
|
|
@@ -99,7 +107,7 @@ class NN
|
|
99
107
|
def accurate(x_test, y_test, tolerance = 0.5, &block)
|
100
108
|
correct = 0
|
101
109
|
num_test_data = x_test.is_a?(SFloat) ? x_test.shape[0] : x_test.length
|
102
|
-
(num_test_data / @batch_size).times do |i|
|
110
|
+
(num_test_data.to_f / @batch_size).ceil.times do |i|
|
103
111
|
x = SFloat.zeros(@batch_size, @num_nodes.first)
|
104
112
|
y = SFloat.zeros(@batch_size, @num_nodes.last)
|
105
113
|
@batch_size.times do |j|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: '1.
|
4
|
+
version: '1.6'
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-04-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -69,7 +69,6 @@ files:
|
|
69
69
|
- document.txt
|
70
70
|
- lib/nn.rb
|
71
71
|
- lib/nn/mnist.rb
|
72
|
-
- lib/nn/version.rb
|
73
72
|
- nn.gemspec
|
74
73
|
homepage: https://github.com/unagiootoro/nn.git
|
75
74
|
licenses:
|
data/lib/nn/version.rb
DELETED