nn 2.2.0 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +0 -0
- data/LICENSE.txt +0 -0
- data/README.md +0 -0
- data/Rakefile +0 -0
- data/document.txt +3 -2
- data/lib/nn.rb +14 -17
- data/nn.gemspec +0 -0
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dae759f927b41bad2f29c7d4739cf198cf0080e7054d2e8276d518db94bb33f3
|
4
|
+
data.tar.gz: 99594955a3e1d8801a2415413387cc9ca86b068a25e3f079ce15d19226318952
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 2612677d78320574f714002006b8be85f5abe5d0b0eb107dbb83d8be67eebc99f077c62771325f4d06e372ba54e4a60f9f1c7edbb5bea3214a610790502e9b0e
|
7
|
+
data.tar.gz: 36268d75d8b43bf9d3fc7442d4d3b15ca93f45bb585e0fd73f6967586ba576ebbadbeed957ef59c731ce76daf40159441c6684c26238e3e6e53a1ca659e7fc7f
|
data/Gemfile
CHANGED
File without changes
|
data/LICENSE.txt
CHANGED
File without changes
|
data/README.md
CHANGED
File without changes
|
data/Rakefile
CHANGED
File without changes
|
data/document.txt
CHANGED
data/lib/nn.rb
CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
|
|
2
2
|
require "json"
|
3
3
|
|
4
4
|
class NN
|
5
|
-
VERSION = "2.
|
5
|
+
VERSION = "2.3"
|
6
6
|
|
7
7
|
include Numo
|
8
8
|
|
@@ -124,17 +124,17 @@ class NN
|
|
124
124
|
end
|
125
125
|
|
126
126
|
def learn(x_train, y_train, &block)
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
x[i, true] = SFloat.cast(x_train[
|
137
|
-
y[i, true] = SFloat.cast(y_train[
|
127
|
+
if x_train.is_a?(SFloat)
|
128
|
+
indexes = (0...x_train.shape[0]).to_a.sample(@batch_size)
|
129
|
+
x = x_train[indexes, true]
|
130
|
+
y = y_train[indexes, true]
|
131
|
+
else
|
132
|
+
indexes = (0...x_train.length).to_a.sample(@batch_size)
|
133
|
+
x = SFloat.zeros(@batch_size, @num_nodes.first)
|
134
|
+
y = SFloat.zeros(@batch_size, @num_nodes.last)
|
135
|
+
@batch_size.times do |i|
|
136
|
+
x[i, true] = SFloat.cast(x_train[indexes[i]])
|
137
|
+
y[i, true] = SFloat.cast(y_train[indexes[i]])
|
138
138
|
end
|
139
139
|
end
|
140
140
|
x, y = block.call(x, y) if block
|
@@ -415,14 +415,12 @@ class NN::BatchNorm
|
|
415
415
|
end
|
416
416
|
|
417
417
|
def forward(x)
|
418
|
-
@x = x
|
419
418
|
@mean = x.mean(0)
|
420
419
|
@xc = x - @mean
|
421
420
|
@var = (@xc ** 2).mean(0)
|
422
421
|
@std = NMath.sqrt(@var + 1e-7)
|
423
422
|
@xn = @xc / @std
|
424
|
-
|
425
|
-
out.reshape(*@x.shape)
|
423
|
+
@nn.gammas[@index] * @xn + @nn.betas[@index]
|
426
424
|
end
|
427
425
|
|
428
426
|
def backward(dout)
|
@@ -434,7 +432,6 @@ class NN::BatchNorm
|
|
434
432
|
dvar = 0.5 * dstd / @std
|
435
433
|
dxc += (2.0 / @nn.batch_size) * @xc * dvar
|
436
434
|
dmean = dxc.sum(0)
|
437
|
-
|
438
|
-
dx.reshape(*@x.shape)
|
435
|
+
dxc - dmean / @nn.batch_size
|
439
436
|
end
|
440
437
|
end
|
data/nn.gemspec
CHANGED
File without changes
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: nn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 2.
|
4
|
+
version: 2.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-06-
|
11
|
+
date: 2018-06-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -93,7 +93,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
93
93
|
version: '0'
|
94
94
|
requirements: []
|
95
95
|
rubyforge_project:
|
96
|
-
rubygems_version: 2.7.
|
96
|
+
rubygems_version: 2.7.6
|
97
97
|
signing_key:
|
98
98
|
specification_version: 4
|
99
99
|
summary: Ruby用ニューラルネットワークライブラリ
|