nn 1.6 → 1.8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/document.txt +3 -2
- data/lib/nn.rb +4 -4
- data/lib/nn/cifar10.rb +53 -0
- data/lib/nn/mnist.rb +0 -0
- metadata +3 -3
- data/.gitignore +0 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ac474651871e2134d4e2372cf8254d21f3e2e6e53173d9b0a2897e72a5c20979
|
4
|
+
data.tar.gz: fe0813922ccb9f7a1351d9bc4a7377ed99c7ff4e4140b537074a7540afe8ed69
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9b84f7be4c0aa1b0bf00d24b3e5e7df1f47cc4c73174a4b54c9d8b1e691db414c95ed174213b3aa7275d26752684cb9cd927d6b2e879951e1f278deb15727010
|
7
|
+
data.tar.gz: 7b50ce21afdb88cd306d1e38fa1c7718740cdc313e2ed669558f1d251c439cc939d996fe85cf76b4b16d9b902ebacecf343b81065fc4b0ccc84578d64400d948
|
data/document.txt
CHANGED
@@ -35,7 +35,7 @@ initialize(num_nodes,
|
|
35
35
|
learning_rate: 0.01,
|
36
36
|
batch_size: 1,
|
37
37
|
activation: [:relu, :identity],
|
38
|
-
momentum: 0
|
38
|
+
momentum: 0,
|
39
39
|
weight_decay: 0,
|
40
40
|
use_dropout: false,
|
41
41
|
dropout_ratio: 0.5,
|
@@ -211,4 +211,5 @@ end
|
|
211
211
|
2018/3/14 バージョン1.3公開
|
212
212
|
2018/3/18 バージョン1.4公開
|
213
213
|
2018/3/22 バージョン1.5公開
|
214
|
-
2018/4/15 バージョン1.6公開
|
214
|
+
2018/4/15 バージョン1.6公開
|
215
|
+
2018/5/4 バージョン1.8公開
|
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "1.6"
|
5
|
+
VERSION = "1.8"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -22,7 +22,7 @@ class NN
|
|
22
22
|
learning_rate: 0.01,
|
23
23
|
batch_size: 1,
|
24
24
|
activation: %i(relu identity),
|
25
|
-
momentum: 0
|
25
|
+
momentum: 0,
|
26
26
|
weight_decay: 0,
|
27
27
|
use_dropout: false,
|
28
28
|
dropout_ratio: 0.5,
|
@@ -79,7 +79,7 @@ class NN
|
|
79
79
|
loss = learn(x_train, y_train, &block)
|
80
80
|
if loss.nan?
|
81
81
|
puts "loss is nan"
|
82
|
-
|
82
|
+
return
|
83
83
|
end
|
84
84
|
end
|
85
85
|
if save_dir && epoch % save_interval == 0
|
@@ -112,6 +112,7 @@ class NN
|
|
112
112
|
y = SFloat.zeros(@batch_size, @num_nodes.last)
|
113
113
|
@batch_size.times do |j|
|
114
114
|
k = i * @batch_size + j
|
115
|
+
break if k >= num_test_data
|
115
116
|
if x_test.is_a?(SFloat)
|
116
117
|
x[j, true] = x_test[k, true]
|
117
118
|
y[j, true] = y_test[k, true]
|
@@ -216,7 +217,6 @@ class NN
|
|
216
217
|
@beta_amounts = Array.new(@num_nodes.length - 2, 0)
|
217
218
|
end
|
218
219
|
|
219
|
-
|
220
220
|
def init_layers
|
221
221
|
@layers = []
|
222
222
|
@num_nodes[0...-2].each_index do |i|
|
data/lib/nn/cifar10.rb
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
# Loader for the CIFAR-10 binary dataset files found under +dir+.
#
# Every batch file is a flat sequence of records: one label byte followed by
# IMAGE_SIZE pixel bytes. Parsed batches are cached as Marshal dumps in the
# current working directory, so later calls skip re-reading the .bin files.
module CIFAR10
  # Number of pixel bytes that follow each label byte in a record.
  IMAGE_SIZE = 3072
  # Total bytes per record: the label byte plus the image bytes.
  RECORD_SIZE = IMAGE_SIZE + 1
  # Length of each one-hot vector produced by +categorical+.
  NUM_CLASSES = 10

  # Loads training batch +index+.
  #
  # @param index [Integer] batch number; reads "#{dir}/data_batch_#{index}.bin"
  #   unless "CIFAR-10-train#{index}.marshal" already exists in the CWD
  # @return [Array] a two-element array [x_train, y_train], where x_train holds
  #   one array of pixel bytes per record and y_train the matching label bytes
  def self.load_train(index)
    cached("CIFAR-10-train#{index}.marshal") do
      parse_batch("#{dir}/data_batch_#{index}.bin")
    end
  end

  # Loads the test batch.
  #
  # @return [Array] a two-element array [x_test, y_test], in the same layout
  #   as +load_train+; read from "#{dir}/test_batch.bin" unless
  #   "CIFAR-10-test.marshal" already exists in the CWD
  def self.load_test
    cached("CIFAR-10-test.marshal") do
      parse_batch("#{dir}/test_batch.bin")
    end
  end

  # Converts integer labels to one-hot rows of length NUM_CLASSES.
  #
  # @param y_data [Array<Integer>] labels
  # @return [Array<Array<Integer>>] one row per label, with 1 at the label index
  def self.categorical(y_data)
    y_data.map do |label|
      Array.new(NUM_CLASSES, 0).tap { |row| row[label] = 1 }
    end
  end

  # Directory that holds the raw .bin batch files, relative to the CWD.
  def self.dir
    "cifar-10-batches-bin"
  end

  # Returns the value Marshal-cached at +cache_path+. On a cache miss, it
  # computes the value with the block, writes the cache, and returns the value.
  def self.cached(cache_path)
    # NOTE(review): Marshal.load can instantiate arbitrary objects. This is
    # only safe while the cache file is one this module wrote itself; do not
    # point it at untrusted files.
    return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

    data = yield
    File.binwrite(cache_path, Marshal.dump(data))
    data
  end

  # Splits the batch file at +path+ into [images, labels].
  # A short trailing record yields a short pixel array, as before.
  def self.parse_batch(path)
    images = []
    labels = []
    # A single each_slice pass keeps parsing linear. The previous
    # shift/slice!(0, ...) loop re-copied the remaining ~30M-element array
    # once per record, which made it O(n^2).
    File.binread(path).unpack("C*").each_slice(RECORD_SIZE) do |label, *pixels|
      labels << label
      images << pixels
    end
    [images, labels]
  end

  private_class_method :cached, :parse_batch
end
|
data/lib/nn/mnist.rb
CHANGED
File without changes
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: '1.6'
|
4
|
+
version: '1.8'
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-05-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -59,7 +59,6 @@ executables: []
|
|
59
59
|
extensions: []
|
60
60
|
extra_rdoc_files: []
|
61
61
|
files:
|
62
|
-
- ".gitignore"
|
63
62
|
- Gemfile
|
64
63
|
- LICENSE.txt
|
65
64
|
- README.md
|
@@ -68,6 +67,7 @@ files:
|
|
68
67
|
- bin/setup
|
69
68
|
- document.txt
|
70
69
|
- lib/nn.rb
|
70
|
+
- lib/nn/cifar10.rb
|
71
71
|
- lib/nn/mnist.rb
|
72
72
|
- nn.gemspec
|
73
73
|
homepage: https://github.com/unagiootoro/nn.git
|