ruby-dnn 1.1.5 → 1.1.6

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 55dea04f1d2a6bb4806c3f029086474b46b8a225f7168e4c72af92e0f7d69f71
4
- data.tar.gz: 2522778ffabbce31315b48ad3abfedd2dea2e9dba0d09af4e9700cb9c588393b
3
+ metadata.gz: 6b3228eaf257e50c55bdd51b348e71dcc45d74a7cc14668231c4a5e1c9fed318
4
+ data.tar.gz: ae3a1217cb1aa0a3d0f50ad2b042fd255385a713c7787ad318f34410170dbb83
5
5
  SHA512:
6
- metadata.gz: 05056e7619f52dd8efac34c6aeae1e1652e3635257baf302f15908131d6908d7ee49e892040514b0ea3011c46148eb3c0fedd3857b85ab0cf271c521bb907349
7
- data.tar.gz: 41828ada6a07129fdfc4ed5aff6af9c6dc7d934dd0f39eec8792d77a7b2fc5769b2f3d278ced6b5def360db58a7a34c014a0876e65ca7525d35a67219285aab7
6
+ metadata.gz: 977ddba86307314f9af523643e151f397705600f2aa244980dfefbda8af7c74114dd4b17d515c64d30e412d8e6aa3de8a127ba085dd8666e03a99ac2cb1fd573
7
+ data.tar.gz: 80e1d0a963e066aa74f0e9e9e786113ea8505112990de217b6d3a7b7b9f60b06f5ee777e94fb95fb7f006fd92f7ef4b11b3aaa87f605ae66115c30d007055657
@@ -0,0 +1,29 @@
1
+ # Prepare
2
+ This example uses Sinatra.
3
+
4
+ ```
5
+ $ gem install sinatra
6
+ $ gem install sinatra-contrib
7
+ ```
8
+
9
+ # Let's try
10
+ This example ships with weights that have already been trained.
11
+ If you want to try it immediately, skip steps (1) and (2).
12
+
13
+ ### (1) Training MNIST
14
+ ```
15
+ $ ruby mnist_train.rb
16
+ ```
17
+
18
+ ### (2) Make weights
19
+ ```
20
+ $ ruby make_weights.rb
21
+ ```
22
+
23
+ ### (3) Launch sinatra server
24
+ ```
25
+ $ ruby server.rb
26
+ ```
27
+
28
+ ### (4) Access 127.0.0.1:4567 with your browser
29
+ ![](capture.PNG)
@@ -0,0 +1,70 @@
1
+ require "dnn"
2
+ require "numo/linalg/autoloader"
3
+
4
+ include DNN::Models
5
+ include DNN::Layers
6
+ include DNN::Optimizers
7
+ include DNN::Losses
8
+
9
+ class ConvNet < Model
10
+ def self.create(input_shape)
11
+ convnet = ConvNet.new(input_shape, 32)
12
+ convnet.setup(Adam.new, SoftmaxCrossEntropy.new)
13
+ convnet
14
+ end
15
+
16
+ def initialize(input_shape, base_filter_size)
17
+ super()
18
+ @input_shape = input_shape
19
+ @cv1 = Conv2D.new(base_filter_size, 3, padding: true)
20
+ @cv2 = Conv2D.new(base_filter_size, 3, padding: true)
21
+ @cv3 = Conv2D.new(base_filter_size * 2, 3, padding: true)
22
+ @cv4 = Conv2D.new(base_filter_size * 2, 3, padding: true)
23
+ @cv5 = Conv2D.new(base_filter_size * 4, 3, padding: true)
24
+ @cv6 = Conv2D.new(base_filter_size * 4, 3, padding: true)
25
+ @bn1 = BatchNormalization.new
26
+ @bn2 = BatchNormalization.new
27
+ @bn3 = BatchNormalization.new
28
+ @bn4 = BatchNormalization.new
29
+ @d1 = Dense.new(512)
30
+ @d2 = Dense.new(10)
31
+ end
32
+
33
+ def forward(x)
34
+ x = InputLayer.new(@input_shape).(x)
35
+
36
+ x = @cv1.(x)
37
+ x = ReLU.(x)
38
+ x = Dropout.(x, 0.25)
39
+
40
+ x = @cv2.(x)
41
+ x = @bn1.(x)
42
+ x = ReLU.(x)
43
+ x = MaxPool2D.(x, 2)
44
+
45
+ x = @cv3.(x)
46
+ x = ReLU.(x)
47
+ x = Dropout.(x, 0.25)
48
+
49
+ x = @cv4.(x)
50
+ x = @bn2.(x)
51
+ x = ReLU.(x)
52
+ x = MaxPool2D.(x, 2)
53
+
54
+ x = @cv5.(x)
55
+ x = ReLU.(x)
56
+ x = Dropout.(x, 0.25)
57
+
58
+ x = @cv6.(x)
59
+ x = @bn3.(x)
60
+ x = ReLU.(x)
61
+ x = MaxPool2D.(x, 2)
62
+
63
+ x = Flatten.(x)
64
+ x = @d1.(x)
65
+ x = @bn4.(x)
66
+ x = ReLU.(x)
67
+ x = @d2.(x)
68
+ x
69
+ end
70
+ end
@@ -0,0 +1,5 @@
1
+ require "dnn"
2
+ require_relative "convnet8"
3
+
4
+ model = ConvNet.load("trained_mnist_epoch20.marshal")
5
+ model.save_params("trained_mnist_params.marshal")
@@ -0,0 +1,20 @@
1
+ require "dnn"
2
+ require "dnn/image"
3
+ require_relative "convnet8"
4
+
5
+ def load_model
6
+ return if $model
7
+ $model = ConvNet.create([28, 28, 1])
8
+ $model.predict1(Numo::SFloat.zeros(28, 28, 1))
9
+ $model.load_params("trained_mnist_params.marshal")
10
+ end
11
+
12
+ def mnist_predict(img, width, height)
13
+ load_model
14
+ img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
15
+ img = img[true, true, 0...DNN::Image::RGB]
16
+ img = DNN::Image.to_gray_scale(img)
17
+ x = Numo::SFloat.cast(img) / 255
18
+ out = $model.predict1(x)
19
+ out.to_a.map { |v| v.round(4) * 100 }
20
+ end
@@ -0,0 +1,19 @@
1
+ require "dnn"
2
+ require "dnn/datasets/mnist"
3
+ require_relative "convnet8"
4
+
5
+ include DNN::Callbacks
6
+
7
+ x_train, y_train = DNN::MNIST.load_train
8
+ x_test, y_test = DNN::MNIST.load_test
9
+
10
+ x_train = Numo::SFloat.cast(x_train) / 255
11
+ x_test = Numo::SFloat.cast(x_test) / 255
12
+
13
+ y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
14
+ y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)
15
+
16
+ model = ConvNet.create([28, 28, 1])
17
+ model.add_callback(CheckPoint.new("trained/trained_mnist", interval: 5))
18
+
19
+ model.train(x_train, y_train, 20, batch_size: 128, test: [x_test, y_test])
@@ -0,0 +1,44 @@
1
+ class HttpRequest {
2
+ static get(path, responseCallback) {
3
+ const req = new HttpRequest(path, "GET", responseCallback);
4
+ req.send();
5
+ return req;
6
+ }
7
+
8
+ static post(path, params, responseCallback) {
9
+ const req = new HttpRequest(path, "POST", responseCallback);
10
+ req.send(params);
11
+ return req;
12
+ }
13
+
14
+ constructor(path, method, responseCallback) {
15
+ this._path = path;
16
+ this._method = method;
17
+ this._responseCallback = responseCallback;
18
+ }
19
+
20
+ send(params = null) {
21
+ const xhr = new XMLHttpRequest();
22
+ xhr.open(this._method, this._path);
23
+ let json = null;
24
+ if (params) json = JSON.stringify(params);
25
+ xhr.addEventListener("load", (e) => {
26
+ const res = {
27
+ response: xhr.response,
28
+ event: e
29
+ };
30
+ this._responseCallback(res);
31
+ });
32
+ xhr.send(json);
33
+ }
34
+ }
35
+
36
+ class Base64 {
37
+ static encode(obj) {
38
+ if (typeof(obj) === "string") {
39
+ return btoa(obj);
40
+ } else if (obj instanceof Uint8Array || obj instanceof Uint8ClampedArray) {
41
+ return btoa(String.fromCharCode(...obj));
42
+ }
43
+ }
44
+ }
@@ -0,0 +1,61 @@
1
+ const drawCanvas = document.getElementById("draw");
2
+ const viewCanvas = document.getElementById("view");
3
+
4
+ const drawContext = drawCanvas.getContext("2d");
5
+ drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
6
+ const viewContext = viewCanvas.getContext("2d");
7
+ viewContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
8
+
9
+ const judgeButton = document.getElementById("judge");
10
+ const clearButton = document.getElementById("clear");
11
+
12
+ const resultArea = document.getElementById("result");
13
+
14
+ const updateResult = (classification) => {
15
+ let str = "";
16
+ for(let i = 0; i <= 9; i++){
17
+ str += `${i}: ${classification[i]}%<br>`;
18
+ }
19
+ resultArea.innerHTML = str;
20
+ };
21
+
22
+ judgeButton.addEventListener("click", () =>{
23
+ viewContext.drawImage(drawCanvas, 0, 0, viewCanvas.width, viewCanvas.height);
24
+ const data = viewContext.getImageData(0, 0, viewCanvas.width, viewCanvas.height).data;
25
+ params = {
26
+ img: Base64.encode(data),
27
+ width: viewCanvas.width,
28
+ height: viewCanvas.height,
29
+ }
30
+ HttpRequest.post("/predict", params, (res) => {
31
+ updateResult(JSON.parse(res.response));
32
+ });
33
+ });
34
+
35
+ clearButton.addEventListener("click", () =>{
36
+ drawContext.fillStyle = "black";
37
+ drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
38
+ viewContext.fillStyle = "black";
39
+ viewContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
40
+ result.innerHTML = "";
41
+ });
42
+
43
+ let mouseDown = false;
44
+
45
+ window.addEventListener("mousedown", e =>{
46
+ mouseDown = true;
47
+ });
48
+
49
+ window.addEventListener("mouseup", e =>{
50
+ mouseDown = false;
51
+ });
52
+
53
+ drawCanvas.addEventListener("mousemove", e =>{
54
+ if(mouseDown){
55
+ let rect = e.target.getBoundingClientRect();
56
+ let x = e.clientX - 10 - rect.left;
57
+ let y = e.clientY - 10 - rect.top;
58
+ drawContext.fillStyle = "white";
59
+ drawContext.fillRect(x, y, 20, 20);
60
+ }
61
+ });
@@ -0,0 +1,19 @@
1
+ require "sinatra"
2
+ require "sinatra/reloader"
3
+ require "json"
4
+ require "base64"
5
+ require_relative "mnist_predict"
6
+
7
+ get "/" do
8
+ erb :index
9
+ end
10
+
11
+ post "/predict" do
12
+ json = request.body.read
13
+ params = JSON.parse(json, symbolize_names: true)
14
+ img = Base64.decode64(params[:img])
15
+ width = params[:width].to_i
16
+ height = params[:height].to_i
17
+ result = mnist_predict(img, width, height)
18
+ JSON.dump(result)
19
+ end
@@ -0,0 +1,7 @@
1
+ <canvas id="draw" width=256 height=256></canvas>
2
+ <canvas id="view" width=28 height=28></canvas>
3
+ <button id="judge">Judge</button>
4
+ <button id="clear">Clear</button>
5
+ <p id="result"></p>
6
+ <script src="judgeNumber.js"></script>
7
+ <script src="httpRequest.js"></script>
@@ -42,12 +42,6 @@ module DNN
42
42
  loss
43
43
  end
44
44
 
45
- def regularizers_backward(layers)
46
- layers.select { |layer| layer.respond_to?(:regularizers) }.each do |layer|
47
- layer.regularizers.each(&:backward)
48
- end
49
- end
50
-
51
45
  def to_hash(merge_hash = nil)
52
46
  hash = { class: self.class.name }
53
47
  hash.merge!(merge_hash) if merge_hash
@@ -474,6 +474,15 @@ module DNN
474
474
  @callbacks << callback
475
475
  end
476
476
 
477
+ # Add lambda callback.
478
+ # @param [Symbol] event Event to execute callback.
479
+ # @yield Register the contents of the callback.
480
+ def add_lambda_callback(event, &block)
481
+ callback = Callbacks::LambdaCallback.new(event, &block)
482
+ callback.model = self
483
+ @callbacks << callback
484
+ end
485
+
477
486
  # Clear the callback function registered for each event.
478
487
  def clear_callbacks
479
488
  @callbacks = []
@@ -64,6 +64,19 @@ module DNN
64
64
  raise ImageWriteError, "Image write failed." if res == 0
65
65
  end
66
66
 
67
+ # Create an image from binary.
68
+ # @param [String] bin binary data.
69
+ # @param [Integer] height Image height.
70
+ # @param [Integer] width Image width.
71
+ # @param [Integer] channel Image channel.
72
+ def self.from_binary(bin, height, width, channel = DNN::Image::RGB)
73
+ expected_size = height * width * channel
74
+ unless bin.size == expected_size
75
+ raise ImageError, "binary size is #{bin.size}, but expected binary size is #{expected_size}"
76
+ end
77
+ Numo::UInt8.from_binary(bin).reshape(height, width, channel)
78
+ end
79
+
67
80
  # Resize the image.
68
81
  # @param [Numo::UInt8] img Image to resize.
69
82
  # @param [Integer] out_height Image height to resize.
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "1.1.5"
2
+ VERSION = "1.1.6"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.5
4
+ version: 1.1.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-03-20 00:00:00.000000000 Z
11
+ date: 2020-05-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -121,6 +121,17 @@ files:
121
121
  - examples/dcgan/imgen.rb
122
122
  - examples/dcgan/train.rb
123
123
  - examples/iris_example.rb
124
+ - examples/judge-number/README.md
125
+ - examples/judge-number/capture.PNG
126
+ - examples/judge-number/convnet8.rb
127
+ - examples/judge-number/make_weights.rb
128
+ - examples/judge-number/mnist_predict.rb
129
+ - examples/judge-number/mnist_train.rb
130
+ - examples/judge-number/public/httpRequest.js
131
+ - examples/judge-number/public/judgeNumber.js
132
+ - examples/judge-number/server.rb
133
+ - examples/judge-number/trained_mnist_params.marshal
134
+ - examples/judge-number/views/index.erb
124
135
  - examples/mnist_conv2d_example.rb
125
136
  - examples/mnist_define_by_run.rb
126
137
  - examples/mnist_example.rb