ruby-dnn 1.2.0 → 1.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/examples/judge-number/mnist_predict.rb +1 -1
- data/examples/pix2pix/dcgan.rb +0 -71
- data/lib/dnn/core/layers/basic_layers.rb +3 -4
- data/lib/dnn/image.rb +35 -0
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8475052a51f8b81f176c8b3a77c953d116fb1098b7f24fe8a2ad9135b2f2a33b
|
4
|
+
data.tar.gz: f291c99be16b9a35fee59a09cefda3b9aece1bbaef5d6b482fad6d238043e68b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: '08f061510f73297b12b53cd3da5748dd47e67605d1b40c35388779a94e023e37450d2c9ebd9ec1bf0970c1580c7481f46c6cef05d54af5748fe175522b198c5e'
|
7
|
+
data.tar.gz: fdfe99ac9b5a63e4914aaab4f71a37e56b097c6b00560245a82438c6fd70ab658683de95c1ddef0283f4cc2f0b9e261dda798158cd374cf19cc17966ed338b16
|
@@ -12,7 +12,7 @@ end
|
|
12
12
|
def mnist_predict(img, width, height)
|
13
13
|
load_model
|
14
14
|
img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
|
15
|
-
img =
|
15
|
+
img = DNN::Image.to_rgb(img)
|
16
16
|
img = DNN::Image.to_gray_scale(img)
|
17
17
|
x = Numo::SFloat.cast(img) / 255
|
18
18
|
out = $model.predict1(x)
|
data/examples/pix2pix/dcgan.rb
CHANGED
@@ -70,77 +70,6 @@ class Generator < Model
|
|
70
70
|
end
|
71
71
|
end
|
72
72
|
|
73
|
-
class Discriminator < Model
|
74
|
-
def initialize(gen_input_shape, gen_output_shape)
|
75
|
-
super()
|
76
|
-
@gen_input_shape = gen_input_shape
|
77
|
-
@gen_output_shape = gen_output_shape
|
78
|
-
@l1_1 = Conv2D.new(32, 4, padding: true)
|
79
|
-
@l1_2 = Conv2D.new(32, 4, padding: true)
|
80
|
-
@l2 = Conv2D.new(32, 4, strides: 2, padding: true)
|
81
|
-
@l3 = Conv2D.new(32, 4, padding: true)
|
82
|
-
@l4 = Conv2D.new(64, 4, strides: 2, padding: true)
|
83
|
-
@l5 = Conv2D.new(64, 4, padding: true)
|
84
|
-
@l6 = Dense.new(1024)
|
85
|
-
@l7 = Dense.new(1)
|
86
|
-
@bn1 = BatchNormalization.new
|
87
|
-
@bn2 = BatchNormalization.new
|
88
|
-
@bn3 = BatchNormalization.new
|
89
|
-
@bn4 = BatchNormalization.new
|
90
|
-
@bn5 = BatchNormalization.new
|
91
|
-
@bn6 = BatchNormalization.new
|
92
|
-
end
|
93
|
-
|
94
|
-
def forward(inputs)
|
95
|
-
input, images = *inputs
|
96
|
-
x = InputLayer.new(@gen_input_shape).(input)
|
97
|
-
x = @l1_1.(x)
|
98
|
-
x = @bn1.(x)
|
99
|
-
x1 = LeakyReLU.(x, 0.2)
|
100
|
-
|
101
|
-
x = InputLayer.new(@gen_output_shape).(images)
|
102
|
-
x = @l1_2.(x)
|
103
|
-
x = @bn2.(x)
|
104
|
-
x2 = LeakyReLU.(x, 0.2)
|
105
|
-
|
106
|
-
x = Concatenate.(x1, x2)
|
107
|
-
x = @l2.(x)
|
108
|
-
x = @bn3.(x)
|
109
|
-
x = LeakyReLU.(x, 0.2)
|
110
|
-
|
111
|
-
x = @l3.(x)
|
112
|
-
x = @bn4.(x)
|
113
|
-
x = LeakyReLU.(x, 0.2)
|
114
|
-
|
115
|
-
x = @l4.(x)
|
116
|
-
x = @bn5.(x)
|
117
|
-
x = LeakyReLU.(x, 0.2)
|
118
|
-
|
119
|
-
x = @l5.(x)
|
120
|
-
x = @bn6.(x)
|
121
|
-
x = LeakyReLU.(x, 0.2)
|
122
|
-
|
123
|
-
x = Flatten.(x)
|
124
|
-
x = @l6.(x)
|
125
|
-
x = LeakyReLU.(x, 0.2)
|
126
|
-
|
127
|
-
x = @l7.(x)
|
128
|
-
x
|
129
|
-
end
|
130
|
-
|
131
|
-
def enable_training
|
132
|
-
trainable_layers.each do |layer|
|
133
|
-
layer.trainable = true
|
134
|
-
end
|
135
|
-
end
|
136
|
-
|
137
|
-
def disable_training
|
138
|
-
trainable_layers.each do |layer|
|
139
|
-
layer.trainable = false
|
140
|
-
end
|
141
|
-
end
|
142
|
-
end
|
143
|
-
|
144
73
|
class Discriminator < Model
|
145
74
|
def initialize(gen_input_shape, gen_output_shape, base_num_filters)
|
146
75
|
super()
|
@@ -436,8 +436,8 @@ module DNN
|
|
436
436
|
def forward_node(x)
|
437
437
|
if DNN.learning_phase
|
438
438
|
Xumo::SFloat.srand(@rnd.rand(1 << 31))
|
439
|
-
@mask = Xumo::SFloat.new(*x.shape).rand
|
440
|
-
x
|
439
|
+
@mask = Xumo::SFloat.cast(Xumo::SFloat.new(*x.shape).rand >= @dropout_ratio)
|
440
|
+
x = x * @mask
|
441
441
|
elsif @use_scale
|
442
442
|
x *= (1 - @dropout_ratio)
|
443
443
|
end
|
@@ -445,8 +445,7 @@ module DNN
|
|
445
445
|
end
|
446
446
|
|
447
447
|
def backward_node(dy)
|
448
|
-
dy
|
449
|
-
dy
|
448
|
+
dy * @mask
|
450
449
|
end
|
451
450
|
|
452
451
|
def to_hash
|
data/lib/dnn/image.rb
CHANGED
@@ -114,6 +114,41 @@ module DNN
|
|
114
114
|
Numo::UInt8.cast(x)
|
115
115
|
end
|
116
116
|
|
117
|
+
# Image convert image channel to RGB.
|
118
|
+
# @param [Numo::UInt8] img Image to RGB.
|
119
|
+
def self.to_rgb(img)
|
120
|
+
img_check(img)
|
121
|
+
case img.shape[2]
|
122
|
+
when 1
|
123
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
|
124
|
+
when 2
|
125
|
+
img = img[true, true, 0...1]
|
126
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
|
127
|
+
when 4
|
128
|
+
return img[true, true, 0...3].clone
|
129
|
+
end
|
130
|
+
img
|
131
|
+
end
|
132
|
+
|
133
|
+
# Image convert image channel to RGBA.
|
134
|
+
# @param [Numo::UInt8] img Image to RGBA.
|
135
|
+
def self.to_rgba(img)
|
136
|
+
img_check(img)
|
137
|
+
case img.shape[2]
|
138
|
+
when 1
|
139
|
+
alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
|
140
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
|
141
|
+
when 2
|
142
|
+
alpha = img[true, true, 1...2]
|
143
|
+
img = img[true, true, 0...1]
|
144
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
|
145
|
+
when 3
|
146
|
+
alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
|
147
|
+
return img.concatenate(alpha, axis: 2)
|
148
|
+
end
|
149
|
+
img
|
150
|
+
end
|
151
|
+
|
117
152
|
private_class_method def self.img_check(img)
|
118
153
|
raise TypeError, "img: #{img.class} is not an instance of the Numo::UInt8 class." unless img.is_a?(Numo::UInt8)
|
119
154
|
if img.shape.length != 3
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.
|
4
|
+
version: 1.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-05-
|
11
|
+
date: 2020-05-24 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|