nn 1.6 → 1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/document.txt +3 -2
- data/lib/nn.rb +4 -4
- data/lib/nn/cifar10.rb +53 -0
- data/lib/nn/mnist.rb +0 -0
- metadata +3 -3
- data/.gitignore +0 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ac474651871e2134d4e2372cf8254d21f3e2e6e53173d9b0a2897e72a5c20979
|
4
|
+
data.tar.gz: fe0813922ccb9f7a1351d9bc4a7377ed99c7ff4e4140b537074a7540afe8ed69
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9b84f7be4c0aa1b0bf00d24b3e5e7df1f47cc4c73174a4b54c9d8b1e691db414c95ed174213b3aa7275d26752684cb9cd927d6b2e879951e1f278deb15727010
|
7
|
+
data.tar.gz: 7b50ce21afdb88cd306d1e38fa1c7718740cdc313e2ed669558f1d251c439cc939d996fe85cf76b4b16d9b902ebacecf343b81065fc4b0ccc84578d64400d948
|
data/document.txt
CHANGED
@@ -35,7 +35,7 @@ initialize(num_nodes,
|
|
35
35
|
learning_rate: 0.01,
|
36
36
|
batch_size: 1,
|
37
37
|
activation: [:relu, :identity],
|
38
|
-
momentum: 0
|
38
|
+
momentum: 0,
|
39
39
|
weight_decay: 0,
|
40
40
|
use_dropout: false,
|
41
41
|
dropout_ratio: 0.5,
|
@@ -211,4 +211,5 @@ end
|
|
211
211
|
2018/3/14 バージョン1.3公開
|
212
212
|
2018/3/18 バージョン1.4公開
|
213
213
|
2018/3/22 バージョン1.5公開
|
214
|
-
2018/4/15 バージョン1.6公開
|
214
|
+
2018/4/15 バージョン1.6公開
|
215
|
+
2018/5/4 バージョン1.8公開
|
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "1.6"
|
5
|
+
VERSION = "1.8"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -22,7 +22,7 @@ class NN
|
|
22
22
|
learning_rate: 0.01,
|
23
23
|
batch_size: 1,
|
24
24
|
activation: %i(relu identity),
|
25
|
-
momentum: 0
|
25
|
+
momentum: 0,
|
26
26
|
weight_decay: 0,
|
27
27
|
use_dropout: false,
|
28
28
|
dropout_ratio: 0.5,
|
@@ -79,7 +79,7 @@ class NN
|
|
79
79
|
loss = learn(x_train, y_train, &block)
|
80
80
|
if loss.nan?
|
81
81
|
puts "loss is nan"
|
82
|
-
|
82
|
+
return
|
83
83
|
end
|
84
84
|
end
|
85
85
|
if save_dir && epoch % save_interval == 0
|
@@ -112,6 +112,7 @@ class NN
|
|
112
112
|
y = SFloat.zeros(@batch_size, @num_nodes.last)
|
113
113
|
@batch_size.times do |j|
|
114
114
|
k = i * @batch_size + j
|
115
|
+
break if k >= num_test_data
|
115
116
|
if x_test.is_a?(SFloat)
|
116
117
|
x[j, true] = x_test[k, true]
|
117
118
|
y[j, true] = y_test[k, true]
|
@@ -216,7 +217,6 @@ class NN
|
|
216
217
|
@beta_amounts = Array.new(@num_nodes.length - 2, 0)
|
217
218
|
end
|
218
219
|
|
219
|
-
|
220
220
|
def init_layers
|
221
221
|
@layers = []
|
222
222
|
@num_nodes[0...-2].each_index do |i|
|
data/lib/nn/cifar10.rb
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
# Loader for the binary CIFAR-10 batches.
#
# Each record in a batch file is one label byte followed by IMAGE_BYTES pixel
# bytes (presumably 32x32 pixels x 3 channels, per the CIFAR-10 binary
# format -- confirm against the dataset description).
#
# Parsed batches are cached as Marshal dumps in the current working directory
# so later calls can skip re-parsing the binary file.
module CIFAR10
  # Number of pixel bytes that follow the label byte in every record.
  IMAGE_BYTES = 3072
  # Total size of one record: the label byte plus the pixel bytes.
  RECORD_BYTES = IMAGE_BYTES + 1

  # Loads one training batch.
  #
  # @param index [Integer, String] batch number, interpolated into
  #   "data_batch_<index>.bin" and into the cache file name
  # @return [Array(Array<Array<Integer>>, Array<Integer>)] `[x_train, y_train]`:
  #   the pixel byte arrays and their labels, in file order
  # @raise [SystemCallError] if there is no cache and the batch file cannot be read
  def self.load_train(index)
    load_cached("CIFAR-10-train#{index}.marshal", "data_batch_#{index}.bin")
  end

  # Loads the test batch.
  #
  # @return [Array(Array<Array<Integer>>, Array<Integer>)] `[x_test, y_test]`
  # @raise [SystemCallError] if there is no cache and the batch file cannot be read
  def self.load_test
    load_cached("CIFAR-10-test.marshal", "test_batch.bin")
  end

  # Converts integer labels into one-hot arrays of length 10.
  #
  # @param y_data [Array<Integer>] class labels
  # @return [Array<Array<Integer>>] one array per label with a 1 at the label's index
  def self.categorical(y_data)
    y_data.map do |label|
      classes = Array.new(10, 0)
      classes[label] = 1
      classes
    end
  end

  # Directory (relative to the working directory) that holds the *.bin batches.
  #
  # @return [String]
  def self.dir
    "cifar-10-batches-bin"
  end

  # Returns the cached dataset at +cache_path+ when it exists; otherwise parses
  # "<dir>/<bin_name>", writes the cache, and returns the parsed dataset.
  def self.load_cached(cache_path, bin_name)
    if File.exist?(cache_path)
      # NOTE(review): Marshal.load is unsafe on untrusted data; only run this
      # in a working directory whose cache files you created yourself.
      return Marshal.load(File.binread(cache_path))
    end

    bin = File.binread("#{dir}/#{bin_name}")
    images = []
    labels = []
    # Walk the file one record at a time instead of unpacking the whole file
    # into a single huge Integer array and repeatedly shifting from its front.
    (0...bin.bytesize).step(RECORD_BYTES) do |offset|
      label, *pixels = bin.byteslice(offset, RECORD_BYTES).unpack("C*")
      images << pixels
      labels << label
    end

    dataset = [images, labels]
    File.binwrite(cache_path, Marshal.dump(dataset))
    dataset
  end
  private_class_method :load_cached
end
|
data/lib/nn/mnist.rb
CHANGED
File without changes
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: '1.6'
|
4
|
+
version: '1.8'
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-05-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -59,7 +59,6 @@ executables: []
|
|
59
59
|
extensions: []
|
60
60
|
extra_rdoc_files: []
|
61
61
|
files:
|
62
|
-
- ".gitignore"
|
63
62
|
- Gemfile
|
64
63
|
- LICENSE.txt
|
65
64
|
- README.md
|
@@ -68,6 +67,7 @@ files:
|
|
68
67
|
- bin/setup
|
69
68
|
- document.txt
|
70
69
|
- lib/nn.rb
|
70
|
+
- lib/nn/cifar10.rb
|
71
71
|
- lib/nn/mnist.rb
|
72
72
|
- nn.gemspec
|
73
73
|
homepage: https://github.com/unagiootoro/nn.git
|