nn 2.1.0 → 2.2.0

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
Files changed (4)
  1. checksums.yaml +4 -4
  2. data/document.txt +1 -0
  3. data/lib/nn.rb +5 -6
  4. metadata +2 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8e26103663135846646308af8fa48851d0f56d34122bfdb6172e1a8bd4675bdf
4
- data.tar.gz: cffd4270bae567540673a1ec327ea416b14c3794172b7db3c589c0053bf15466
3
+ metadata.gz: a22a2db0abf0ff95fda30fcfdc92ec12847bd033b544cdb87ce007f69cb9f760
4
+ data.tar.gz: 91528c11fa35c99e0a3b0d84b863266a0107e810d8493d0489f6c7c372483fbe
5
5
  SHA512:
6
- metadata.gz: abaa7793eb7e44cd44d9f49f63d7cd252f01e11e4c41d7e709cf6c8124b923d300ff6c2e2b07f410f4dbeaaa8f53a82f876e22582681ec4178364110128726ab
7
- data.tar.gz: c8f830238afd0a767f5cc34a7be600c2fc481c591f8f8f433c13d0ad9647011e49f3b0ff4b87d5784a0c8ce2c01552445dc1540e252321cd856eea23c4a6f84d
6
+ metadata.gz: d6cf94cf86c9c0e160f369466a146b18b5114bafff755ed6ff4cb88a98898727fb67a94ce04a117b3d90761028be57777e7fb7effdca01a023586ba6a84cd1f1
7
+ data.tar.gz: 27cf471cece4749dea22e2a470d5e342fd33de689ff38191cfc0445a6df3655b0da5f91552766fbd2db10e0ba3bd7a17997a863a2811db2b937bddaea5dc5a60
@@ -152,3 +152,4 @@ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
152
152
  2018/5/16 バージョン2.0公開
153
153
  2018/6/10 バージョン2.0.1公開
154
154
  2018/6/10 バージョン2.1.0公開
155
+ 2018/6/10 バージョン2.2.0公開
data/lib/nn.rb CHANGED
@@ -2,7 +2,7 @@ require "numo/narray"
2
2
  require "json"
3
3
 
4
4
  class NN
5
- VERSION = "2.1"
5
+ VERSION = "2.2"
6
6
 
7
7
  include Numo
8
8
 
@@ -297,13 +297,12 @@ class NN::Affine
297
297
  end
298
298
 
299
299
  def backward(dout)
300
- x = @x.reshape(*@x.shape, 1)
301
- @d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1])).mean(0)
300
+ @d_weight = @x.transpose.dot(dout)
302
301
  if @nn.weight_decay > 0
303
302
  dridge = @nn.weight_decay * @nn.weights[@index]
304
303
  @d_weight += dridge
305
304
  end
306
- @d_bias = dout.mean
305
+ @d_bias = dout.sum(0)
307
306
  dout.dot(@nn.weights[@index].transpose)
308
307
  end
309
308
  end
@@ -347,7 +346,7 @@ class NN::Identity
347
346
  end
348
347
 
349
348
  def backward(y)
350
- @out - y
349
+ (@out - y) / @nn.batch_size
351
350
  end
352
351
 
353
352
  def loss(y)
@@ -369,7 +368,7 @@ class NN::Softmax
369
368
  end
370
369
 
371
370
  def backward(y)
372
- @out - y
371
+ (@out - y) / @nn.batch_size
373
372
  end
374
373
 
375
374
  def loss(y)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: nn
3
3
  version: !ruby/object:Gem::Version
4
- version: 2.1.0
4
+ version: 2.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-06-10 00:00:00.000000000 Z
11
+ date: 2018-06-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray