ruby-dnn 1.1.5 → 1.1.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/examples/judge-number/README.md +29 -0
- data/examples/judge-number/capture.PNG +0 -0
- data/examples/judge-number/convnet8.rb +70 -0
- data/examples/judge-number/make_weights.rb +5 -0
- data/examples/judge-number/mnist_predict.rb +20 -0
- data/examples/judge-number/mnist_train.rb +19 -0
- data/examples/judge-number/public/httpRequest.js +44 -0
- data/examples/judge-number/public/judgeNumber.js +61 -0
- data/examples/judge-number/server.rb +19 -0
- data/examples/judge-number/trained_mnist_params.marshal +0 -0
- data/examples/judge-number/views/index.erb +7 -0
- data/lib/dnn/core/losses.rb +0 -6
- data/lib/dnn/core/models.rb +9 -0
- data/lib/dnn/image.rb +13 -0
- data/lib/dnn/version.rb +1 -1
- metadata +13 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6b3228eaf257e50c55bdd51b348e71dcc45d74a7cc14668231c4a5e1c9fed318
|
4
|
+
data.tar.gz: ae3a1217cb1aa0a3d0f50ad2b042fd255385a713c7787ad318f34410170dbb83
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 977ddba86307314f9af523643e151f397705600f2aa244980dfefbda8af7c74114dd4b17d515c64d30e412d8e6aa3de8a127ba085dd8666e03a99ac2cb1fd573
|
7
|
+
data.tar.gz: 80e1d0a963e066aa74f0e9e9e786113ea8505112990de217b6d3a7b7b9f60b06f5ee777e94fb95fb7f006fd92f7ef4b11b3aaa87f605ae66115c30d007055657
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# Prepare
|
2
|
+
This example uses Sinatra. Install it and sinatra-contrib first:
|
3
|
+
|
4
|
+
```
|
5
|
+
$ gem install sinatra
|
6
|
+
$ gem install sinatra-contrib
|
7
|
+
```
|
8
|
+
|
9
|
+
# Let's try
|
10
|
+
This example ships with pre-trained weights (trained_mnist_params.marshal).
|
11
|
+
If you want to try it immediately, skip steps (1) and (2).
|
12
|
+
|
13
|
+
### (1) Training MNIST
|
14
|
+
```
|
15
|
+
$ ruby mnist_train.rb
|
16
|
+
```
|
17
|
+
|
18
|
+
### (2) Make weights
|
19
|
+
```
|
20
|
+
$ ruby make_weights.rb
|
21
|
+
```
|
22
|
+
|
23
|
+
### (3) Launch sinatra server
|
24
|
+
```
|
25
|
+
$ ruby server.rb
|
26
|
+
```
|
27
|
+
|
28
|
+
### (4) Open http://127.0.0.1:4567 in your browser
|
29
|
+
![Screenshot of the judge-number app](capture.PNG)
|
Binary file
|
@@ -0,0 +1,70 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "numo/linalg/autoloader"
|
3
|
+
|
4
|
+
include DNN::Models
|
5
|
+
include DNN::Layers
|
6
|
+
include DNN::Optimizers
|
7
|
+
include DNN::Losses
|
8
|
+
|
9
|
+
# An 8-layer convolutional network (6 conv + 2 dense) for MNIST-sized input,
# written in ruby-dnn's define-by-run style (layers are wired in #forward).
class ConvNet < Model
  # Build a ConvNet and set it up with Adam and softmax cross-entropy.
  # @param [Array<Integer>] input_shape Shape of one input sample, e.g. [28, 28, 1].
  # @param [Integer] base_filter_size Filter count of the first conv stage; the later
  #   stages use 2x and 4x this value. The default (32) is what mnist_train.rb and
  #   mnist_predict.rb use, so the bundled params only fit a model built with it.
  # @return [ConvNet] A model ready for training or for load_params.
  def self.create(input_shape, base_filter_size = 32)
    # `new` (not `ConvNet.new`) so subclasses calling create get an instance of themselves.
    convnet = new(input_shape, base_filter_size)
    convnet.setup(Adam.new, SoftmaxCrossEntropy.new)
    convnet
  end

  # @param [Array<Integer>] input_shape Shape of one input sample.
  # @param [Integer] base_filter_size Filter count of the first conv stage.
  def initialize(input_shape, base_filter_size)
    super()
    @input_shape = input_shape
    # NOTE(review): ruby-dnn may discover layers via instance variables; keep this
    # declaration order stable so saved params keep loading — confirm before reordering.
    @cv1 = Conv2D.new(base_filter_size, 3, padding: true)
    @cv2 = Conv2D.new(base_filter_size, 3, padding: true)
    @cv3 = Conv2D.new(base_filter_size * 2, 3, padding: true)
    @cv4 = Conv2D.new(base_filter_size * 2, 3, padding: true)
    @cv5 = Conv2D.new(base_filter_size * 4, 3, padding: true)
    @cv6 = Conv2D.new(base_filter_size * 4, 3, padding: true)
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
    @bn3 = BatchNormalization.new
    @bn4 = BatchNormalization.new
    @d1 = Dense.new(512)
    @d2 = Dense.new(10)
  end

  # Run the network: three conv stages (each conv-ReLU-dropout, conv-BN-ReLU, 2x2 max-pool),
  # then flatten -> dense(512) -> BN -> ReLU -> dense(10).
  # @return Raw scores for the 10 digit classes.
  def forward(x)
    x = InputLayer.new(@input_shape).(x)

    # Stage 1
    x = @cv1.(x)
    x = ReLU.(x)
    x = Dropout.(x, 0.25)

    x = @cv2.(x)
    x = @bn1.(x)
    x = ReLU.(x)
    x = MaxPool2D.(x, 2)

    # Stage 2
    x = @cv3.(x)
    x = ReLU.(x)
    x = Dropout.(x, 0.25)

    x = @cv4.(x)
    x = @bn2.(x)
    x = ReLU.(x)
    x = MaxPool2D.(x, 2)

    # Stage 3
    x = @cv5.(x)
    x = ReLU.(x)
    x = Dropout.(x, 0.25)

    x = @cv6.(x)
    x = @bn3.(x)
    x = ReLU.(x)
    x = MaxPool2D.(x, 2)

    # Classifier head
    x = Flatten.(x)
    x = @d1.(x)
    x = @bn4.(x)
    x = ReLU.(x)
    x = @d2.(x)
    x
  end
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/image"
|
3
|
+
require_relative "convnet8"
|
4
|
+
|
5
|
+
# Lazily build the ConvNet and load the pre-trained MNIST weights into $model.
# $model is only assigned after the weights load successfully. A failed load
# (e.g. missing params file) raises and is retried on the next call, instead of
# leaving an untrained model behind in $model.
def load_model
  return if $model

  model = ConvNet.create([28, 28, 1])
  # One dummy forward pass before load_params — presumably so the define-by-run
  # layers are built and own their params first (confirm against ruby-dnn).
  model.predict1(Numo::SFloat.zeros(28, 28, 1))
  # Resolve the params file next to this script so the server works from any cwd.
  model.load_params(File.expand_path("trained_mnist_params.marshal", __dir__))
  $model = model
