ruby-dnn 1.1.4 → 1.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/.travis.yml +2 -1
- data/README.md +39 -22
- data/examples/api-examples/early_stopping_example.rb +6 -6
- data/examples/api-examples/initializer_example.rb +6 -6
- data/examples/api-examples/regularizer_example.rb +6 -6
- data/examples/api-examples/save_example.rb +6 -6
- data/examples/dcgan/dcgan.rb +27 -27
- data/examples/judge-number/README.md +29 -0
- data/examples/judge-number/capture.PNG +0 -0
- data/examples/judge-number/convnet8.rb +70 -0
- data/examples/judge-number/make_weights.rb +5 -0
- data/examples/judge-number/mnist_predict.rb +20 -0
- data/examples/judge-number/mnist_train.rb +19 -0
- data/examples/judge-number/public/httpRequest.js +44 -0
- data/examples/judge-number/public/judgeNumber.js +61 -0
- data/examples/judge-number/server.rb +19 -0
- data/examples/judge-number/trained_mnist_params.marshal +0 -0
- data/examples/judge-number/views/index.erb +7 -0
- data/examples/mnist_conv2d_example.rb +3 -3
- data/examples/mnist_define_by_run.rb +7 -7
- data/examples/mnist_gpu.rb +47 -0
- data/examples/mnist_lstm_example.rb +1 -1
- data/examples/pix2pix/dcgan.rb +54 -66
- data/examples/pix2pix/train.rb +2 -2
- data/examples/vae.rb +13 -13
- data/img/cart-pole.gif +0 -0
- data/img/cycle-gan.PNG +0 -0
- data/img/facade-pix2pix.png +0 -0
- data/lib/dnn.rb +24 -3
- data/lib/dnn/core/callbacks.rb +6 -4
- data/lib/dnn/core/layers/basic_layers.rb +40 -22
- data/lib/dnn/core/layers/cnn_layers.rb +33 -5
- data/lib/dnn/core/layers/math_layers.rb +17 -9
- data/lib/dnn/core/layers/merge_layers.rb +2 -26
- data/lib/dnn/core/layers/split_layers.rb +39 -0
- data/lib/dnn/core/link.rb +14 -33
- data/lib/dnn/core/losses.rb +6 -12
- data/lib/dnn/core/models.rb +77 -10
- data/lib/dnn/core/optimizers.rb +8 -1
- data/lib/dnn/core/utils.rb +23 -0
- data/lib/dnn/image.rb +48 -0
- data/lib/dnn/version.rb +1 -1
- data/ruby-dnn.gemspec +2 -15
- metadata +40 -20
- data/bin/console +0 -14
- data/bin/setup +0 -8
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
require "dnn"
require "dnn/image"
require_relative "convnet8"

# Builds the ConvNet once and loads the trained weights into it.
# The model is cached in the global $model, so later calls are no-ops.
def load_model
  return if $model

  $model = ConvNet.create([28, 28, 1])
  # A dummy prediction is run before load_params — presumably so every layer
  # is built (its params allocated) before the saved values are loaded.
  # NOTE(review): confirm against DNN::Models::Model#load_params.
  $model.predict1(Numo::SFloat.zeros(28, 28, 1))
  $model.load_params("trained_mnist_params.marshal")
end

# Classifies a raw RGBA image with the cached model.
#
# img    - binary String of RGBA pixel bytes (decoded as DNN::Image::RGBA).
# width  - image width in pixels.
# height - image height in pixels.
#
# Returns an Array with one Float per class (digits 0..9): the model output
# scaled by 100 and rounded to two decimal places.
# NOTE(review): assumes predict1 yields class probabilities (softmax applied).
def mnist_predict(img, width, height)
  load_model
  img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
  img = DNN::Image.to_rgb(img)
  img = DNN::Image.to_gray_scale(img)
  x = Numo::SFloat.cast(img) / 255
  out = $model.predict1(x)
  # Scale first, then round. Rounding first (`v.round(4) * 100`) re-introduces
  # binary floating-point noise, e.g. 0.07 * 100 => 7.000000000000001.
  out.to_a.map { |v| (v * 100).round(2) }
end
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
require "dnn"
require "dnn/datasets/mnist"
require_relative "convnet8"

include DNN::Callbacks

# Scales raw pixel values into [0, 1] and one-hot encodes the 10 digit labels.
preprocess = lambda do |images, labels|
  [Numo::SFloat.cast(images) / 255, DNN::Utils.to_categorical(labels, 10, Numo::SFloat)]
