nn 2.2.0 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a22a2db0abf0ff95fda30fcfdc92ec12847bd033b544cdb87ce007f69cb9f760
4
- data.tar.gz: 91528c11fa35c99e0a3b0d84b863266a0107e810d8493d0489f6c7c372483fbe
3
+ metadata.gz: dae759f927b41bad2f29c7d4739cf198cf0080e7054d2e8276d518db94bb33f3
4
+ data.tar.gz: 99594955a3e1d8801a2415413387cc9ca86b068a25e3f079ce15d19226318952
5
5
  SHA512:
6
- metadata.gz: d6cf94cf86c9c0e160f369466a146b18b5114bafff755ed6ff4cb88a98898727fb67a94ce04a117b3d90761028be57777e7fb7effdca01a023586ba6a84cd1f1
7
- data.tar.gz: 27cf471cece4749dea22e2a470d5e342fd33de689ff38191cfc0445a6df3655b0da5f91552766fbd2db10e0ba3bd7a17997a863a2811db2b937bddaea5dc5a60
6
+ metadata.gz: 2612677d78320574f714002006b8be85f5abe5d0b0eb107dbb83d8be67eebc99f077c62771325f4d06e372ba54e4a60f9f1c7edbb5bea3214a610790502e9b0e
7
+ data.tar.gz: 36268d75d8b43bf9d3fc7442d4d3b15ca93f45bb585e0fd73f6967586ba576ebbadbeed957ef59c731ce76daf40159441c6684c26238e3e6e53a1ca659e7fc7f
data/Gemfile CHANGED
File without changes
data/LICENSE.txt CHANGED
File without changes
data/README.md CHANGED
File without changes
data/Rakefile CHANGED
File without changes
data/document.txt CHANGED
@@ -151,5 +151,6 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
151
151
  2018/5/4 バージョン1.8公開
152
152
  2018/5/16 バージョン2.0公開
153
153
  2018/6/10 バージョン2.0.1公開
154
- 2018/6/10 バージョン2.1.0公開
155
- 2018/6/10 バージョン2.2.0公開
154
+ 2018/6/17 バージョン2.1.0公開
155
+ 2018/6/17 バージョン2.2.0公開
156
+ 2018/6/24 バージョン2.3.0公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "2.2"
5
+ VERSION = "2.3"
6
6
 
7
7
  include Numo
8
8
 
@@ -124,17 +124,17 @@ class NN
124
124
  end
125
125
 
126
126
  def learn(x_train, y_train, &block)
127
- x = SFloat.zeros(@batch_size, @num_nodes.first)
128
- y = SFloat.zeros(@batch_size, @num_nodes.last)
129
- @batch_size.times do |i|
130
- if x_train.is_a?(SFloat)
131
- r = rand(x_train.shape[0])
132
- x[i, true] = x_train[r, true]
133
- y[i, true] = y_train[r, true]
134
- else
135
- r = rand(x_train.length)
136
- x[i, true] = SFloat.cast(x_train[r])
137
- y[i, true] = SFloat.cast(y_train[r])
127
+ if x_train.is_a?(SFloat)
128
+ indexes = (0...x_train.shape[0]).to_a.sample(@batch_size)
129
+ x = x_train[indexes, true]
130
+ y = y_train[indexes, true]
131
+ else
132
+ indexes = (0...x_train.length).to_a.sample(@batch_size)
133
+ x = SFloat.zeros(@batch_size, @num_nodes.first)
134
+ y = SFloat.zeros(@batch_size, @num_nodes.last)
135
+ @batch_size.times do |i|
136
+ x[i, true] = SFloat.cast(x_train[indexes[i]])
137
+ y[i, true] = SFloat.cast(y_train[indexes[i]])
138
138
  end
139
139
  end
140
140
  x, y = block.call(x, y) if block
@@ -415,14 +415,12 @@ class NN::BatchNorm
415
415
  end
416
416
 
417
417
  def forward(x)
418
- @x = x
419
418
  @mean = x.mean(0)
420
419
  @xc = x - @mean
421
420
  @var = (@xc ** 2).mean(0)
422
421
  @std = NMath.sqrt(@var + 1e-7)
423
422
  @xn = @xc / @std
424
- out = @nn.gammas[@index] * @xn + @nn.betas[@index]
425
- out.reshape(*@x.shape)
423
+ @nn.gammas[@index] * @xn + @nn.betas[@index]
426
424
  end
427
425
 
428
426
  def backward(dout)
@@ -434,7 +432,6 @@ class NN::BatchNorm
434
432
  dvar = 0.5 * dstd / @std
435
433
  dxc += (2.0 / @nn.batch_size) * @xc * dvar
436
434
  dmean = dxc.sum(0)
437
- dx = dxc - dmean / @nn.batch_size
438
- dx.reshape(*@x.shape)
435
+ dxc - dmean / @nn.batch_size
439
436
  end
440
437
  end
data/nn.gemspec CHANGED
File without changes
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: 2.2.0
4
+ version: 2.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-06-17 00:00:00.000000000 Z
11
+ date: 2018-06-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -93,7 +93,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
93
93
  version: '0'
94
94
  requirements: []
95
95
  rubyforge_project:
96
- rubygems_version: 2.7.3
96
+ rubygems_version: 2.7.6
97
97
  signing_key:
98
98
  specification_version: 4
99
99
  summary: Ruby用ニューラルネットワークライブラリ