end
|
11
|
+
|
12
|
+
# Classify a hand-drawn digit sent by the browser.
# @param [String] img Raw RGBA pixel bytes (height * width * 4 bytes).
# @param [Integer] width Image width in pixels.
# @param [Integer] height Image height in pixels.
# @return [Array<Float>] One score per digit 0-9, scaled by 100 and rounded to 2 decimals
#   (presumably softmax probabilities, since the front-end shows them with a "%").
def mnist_predict(img, width, height)
  load_model
  img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
  # Keep only the first DNN::Image::RGB channels (drops alpha), then convert to gray.
  img = img[true, true, 0...DNN::Image::RGB]
  img = DNN::Image.to_gray_scale(img)
  # Scale pixel values 0..255 to 0.0..1.0, matching the training preprocessing.
  x = Numo::SFloat.cast(img) / 255
  out = $model.predict1(x)
  # Scale first, then round: `v.round(4) * 100` reintroduces float noise
  # such as 12.340000000000002 in the displayed percentages.
  out.to_a.map { |v| (v * 100).round(2) }
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/datasets/mnist"
|
3
|
+
require_relative "convnet8"
|
4
|
+
|
5
|
+
require "fileutils"

include DNN::Callbacks

# Load MNIST and scale pixel values from 0..255 to 0.0..1.0.
x_train, y_train = DNN::MNIST.load_train
x_test, y_test = DNN::MNIST.load_test

x_train = Numo::SFloat.cast(x_train) / 255
x_test = Numo::SFloat.cast(x_test) / 255

# Convert the integer labels to 10-class categorical vectors.
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)

# The checkpoint callback writes under trained/, which the example does not ship.
# Create it up front (no-op if it already exists) so checkpointing cannot fail on
# a fresh checkout.
FileUtils.mkdir_p("trained")

model = ConvNet.create([28, 28, 1])
# interval: 5 — presumably a checkpoint every 5 epochs.
model.add_callback(CheckPoint.new("trained/trained_mnist", interval: 5))

# 20 epochs, batch size 128, with the test split passed via `test:`.
model.train(x_train, y_train, 20, batch_size: 128, test: [x_test, y_test])
|
@@ -0,0 +1,44 @@
|
|
1
|
+
// Minimal XMLHttpRequest wrapper. When the request finishes loading (for any HTTP
// status), the callback receives { response, status, event }.
class HttpRequest {
  // Send a GET request to `path`; returns the HttpRequest instance.
  static get(path, responseCallback) {
    const req = new HttpRequest(path, "GET", responseCallback);
    req.send();
    return req;
  }

  // POST `params` serialized as JSON to `path`; returns the HttpRequest instance.
  static post(path, params, responseCallback) {
    const req = new HttpRequest(path, "POST", responseCallback);
    req.send(params);
    return req;
  }

  constructor(path, method, responseCallback) {
    this._path = path;
    this._method = method;
    this._responseCallback = responseCallback;
  }

  // Open and send the request. A truthy `params` is sent as a JSON body with a
  // matching Content-Type header; otherwise the request has no body.
  send(params = null) {
    const xhr = new XMLHttpRequest();
    xhr.open(this._method, this._path);
    let json = null;
    if (params) {
      json = JSON.stringify(params);
      // Must be set after open() and before send().
      xhr.setRequestHeader("Content-Type", "application/json");
    }
    xhr.addEventListener("load", (e) => {
      this._responseCallback({
        response: xhr.response,
        status: xhr.status,
        event: e,
      });
    });
    xhr.send(json);
  }
}
|
35
|
+
|
36
|
+
// Base64 helpers for sending binary canvas data inside a JSON payload.
class Base64 {
  // Encode a string (via btoa, so Latin-1 characters only) or the bytes of a
  // Uint8Array / Uint8ClampedArray. Returns undefined for any other input.
  static encode(obj) {
    if (typeof obj === "string") {
      return btoa(obj);
    }
    if (obj instanceof Uint8Array || obj instanceof Uint8ClampedArray) {
      // Convert in chunks: spreading a large array into String.fromCharCode can
      // exceed the engine's argument/stack limit and throw a RangeError.
      const CHUNK_SIZE = 0x8000;
      let binary = "";
      for (let i = 0; i < obj.length; i += CHUNK_SIZE) {
        binary += String.fromCharCode.apply(null, obj.subarray(i, i + CHUNK_SIZE));
      }
      return btoa(binary);
    }
    return undefined;
  }
}
|
@@ -0,0 +1,61 @@
|
|
1
|
+
// Front-end for the judge-number demo. The user draws on the large "draw" canvas;
// "Judge" downsamples it onto the small "view" canvas and posts the RGBA pixels to
// /predict. HttpRequest and Base64 come from httpRequest.js and are used at click time.
const drawCanvas = document.getElementById("draw");
const viewCanvas = document.getElementById("view");

const drawContext = drawCanvas.getContext("2d");
const viewContext = viewCanvas.getContext("2d");

const judgeButton = document.getElementById("judge");
const clearButton = document.getElementById("clear");

const resultArea = document.getElementById("result");

// Paint both canvases black; strokes are drawn in white on top.
const clearCanvases = () => {
  drawContext.fillStyle = "black";
  drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
  viewContext.fillStyle = "black";
  viewContext.fillRect(0, 0, viewCanvas.width, viewCanvas.height);
};
clearCanvases();

// Show the server's score for each digit 0-9, one per line.
const updateResult = (classification) => {
  let str = "";
  for (let i = 0; i <= 9; i++) {
    str += `${i}: ${classification[i]}%<br>`;
  }
  resultArea.innerHTML = str;
};

judgeButton.addEventListener("click", () => {
  // Downscale the drawing to the view canvas size and read back its RGBA bytes.
  viewContext.drawImage(drawCanvas, 0, 0, viewCanvas.width, viewCanvas.height);
  const data = viewContext.getImageData(0, 0, viewCanvas.width, viewCanvas.height).data;
  const params = {
    img: Base64.encode(data),
    width: viewCanvas.width,
    height: viewCanvas.height,
  };
  HttpRequest.post("/predict", params, (res) => {
    // The load event fires for error statuses too; don't parse an error body as scores.
    if (res.event.target.status !== 200) {
      resultArea.textContent = "Prediction failed.";
      return;
    }
    updateResult(JSON.parse(res.response));
  });
});

clearButton.addEventListener("click", () => {
  clearCanvases();
  resultArea.innerHTML = "";
});

let mouseDown = false;

window.addEventListener("mousedown", () => {
  mouseDown = true;
});

window.addEventListener("mouseup", () => {
  mouseDown = false;
});

// While the mouse button is held, stamp a 20x20 white square centred on the cursor.
drawCanvas.addEventListener("mousemove", (e) => {
  if (!mouseDown) return;
  const rect = e.target.getBoundingClientRect();
  const x = e.clientX - 10 - rect.left;
  const y = e.clientY - 10 - rect.top;
  drawContext.fillStyle = "white";
  drawContext.fillRect(x, y, 20, 20);
});
|
@@ -0,0 +1,19 @@
|
|
1
|
+
require "sinatra"
|
2
|
+
require "sinatra/reloader"
|
3
|
+
require "json"
|
4
|
+
require "base64"
|
5
|
+
require_relative "mnist_predict"
|
6
|
+
|
7
|
+
# Serve the drawing page rendered from views/index.erb.
get("/") { erb(:index) }
|
10
|
+
|
11
|
+
# Classify a drawn digit.
# Expects a JSON body {"img": <base64 RGBA bytes>, "width": Integer, "height": Integer}.
# Responds with a JSON array of 10 per-digit scores, or with status 400 and a JSON
# {"error": ...} object when the request is malformed.
post "/predict" do
  content_type :json
  begin
    # Named `payload` so it does not shadow Sinatra's own `params` helper.
    payload = JSON.parse(request.body.read, symbolize_names: true)
  rescue JSON::ParserError
    halt 400, JSON.dump(error: "request body must be JSON")
  end
  unless payload.is_a?(Hash) && payload[:img].is_a?(String)
    halt 400, JSON.dump(error: "img (base64 string) is required")
  end
  width = payload[:width].to_i
  height = payload[:height].to_i
  unless width.positive? && height.positive?
    halt 400, JSON.dump(error: "width and height must be positive integers")
  end
  img = Base64.decode64(payload[:img])
  result = mnist_predict(img, width, height)
  JSON.dump(result)
end
|
Binary file
|
@@ -0,0 +1,7 @@
|
|
1
|
+
<canvas id="draw" width="256" height="256"></canvas>
<canvas id="view" width="28" height="28"></canvas>
<button id="judge">Judge</button>
<button id="clear">Clear</button>
<p id="result"></p>
<%# httpRequest.js defines HttpRequest and Base64, which judgeNumber.js uses, so load it first. %>
<script src="httpRequest.js"></script>
<script src="judgeNumber.js"></script>
|
data/lib/dnn/core/losses.rb
CHANGED
@@ -42,12 +42,6 @@ module DNN
|
|
42
42
|
loss
|
43
43
|
end
|
44
44
|
|
45
|
-
def regularizers_backward(layers)
|
46
|
-
layers.select { |layer| layer.respond_to?(:regularizers) }.each do |layer|
|
47
|
-
layer.regularizers.each(&:backward)
|
48
|
-
end
|
49
|
-
end
|
50
|
-
|
51
45
|
def to_hash(merge_hash = nil)
|
52
46
|
hash = { class: self.class.name }
|
53
47
|
hash.merge!(merge_hash) if merge_hash
|
data/lib/dnn/core/models.rb
CHANGED
@@ -474,6 +474,15 @@ module DNN
|
|
474
474
|
@callbacks << callback
|
475
475
|
end
|
476
476
|
|
477
|
+
# Register a callback whose body is the given block.
# @param [Symbol] event Event to execute callback.
# @yield Register the contents of the callback.
def add_lambda_callback(event, &block)
  lambda_callback = Callbacks::LambdaCallback.new(event, &block)
  lambda_callback.model = self
  @callbacks.push(lambda_callback)
end
|
485
|
+
|
477
486
|
# Clear the callback function registered for each event.
|
478
487
|
def clear_callbacks
|
479
488
|
@callbacks = []
|
data/lib/dnn/image.rb
CHANGED
@@ -64,6 +64,19 @@ module DNN
|
|
64
64
|
raise ImageWriteError, "Image write failed." if res == 0
|
65
65
|
end
|
66
66
|
|
67
|
+
# Create an image from binary.
# @param [String] bin Raw pixel bytes laid out as (height, width, channel).
# @param [Integer] height Image height.
# @param [Integer] width Image width.
# @param [Integer] channel Number of channels per pixel (e.g. DNN::Image::RGB or DNN::Image::RGBA).
# @return [Numo::UInt8] Image array of shape [height, width, channel].
# @raise [ImageError] If the byte length of bin is not height * width * channel.
def self.from_binary(bin, height, width, channel = DNN::Image::RGB)
  expected_size = height * width * channel
  # Compare byte length, not character count: Numo reads raw bytes, while
  # String#size counts characters for multibyte encodings such as UTF-8.
  actual_size = bin.bytesize
  unless actual_size == expected_size
    raise ImageError, "binary size is #{actual_size}, but expected binary size is #{expected_size}"
  end
  Numo::UInt8.from_binary(bin).reshape(height, width, channel)
end
|
79
|
+
|
67
80
|
# Resize the image.
|
68
81
|
# @param [Numo::UInt8] img Image to resize.
|
69
82
|
# @param [Integer] out_height Image height to resize.
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.1.
|
4
|
+
version: 1.1.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-03
|
11
|
+
date: 2020-05-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -121,6 +121,17 @@ files:
|
|
121
121
|
- examples/dcgan/imgen.rb
|
122
122
|
- examples/dcgan/train.rb
|
123
123
|
- examples/iris_example.rb
|
124
|
+
- examples/judge-number/README.md
|
125
|
+
- examples/judge-number/capture.PNG
|
126
|
+
- examples/judge-number/convnet8.rb
|
127
|
+
- examples/judge-number/make_weights.rb
|
128
|
+
- examples/judge-number/mnist_predict.rb
|
129
|
+
- examples/judge-number/mnist_train.rb
|
130
|
+
- examples/judge-number/public/httpRequest.js
|
131
|
+
- examples/judge-number/public/judgeNumber.js
|
132
|
+
- examples/judge-number/server.rb
|
133
|
+
- examples/judge-number/trained_mnist_params.marshal
|
134
|
+
- examples/judge-number/views/index.erb
|
124
135
|
- examples/mnist_conv2d_example.rb
|
125
136
|
- examples/mnist_define_by_run.rb
|
126
137
|
- examples/mnist_example.rb
|