end

train_images, train_labels = DNN::MNIST.load_train
test_images, test_labels = DNN::MNIST.load_test

x_train, y_train = preprocess.(train_images, train_labels)
x_test, y_test = preprocess.(test_images, test_labels)

model = ConvNet.create([28, 28, 1])
# Checkpoint under "trained/trained_mnist" every 5 epochs (interval: 5).
model.add_callback(CheckPoint.new("trained/trained_mnist", interval: 5))

model.train(x_train, y_train, 20, batch_size: 128, test: [x_test, y_test])
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/**
 * Minimal XMLHttpRequest wrapper.
 *
 * Request params (when truthy) are JSON-encoded into the body. On the XHR
 * "load" event the callback receives `{ response, event }`.
 * NOTE(review): no error callback is wired up, and "load" also fires for
 * HTTP error statuses — callers must check the status themselves.
 */
class HttpRequest {
  /**
   * @param {string} path - URL to request.
   * @param {string} method - HTTP method, e.g. "GET" or "POST".
   * @param {function({response: *, event: Event}): void} responseCallback
   */
  constructor(path, method, responseCallback) {
    Object.assign(this, {
      _path: path,
      _method: method,
      _responseCallback: responseCallback,
    });
  }

  /** Sends a GET request to `path` and returns the HttpRequest. */
  static get(path, responseCallback) {
    const request = new HttpRequest(path, "GET", responseCallback);
    request.send();
    return request;
  }

  /** Sends `params` as a JSON body via POST to `path` and returns the HttpRequest. */
  static post(path, params, responseCallback) {
    const request = new HttpRequest(path, "POST", responseCallback);
    request.send(params);
    return request;
  }

  /**
   * Opens and sends the request.
   * @param {?Object} params - body to JSON-encode; a falsy value sends no body.
   */
  send(params = null) {
    const xhr = new XMLHttpRequest();
    xhr.open(this._method, this._path);
    const body = params ? JSON.stringify(params) : null;
    xhr.addEventListener("load", (event) => {
      this._responseCallback({ response: xhr.response, event });
    });
    xhr.send(body);
  }
}
|
|
35
|
+
|
|
36
|
+
/** Base64 helpers built on the global btoa(). */
class Base64 {
  /**
   * Encodes a binary string or a byte array as Base64.
   *
   * Byte arrays are turned into a binary string in fixed-size chunks.
   * Spreading the whole array into String.fromCharCode() passes one argument
   * per byte. That throws a RangeError once the array exceeds the engine's
   * argument limit, e.g. the 262,144 bytes of a 256x256 RGBA canvas.
   *
   * @param {string|Uint8Array|Uint8ClampedArray} obj - data to encode.
   * @returns {string} the Base64 text.
   * @throws {TypeError} if obj is none of the supported types (previously
   *   this silently returned undefined).
   */
  static encode(obj) {
    if (typeof obj === "string") {
      return btoa(obj);
    }
    if (obj instanceof Uint8Array || obj instanceof Uint8ClampedArray) {
      const CHUNK_SIZE = 0x8000;
      let binary = "";
      for (let i = 0; i < obj.length; i += CHUNK_SIZE) {
        binary += String.fromCharCode.apply(null, obj.subarray(i, i + CHUNK_SIZE));
      }
      return btoa(binary);
    }
    throw new TypeError("Base64.encode expects a string, Uint8Array or Uint8ClampedArray");
  }
}
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
// Digit-judge front end. The user draws on the "draw" canvas. "Judge"
// downsamples the drawing onto the "view" canvas, posts its RGBA pixels to
// /predict and shows the per-digit percentages that come back.
const drawCanvas = document.getElementById("draw");
const viewCanvas = document.getElementById("view");

const drawContext = drawCanvas.getContext("2d");
const viewContext = viewCanvas.getContext("2d");

const judgeButton = document.getElementById("judge");
const clearButton = document.getElementById("clear");

const resultArea = document.getElementById("result");

// Side length, in pixels, of the square brush painted on each mouse move.
const BRUSH_SIZE = 20;

// Paints both canvases solid black; strokes are drawn in white on top.
const clearCanvases = () => {
  drawContext.fillStyle = "black";
  drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
  viewContext.fillStyle = "black";
  viewContext.fillRect(0, 0, viewCanvas.width, viewCanvas.height);
};

// Renders the 10 per-digit percentages as "digit: value%" lines.
const updateResult = (classification) => {
  let str = "";
  for (let i = 0; i <= 9; i++) {
    str += `${i}: ${classification[i]}%<br>`;
  }
  resultArea.innerHTML = str;
};

