ruby-dnn 1.2.0 → 1.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/examples/judge-number/mnist_predict.rb +1 -1
- data/examples/pix2pix/dcgan.rb +0 -71
- data/lib/dnn/core/layers/basic_layers.rb +3 -4
- data/lib/dnn/image.rb +35 -0
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8475052a51f8b81f176c8b3a77c953d116fb1098b7f24fe8a2ad9135b2f2a33b
|
4
|
+
data.tar.gz: f291c99be16b9a35fee59a09cefda3b9aece1bbaef5d6b482fad6d238043e68b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: '08f061510f73297b12b53cd3da5748dd47e67605d1b40c35388779a94e023e37450d2c9ebd9ec1bf0970c1580c7481f46c6cef05d54af5748fe175522b198c5e'
|
7
|
+
data.tar.gz: fdfe99ac9b5a63e4914aaab4f71a37e56b097c6b00560245a82438c6fd70ab658683de95c1ddef0283f4cc2f0b9e261dda798158cd374cf19cc17966ed338b16
|
data/examples/judge-number/mnist_predict.rb
CHANGED
@@ -12,7 +12,7 @@ end
|
|
12
12
|
def mnist_predict(img, width, height)
|
13
13
|
load_model
|
14
14
|
img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
|
15
|
-
img =
|
15
|
+
img = DNN::Image.to_rgb(img)
|
16
16
|
img = DNN::Image.to_gray_scale(img)
|
17
17
|
x = Numo::SFloat.cast(img) / 255
|
18
18
|
out = $model.predict1(x)
|
data/examples/pix2pix/dcgan.rb
CHANGED
@@ -70,77 +70,6 @@ class Generator < Model
|
|
70
70
|
end
|
71
71
|
end
|
72
72
|
|
73
|
-
class Discriminator < Model
|
74
|
-
def initialize(gen_input_shape, gen_output_shape)
|
75
|
-
super()
|
76
|
-
@gen_input_shape = gen_input_shape
|
77
|
-
@gen_output_shape = gen_output_shape
|
78
|
-
@l1_1 = Conv2D.new(32, 4, padding: true)
|
79
|
-
@l1_2 = Conv2D.new(32, 4, padding: true)
|
80
|
-
@l2 = Conv2D.new(32, 4, strides: 2, padding: true)
|
81
|
-
@l3 = Conv2D.new(32, 4, padding: true)
|
82
|
-
@l4 = Conv2D.new(64, 4, strides: 2, padding: true)
|
83
|
-
@l5 = Conv2D.new(64, 4, padding: true)
|
84
|
-
@l6 = Dense.new(1024)
|
85
|
-
@l7 = Dense.new(1)
|
86
|
-
@bn1 = BatchNormalization.new
|
87
|
-
@bn2 = BatchNormalization.new
|
88
|
-
@bn3 = BatchNormalization.new
|
89
|
-
@bn4 = BatchNormalization.new
|
90
|
-
@bn5 = BatchNormalization.new
|
91
|
-
@bn6 = BatchNormalization.new
|
92
|
-
end
|
93
|
-
|
94
|
-
def forward(inputs)
|
95
|
-
input, images = *inputs
|
96
|
-
x = InputLayer.new(@gen_input_shape).(input)
|
97
|
-
x = @l1_1.(x)
|
98
|
-
x = @bn1.(x)
|
99
|
-
x1 = LeakyReLU.(x, 0.2)
|
100
|
-
|
101
|
-
x = InputLayer.new(@gen_output_shape).(images)
|
102
|
-
x = @l1_2.(x)
|
103
|
-
x = @bn2.(x)
|
104
|
-
x2 = LeakyReLU.(x, 0.2)
|
105
|
-
|
106
|
-
x = Concatenate.(x1, x2)
|
107
|
-
x = @l2.(x)
|
108
|
-
x = @bn3.(x)
|
109
|
-
x = LeakyReLU.(x, 0.2)
|
110
|
-
|
111
|
-
x = @l3.(x)
|
112
|
-
x = @bn4.(x)
|
113
|
-
x = LeakyReLU.(x, 0.2)
|
114
|
-
|
115
|
-
x = @l4.(x)
|
116
|
-
x = @bn5.(x)
|
117
|
-
x = LeakyReLU.(x, 0.2)
|
118
|
-
|
119
|
-
x = @l5.(x)
|
120
|
-
x = @bn6.(x)
|
121
|
-
x = LeakyReLU.(x, 0.2)
|
122
|
-
|
123
|
-
x = Flatten.(x)
|
124
|
-
x = @l6.(x)
|
125
|
-
x = LeakyReLU.(x, 0.2)
|
126
|
-
|
127
|
-
x = @l7.(x)
|
128
|
-
x
|
129
|
-
end
|
130
|
-
|
131
|
-
def enable_training
|
132
|
-
trainable_layers.each do |layer|
|
133
|
-
layer.trainable = true
|
134
|
-
end
|
135
|
-
end
|
136
|
-
|
137
|
-
def disable_training
|
138
|
-
trainable_layers.each do |layer|
|
139
|
-
layer.trainable = false
|
140
|
-
end
|
141
|
-
end
|
142
|
-
end
|
143
|
-
|
144
73
|
class Discriminator < Model
|
145
74
|
def initialize(gen_input_shape, gen_output_shape, base_num_filters)
|
146
75
|
super()
|
data/lib/dnn/core/layers/basic_layers.rb
CHANGED
@@ -436,8 +436,8 @@ module DNN
|
|
436
436
|
def forward_node(x)
|
437
437
|
if DNN.learning_phase
|
438
438
|
Xumo::SFloat.srand(@rnd.rand(1 << 31))
|
439
|
-
@mask = Xumo::SFloat.new(*x.shape).rand
|
440
|
-
x
|
439
|
+
@mask = Xumo::SFloat.cast(Xumo::SFloat.new(*x.shape).rand >= @dropout_ratio)
|
440
|
+
x = x * @mask
|
441
441
|
elsif @use_scale
|
442
442
|
x *= (1 - @dropout_ratio)
|
443
443
|
end
|
@@ -445,8 +445,7 @@ module DNN
|
|
445
445
|
end
|
446
446
|
|
447
447
|
def backward_node(dy)
|
448
|
-
dy
|
449
|
-
dy
|
448
|
+
dy * @mask
|
450
449
|
end
|
451
450
|
|
452
451
|
def to_hash
|
data/lib/dnn/image.rb
CHANGED
@@ -114,6 +114,41 @@ module DNN
|
|
114
114
|
Numo::UInt8.cast(x)
|
115
115
|
end
|
116
116
|
|
117
|
+
# Image convert image channel to RGB.
|
118
|
+
# @param [Numo::UInt8] img Image to RGB.
|
119
|
+
def self.to_rgb(img)
|
120
|
+
img_check(img)
|
121
|
+
case img.shape[2]
|
122
|
+
when 1
|
123
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
|
124
|
+
when 2
|
125
|
+
img = img[true, true, 0...1]
|
126
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
|
127
|
+
when 4
|
128
|
+
return img[true, true, 0...3].clone
|
129
|
+
end
|
130
|
+
img
|
131
|
+
end
|
132
|
+
|
133
|
+
# Image convert image channel to RGBA.
|
134
|
+
# @param [Numo::UInt8] img Image to RGBA.
|
135
|
+
def self.to_rgba(img)
|
136
|
+
img_check(img)
|
137
|
+
case img.shape[2]
|
138
|
+
when 1
|
139
|
+
alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
|
140
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
|
141
|
+
when 2
|
142
|
+
alpha = img[true, true, 1...2]
|
143
|
+
img = img[true, true, 0...1]
|
144
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
|
145
|
+
when 3
|
146
|
+
alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
|
147
|
+
return img.concatenate(alpha, axis: 2)
|
148
|
+
end
|
149
|
+
img
|
150
|
+
end
|
151
|
+
|
117
152
|
private_class_method def self.img_check(img)
|
118
153
|
raise TypeError, "img: #{img.class} is not an instance of the Numo::UInt8 class." unless img.is_a?(Numo::UInt8)
|
119
154
|
if img.shape.length != 3
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.2.0
|
4
|
+
version: 1.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-05-
|
11
|
+
date: 2020-05-24 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|