nn 1.6 → 1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8afd54182615146fe3023709fdcecdbdf8ef1857e048611efcde6abab91f00fc
4
- data.tar.gz: 4d1f66eadf5c5c33ba416d30a7898af4f999c5a541d6018d914f10b3d0840285
3
+ metadata.gz: ac474651871e2134d4e2372cf8254d21f3e2e6e53173d9b0a2897e72a5c20979
4
+ data.tar.gz: fe0813922ccb9f7a1351d9bc4a7377ed99c7ff4e4140b537074a7540afe8ed69
5
5
  SHA512:
6
- metadata.gz: de6ec3e4db6512fb71d20ac10db29eae197a57e5893befce8ab3493d2a4f22fc1d6b0a4c7779ff2031114d3541068c5095da32b35daad35348e50c0209794b87
7
- data.tar.gz: 6f0aa482117bd51dcc255e923450646dbf5ae2df3506b1da466a162285824b3924602b79f19aaf255b6bf741f3fc706ab0c5daccf8aef1c7353e3087a0efb9cf
6
+ metadata.gz: 9b84f7be4c0aa1b0bf00d24b3e5e7df1f47cc4c73174a4b54c9d8b1e691db414c95ed174213b3aa7275d26752684cb9cd927d6b2e879951e1f278deb15727010
7
+ data.tar.gz: 7b50ce21afdb88cd306d1e38fa1c7718740cdc313e2ed669558f1d251c439cc939d996fe85cf76b4b16d9b902ebacecf343b81065fc4b0ccc84578d64400d948
@@ -35,7 +35,7 @@ initialize(num_nodes,
35
35
  learning_rate: 0.01,
36
36
  batch_size: 1,
37
37
  activation: [:relu, :identity],
38
- momentum: 0.9,
38
+ momentum: 0,
39
39
  weight_decay: 0,
40
40
  use_dropout: false,
41
41
  dropout_ratio: 0.5,
@@ -211,4 +211,5 @@ end
211
211
  2018/3/14 バージョン1.3公開
212
212
  2018/3/18 バージョン1.4公開
213
213
  2018/3/22 バージョン1.5公開
214
- 2018/4/15 バージョン1.6公開
214
+ 2018/4/15 バージョン1.6公開
215
+ 2018/5/4 バージョン1.8公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "1.6"
5
+ VERSION = "1.8"
6
6
 
7
7
  include Numo
8
8
 
@@ -22,7 +22,7 @@ class NN
22
22
  learning_rate: 0.01,
23
23
  batch_size: 1,
24
24
  activation: %i(relu identity),
25
- momentum: 0.9,
25
+ momentum: 0,
26
26
  weight_decay: 0,
27
27
  use_dropout: false,
28
28
  dropout_ratio: 0.5,
@@ -79,7 +79,7 @@ class NN
79
79
  loss = learn(x_train, y_train, &block)
80
80
  if loss.nan?
81
81
  puts "loss is nan"
82
- break
82
+ return
83
83
  end
84
84
  end
85
85
  if save_dir && epoch % save_interval == 0
@@ -112,6 +112,7 @@ class NN
112
112
  y = SFloat.zeros(@batch_size, @num_nodes.last)
113
113
  @batch_size.times do |j|
114
114
  k = i * @batch_size + j
115
+ break if k >= num_test_data
115
116
  if x_test.is_a?(SFloat)
116
117
  x[j, true] = x_test[k, true]
117
118
  y[j, true] = y_test[k, true]
@@ -216,7 +217,6 @@ class NN
216
217
  @beta_amounts = Array.new(@num_nodes.length - 2, 0)
217
218
  end
218
219
 
219
-
220
220
  def init_layers
221
221
  @layers = []
222
222
  @num_nodes[0...-2].each_index do |i|
@@ -0,0 +1,53 @@
# Loader for CIFAR-10 batch files stored under +CIFAR10.dir+.
#
# Each .bin file is read as a flat sequence of records: one label byte
# followed by RECORD_PIXELS (3072) pixel bytes. After the first parse, the
# result is cached with Marshal in the current working directory, and later
# calls read that cache instead of the .bin file.
module CIFAR10
  # Number of pixel bytes that follow the label byte in each record.
  RECORD_PIXELS = 3072

  # Loads training batch +index+.
  #
  # @param index [Integer] batch number; reads "#{dir}/data_batch_#{index}.bin"
  #   and caches to "CIFAR-10-train#{index}.marshal".
  # @return [Array] [images, labels] — images is an Array of pixel-byte Arrays,
  #   labels is an Array of Integer label bytes.
  def self.load_train(index)
    load_batch("#{dir}/data_batch_#{index}.bin", "CIFAR-10-train#{index}.marshal")
  end

  # Loads the test batch ("#{dir}/test_batch.bin"), cached to
  # "CIFAR-10-test.marshal".
  #
  # @return [Array] [images, labels], same layout as load_train.
  def self.load_test
    load_batch("#{dir}/test_batch.bin", "CIFAR-10-test.marshal")
  end

  # One-hot encodes integer labels.
  #
  # @param y_data [Array<Integer>] labels.
  # @param num_classes [Integer] length of each one-hot vector (default 10).
  # @return [Array<Array<Integer>>] one vector per label, with 1 at the label index.
  def self.categorical(y_data, num_classes: 10)
    y_data.map do |label|
      classes = Array.new(num_classes, 0)
      classes[label] = 1
      classes
    end
  end

  # Directory (relative to the working directory) holding the .bin batch files.
  def self.dir
    "cifar-10-batches-bin"
  end

  # Returns the cached parse at +cache_path+ if it exists; otherwise parses
  # +bin_path+, writes the cache, and returns the parse.
  def self.load_batch(bin_path, cache_path)
    # NOTE(review): Marshal.load can instantiate arbitrary objects; only load
    # cache files this library wrote itself, never files from untrusted sources.
    return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

    images = []
    labels = []
    # One linear pass over fixed-size records. A trailing partial record
    # yields a shorter pixel array, as the previous shift/slice! loop did.
    File.binread(bin_path).unpack("C*").each_slice(RECORD_PIXELS + 1) do |label, *pixels|
      labels << label
      images << pixels
    end
    batch = [images, labels]
    File.binwrite(cache_path, Marshal.dump(batch))
    batch
  end
  private_class_method :load_batch
end
File without changes
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: '1.6'
4
+ version: '1.8'
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-04-15 00:00:00.000000000 Z
11
+ date: 2018-05-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -59,7 +59,6 @@ executables: []
59
59
  extensions: []
60
60
  extra_rdoc_files: []
61
61
  files:
62
- - ".gitignore"
63
62
  - Gemfile
64
63
  - LICENSE.txt
65
64
  - README.md
@@ -68,6 +67,7 @@ files:
68
67
  - bin/setup
69
68
  - document.txt
70
69
  - lib/nn.rb
70
+ - lib/nn/cifar10.rb
71
71
  - lib/nn/mnist.rb
72
72
  - nn.gemspec
73
73
  homepage: https://github.com/unagiootoro/nn.git
data/.gitignore DELETED
@@ -1,8 +0,0 @@
1
- /.bundle/
2
- /.yardoc
3
- /_yardoc/
4
- /coverage/
5
- /doc/
6
- /pkg/
7
- /spec/reports/
8
- /tmp/