clearCanvases();

judgeButton.addEventListener("click", () => {
  // Downsample the drawing to the view canvas size and send its RGBA bytes.
  viewContext.drawImage(drawCanvas, 0, 0, viewCanvas.width, viewCanvas.height);
  const data = viewContext.getImageData(0, 0, viewCanvas.width, viewCanvas.height).data;
  const params = {
    img: Base64.encode(data),
    width: viewCanvas.width,
    height: viewCanvas.height,
  };
  HttpRequest.post("/predict", params, (res) => {
    updateResult(JSON.parse(res.response));
  });
});

clearButton.addEventListener("click", () => {
  clearCanvases();
  resultArea.innerHTML = "";
});

// Tracked on window so releasing the button outside the canvas stops drawing.
let mouseDown = false;

window.addEventListener("mousedown", () => {
  mouseDown = true;
});

window.addEventListener("mouseup", () => {
  mouseDown = false;
});

drawCanvas.addEventListener("mousemove", (e) => {
  if (!mouseDown) return;
  const rect = e.target.getBoundingClientRect();
  // Center the brush on the cursor.
  const x = e.clientX - BRUSH_SIZE / 2 - rect.left;
  const y = e.clientY - BRUSH_SIZE / 2 - rect.top;
  drawContext.fillStyle = "white";
  drawContext.fillRect(x, y, BRUSH_SIZE, BRUSH_SIZE);
});
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
require "sinatra"
require "sinatra/reloader"
require "json"
require "base64"
require_relative "mnist_predict"

# Serves the drawing page.
get "/" do
  erb :index
end

# Classifies a drawn digit.
#
# Expects a JSON body {"img": <Base64 RGBA bytes>, "width": Integer,
# "height": Integer}. Responds with a JSON array of the per-digit values
# returned by mnist_predict. Malformed input gets a 400 with a JSON error
# instead of an unhandled exception (500).
post "/predict" do
  content_type :json

  # Named `payload` rather than `params` so Sinatra's params helper is not shadowed.
  begin
    payload = JSON.parse(request.body.read, symbolize_names: true)
  rescue JSON::ParserError
    halt 400, JSON.dump({ error: "request body must be JSON" })
  end
  unless payload.is_a?(Hash) && payload[:img].is_a?(String)
    halt 400, JSON.dump({ error: "missing img" })
  end

  img = Base64.decode64(payload[:img])
  width = payload[:width].to_i
  height = payload[:height].to_i
  # RGBA input: 4 bytes per pixel.
  unless width.positive? && height.positive? && img.bytesize == width * height * 4
    halt 400, JSON.dump({ error: "img size does not match width/height" })
  end

  JSON.dump(mnist_predict(img, width, height))
