nn 1.5 → 1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/document.txt +6 -4
- data/lib/nn.rb +31 -23
- metadata +2 -3
- data/lib/nn/version.rb +0 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8afd54182615146fe3023709fdcecdbdf8ef1857e048611efcde6abab91f00fc
|
4
|
+
data.tar.gz: 4d1f66eadf5c5c33ba416d30a7898af4f999c5a541d6018d914f10b3d0840285
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: de6ec3e4db6512fb71d20ac10db29eae197a57e5893befce8ab3493d2a4f22fc1d6b0a4c7779ff2031114d3541068c5095da32b35daad35348e50c0209794b87
|
7
|
+
data.tar.gz: 6f0aa482117bd51dcc255e923450646dbf5ae2df3506b1da466a162285824b3924602b79f19aaf255b6bf741f3fc706ab0c5daccf8aef1c7353e3087a0efb9cf
|
data/document.txt
CHANGED
@@ -35,7 +35,7 @@ initialize(num_nodes,
|
|
35
35
|
learning_rate: 0.01,
|
36
36
|
batch_size: 1,
|
37
37
|
activation: [:relu, :identity],
|
38
|
-
momentum: 0,
|
38
|
+
momentum: 0.9,
|
39
39
|
weight_decay: 0,
|
40
40
|
use_dropout: false,
|
41
41
|
dropout_ratio: 0.5,
|
@@ -52,7 +52,8 @@ initialize(num_nodes,
|
|
52
52
|
Float dropout_ratio ドロップアウトさせるノードの比率
|
53
53
|
bool use_batch_norm バッチノーマライゼーションを使用するか否か
|
54
54
|
|
55
|
-
train(x_train, y_train, x_test, y_test, epoch,
|
55
|
+
train(x_train, y_train, x_test, y_test, epochs,
|
56
|
+
learning_rate_decay: 0,
|
56
57
|
save_dir: nil,
|
57
58
|
save_interval: 1,
|
58
59
|
test: nil,
|
@@ -62,7 +63,8 @@ train(x_train, y_train, x_test, y_test, epoch,
|
|
62
63
|
学習を行います。
|
63
64
|
Array<Array<Numeric>> | SFloat x_train トレーニング用入力データ。
|
64
65
|
Array<Array<Numeric>> | SFloat y_train トレーニング用正解データ。
|
65
|
-
Integer
|
66
|
+
Integer epochs 学習回数。入力データすべてを見たタイミングを1エポックとします。
|
67
|
+
Float learning_rate_decay 1エポックごとに学習率を減衰させる値。
|
66
68
|
String save_dir 学習中にセーブを行う場合、セーブするディレクトリを指定します。nilの場合、セーブを行いません。
|
67
69
|
Integer save_interval 学習中にセーブするタイミングをエポック単位で指定します。
|
68
70
|
Array<Array<Array<Numeric>> | SFloat> test テストで使用するデータ。[x_test, y_test]の形式で指定してください。
|
@@ -209,4 +211,4 @@ end
|
|
209
211
|
2018/3/14 バージョン1.3公開
|
210
212
|
2018/3/18 バージョン1.4公開
|
211
213
|
2018/3/22 バージョン1.5公開
|
212
|
-
2018/
|
214
|
+
2018/4/15 バージョン1.6公開
|
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "1.5"
|
5
|
+
VERSION = "1.6"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -22,7 +22,7 @@ class NN
|
|
22
22
|
learning_rate: 0.01,
|
23
23
|
batch_size: 1,
|
24
24
|
activation: %i(relu identity),
|
25
|
-
momentum: 0,
|
25
|
+
momentum: 0.9,
|
26
26
|
weight_decay: 0,
|
27
27
|
use_dropout: false,
|
28
28
|
dropout_ratio: 0.5,
|
@@ -64,29 +64,37 @@ class NN
|
|
64
64
|
nn
|
65
65
|
end
|
66
66
|
|
67
|
-
def train(x_train, y_train, epoch,
|
68
|
-
|
67
|
+
def train(x_train, y_train, epochs,
|
68
|
+
learning_rate_decay: 0,
|
69
|
+
save_dir: nil,
|
70
|
+
save_interval: 1,
|
71
|
+
test: nil,
|
72
|
+
border: nil,
|
73
|
+
tolerance: 0.5,
|
74
|
+
&block)
|
69
75
|
num_train_data = x_train.is_a?(SFloat) ? x_train.shape[0] : x_train.length
|
70
|
-
(
|
71
|
-
loss =
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
now_epoch = (count + 1) / (num_train_data / @batch_size)
|
78
|
-
if save_dir && now_epoch % save_interval == 0
|
79
|
-
save("#{save_dir}/epoch#{now_epoch}.json")
|
80
|
-
end
|
81
|
-
msg = "epoch #{now_epoch}/#{epoch} loss: #{loss}"
|
82
|
-
if test
|
83
|
-
acc = accurate(*test, tolerance, &block)
|
84
|
-
puts "#{msg} accurate: #{acc}"
|
85
|
-
break if border && acc >= border
|
86
|
-
else
|
87
|
-
puts msg
|
76
|
+
(1..epochs).each do |epoch|
|
77
|
+
loss = nil
|
78
|
+
(num_train_data.to_f / @batch_size).ceil.times do
|
79
|
+
loss = learn(x_train, y_train, &block)
|
80
|
+
if loss.nan?
|
81
|
+
puts "loss is nan"
|
82
|
+
break
|
88
83
|
end
|
89
84
|
end
|
85
|
+
if save_dir && epoch % save_interval == 0
|
86
|
+
save("#{save_dir}/epoch#{epoch}.json")
|
87
|
+
end
|
88
|
+
msg = "epoch #{epoch}/#{epochs} loss: #{loss}"
|
89
|
+
if test
|
90
|
+
acc = accurate(*test, tolerance, &block)
|
91
|
+
puts "#{msg} accurate: #{acc}"
|
92
|
+
break if border && acc >= border
|
93
|
+
else
|
94
|
+
puts msg
|
95
|
+
end
|
96
|
+
@learning_rate -= learning_rate_decay
|
97
|
+
@learning_rate = 1e-7 if @learning_rate < 1e-7
|
90
98
|
end
|
91
99
|
end
|
92
100
|
|
@@ -99,7 +107,7 @@ class NN
|
|
99
107
|
def accurate(x_test, y_test, tolerance = 0.5, &block)
|
100
108
|
correct = 0
|
101
109
|
num_test_data = x_test.is_a?(SFloat) ? x_test.shape[0] : x_test.length
|
102
|
-
(num_test_data / @batch_size).times do |i|
|
110
|
+
(num_test_data.to_f / @batch_size).ceil.times do |i|
|
103
111
|
x = SFloat.zeros(@batch_size, @num_nodes.first)
|
104
112
|
y = SFloat.zeros(@batch_size, @num_nodes.last)
|
105
113
|
@batch_size.times do |j|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: '1.5'
|
4
|
+
version: '1.6'
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-04-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -69,7 +69,6 @@ files:
|
|
69
69
|
- document.txt
|
70
70
|
- lib/nn.rb
|
71
71
|
- lib/nn/mnist.rb
|
72
|
-
- lib/nn/version.rb
|
73
72
|
- nn.gemspec
|
74
73
|
homepage: https://github.com/unagiootoro/nn.git
|
75
74
|
licenses:
|
data/lib/nn/version.rb
DELETED