ruby-dnn 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +1 -1
- data/lib/dnn/core/layers.rb +60 -9
- data/lib/dnn/core/model.rb +11 -2
- data/lib/dnn/core/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ccb761989aa380096c4b16db46da21a4fb74e95bb8b15ed4564f4da6818bf04f
|
4
|
+
data.tar.gz: 2764cc949b09a42b2bbb7ccee7504e62055b33318cf5f3beb4977eb50c9b86ff
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: da160fcb2a10367a916dbb020b10ecf6b7cc8eba0e1aa30d68e0f788c06a5a5ff29122a758308666384921f49866a493230aae527e343bec1b3470f6de849fa2
|
7
|
+
data.tar.gz: 6016b580542781fae7e1b07e2cb3aac219070a6428d0d02505fe833a1a54564947dbb196afeea76122fb964f761c7b23aca2513162d5b14be64113cddeac4f2b
|
data/README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# ruby-dnn
|
2
2
|
|
3
3
|
ruby-dnn is a ruby deep learning library. This library supports full connected neural network and convolution neural network.
|
4
|
-
Currently, you can get 99% accuracy with MNIST and
|
4
|
+
Currently, you can get 99% accuracy with MNIST and 74% with CIFAR 10.
|
5
5
|
|
6
6
|
## Installation
|
7
7
|
|
data/lib/dnn/core/layers.rb
CHANGED
@@ -314,11 +314,16 @@ module DNN
|
|
314
314
|
include Convert
|
315
315
|
|
316
316
|
def initialize(pool_width, pool_height, strides: nil, padding: false)
|
317
|
+
super()
|
317
318
|
@pool_width = pool_width
|
318
319
|
@pool_height = pool_height
|
319
320
|
@strides = strides ? strides : [@pool_width, @pool_height]
|
320
321
|
@padding = padding
|
321
|
-
end
|
322
|
+
end
|
323
|
+
|
324
|
+
def self.load_hash(hash)
|
325
|
+
MaxPool2D.new(hash[:pool_width], hash[:pool_height], strides: hash[:strides], padding: hash[:padding])
|
326
|
+
end
|
322
327
|
|
323
328
|
def build(model)
|
324
329
|
super
|
@@ -386,6 +391,7 @@ module DNN
|
|
386
391
|
attr_reader :shape
|
387
392
|
|
388
393
|
def initialize(shape)
|
394
|
+
super()
|
389
395
|
@shape = shape
|
390
396
|
@x_shape = nil
|
391
397
|
end
|
@@ -421,16 +427,21 @@ module DNN
|
|
421
427
|
|
422
428
|
class Dropout < Layer
|
423
429
|
def initialize(dropout_ratio)
|
430
|
+
super()
|
424
431
|
@dropout_ratio = dropout_ratio
|
425
432
|
@mask = nil
|
426
433
|
end
|
427
434
|
|
435
|
+
def self.load_hash(hash)
|
436
|
+
self.new(hash[:dropout_ratio])
|
437
|
+
end
|
438
|
+
|
428
439
|
def self.load(hash)
|
429
440
|
self.new(hash[:dropout_ratio])
|
430
441
|
end
|
431
442
|
|
432
443
|
def forward(x)
|
433
|
-
if @model.training
|
444
|
+
if @model.training?
|
434
445
|
@mask = SFloat.ones(*x.shape).rand < @dropout_ratio
|
435
446
|
x[@mask] = 0
|
436
447
|
else
|
@@ -443,17 +454,48 @@ module DNN
|
|
443
454
|
dout[@mask] = 0 if @model.training
|
444
455
|
dout
|
445
456
|
end
|
457
|
+
|
458
|
+
def to_hash
|
459
|
+
{name: self.class.name, dropout_ratio: @dropout_ratio}
|
460
|
+
end
|
446
461
|
end
|
447
462
|
|
448
463
|
|
449
464
|
class BatchNormalization < HasParamLayer
|
465
|
+
def initialize(momentum: 0.9, running_mean: nil, running_var: nil)
|
466
|
+
super()
|
467
|
+
@momentum = momentum
|
468
|
+
@running_mean = running_mean
|
469
|
+
@running_var = running_var
|
470
|
+
end
|
471
|
+
|
472
|
+
def self.load_hash(hash)
|
473
|
+
running_mean = SFloat.cast(hash[:running_mean])
|
474
|
+
running_var = SFloat.cast(hash[:running_var])
|
475
|
+
self.new(momentum: hash[:momentum], running_mean: running_mean, running_var: running_var)
|
476
|
+
end
|
477
|
+
|
478
|
+
def build(model)
|
479
|
+
super
|
480
|
+
@running_mean ||= SFloat.zeros(*shape)
|
481
|
+
@running_var ||= SFloat.zeros(*shape)
|
482
|
+
end
|
483
|
+
|
450
484
|
def forward(x)
|
451
|
-
@
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
485
|
+
if @model.training?
|
486
|
+
mean = x.mean(0)
|
487
|
+
@xc = x - mean
|
488
|
+
var = (@xc**2).mean(0)
|
489
|
+
@std = NMath.sqrt(var + 1e-7)
|
490
|
+
xn = @xc / @std
|
491
|
+
@xn = xn
|
492
|
+
@running_mean = @momentum * @running_mean + (1 - @momentum) * mean
|
493
|
+
@running_var = @momentum * @running_var + (1 - @momentum) * var
|
494
|
+
else
|
495
|
+
xc = x - @running_mean
|
496
|
+
xn = xc / NMath.sqrt(@running_var + 1e-7)
|
497
|
+
end
|
498
|
+
@params[:gamma] * xn + @params[:beta]
|
457
499
|
end
|
458
500
|
|
459
501
|
def backward(dout)
|
@@ -466,7 +508,16 @@ module DNN
|
|
466
508
|
dvar = 0.5 * dstd / @std
|
467
509
|
dxc += (2.0 / batch_size) * @xc * dvar
|
468
510
|
dmean = dxc.sum(0)
|
469
|
-
dxc - dmean / batch_size
|
511
|
+
dxc - dmean / batch_size
|
512
|
+
end
|
513
|
+
|
514
|
+
def to_hash
|
515
|
+
{
|
516
|
+
name: self.class.name,
|
517
|
+
momentum: @momentum,
|
518
|
+
running_mean: @running_mean.to_a,
|
519
|
+
running_var: @running_var.to_a,
|
520
|
+
}
|
470
521
|
end
|
471
522
|
|
472
523
|
private
|
data/lib/dnn/core/model.rb
CHANGED
@@ -8,7 +8,6 @@ module DNN
|
|
8
8
|
attr_accessor :layers
|
9
9
|
attr_reader :optimizer
|
10
10
|
attr_reader :batch_size
|
11
|
-
attr_reader :training
|
12
11
|
|
13
12
|
def initialize
|
14
13
|
@layers = []
|
@@ -86,6 +85,10 @@ module DNN
|
|
86
85
|
def compiled?
|
87
86
|
@compiled
|
88
87
|
end
|
88
|
+
|
89
|
+
def training?
|
90
|
+
@training
|
91
|
+
end
|
89
92
|
|
90
93
|
def train(x, y, epochs,
|
91
94
|
batch_size: 1,
|
@@ -136,7 +139,13 @@ module DNN
|
|
136
139
|
end
|
137
140
|
|
138
141
|
def accurate(x, y, batch_size = nil, &batch_proc)
|
139
|
-
|
142
|
+
unless batch_size
|
143
|
+
if @batch_size
|
144
|
+
batch_size = @batch_size >= x.shape[0] ? @batch_size : x.shape[0]
|
145
|
+
else
|
146
|
+
batch_size = 1
|
147
|
+
end
|
148
|
+
end
|
140
149
|
correct = 0
|
141
150
|
(x.shape[0].to_f / @batch_size).ceil.times do |i|
|
142
151
|
x_batch = SFloat.zeros(@batch_size, *x.shape[1..-1])
|
data/lib/dnn/core/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-07-
|
11
|
+
date: 2018-07-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|