ruby-dnn 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c62e4bbbd5aa1e89dbff3e112c161f14fef9b4165721fed81815b6103af184f5
4
- data.tar.gz: d5b46dfde07d33f9309adb4af6256fa404218b9ad1d1cfc82a79b6752fc4549e
3
+ metadata.gz: ccb761989aa380096c4b16db46da21a4fb74e95bb8b15ed4564f4da6818bf04f
4
+ data.tar.gz: 2764cc949b09a42b2bbb7ccee7504e62055b33318cf5f3beb4977eb50c9b86ff
5
5
  SHA512:
6
- metadata.gz: 49515dd6ed15f07b551b88d3b724c280f417a07b12c869974d5b4f34a2c2bc6ff4747ba6c853840588b365e3d77ee50cebded9d82f739e79142aebaf505313b5
7
- data.tar.gz: dbd7084f28a5a1c9b5675eadb9c20e76c42897ba32a20a552ed183ada75dc83aa7953718e853a9a658dcdb794317943c56f4e46ebeedf8d6a663eb02f90832c0
6
+ metadata.gz: da160fcb2a10367a916dbb020b10ecf6b7cc8eba0e1aa30d68e0f788c06a5a5ff29122a758308666384921f49866a493230aae527e343bec1b3470f6de849fa2
7
+ data.tar.gz: 6016b580542781fae7e1b07e2cb3aac219070a6428d0d02505fe833a1a54564947dbb196afeea76122fb964f761c7b23aca2513162d5b14be64113cddeac4f2b
data/README.md CHANGED
@@ -1,7 +1,7 @@
1
1
  # ruby-dnn
2
2
 
3
3
  ruby-dnn is a Ruby deep learning library. It supports fully connected neural networks and convolutional neural networks.
4
- Currently, you can get 99% accuracy on MNIST and 60% on CIFAR-10.
4
+ Currently, you can get 99% accuracy on MNIST and 74% on CIFAR-10.
5
5
 
6
6
  ## Installation
7
7
 
@@ -314,11 +314,16 @@ module DNN
314
314
  include Convert
315
315
 
316
316
  def initialize(pool_width, pool_height, strides: nil, padding: false)
317
+ super()
317
318
  @pool_width = pool_width
318
319
  @pool_height = pool_height
319
320
  @strides = strides ? strides : [@pool_width, @pool_height]
320
321
  @padding = padding
321
- end
322
+ end
323
+
324
+ def self.load_hash(hash)
325
+ MaxPool2D.new(hash[:pool_width], hash[:pool_height], strides: hash[:strides], padding: hash[:padding])
326
+ end
322
327
 
323
328
  def build(model)
324
329
  super
@@ -386,6 +391,7 @@ module DNN
386
391
  attr_reader :shape
387
392
 
388
393
  def initialize(shape)
394
+ super()
389
395
  @shape = shape
390
396
  @x_shape = nil
391
397
  end
@@ -421,16 +427,21 @@ module DNN
421
427
 
422
428
  class Dropout < Layer
423
429
  def initialize(dropout_ratio)
430
+ super()
424
431
  @dropout_ratio = dropout_ratio
425
432
  @mask = nil
426
433
  end
427
434
 
435
+ def self.load_hash(hash)
436
+ self.new(hash[:dropout_ratio])
437
+ end
438
+
428
439
  def self.load(hash)
429
440
  self.new(hash[:dropout_ratio])
430
441
  end
431
442
 
432
443
  def forward(x)
433
- if @model.training
444
+ if @model.training?
434
445
  @mask = SFloat.ones(*x.shape).rand < @dropout_ratio
435
446
  x[@mask] = 0
436
447
  else
@@ -443,17 +454,48 @@ module DNN
443
454
  dout[@mask] = 0 if @model.training
444
455
  dout
445
456
  end
457
+
458
+ def to_hash
459
+ {name: self.class.name, dropout_ratio: @dropout_ratio}
460
+ end
446
461
  end
447
462
 
448
463
 
449
464
  class BatchNormalization < HasParamLayer
465
+ def initialize(momentum: 0.9, running_mean: nil, running_var: nil)
466
+ super()
467
+ @momentum = momentum
468
+ @running_mean = running_mean
469
+ @running_var = running_var
470
+ end
471
+
472
+ def self.load_hash(hash)
473
+ running_mean = SFloat.cast(hash[:running_mean])
474
+ running_var = SFloat.cast(hash[:running_var])
475
+ self.new(momentum: hash[:momentum], running_mean: running_mean, running_var: running_var)
476
+ end
477
+
478
+ def build(model)
479
+ super
480
+ @running_mean ||= SFloat.zeros(*shape)
481
+ @running_var ||= SFloat.zeros(*shape)
482
+ end
483
+
450
484
  def forward(x)
451
- @mean = x.mean(0)
452
- @xc = x - @mean
453
- @var = (@xc**2).mean(0)
454
- @std = NMath.sqrt(@var + 1e-7)
455
- @xn = @xc / @std
456
- @params[:gamma] * @xn + @params[:beta]
485
+ if @model.training?
486
+ mean = x.mean(0)
487
+ @xc = x - mean
488
+ var = (@xc**2).mean(0)
489
+ @std = NMath.sqrt(var + 1e-7)
490
+ xn = @xc / @std
491
+ @xn = xn
492
+ @running_mean = @momentum * @running_mean + (1 - @momentum) * mean
493
+ @running_var = @momentum * @running_var + (1 - @momentum) * var
494
+ else
495
+ xc = x - @running_mean
496
+ xn = xc / NMath.sqrt(@running_var + 1e-7)
497
+ end
498
+ @params[:gamma] * xn + @params[:beta]
457
499
  end
458
500
 
459
501
  def backward(dout)
@@ -466,7 +508,16 @@ module DNN
466
508
  dvar = 0.5 * dstd / @std
467
509
  dxc += (2.0 / batch_size) * @xc * dvar
468
510
  dmean = dxc.sum(0)
469
- dxc - dmean / batch_size
511
+ dxc - dmean / batch_size
512
+ end
513
+
514
+ def to_hash
515
+ {
516
+ name: self.class.name,
517
+ momentum: @momentum,
518
+ running_mean: @running_mean.to_a,
519
+ running_var: @running_var.to_a,
520
+ }
470
521
  end
471
522
 
472
523
  private
@@ -8,7 +8,6 @@ module DNN
8
8
  attr_accessor :layers
9
9
  attr_reader :optimizer
10
10
  attr_reader :batch_size
11
- attr_reader :training
12
11
 
13
12
  def initialize
14
13
  @layers = []
@@ -86,6 +85,10 @@ module DNN
86
85
  def compiled?
87
86
  @compiled
88
87
  end
88
+
89
+ def training?
90
+ @training
91
+ end
89
92
 
90
93
  def train(x, y, epochs,
91
94
  batch_size: 1,
@@ -136,7 +139,13 @@ module DNN
136
139
  end
137
140
 
138
141
  def accurate(x, y, batch_size = nil, &batch_proc)
139
- @batch_size = batch_size if batch_size
142
+ unless batch_size
143
+ if @batch_size
144
+ batch_size = @batch_size >= x.shape[0] ? @batch_size : x.shape[0]
145
+ else
146
+ batch_size = 1
147
+ end
148
+ end
140
149
  correct = 0
141
150
  (x.shape[0].to_f / @batch_size).ceil.times do |i|
142
151
  x_batch = SFloat.zeros(@batch_size, *x.shape[1..-1])
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.2.0"
2
+ VERSION = "0.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-07-11 00:00:00.000000000 Z
11
+ date: 2018-07-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray