nn 2.2.0 → 2.3.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a22a2db0abf0ff95fda30fcfdc92ec12847bd033b544cdb87ce007f69cb9f760
4
- data.tar.gz: 91528c11fa35c99e0a3b0d84b863266a0107e810d8493d0489f6c7c372483fbe
3
+ metadata.gz: dae759f927b41bad2f29c7d4739cf198cf0080e7054d2e8276d518db94bb33f3
4
+ data.tar.gz: 99594955a3e1d8801a2415413387cc9ca86b068a25e3f079ce15d19226318952
5
5
  SHA512:
6
- metadata.gz: d6cf94cf86c9c0e160f369466a146b18b5114bafff755ed6ff4cb88a98898727fb67a94ce04a117b3d90761028be57777e7fb7effdca01a023586ba6a84cd1f1
7
- data.tar.gz: 27cf471cece4749dea22e2a470d5e342fd33de689ff38191cfc0445a6df3655b0da5f91552766fbd2db10e0ba3bd7a17997a863a2811db2b937bddaea5dc5a60
6
+ metadata.gz: 2612677d78320574f714002006b8be85f5abe5d0b0eb107dbb83d8be67eebc99f077c62771325f4d06e372ba54e4a60f9f1c7edbb5bea3214a610790502e9b0e
7
+ data.tar.gz: 36268d75d8b43bf9d3fc7442d4d3b15ca93f45bb585e0fd73f6967586ba576ebbadbeed957ef59c731ce76daf40159441c6684c26238e3e6e53a1ca659e7fc7f
data/Gemfile CHANGED
File without changes
data/LICENSE.txt CHANGED
File without changes
data/README.md CHANGED
File without changes
data/Rakefile CHANGED
File without changes
data/document.txt CHANGED
@@ -151,5 +151,6 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
151
151
  2018/5/4 バージョン1.8公開
152
152
  2018/5/16 バージョン2.0公開
153
153
  2018/6/10 バージョン2.0.1公開
154
- 2018/6/10 バージョン2.1.0公開
155
- 2018/6/10 バージョン2.2.0公開
154
+ 2018/6/17 バージョン2.1.0公開
155
+ 2018/6/17 バージョン2.2.0公開
156
+ 2018/6/24 バージョン2.3.0公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "2.2"
5
+ VERSION = "2.3"
6
6
 
7
7
  include Numo
8
8
 
@@ -124,17 +124,17 @@ class NN
124
124
  end
125
125
 
126
126
  def learn(x_train, y_train, &block)
127
- x = SFloat.zeros(@batch_size, @num_nodes.first)
128
- y = SFloat.zeros(@batch_size, @num_nodes.last)
129
- @batch_size.times do |i|
130
- if x_train.is_a?(SFloat)
131
- r = rand(x_train.shape[0])
132
- x[i, true] = x_train[r, true]
133
- y[i, true] = y_train[r, true]
134
- else
135
- r = rand(x_train.length)
136
- x[i, true] = SFloat.cast(x_train[r])
137
- y[i, true] = SFloat.cast(y_train[r])
127
+ if x_train.is_a?(SFloat)
128
+ indexes = (0...x_train.shape[0]).to_a.sample(@batch_size)
129
+ x = x_train[indexes, true]
130
+ y = y_train[indexes, true]
131
+ else
132
+ indexes = (0...x_train.length).to_a.sample(@batch_size)
133
+ x = SFloat.zeros(@batch_size, @num_nodes.first)
134
+ y = SFloat.zeros(@batch_size, @num_nodes.last)
135
+ @batch_size.times do |i|
136
+ x[i, true] = SFloat.cast(x_train[indexes[i]])
137
+ y[i, true] = SFloat.cast(y_train[indexes[i]])
138
138
  end
139
139
  end
140
140
  x, y = block.call(x, y) if block
@@ -415,14 +415,12 @@ class NN::BatchNorm
415
415
  end
416
416
 
417
417
  def forward(x)
418
- @x = x
419
418
  @mean = x.mean(0)
420
419
  @xc = x - @mean
421
420
  @var = (@xc ** 2).mean(0)
422
421
  @std = NMath.sqrt(@var + 1e-7)
423
422
  @xn = @xc / @std
424
- out = @nn.gammas[@index] * @xn + @nn.betas[@index]
425
- out.reshape(*@x.shape)
423
+ @nn.gammas[@index] * @xn + @nn.betas[@index]
426
424
  end
427
425
 
428
426
  def backward(dout)
@@ -434,7 +432,6 @@ class NN::BatchNorm
434
432
  dvar = 0.5 * dstd / @std
435
433
  dxc += (2.0 / @nn.batch_size) * @xc * dvar
436
434
  dmean = dxc.sum(0)
437
- dx = dxc - dmean / @nn.batch_size
438
- dx.reshape(*@x.shape)
435
+ dxc - dmean / @nn.batch_size
439
436
  end
440
437
  end
data/nn.gemspec CHANGED
File without changes
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: 2.2.0
4
+ version: 2.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-06-17 00:00:00.000000000 Z
11
+ date: 2018-06-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -93,7 +93,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
93
93
  version: '0'
94
94
  requirements: []
95
95
  rubyforge_project:
96
- rubygems_version: 2.7.3
96
+ rubygems_version: 2.7.6
97
97
  signing_key:
98
98
  specification_version: 4
99
99
  summary: Ruby用ニューラルネットワークライブラリ