nn 1.6 → 1.8

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8afd54182615146fe3023709fdcecdbdf8ef1857e048611efcde6abab91f00fc
4
- data.tar.gz: 4d1f66eadf5c5c33ba416d30a7898af4f999c5a541d6018d914f10b3d0840285
3
+ metadata.gz: ac474651871e2134d4e2372cf8254d21f3e2e6e53173d9b0a2897e72a5c20979
4
+ data.tar.gz: fe0813922ccb9f7a1351d9bc4a7377ed99c7ff4e4140b537074a7540afe8ed69
5
5
  SHA512:
6
- metadata.gz: de6ec3e4db6512fb71d20ac10db29eae197a57e5893befce8ab3493d2a4f22fc1d6b0a4c7779ff2031114d3541068c5095da32b35daad35348e50c0209794b87
7
- data.tar.gz: 6f0aa482117bd51dcc255e923450646dbf5ae2df3506b1da466a162285824b3924602b79f19aaf255b6bf741f3fc706ab0c5daccf8aef1c7353e3087a0efb9cf
6
+ metadata.gz: 9b84f7be4c0aa1b0bf00d24b3e5e7df1f47cc4c73174a4b54c9d8b1e691db414c95ed174213b3aa7275d26752684cb9cd927d6b2e879951e1f278deb15727010
7
+ data.tar.gz: 7b50ce21afdb88cd306d1e38fa1c7718740cdc313e2ed669558f1d251c439cc939d996fe85cf76b4b16d9b902ebacecf343b81065fc4b0ccc84578d64400d948
@@ -35,7 +35,7 @@ initialize(num_nodes,
35
35
  learning_rate: 0.01,
36
36
  batch_size: 1,
37
37
  activation: [:relu, :identity],
38
- momentum: 0.9,
38
+ momentum: 0,
39
39
  weight_decay: 0,
40
40
  use_dropout: false,
41
41
  dropout_ratio: 0.5,
@@ -211,4 +211,5 @@ end
211
211
  2018/3/14 バージョン1.3公開
212
212
  2018/3/18 バージョン1.4公開
213
213
  2018/3/22 バージョン1.5公開
214
- 2018/4/15 バージョン1.6公開
214
+ 2018/4/15 バージョン1.6公開
215
+ 2018/5/4 バージョン1.8公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "1.6"
5
+ VERSION = "1.8"
6
6
 
7
7
  include Numo
8
8
 
@@ -22,7 +22,7 @@ class NN
22
22
  learning_rate: 0.01,
23
23
  batch_size: 1,
24
24
  activation: %i(relu identity),
25
- momentum: 0.9,
25
+ momentum: 0,
26
26
  weight_decay: 0,
27
27
  use_dropout: false,
28
28
  dropout_ratio: 0.5,
@@ -79,7 +79,7 @@ class NN
79
79
  loss = learn(x_train, y_train, &block)
80
80
  if loss.nan?
81
81
  puts "loss is nan"
82
- break
82
+ return
83
83
  end
84
84
  end
85
85
  if save_dir && epoch % save_interval == 0
@@ -112,6 +112,7 @@ class NN
112
112
  y = SFloat.zeros(@batch_size, @num_nodes.last)
113
113
  @batch_size.times do |j|
114
114
  k = i * @batch_size + j
115
+ break if k >= num_test_data
115
116
  if x_test.is_a?(SFloat)
116
117
  x[j, true] = x_test[k, true]
117
118
  y[j, true] = y_test[k, true]
@@ -216,7 +217,6 @@ class NN
216
217
  @beta_amounts = Array.new(@num_nodes.length - 2, 0)
217
218
  end
218
219
 
219
-
220
220
  def init_layers
221
221
  @layers = []
222
222
  @num_nodes[0...-2].each_index do |i|
data/lib/nn/cifar10.rb ADDED
@@ -0,0 +1,53 @@
1
# Loader for the binary version of the CIFAR-10 dataset.
#
# Each record in a batch file is one label byte followed by 3072 bytes of
# image data. A parsed batch is cached as a Marshal dump in the current
# working directory so that later loads skip the parsing step.
module CIFAR10
  # Bytes per record: 1 label byte + 3072 image bytes.
  RECORD_SIZE = 1 + 3072

  # Loads training batch +index+ from "#{dir}/data_batch_#{index}.bin".
  #
  # @param index [Integer, String] batch number interpolated into the file names
  # @return [Array] two-element array [x_train, y_train], where x_train is an
  #   Array of per-record Integer byte arrays and y_train an Array of labels
  def self.load_train(index)
    load_batch("#{dir}/data_batch_#{index}.bin", "CIFAR-10-train#{index}.marshal")
  end

  # Loads the test batch from "#{dir}/test_batch.bin".
  #
  # @return [Array] two-element array [x_test, y_test], same layout as load_train
  def self.load_test
    load_batch("#{dir}/test_batch.bin", "CIFAR-10-test.marshal")
  end

  # One-hot encodes integer labels into 10-element arrays.
  #
  # @param y_data [Array<Integer>] labels, each expected to be in 0...10
  # @return [Array<Array<Integer>>] one 10-element 0/1 array per label
  def self.categorical(y_data)
    y_data.map do |label|
      classes = Array.new(10, 0)
      classes[label] = 1
      classes
    end
  end

  # Directory holding the extracted binary batch files (relative to the CWD).
  def self.dir
    "cifar-10-batches-bin"
  end

  # Parses one batch file into [images, labels], caching the result.
  #
  # @param bin_path [String] path of the raw batch file
  # @param cache_path [String] path of the Marshal cache to read or write
  # @return [Array] [images, labels]
  def self.load_batch(bin_path, cache_path)
    # NOTE(review): Marshal.load must only be fed trusted data. This cache is
    # written by this module, but any file placed at cache_path is loaded as-is.
    return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

    images = []
    labels = []
    # Walk the byte array once in fixed-size records. Repeatedly removing
    # records from the front of the array would copy the remainder each time.
    File.binread(bin_path).unpack("C*").each_slice(RECORD_SIZE) do |label, *pixels|
      labels << label
      images << pixels
    end
    batch = [images, labels]
    File.binwrite(cache_path, Marshal.dump(batch))
    batch
  end
  private_class_method :load_batch
end
File without changes
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: '1.6'
4
+ version: '1.8'
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-04-15 00:00:00.000000000 Z
11
+ date: 2018-05-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -59,7 +59,6 @@ executables: []
59
59
  extensions: []
60
60
  extra_rdoc_files: []
61
61
  files:
62
- - ".gitignore"
63
62
  - Gemfile
64
63
  - LICENSE.txt
65
64
  - README.md
@@ -68,6 +67,7 @@ files:
68
67
  - bin/setup
69
68
  - document.txt
70
69
  - lib/nn.rb
70
+ - lib/nn/cifar10.rb
71
71
  - lib/nn/mnist.rb
72
72
  - nn.gemspec
73
73
  homepage: https://github.com/unagiootoro/nn.git
data/.gitignore DELETED
@@ -1,8 +0,0 @@
1
- /.bundle/
2
- /.yardoc
3
- /_yardoc/
4
- /coverage/
5
- /doc/
6
- /pkg/
7
- /spec/reports/
8
- /tmp/