end
|
|
Binary file
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
<canvas id="draw" width="256" height="256"></canvas>
<canvas id="view" width="28" height="28"></canvas>
<button id="judge">Judge</button>
<button id="clear">Clear</button>
<p id="result"></p>
<!-- httpRequest.js defines HttpRequest and Base64, which judgeNumber.js uses, so load it first. -->
<script src="httpRequest.js"></script>
<script src="judgeNumber.js"></script>
|
|
@@ -21,13 +21,13 @@ model = Sequential.new
|
|
|
21
21
|
|
|
22
22
|
model << InputLayer.new([28, 28, 1])
|
|
23
23
|
|
|
24
|
-
model << Conv2D.new(16,
|
|
24
|
+
model << Conv2D.new(16, 3)
|
|
25
25
|
model << BatchNormalization.new
|
|
26
26
|
model << ReLU.new
|
|
27
27
|
|
|
28
28
|
model << MaxPool2D.new(2)
|
|
29
29
|
|
|
30
|
-
model << Conv2D.new(32,
|
|
30
|
+
model << Conv2D.new(32, 3)
|
|
31
31
|
model << BatchNormalization.new
|
|
32
32
|
model << ReLU.new
|
|
33
33
|
|
|
@@ -42,7 +42,7 @@ model << Dense.new(10)
|
|
|
42
42
|
|
|
43
43
|
model.setup(Adam.new, SoftmaxCrossEntropy.new)
|
|
44
44
|
|
|
45
|
-
model.train(x_train, y_train, 10, batch_size:
|
|
45
|
+
model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
|
|
46
46
|
|
|
47
47
|
accuracy, loss = model.evaluate(x_test, y_test)
|
|
48
48
|
puts "accuracy: #{accuracy}"
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
require "dnn"
|
|
2
2
|
require "dnn/datasets/mnist"
|
|
3
3
|
# If you use numo/linalg then please uncomment out.
|
|
4
|
-
require "numo/linalg/autoloader"
|
|
4
|
+
# require "numo/linalg/autoloader"
|
|
5
5
|
|
|
6
6
|
include DNN::Models
|
|
7
7
|
include DNN::Layers
|
|
@@ -23,18 +23,18 @@ y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
|
|
|
23
23
|
# Define-by-run multilayer perceptron over 784-element input rows:
# Dense(256) -> ReLU -> Dense(256) -> ReLU -> Dense(10).
class MLP < Model
  def initialize
    super
    @d1 = Dense.new(256)
    @d2 = Dense.new(256)
    @d3 = Dense.new(10)
  end

  # Runs one forward pass and returns the 10-unit output of the last layer.
  def forward(x)
    hidden = InputLayer.new(784).(x)
    hidden = ReLU.(@d1.(hidden))
    hidden = ReLU.(@d2.(hidden))
    @d3.(hidden)
  end
end
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
require "cumo/narray"
require "dnn"
require "dnn/datasets/mnist"

include DNN::Models
include DNN::Layers
include DNN::Optimizers
include DNN::Losses

# Flattens each image to a 784-element row scaled into [0, 1] and one-hot
# encodes the 10 digit labels.
prepare = lambda do |images, labels|
  flat = images.reshape(images.shape[0], 784)
  [Numo::SFloat.cast(flat) / 255, DNN::Utils.to_categorical(labels, 10, Numo::SFloat)]
end

train_images, train_labels = DNN::MNIST.load_train
test_images, test_labels = DNN::MNIST.load_test

x_train, y_train = prepare.(train_images, train_labels)
x_test, y_test = prepare.(test_images, test_labels)

# When ruby-dnn runs on Cumo, convert the Numo arrays to Cumo arrays.
if DNN.use_cumo?
  x_train, y_train, x_test, y_test =
    [x_train, y_train, x_test, y_test].map { |arr| DNN::Utils.numo2cumo(arr) }
end

model = Sequential.new
model << InputLayer.new(784)
[256, 256].each do |units|
  model << Dense.new(units)
  model << ReLU.new
end
model << Dense.new(10)

model.setup(Adam.new, SoftmaxCrossEntropy.new)
model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])

accuracy, loss = model.evaluate(x_test, y_test)
puts "accuracy: #{accuracy}"
puts "loss: #{loss}"
|
|
@@ -31,7 +31,7 @@ model << Dense.new(10)
|
|
|
31
31
|
|
|
32
32
|
model.setup(Adam.new, SoftmaxCrossEntropy.new)
|
|
33
33
|
|
|
34
|
-
model.train(x_train, y_train, 10, batch_size:
|
|
34
|
+
model.train(x_train, y_train, 10, batch_size: 128, test: [x_test, y_test])
|
|
35
35
|
|
|
36
36
|
accuracy, loss = model.evaluate(x_test, y_test)
|
|
37
37
|
puts "accuracy: #{accuracy}"
|
data/examples/pix2pix/dcgan.rb
CHANGED
|
@@ -2,20 +2,19 @@ include DNN::Models
|
|
|
2
2
|
include DNN::Layers
|
|
3
3
|
|
|
4
4
|
class Generator < Model
|
|
5
|
-
def initialize(input_shape)
|
|
5
|
+
def initialize(input_shape, base_num_filters)
|
|
6
6
|
super()
|
|
7
7
|
@input_shape = input_shape
|
|
8
|
-
@
|
|
9
|
-
@
|
|
10
|
-
@
|
|
11
|
-
@
|
|
12
|
-
@
|
|
13
|
-
@
|
|
14
|
-
@
|
|
15
|
-
@
|
|
16
|
-
@
|
|
17
|
-
@
|
|
18
|
-
@l11 = Conv2D.new(3, 4, padding: true)
|
|
8
|
+
@cv1 = Conv2D.new(base_num_filters, 4, padding: true)
|
|
9
|
+
@cv2 = Conv2D.new(base_num_filters, 4, strides: 2, padding: true)
|
|
10
|
+
@cv3 = Conv2D.new(base_num_filters * 2, 4, padding: true)
|
|
11
|
+
@cv4 = Conv2D.new(base_num_filters * 2, 4, strides: 2, padding: true)
|
|
12
|
+
@cv5 = Conv2D.new(base_num_filters * 2, 4, padding: true)
|
|
13
|
+
@cv6 = Conv2D.new(base_num_filters, 4, padding: true)
|
|
14
|
+
@cv7 = Conv2D.new(base_num_filters, 4, padding: true)
|
|
15
|
+
@cv8 = Conv2D.new(3, 4, padding: true)
|
|
16
|
+
@cvt1 = Conv2DTranspose.new(base_num_filters * 2, 4, strides: 2, padding: true)
|
|
17
|
+
@cvt2 = Conv2DTranspose.new(base_num_filters, 4, strides: 2, padding: true)
|
|
19
18
|
@bn1 = BatchNormalization.new
|
|
20
19
|
@bn2 = BatchNormalization.new
|
|
21
20
|
@bn3 = BatchNormalization.new
|
|
@@ -24,113 +23,102 @@ class Generator < Model
|
|
|
24
23
|
@bn6 = BatchNormalization.new
|
|
25
24
|
@bn7 = BatchNormalization.new
|
|
26
25
|
@bn8 = BatchNormalization.new
|
|
27
|
-
@bn9 = BatchNormalization.new
|
|
28
26
|
end
|
|
29
27
|
|
|
30
28
|
def forward(x)
|
|
31
29
|
input = InputLayer.new(@input_shape).(x)
|
|
32
|
-
x = @
|
|
30
|
+
x = @cv1.(input)
|
|
33
31
|
x = @bn1.(x)
|
|
34
|
-
h1 =
|
|
32
|
+
h1 = LeakyReLU.(x, 0.2)
|
|
35
33
|
|
|
36
|
-
x = @
|
|
34
|
+
x = @cv2.(h1)
|
|
37
35
|
x = @bn2.(x)
|
|
38
|
-
x =
|
|
36
|
+
x = LeakyReLU.(x, 0.2)
|
|
39
37
|
|
|
40
|
-
x = @
|
|
38
|
+
x = @cv3.(x)
|
|
41
39
|
x = @bn3.(x)
|
|
42
|
-
h2 =
|
|
40
|
+
h2 = LeakyReLU.(x, 0.2)
|
|
43
41
|
|
|
44
|
-
x = @
|
|
42
|
+
x = @cv4.(h2)
|
|
45
43
|
x = @bn4.(x)
|
|
46
|
-
x =
|
|
44
|
+
x = LeakyReLU.(x, 0.2)
|
|
47
45
|
|
|
48
|
-
x = @
|
|
46
|
+
x = @cv5.(x)
|
|
49
47
|
x = @bn5.(x)
|
|
50
|
-
x =
|
|
48
|
+
x = LeakyReLU.(x, 0.2)
|
|
51
49
|
|
|
52
|
-
x = @
|
|
50
|
+
x = @cvt1.(x)
|
|
53
51
|
x = @bn6.(x)
|
|
54
|
-
x =
|
|
52
|
+
x = LeakyReLU.(x, 0.2)
|
|
53
|
+
x = Concatenate.(x, h2, axis: 3)
|
|
55
54
|
|
|
56
|
-
x = @
|
|
55
|
+
x = @cv6.(x)
|
|
57
56
|
x = @bn7.(x)
|
|
58
|
-
x =
|
|
59
|
-
x = Concatenate.(x, h2, axis: 3)
|
|
57
|
+
x = LeakyReLU.(x, 0.2)
|
|
60
58
|
|
|
61
|
-
x = @
|
|
59
|
+
x = @cvt2.(x)
|
|
62
60
|
x = @bn8.(x)
|
|
63
|
-
x =
|
|
64
|
-
|
|
65
|
-
x = @l9.(x)
|
|
66
|
-
x = @bn9.(x)
|
|
67
|
-
x = ReLU.(x)
|
|
61
|
+
x = LeakyReLU.(x, 0.2)
|
|
68
62
|
x = Concatenate.(x, h1, axis: 3)
|
|
69
63
|
|
|
70
|
-
x = @
|
|
71
|
-
x =
|
|
64
|
+
x = @cv7.(x)
|
|
65
|
+
x = LeakyReLU.(x, 0.2)
|
|
72
66
|
|
|
73
|
-
x = @
|
|
67
|
+
x = @cv8.(x)
|
|
74
68
|
x = Tanh.(x)
|
|
75
69
|
x
|
|
76
70
|
end
|
|
77
71
|
end
|
|
78
72
|
|
|
79
73
|
class Discriminator < Model
|
|
80
|
-
def initialize(gen_input_shape, gen_output_shape)
|
|
74
|
+
def initialize(gen_input_shape, gen_output_shape, base_num_filters)
|
|
81
75
|
super()
|
|
82
76
|
@gen_input_shape = gen_input_shape
|
|
83
77
|
@gen_output_shape = gen_output_shape
|
|
84
|
-
@
|
|
85
|
-
@
|
|
86
|
-
@
|
|
87
|
-
@
|
|
88
|
-
@
|
|
89
|
-
@
|
|
90
|
-
@
|
|
91
|
-
@
|
|
92
|
-
@
|
|
78
|
+
@cv1_1 = Conv2D.new(base_num_filters, 4, padding: true)
|
|
79
|
+
@cv1_2 = Conv2D.new(base_num_filters, 4, padding: true)
|
|
80
|
+
@cv2 = Conv2D.new(base_num_filters, 4, strides: 2, padding: true)
|
|
81
|
+
@cv3 = Conv2D.new(base_num_filters * 2, 4, padding: true)
|
|
82
|
+
@cv4 = Conv2D.new(base_num_filters * 2, 4, strides: 2, padding: true)
|
|
83
|
+
@d1 = Dense.new(1024)
|
|
84
|
+
@d2 = Dense.new(1)
|
|
85
|
+
@bn1_1 = BatchNormalization.new
|
|
86
|
+
@bn1_2 = BatchNormalization.new
|
|
93
87
|
@bn2 = BatchNormalization.new
|
|
94
88
|
@bn3 = BatchNormalization.new
|
|
95
89
|
@bn4 = BatchNormalization.new
|
|
96
|
-
@bn5 = BatchNormalization.new
|
|
97
|
-
@bn6 = BatchNormalization.new
|
|
98
90
|
end
|
|
99
91
|
|
|
100
92
|
def forward(inputs)
|
|
101
93
|
input, images = *inputs
|
|
102
94
|
x = InputLayer.new(@gen_input_shape).(input)
|
|
103
|
-
x = @
|
|
104
|
-
x = @
|
|
95
|
+
x = @cv1_1.(x)
|
|
96
|
+
x = @bn1_1.(x)
|
|
105
97
|
x1 = LeakyReLU.(x, 0.2)
|
|
106
98
|
|
|
107
99
|
x = InputLayer.new(@gen_output_shape).(images)
|
|
108
|
-
x = @
|
|
109
|
-
x = @
|
|
100
|
+
x = @cv1_2.(x)
|
|
101
|
+
x = @bn1_2.(x)
|
|
110
102
|
x2 = LeakyReLU.(x, 0.2)
|
|
111
103
|
|
|
112
104
|
x = Concatenate.(x1, x2)
|
|
113
|
-
x = @
|
|
114
|
-
x = @
|
|
115
|
-
x = LeakyReLU.(x, 0.2)
|
|
116
|
-
|
|
117
|
-
x = @l3.(x)
|
|
118
|
-
x = @bn4.(x)
|
|
105
|
+
x = @cv2.(x)
|
|
106
|
+
x = @bn2.(x)
|
|
119
107
|
x = LeakyReLU.(x, 0.2)
|
|
120
108
|
|
|
121
|
-
x = @
|
|
122
|
-
x = @
|
|
109
|
+
x = @cv3.(x)
|
|
110
|
+
x = @bn3.(x)
|
|
123
111
|
x = LeakyReLU.(x, 0.2)
|
|
124
112
|
|
|
125
|
-
x = @
|
|
126
|
-
x = @
|
|
113
|
+
x = @cv4.(x)
|
|
114
|
+
x = @bn4.(x)
|
|
127
115
|
x = LeakyReLU.(x, 0.2)
|
|
128
116
|
|
|
129
117
|
x = Flatten.(x)
|
|
130
|
-
x = @
|
|
118
|
+
x = @d1.(x)
|
|
131
119
|
x = LeakyReLU.(x, 0.2)
|
|
132
120
|
|
|
133
|
-
x = @
|
|
121
|
+
x = @d2.(x)
|
|
134
122
|
x
|
|
135
123
|
end
|
|
136
124
|
|
|
@@ -139,7 +127,7 @@ class Discriminator < Model
|
|
|
139
127
|
layer.trainable = true
|
|
140
128
|
end
|
|
141
129
|
end
|
|
142
|
-
|
|
130
|
+
|
|
143
131
|
def disable_training
|
|
144
132
|
trainable_layers.each do |layer|
|
|
145
133
|
layer.trainable = false
|