ruby-dnn 0.2.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +1 -1
- data/lib/dnn/core/layers.rb +60 -9
- data/lib/dnn/core/model.rb +11 -2
- data/lib/dnn/core/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ccb761989aa380096c4b16db46da21a4fb74e95bb8b15ed4564f4da6818bf04f
|
4
|
+
data.tar.gz: 2764cc949b09a42b2bbb7ccee7504e62055b33318cf5f3beb4977eb50c9b86ff
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: da160fcb2a10367a916dbb020b10ecf6b7cc8eba0e1aa30d68e0f788c06a5a5ff29122a758308666384921f49866a493230aae527e343bec1b3470f6de849fa2
|
7
|
+
data.tar.gz: 6016b580542781fae7e1b07e2cb3aac219070a6428d0d02505fe833a1a54564947dbb196afeea76122fb964f761c7b23aca2513162d5b14be64113cddeac4f2b
|
data/README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# ruby-dnn
|
2
2
|
|
3
3
|
ruby-dnn is a ruby deep learning library. This library supports full connected neural network and convolution neural network.
|
4
|
-
Currently, you can get 99% accuracy with MNIST and
|
4
|
+
Currently, you can get 99% accuracy with MNIST and 74% with CIFAR 10.
|
5
5
|
|
6
6
|
## Installation
|
7
7
|
|
data/lib/dnn/core/layers.rb
CHANGED
@@ -314,11 +314,16 @@ module DNN
|
|
314
314
|
include Convert
|
315
315
|
|
316
316
|
def initialize(pool_width, pool_height, strides: nil, padding: false)
|
317
|
+
super()
|
317
318
|
@pool_width = pool_width
|
318
319
|
@pool_height = pool_height
|
319
320
|
@strides = strides ? strides : [@pool_width, @pool_height]
|
320
321
|
@padding = padding
|
321
|
-
end
|
322
|
+
end
|
323
|
+
|
324
|
+
def self.load_hash(hash)
|
325
|
+
MaxPool2D.new(hash[:pool_width], hash[:pool_height], strides: hash[:strides], padding: hash[:padding])
|
326
|
+
end
|
322
327
|
|
323
328
|
def build(model)
|
324
329
|
super
|
@@ -386,6 +391,7 @@ module DNN
|
|
386
391
|
attr_reader :shape
|
387
392
|
|
388
393
|
def initialize(shape)
|
394
|
+
super()
|
389
395
|
@shape = shape
|
390
396
|
@x_shape = nil
|
391
397
|
end
|
@@ -421,16 +427,21 @@ module DNN
|
|
421
427
|
|
422
428
|
class Dropout < Layer
|
423
429
|
def initialize(dropout_ratio)
|
430
|
+
super()
|
424
431
|
@dropout_ratio = dropout_ratio
|
425
432
|
@mask = nil
|
426
433
|
end
|
427
434
|
|
435
|
+
def self.load_hash(hash)
|
436
|
+
self.new(hash[:dropout_ratio])
|
437
|
+
end
|
438
|
+
|
428
439
|
def self.load(hash)
|
429
440
|
self.new(hash[:dropout_ratio])
|
430
441
|
end
|
431
442
|
|
432
443
|
def forward(x)
|
433
|
-
if @model.training
|
444
|
+
if @model.training?
|
434
445
|
@mask = SFloat.ones(*x.shape).rand < @dropout_ratio
|
435
446
|
x[@mask] = 0
|
436
447
|
else
|
@@ -443,17 +454,48 @@ module DNN
|
|
443
454
|
dout[@mask] = 0 if @model.training
|
444
455
|
dout
|
445
456
|
end
|
457
|
+
|
458
|
+
def to_hash
|
459
|
+
{name: self.class.name, dropout_ratio: @dropout_ratio}
|
460
|
+
end
|
446
461
|
end
|
447
462
|
|
448
463
|
|
449
464
|
class BatchNormalization < HasParamLayer
|
465
|
+
def initialize(momentum: 0.9, running_mean: nil, running_var: nil)
|
466
|
+
super()
|
467
|
+
@momentum = momentum
|
468
|
+
@running_mean = running_mean
|
469
|
+
@running_var = running_var
|
470
|
+
end
|
471
|
+
|
472
|
+
def self.load_hash(hash)
|
473
|
+
running_mean = SFloat.cast(hash[:running_mean])
|
474
|
+
running_var = SFloat.cast(hash[:running_var])
|
475
|
+
self.new(momentum: hash[:momentum], running_mean: running_mean, running_var: running_var)
|
476
|
+
end
|
477
|
+
|
478
|
+
def build(model)
|
479
|
+
super
|
480
|
+
@running_mean ||= SFloat.zeros(*shape)
|
481
|
+
@running_var ||= SFloat.zeros(*shape)
|
482
|
+
end
|
483
|
+
|
450
484
|
def forward(x)
|
451
|
-
@
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
485
|
+
if @model.training?
|
486
|
+
mean = x.mean(0)
|
487
|
+
@xc = x - mean
|
488
|
+
var = (@xc**2).mean(0)
|
489
|
+
@std = NMath.sqrt(var + 1e-7)
|
490
|
+
xn = @xc / @std
|
491
|
+
@xn = xn
|
492
|
+
@running_mean = @momentum * @running_mean + (1 - @momentum) * mean
|
493
|
+
@running_var = @momentum * @running_var + (1 - @momentum) * var
|
494
|
+
else
|
495
|
+
xc = x - @running_mean
|
496
|
+
xn = xc / NMath.sqrt(@running_var + 1e-7)
|
497
|
+
end
|
498
|
+
@params[:gamma] * xn + @params[:beta]
|
457
499
|
end
|
458
500
|
|
459
501
|
def backward(dout)
|
@@ -466,7 +508,16 @@ module DNN
|
|
466
508
|
dvar = 0.5 * dstd / @std
|
467
509
|
dxc += (2.0 / batch_size) * @xc * dvar
|
468
510
|
dmean = dxc.sum(0)
|
469
|
-
dxc - dmean / batch_size
|
511
|
+
dxc - dmean / batch_size
|
512
|
+
end
|
513
|
+
|
514
|
+
def to_hash
|
515
|
+
{
|
516
|
+
name: self.class.name,
|
517
|
+
momentum: @momentum,
|
518
|
+
running_mean: @running_mean.to_a,
|
519
|
+
running_var: @running_var.to_a,
|
520
|
+
}
|
470
521
|
end
|
471
522
|
|
472
523
|
private
|
data/lib/dnn/core/model.rb
CHANGED
@@ -8,7 +8,6 @@ module DNN
|
|
8
8
|
attr_accessor :layers
|
9
9
|
attr_reader :optimizer
|
10
10
|
attr_reader :batch_size
|
11
|
-
attr_reader :training
|
12
11
|
|
13
12
|
def initialize
|
14
13
|
@layers = []
|
@@ -86,6 +85,10 @@ module DNN
|
|
86
85
|
def compiled?
|
87
86
|
@compiled
|
88
87
|
end
|
88
|
+
|
89
|
+
def training?
|
90
|
+
@training
|
91
|
+
end
|
89
92
|
|
90
93
|
def train(x, y, epochs,
|
91
94
|
batch_size: 1,
|
@@ -136,7 +139,13 @@ module DNN
|
|
136
139
|
end
|
137
140
|
|
138
141
|
def accurate(x, y, batch_size = nil, &batch_proc)
|
139
|
-
|
142
|
+
unless batch_size
|
143
|
+
if @batch_size
|
144
|
+
batch_size = @batch_size >= x.shape[0] ? @batch_size : x.shape[0]
|
145
|
+
else
|
146
|
+
batch_size = 1
|
147
|
+
end
|
148
|
+
end
|
140
149
|
correct = 0
|
141
150
|
(x.shape[0].to_f / @batch_size).ceil.times do |i|
|
142
151
|
x_batch = SFloat.zeros(@batch_size, *x.shape[1..-1])
|
data/lib/dnn/core/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-07-
|
11
|
+
date: 2018-07-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|