ruby-dnn 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c62e4bbbd5aa1e89dbff3e112c161f14fef9b4165721fed81815b6103af184f5
4
- data.tar.gz: d5b46dfde07d33f9309adb4af6256fa404218b9ad1d1cfc82a79b6752fc4549e
3
+ metadata.gz: ccb761989aa380096c4b16db46da21a4fb74e95bb8b15ed4564f4da6818bf04f
4
+ data.tar.gz: 2764cc949b09a42b2bbb7ccee7504e62055b33318cf5f3beb4977eb50c9b86ff
5
5
  SHA512:
6
- metadata.gz: 49515dd6ed15f07b551b88d3b724c280f417a07b12c869974d5b4f34a2c2bc6ff4747ba6c853840588b365e3d77ee50cebded9d82f739e79142aebaf505313b5
7
- data.tar.gz: dbd7084f28a5a1c9b5675eadb9c20e76c42897ba32a20a552ed183ada75dc83aa7953718e853a9a658dcdb794317943c56f4e46ebeedf8d6a663eb02f90832c0
6
+ metadata.gz: da160fcb2a10367a916dbb020b10ecf6b7cc8eba0e1aa30d68e0f788c06a5a5ff29122a758308666384921f49866a493230aae527e343bec1b3470f6de849fa2
7
+ data.tar.gz: 6016b580542781fae7e1b07e2cb3aac219070a6428d0d02505fe833a1a54564947dbb196afeea76122fb964f761c7b23aca2513162d5b14be64113cddeac4f2b
data/README.md CHANGED
@@ -1,7 +1,7 @@
1
1
  # ruby-dnn
2
2
 
3
3
  ruby-dnn is a ruby deep learning library. This library supports full connected neural network and convolution neural network.
4
- Currently, you can get 99% accuracy with MNIST and 60% with CIFAR 10.
4
+ Currently, you can get 99% accuracy with MNIST and 74% with CIFAR 10.
5
5
 
6
6
  ## Installation
7
7
 
@@ -314,11 +314,16 @@ module DNN
314
314
  include Convert
315
315
 
316
316
  def initialize(pool_width, pool_height, strides: nil, padding: false)
317
+ super()
317
318
  @pool_width = pool_width
318
319
  @pool_height = pool_height
319
320
  @strides = strides ? strides : [@pool_width, @pool_height]
320
321
  @padding = padding
321
- end
322
+ end
323
+
324
+ def self.load_hash(hash)
325
+ MaxPool2D.new(hash[:pool_width], hash[:pool_height], strides: hash[:strides], padding: hash[:padding])
326
+ end
322
327
 
323
328
  def build(model)
324
329
  super
@@ -386,6 +391,7 @@ module DNN
386
391
  attr_reader :shape
387
392
 
388
393
  def initialize(shape)
394
+ super()
389
395
  @shape = shape
390
396
  @x_shape = nil
391
397
  end
@@ -421,16 +427,21 @@ module DNN
421
427
 
422
428
  class Dropout < Layer
423
429
  def initialize(dropout_ratio)
430
+ super()
424
431
  @dropout_ratio = dropout_ratio
425
432
  @mask = nil
426
433
  end
427
434
 
435
+ def self.load_hash(hash)
436
+ self.new(hash[:dropout_ratio])
437
+ end
438
+
428
439
  def self.load(hash)
429
440
  self.new(hash[:dropout_ratio])
430
441
  end
431
442
 
432
443
  def forward(x)
433
- if @model.training
444
+ if @model.training?
434
445
  @mask = SFloat.ones(*x.shape).rand < @dropout_ratio
435
446
  x[@mask] = 0
436
447
  else
@@ -443,17 +454,48 @@ module DNN
443
454
  dout[@mask] = 0 if @model.training
444
455
  dout
445
456
  end
457
+
458
+ def to_hash
459
+ {name: self.class.name, dropout_ratio: @dropout_ratio}
460
+ end
446
461
  end
447
462
 
448
463
 
449
464
  class BatchNormalization < HasParamLayer
465
+ def initialize(momentum: 0.9, running_mean: nil, running_var: nil)
466
+ super()
467
+ @momentum = momentum
468
+ @running_mean = running_mean
469
+ @running_var = running_var
470
+ end
471
+
472
+ def self.load_hash(hash)
473
+ running_mean = SFloat.cast(hash[:running_mean])
474
+ running_var = SFloat.cast(hash[:running_var])
475
+ self.new(momentum: hash[:momentum], running_mean: running_mean, running_var: running_var)
476
+ end
477
+
478
+ def build(model)
479
+ super
480
+ @running_mean ||= SFloat.zeros(*shape)
481
+ @running_var ||= SFloat.zeros(*shape)
482
+ end
483
+
450
484
  def forward(x)
451
- @mean = x.mean(0)
452
- @xc = x - @mean
453
- @var = (@xc**2).mean(0)
454
- @std = NMath.sqrt(@var + 1e-7)
455
- @xn = @xc / @std
456
- @params[:gamma] * @xn + @params[:beta]
485
+ if @model.training?
486
+ mean = x.mean(0)
487
+ @xc = x - mean
488
+ var = (@xc**2).mean(0)
489
+ @std = NMath.sqrt(var + 1e-7)
490
+ xn = @xc / @std
491
+ @xn = xn
492
+ @running_mean = @momentum * @running_mean + (1 - @momentum) * mean
493
+ @running_var = @momentum * @running_var + (1 - @momentum) * var
494
+ else
495
+ xc = x - @running_mean
496
+ xn = xc / NMath.sqrt(@running_var + 1e-7)
497
+ end
498
+ @params[:gamma] * xn + @params[:beta]
457
499
  end
458
500
 
459
501
  def backward(dout)
@@ -466,7 +508,16 @@ module DNN
466
508
  dvar = 0.5 * dstd / @std
467
509
  dxc += (2.0 / batch_size) * @xc * dvar
468
510
  dmean = dxc.sum(0)
469
- dxc - dmean / batch_size
511
+ dxc - dmean / batch_size
512
+ end
513
+
514
+ def to_hash
515
+ {
516
+ name: self.class.name,
517
+ momentum: @momentum,
518
+ running_mean: @running_mean.to_a,
519
+ running_var: @running_var.to_a,
520
+ }
470
521
  end
471
522
 
472
523
  private
@@ -8,7 +8,6 @@ module DNN
8
8
  attr_accessor :layers
9
9
  attr_reader :optimizer
10
10
  attr_reader :batch_size
11
- attr_reader :training
12
11
 
13
12
  def initialize
14
13
  @layers = []
@@ -86,6 +85,10 @@ module DNN
86
85
  def compiled?
87
86
  @compiled
88
87
  end
88
+
89
+ def training?
90
+ @training
91
+ end
89
92
 
90
93
  def train(x, y, epochs,
91
94
  batch_size: 1,
@@ -136,7 +139,13 @@ module DNN
136
139
  end
137
140
 
138
141
  def accurate(x, y, batch_size = nil, &batch_proc)
139
- @batch_size = batch_size if batch_size
142
+ unless batch_size
143
+ if @batch_size
144
+ batch_size = @batch_size >= x.shape[0] ? @batch_size : x.shape[0]
145
+ else
146
+ batch_size = 1
147
+ end
148
+ end
140
149
  correct = 0
141
150
  (x.shape[0].to_f / @batch_size).ceil.times do |i|
142
151
  x_batch = SFloat.zeros(@batch_size, *x.shape[1..-1])
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.2.0"
2
+ VERSION = "0.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-07-11 00:00:00.000000000 Z
11
+ date: 2018-07-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray