ruby-dnn 1.1.5 → 1.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/examples/judge-number/README.md +29 -0
- data/examples/judge-number/capture.PNG +0 -0
- data/examples/judge-number/convnet8.rb +70 -0
- data/examples/judge-number/make_weights.rb +5 -0
- data/examples/judge-number/mnist_predict.rb +20 -0
- data/examples/judge-number/mnist_train.rb +19 -0
- data/examples/judge-number/public/httpRequest.js +44 -0
- data/examples/judge-number/public/judgeNumber.js +61 -0
- data/examples/judge-number/server.rb +19 -0
- data/examples/judge-number/trained_mnist_params.marshal +0 -0
- data/examples/judge-number/views/index.erb +7 -0
- data/lib/dnn/core/losses.rb +0 -6
- data/lib/dnn/core/models.rb +9 -0
- data/lib/dnn/image.rb +13 -0
- data/lib/dnn/version.rb +1 -1
- metadata +13 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6b3228eaf257e50c55bdd51b348e71dcc45d74a7cc14668231c4a5e1c9fed318
|
4
|
+
data.tar.gz: ae3a1217cb1aa0a3d0f50ad2b042fd255385a713c7787ad318f34410170dbb83
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 977ddba86307314f9af523643e151f397705600f2aa244980dfefbda8af7c74114dd4b17d515c64d30e412d8e6aa3de8a127ba085dd8666e03a99ac2cb1fd573
|
7
|
+
data.tar.gz: 80e1d0a963e066aa74f0e9e9e786113ea8505112990de217b6d3a7b7b9f60b06f5ee777e94fb95fb7f006fd92f7ef4b11b3aaa87f605ae66115c30d007055657
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# Preparation
|
2
|
+
This example uses Sinatra.
|
3
|
+
|
4
|
+
```
|
5
|
+
$ gem install sinatra
|
6
|
+
$ gem install sinatra-contrib
|
7
|
+
```
|
8
|
+
|
9
|
+
# Let's try
|
10
|
+
This example ships with pre-trained weights.
|
11
|
+
If you want to try it immediately, skip steps (1) and (2).
|
12
|
+
|
13
|
+
### (1) Training MNIST
|
14
|
+
```
|
15
|
+
$ ruby mnist_train.rb
|
16
|
+
```
|
17
|
+
|
18
|
+
### (2) Make weights
|
19
|
+
```
|
20
|
+
$ ruby make_weights.rb
|
21
|
+
```
|
22
|
+
|
23
|
+
### (3) Launch sinatra server
|
24
|
+
```
|
25
|
+
$ ruby server.rb
|
26
|
+
```
|
27
|
+
|
28
|
+
### (4) Open http://127.0.0.1:4567 in your browser
|
29
|
+

|
Binary file
|
@@ -0,0 +1,70 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "numo/linalg/autoloader"
|
3
|
+
|
4
|
+
include DNN::Models
|
5
|
+
include DNN::Layers
|
6
|
+
include DNN::Optimizers
|
7
|
+
include DNN::Losses
|
8
|
+
|
9
|
+
# Eight-layer convolutional classifier for the judge-number example:
# three convolution stages (two 3x3 convolutions each) followed by two Dense layers.
class ConvNet < Model
  # Build a ConvNet with 32 base filters, set up with Adam and softmax cross-entropy.
  # @param [Array<Integer>] input_shape Shape of one input sample, e.g. [28, 28, 1].
  # @return [ConvNet]
  def self.create(input_shape)
    ConvNet.new(input_shape, 32).tap do |net|
      net.setup(Adam.new, SoftmaxCrossEntropy.new)
    end
  end

  # @param [Array<Integer>] input_shape Shape of one input sample.
  # @param [Integer] base_filter_size Filter count of the first stage; doubled in each later stage.
  def initialize(input_shape, base_filter_size)
    super()
    @input_shape = input_shape
    stage1_filters = base_filter_size
    stage2_filters = base_filter_size * 2
    stage3_filters = base_filter_size * 4
    # NOTE(review): keep these assignments in this exact order. The model's layer
    # list is presumably derived from instance-variable order, and the bundled
    # trained_mnist_params.marshal must map onto it — confirm before reordering.
    @cv1 = Conv2D.new(stage1_filters, 3, padding: true)
    @cv2 = Conv2D.new(stage1_filters, 3, padding: true)
    @cv3 = Conv2D.new(stage2_filters, 3, padding: true)
    @cv4 = Conv2D.new(stage2_filters, 3, padding: true)
    @cv5 = Conv2D.new(stage3_filters, 3, padding: true)
    @cv6 = Conv2D.new(stage3_filters, 3, padding: true)
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
    @bn3 = BatchNormalization.new
    @bn4 = BatchNormalization.new
    @d1 = Dense.new(512)
    @d2 = Dense.new(10)
  end

  # Define-by-run forward pass. Returns the raw output of the 10-unit Dense
  # layer (no softmax is applied here).
  def forward(x)
    h = InputLayer.new(@input_shape).(x)
    h = conv_stage(h, @cv1, @cv2, @bn1)
    h = conv_stage(h, @cv3, @cv4, @bn2)
    h = conv_stage(h, @cv5, @cv6, @bn3)
    h = Flatten.(h)
    h = ReLU.(@bn4.(@d1.(h)))
    @d2.(h)
  end

  private

  # One stage: conv -> ReLU -> dropout(0.25), then conv -> batch norm -> ReLU -> max-pool (size 2).
  def conv_stage(x, first_conv, second_conv, norm)
    x = Dropout.(ReLU.(first_conv.(x)), 0.25)
    x = ReLU.(norm.(second_conv.(x)))
    MaxPool2D.(x, 2)
  end
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/image"
|
3
|
+
require_relative "convnet8"
|
4
|
+
|
5
|
+
# Build the ConvNet and load the bundled trained weights, once per process.
# The model is cached in $model only after load_params succeeds, so a failed
# load is retried on the next call instead of leaving an untrained model cached.
# @return [ConvNet] The loaded model.
def load_model
  return $model if $model

  model = ConvNet.create([28, 28, 1])
  # NOTE(review): presumably this dummy prediction makes the define-by-run model
  # build its layers (and their params) before load_params assigns the stored
  # weights — confirm against ruby-dnn.
  model.predict1(Numo::SFloat.zeros(28, 28, 1))
  # Resolve relative to this file so the server can be started from any directory.
  model.load_params(File.join(__dir__, "trained_mnist_params.marshal"))
  $model = model
end
|
11
|
+
|
12
|
+
# Classify a digit from raw RGBA pixels captured in the browser.
# @param [String] img Raw RGBA bytes; reshaped by DNN::Image.from_binary to [height, width, RGBA].
# @param [Integer] width Image width in pixels.
# @param [Integer] height Image height in pixels.
# @return [Array<Float>] One value per model output, scaled by 100 and rounded
#   to 2 decimals (presumably softmax probabilities as percentages — confirm
#   that predict1 applies the loss activation).
def mnist_predict(img, width, height)
  load_model
  rgba = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
  # Keep only the RGB channels; the alpha channel is dropped.
  rgb = rgba[true, true, 0...DNN::Image::RGB]
  gray = DNN::Image.to_gray_scale(rgb)
  x = Numo::SFloat.cast(gray) / 255
  scores = $model.predict1(x)
  # Scale first, then round: rounding before multiplying re-introduces binary
  # float noise (0.07 * 100 == 7.000000000000001), which leaks into the UI.
  scores.to_a.map { |v| (v * 100).round(2) }
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
require "dnn"
|
2
|
+
require "dnn/datasets/mnist"
|
3
|
+
require_relative "convnet8"
|
4
|
+
|
5
|
+
include DNN::Callbacks

# Load the MNIST train and test splits as (images, labels) pairs.
x_train, y_train = DNN::MNIST.load_train
x_test, y_test = DNN::MNIST.load_test

# Cast pixel values to single-precision floats and divide by 255
# (presumably mapping 0..255 pixels into [0, 1]).
x_train, x_test = [x_train, x_test].map { |images| Numo::SFloat.cast(images) / 255 }

# One-hot encode the digit labels over 10 classes.
y_train, y_test = [y_train, y_test].map { |labels| DNN::Utils.to_categorical(labels, 10, Numo::SFloat) }

model = ConvNet.create([28, 28, 1])
# Periodic checkpoints under "trained/trained_mnist" (interval: 5 — presumably epochs).
model.add_callback(CheckPoint.new("trained/trained_mnist", interval: 5))

# Train for 20 epochs with mini-batches of 128, passing the test split via `test:`.
model.train(x_train, y_train, 20, batch_size: 128, test: [x_test, y_test])
|
@@ -0,0 +1,44 @@
|
|
1
|
+
/**
 * Thin wrapper around XMLHttpRequest. The request body is JSON-encoded and the
 * callback receives `{ response, event }` once the request fires "load".
 * Note: only "load" is handled, so network errors never invoke the callback.
 */
class HttpRequest {
  /**
   * Send a GET request.
   * @param {string} path Request URL.
   * @param {function({response: *, event: Event})} responseCallback Called on "load".
   * @returns {HttpRequest} The request that was sent.
   */
  static get(path, responseCallback) {
    return HttpRequest._sendNew(path, "GET", responseCallback);
  }

  /**
   * Send a POST request whose body is `params` serialized with JSON.stringify.
   * @param {string} path Request URL.
   * @param {*} params Payload; falsy values send no body.
   * @param {function({response: *, event: Event})} responseCallback Called on "load".
   * @returns {HttpRequest} The request that was sent.
   */
  static post(path, params, responseCallback) {
    return HttpRequest._sendNew(path, "POST", responseCallback, params);
  }

  // Construct a request, send it (an undefined `params` falls back to send()'s default), and return it.
  static _sendNew(path, method, responseCallback, params) {
    const request = new HttpRequest(path, method, responseCallback);
    request.send(params);
    return request;
  }

  constructor(path, method, responseCallback) {
    this._path = path;
    this._method = method;
    this._responseCallback = responseCallback;
  }

  /**
   * Open and send the XHR. A falsy `params` sends a null body; otherwise it is JSON-encoded.
   * @param {*} [params=null]
   */
  send(params = null) {
    const xhr = new XMLHttpRequest();
    xhr.open(this._method, this._path);
    const body = params ? JSON.stringify(params) : null;
    xhr.addEventListener("load", (event) => {
      this._responseCallback({ response: xhr.response, event });
    });
    xhr.send(body);
  }
}
|
35
|
+
|
36
|
+
/**
 * Base64 helpers for the browser (built on btoa).
 */
class Base64 {
  /**
   * Encode a Latin-1 string or a byte array as Base64.
   * @param {string|Uint8Array|Uint8ClampedArray} obj Value to encode.
   * @returns {string|undefined} Base64 text, or undefined for unsupported types.
   */
  static encode(obj) {
    if (typeof obj === "string") {
      return btoa(obj);
    }
    if (obj instanceof Uint8Array || obj instanceof Uint8ClampedArray) {
      // Convert in chunks: passing the whole array as arguments to one
      // String.fromCharCode call (spread/apply) exceeds engine argument limits
      // for large buffers and throws a RangeError.
      const CHUNK_SIZE = 0x8000;
      let binary = "";
      for (let i = 0; i < obj.length; i += CHUNK_SIZE) {
        binary += String.fromCharCode.apply(null, obj.subarray(i, i + CHUNK_SIZE));
      }
      return btoa(binary);
    }
    // Unsupported input types are not encoded (kept for backward compatibility).
    return undefined;
  }
}
|
@@ -0,0 +1,61 @@
|
|
1
|
+
// Canvas the user draws on, and the small canvas (28x28 in index.erb) that
// holds the downscaled copy sent to the server.
const drawCanvas = document.getElementById("draw");
const viewCanvas = document.getElementById("view");

// Start both canvases solid black; strokes are drawn in white.
const drawContext = drawCanvas.getContext("2d");
drawContext.fillStyle = "black";
drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
const viewContext = viewCanvas.getContext("2d");
viewContext.fillStyle = "black";
// Fill using the view canvas's own size, not drawCanvas's.
viewContext.fillRect(0, 0, viewCanvas.width, viewCanvas.height);

const judgeButton = document.getElementById("judge");
const clearButton = document.getElementById("clear");

// Element where the per-digit percentages are rendered.
const resultArea = document.getElementById("result");
|
13
|
+
|
14
|
+
// Render one "<digit>: <value>%" line for each digit 0-9 into the result area.
const updateResult = (classification) => {
  const lines = Array.from({ length: 10 }, (_, digit) => `${digit}: ${classification[digit]}%<br>`);
  resultArea.innerHTML = lines.join("");
};
|
21
|
+
|
22
|
+
// Downscale the drawing into the view canvas, POST its RGBA pixels to the
// server, and display the returned per-digit percentages.
judgeButton.addEventListener("click", () => {
  viewContext.drawImage(drawCanvas, 0, 0, viewCanvas.width, viewCanvas.height);
  const data = viewContext.getImageData(0, 0, viewCanvas.width, viewCanvas.height).data;
  // Declared locally so it does not become an implicit global
  // (which would also throw a ReferenceError in strict mode).
  const params = {
    img: Base64.encode(data),
    width: viewCanvas.width,
    height: viewCanvas.height,
  };
  HttpRequest.post("/predict", params, (res) => {
    // res.event is the XHR "load" event; its target is the XMLHttpRequest.
    const status = res.event.target.status;
    if (status !== 200) {
      // Error pages are not JSON, so report the failure instead of letting JSON.parse throw.
      resultArea.textContent = `Prediction failed (HTTP ${status}).`;
      return;
    }
    updateResult(JSON.parse(res.response));
  });
});
|
34
|
+
|
35
|
+
// Reset both canvases to black and clear the displayed result.
clearButton.addEventListener("click", () => {
  drawContext.fillStyle = "black";
  drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
  viewContext.fillStyle = "black";
  viewContext.fillRect(0, 0, viewCanvas.width, viewCanvas.height);
  // Use the resultArea handle, not the implicit window.result named-element
  // global, which breaks as soon as anything else defines `result`.
  resultArea.innerHTML = "";
});
|
42
|
+
|
43
|
+
// Whether a mouse button is currently held. It is tracked on window, so
// releasing the button outside the canvas still ends the stroke.
let mouseDown = false;

// Side length (px) of the square white brush stamped on each mousemove.
const BRUSH_SIZE = 20;

// Returns a listener that records the given button state.
const trackMouseButton = (pressed) => () => {
  mouseDown = pressed;
};
window.addEventListener("mousedown", trackMouseButton(true));
window.addEventListener("mouseup", trackMouseButton(false));

// While a button is held, stamp a brush square centred on the pointer.
drawCanvas.addEventListener("mousemove", (e) => {
  if (!mouseDown) return;
  const { left, top } = e.target.getBoundingClientRect();
  const offset = BRUSH_SIZE / 2;
  drawContext.fillStyle = "white";
  drawContext.fillRect(e.clientX - offset - left, e.clientY - offset - top, BRUSH_SIZE, BRUSH_SIZE);
});
|
@@ -0,0 +1,19 @@
|
|
1
|
+
require "sinatra"
|
2
|
+
require "sinatra/reloader"
|
3
|
+
require "json"
|
4
|
+
require "base64"
|
5
|
+
require_relative "mnist_predict"
|
6
|
+
|
7
|
+
# Serve the drawing page rendered from the :index ERB view.
get("/") { erb(:index) }
|
10
|
+
|
11
|
+
# Classify a digit drawn in the browser.
# Expects a JSON object body {"img": <base64 RGBA bytes>, "width": <int>, "height": <int>}
# and responds with a JSON array of per-digit values from mnist_predict.
# Malformed requests get a 400 with a JSON {"error": ...} body instead of a 500.
post "/predict" do
  content_type :json
  begin
    # Named `payload` so it does not shadow Sinatra's own `params` helper.
    payload = JSON.parse(request.body.read, symbolize_names: true)
    raise TypeError, "request body must be a JSON object" unless payload.is_a?(Hash)

    encoded = payload.fetch(:img)
    raise TypeError, "img must be a base64 string" unless encoded.is_a?(String)

    img = Base64.decode64(encoded)
    width = Integer(payload.fetch(:width))
    height = Integer(payload.fetch(:height))
  rescue JSON::ParserError, KeyError, ArgumentError, TypeError => e
    halt 400, JSON.dump({ error: e.message })
  end
  JSON.dump(mnist_predict(img, width, height))
end
|
Binary file
|
@@ -0,0 +1,7 @@
|
|
1
|
+
<canvas id="draw" width="256" height="256"></canvas>
<canvas id="view" width="28" height="28"></canvas>
<button id="judge">Judge</button>
<button id="clear">Clear</button>
<p id="result"></p>
<!-- httpRequest.js defines HttpRequest and Base64, which judgeNumber.js uses, so load it first. -->
<script src="httpRequest.js"></script>
<script src="judgeNumber.js"></script>
|
data/lib/dnn/core/losses.rb
CHANGED
@@ -42,12 +42,6 @@ module DNN
|
|
42
42
|
loss
|
43
43
|
end
|
44
44
|
|
45
|
-
def regularizers_backward(layers)
|
46
|
-
layers.select { |layer| layer.respond_to?(:regularizers) }.each do |layer|
|
47
|
-
layer.regularizers.each(&:backward)
|
48
|
-
end
|
49
|
-
end
|
50
|
-
|
51
45
|
def to_hash(merge_hash = nil)
|
52
46
|
hash = { class: self.class.name }
|
53
47
|
hash.merge!(merge_hash) if merge_hash
|
data/lib/dnn/core/models.rb
CHANGED
@@ -474,6 +474,15 @@ module DNN
|
|
474
474
|
@callbacks << callback
|
475
475
|
end
|
476
476
|
|
477
|
+
# Register an ad-hoc callback whose behavior is given by a block.
# The callback is bound to this model before it is stored.
# @param [Symbol] event Event to execute callback.
# @yield Register the contents of the callback.
# @return [Array] The model's callback list after appending.
def add_lambda_callback(event, &block)
  @callbacks.push(
    Callbacks::LambdaCallback.new(event, &block).tap { |lambda_callback| lambda_callback.model = self }
  )
end
|
485
|
+
|
477
486
|
# Clear the callback function registered for each event.
|
478
487
|
def clear_callbacks
|
479
488
|
@callbacks = []
|
data/lib/dnn/image.rb
CHANGED
@@ -64,6 +64,19 @@ module DNN
|
|
64
64
|
raise ImageWriteError, "Image write failed." if res == 0
|
65
65
|
end
|
66
66
|
|
67
|
+
# Create an image from binary.
# @param [String] bin Raw pixel bytes, interpreted as UInt8 values.
# @param [Integer] height Image height.
# @param [Integer] width Image width.
# @param [Integer] channel Image channel.
# @return [Numo::UInt8] The bytes reshaped to [height, width, channel].
# @raise [ImageError] If the byte length of bin is not height * width * channel.
def self.from_binary(bin, height, width, channel = DNN::Image::RGB)
  expected_size = height * width * channel
  # Compare byte length, not String#size: size counts characters, so a
  # multibyte-encoded string could pass the check with the wrong number of bytes.
  actual_size = bin.bytesize
  unless actual_size == expected_size
    raise ImageError, "binary size is #{actual_size}, but expected binary size is #{expected_size}"
  end
  Numo::UInt8.from_binary(bin).reshape(height, width, channel)
end
|
79
|
+
|
67
80
|
# Resize the image.
|
68
81
|
# @param [Numo::UInt8] img Image to resize.
|
69
82
|
# @param [Integer] out_height Image height to resize.
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.1.
|
4
|
+
version: 1.1.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-03
|
11
|
+
date: 2020-05-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -121,6 +121,17 @@ files:
|
|
121
121
|
- examples/dcgan/imgen.rb
|
122
122
|
- examples/dcgan/train.rb
|
123
123
|
- examples/iris_example.rb
|
124
|
+
- examples/judge-number/README.md
|
125
|
+
- examples/judge-number/capture.PNG
|
126
|
+
- examples/judge-number/convnet8.rb
|
127
|
+
- examples/judge-number/make_weights.rb
|
128
|
+
- examples/judge-number/mnist_predict.rb
|
129
|
+
- examples/judge-number/mnist_train.rb
|
130
|
+
- examples/judge-number/public/httpRequest.js
|
131
|
+
- examples/judge-number/public/judgeNumber.js
|
132
|
+
- examples/judge-number/server.rb
|
133
|
+
- examples/judge-number/trained_mnist_params.marshal
|
134
|
+
- examples/judge-number/views/index.erb
|
124
135
|
- examples/mnist_conv2d_example.rb
|
125
136
|
- examples/mnist_define_by_run.rb
|
126
137
|
- examples/mnist_example.rb
|