ruby-dnn 1.1.5 → 1.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 55dea04f1d2a6bb4806c3f029086474b46b8a225f7168e4c72af92e0f7d69f71
4
- data.tar.gz: 2522778ffabbce31315b48ad3abfedd2dea2e9dba0d09af4e9700cb9c588393b
3
+ metadata.gz: 6b3228eaf257e50c55bdd51b348e71dcc45d74a7cc14668231c4a5e1c9fed318
4
+ data.tar.gz: ae3a1217cb1aa0a3d0f50ad2b042fd255385a713c7787ad318f34410170dbb83
5
5
  SHA512:
6
- metadata.gz: 05056e7619f52dd8efac34c6aeae1e1652e3635257baf302f15908131d6908d7ee49e892040514b0ea3011c46148eb3c0fedd3857b85ab0cf271c521bb907349
7
- data.tar.gz: 41828ada6a07129fdfc4ed5aff6af9c6dc7d934dd0f39eec8792d77a7b2fc5769b2f3d278ced6b5def360db58a7a34c014a0876e65ca7525d35a67219285aab7
6
+ metadata.gz: 977ddba86307314f9af523643e151f397705600f2aa244980dfefbda8af7c74114dd4b17d515c64d30e412d8e6aa3de8a127ba085dd8666e03a99ac2cb1fd573
7
+ data.tar.gz: 80e1d0a963e066aa74f0e9e9e786113ea8505112990de217b6d3a7b7b9f60b06f5ee777e94fb95fb7f006fd92f7ef4b11b3aaa87f605ae66115c30d007055657
@@ -0,0 +1,29 @@
1
+ # Prepare
2
+ This example uses Sinatra.
3
+
4
+ ```
5
+ $ gem install sinatra
6
+ $ gem install sinatra-contrib
7
+ ```
8
+
9
+ # Let's try
10
+ This example ships with pre-trained weights (`trained_mnist_params.marshal`).
11
+ If you want to try it immediately, skip steps (1) and (2).
12
+
13
+ ### (1) Training MNIST
14
+ ```
15
+ $ ruby mnist_train.rb
16
+ ```
17
+
18
+ ### (2) Make weights
19
+ ```
20
+ $ ruby make_weights.rb
21
+ ```
22
+
23
+ ### (3) Launch sinatra server
24
+ ```
25
+ $ ruby server.rb
26
+ ```
27
+
28
+ ### (4) Open http://127.0.0.1:4567 in your browser
29
+ ![Screenshot of the judge-number app](capture.PNG)
@@ -0,0 +1,70 @@
1
+ require "dnn"
2
+ require "numo/linalg/autoloader"
3
+
4
+ include DNN::Models
5
+ include DNN::Layers
6
+ include DNN::Optimizers
7
+ include DNN::Losses
8
+
9
# Eight-layer CNN (six 3x3 convolutions followed by two dense layers) for
# 10-class digit classification, written in ruby-dnn's define-by-run style.
class ConvNet < Model
  # Build a ConvNet with 32 base filters and compile it with the Adam
  # optimizer and softmax cross-entropy loss.
  # @param [Array<Integer>] input_shape Shape of one input sample, e.g. [28, 28, 1].
  # @return [ConvNet] The compiled model.
  def self.create(input_shape)
    ConvNet.new(input_shape, 32).tap do |model|
      model.setup(Adam.new, SoftmaxCrossEntropy.new)
    end
  end

  # @param [Array<Integer>] input_shape Shape of one input sample.
  # @param [Integer] base_filter_size Filter count of the first conv stage;
  #   the second and third stages use 2x and 4x this many filters.
  def initialize(input_shape, base_filter_size)
    super()
    @input_shape = input_shape
    # NOTE: keep the assignment order below unchanged; it may determine how
    # layers (and therefore saved params) are enumerated — TODO confirm.
    # Three stages of two "same"-padded 3x3 convolutions each.
    @cv1 = Conv2D.new(base_filter_size, 3, padding: true)
    @cv2 = Conv2D.new(base_filter_size, 3, padding: true)
    @cv3 = Conv2D.new(base_filter_size * 2, 3, padding: true)
    @cv4 = Conv2D.new(base_filter_size * 2, 3, padding: true)
    @cv5 = Conv2D.new(base_filter_size * 4, 3, padding: true)
    @cv6 = Conv2D.new(base_filter_size * 4, 3, padding: true)
    # One batch-norm per conv stage plus one after the hidden dense layer.
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
    @bn3 = BatchNormalization.new
    @bn4 = BatchNormalization.new
    @d1 = Dense.new(512)
    @d2 = Dense.new(10)
  end

  # Run one forward pass.
  # @param x Input batch shaped per input_shape.
  # @return Raw class scores from the final Dense(10) layer.
  def forward(x)
    x = InputLayer.new(@input_shape).(x)
    x = conv_pool_stage(x, @cv1, @cv2, @bn1)
    x = conv_pool_stage(x, @cv3, @cv4, @bn2)
    x = conv_pool_stage(x, @cv5, @cv6, @bn3)
    x = Flatten.(x)
    x = ReLU.(@bn4.(@d1.(x)))
    @d2.(x)
  end

  private

  # conv -> ReLU -> Dropout(0.25) -> conv -> batch-norm -> ReLU -> 2x2 max-pool
  def conv_pool_stage(x, first_conv, second_conv, batch_norm)
    x = Dropout.(ReLU.(first_conv.(x)), 0.25)
    x = ReLU.(batch_norm.(second_conv.(x)))
    MaxPool2D.(x, 2)
  end
end
@@ -0,0 +1,5 @@
1
+ require "dnn"
2
+ require_relative "convnet8"
3
+
4
# Load the epoch-20 training checkpoint and re-save only its parameters,
# producing the file that the prediction server loads.
ConvNet.load("trained_mnist_epoch20.marshal").save_params("trained_mnist_params.marshal")
@@ -0,0 +1,20 @@
1
+ require "dnn"
2
+ require "dnn/image"
3
+ require_relative "convnet8"
4
+
5
# Build the ConvNet and load the bundled trained parameters once per process,
# caching the ready model in $model.
# The global is assigned only after load_params succeeds, so a failed load
# never leaves a model with untrained weights cached for later calls.
# @return [ConvNet] The cached model.
def load_model
  return $model if $model

  model = ConvNet.create([28, 28, 1])
  # Dummy forward pass before loading params. Presumably required because the
  # define-by-run model only builds its layers on first call — TODO confirm.
  model.predict1(Numo::SFloat.zeros(28, 28, 1))
  model.load_params("trained_mnist_params.marshal")
  $model = model
end
11
+
12
# Classify a hand-drawn digit image.
# @param [String] img Raw RGBA pixel bytes laid out as height x width x 4.
# @param [Integer] width Image width in pixels.
# @param [Integer] height Image height in pixels.
# @return [Array<Float>] One score per digit 0-9 from the model output,
#   multiplied by 100 (presumably softmax probabilities as percentages) and
#   rounded to two decimal places.
def mnist_predict(img, width, height)
  load_model
  img = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
  # Keep only the RGB channels (drop alpha), then collapse to gray.
  img = img[true, true, 0...DNN::Image::RGB]
  img = DNN::Image.to_gray_scale(img)
  x = Numo::SFloat.cast(img) / 255
  out = $model.predict1(x)
  # Scale before rounding: rounding first and then multiplying by 100
  # re-introduces float noise (0.07 * 100 == 7.000000000000001).
  out.to_a.map { |v| (v * 100).round(2) }
end
@@ -0,0 +1,19 @@
1
+ require "dnn"
2
+ require "dnn/datasets/mnist"
3
+ require_relative "convnet8"
4
+
5
+ include DNN::Callbacks
6
+
7
x_train, y_train = DNN::MNIST.load_train
x_test, y_test = DNN::MNIST.load_test

# Scale pixel values into 0.0..1.0.
x_train = Numo::SFloat.cast(x_train) / 255
x_test = Numo::SFloat.cast(x_test) / 255

# One-hot encode the digit labels (10 classes).
y_train = DNN::Utils.to_categorical(y_train, 10, Numo::SFloat)
y_test = DNN::Utils.to_categorical(y_test, 10, Numo::SFloat)

model = ConvNet.create([28, 28, 1])
# Checkpoint every 5 epochs into the working directory. The base name has no
# directory component because nothing creates a "trained/" directory, and
# make_weights.rb loads "trained_mnist_epoch20.marshal" from the working
# directory (assumes CheckPoint appends "_epoch<N>.marshal" — TODO confirm).
model.add_callback(CheckPoint.new("trained_mnist", interval: 5))

model.train(x_train, y_train, 20, batch_size: 128, test: [x_test, y_test])
@@ -0,0 +1,44 @@
1
/**
 * Minimal XMLHttpRequest wrapper: sends an optional JSON body and hands the
 * raw response (plus HTTP status) to a callback.
 */
class HttpRequest {
  /**
   * Send a GET request.
   * @param {string} path Request URL.
   * @param {function({response: *, status: number, event: Event})} responseCallback
   * @returns {HttpRequest}
   */
  static get(path, responseCallback) {
    const req = new HttpRequest(path, "GET", responseCallback);
    req.send();
    return req;
  }

  /**
   * Send a POST request whose body is `params` serialized as JSON.
   * @param {string} path Request URL.
   * @param {Object} params Value to serialize as the JSON body.
   * @param {function({response: *, status: number, event: Event})} responseCallback
   * @returns {HttpRequest}
   */
  static post(path, params, responseCallback) {
    const req = new HttpRequest(path, "POST", responseCallback);
    req.send(params);
    return req;
  }

  constructor(path, method, responseCallback) {
    this._path = path;
    this._method = method;
    this._responseCallback = responseCallback;
  }

  /**
   * Open and send the request. When `params` is truthy it is sent as JSON.
   * The callback runs on the "load" event, which fires for every completed
   * HTTP response (including 4xx/5xx), so `status` is passed along.
   * @param {?Object} params Optional body to serialize as JSON.
   */
  send(params = null) {
    const xhr = new XMLHttpRequest();
    xhr.open(this._method, this._path);
    let json = null;
    if (params) {
      json = JSON.stringify(params);
      // Label the body correctly; otherwise it is sent as text/plain.
      xhr.setRequestHeader("Content-Type", "application/json");
    }
    xhr.addEventListener("load", (e) => {
      this._responseCallback({
        response: xhr.response,
        status: xhr.status,
        event: e,
      });
    });
    xhr.send(json);
  }
}
35
+
36
/** Base64 helpers built on btoa(). */
class Base64 {
  /**
   * Encode a Latin-1 string or a byte array as Base64.
   * @param {string|Uint8Array|Uint8ClampedArray} obj Value to encode.
   * @returns {string} Base64 text.
   * @throws {TypeError} If obj is none of the supported types.
   */
  static encode(obj) {
    if (typeof obj === "string") {
      return btoa(obj);
    }
    if (obj instanceof Uint8Array || obj instanceof Uint8ClampedArray) {
      // Convert in chunks: spreading a large array into fromCharCode exceeds
      // the engine's argument-count limit and throws a RangeError.
      const CHUNK = 0x8000;
      let binary = "";
      for (let i = 0; i < obj.length; i += CHUNK) {
        binary += String.fromCharCode.apply(null, obj.subarray(i, i + CHUNK));
      }
      return btoa(binary);
    }
    throw new TypeError("Base64.encode expects a string, Uint8Array or Uint8ClampedArray");
  }
}
@@ -0,0 +1,61 @@
1
// The canvas the user draws on, and the small canvas holding the downscaled
// image that is sent to the server.
const drawCanvas = document.getElementById("draw");
const viewCanvas = document.getElementById("view");

// Start both canvases filled with the default (black) fill style, each
// using its own dimensions.
const drawContext = drawCanvas.getContext("2d");
drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
const viewContext = viewCanvas.getContext("2d");
viewContext.fillRect(0, 0, viewCanvas.width, viewCanvas.height);

const judgeButton = document.getElementById("judge");
const clearButton = document.getElementById("clear");

const resultArea = document.getElementById("result");
13
+
14
/**
 * Render one "digit: score%" line for each digit 0-9 into the result area.
 * @param {number[]} classification Scores indexed by digit.
 */
const updateResult = (classification) => {
  const lines = [];
  for (let digit = 0; digit < 10; digit++) {
    lines.push(`${digit}: ${classification[digit]}%<br>`);
  }
  resultArea.innerHTML = lines.join("");
};
21
+
22
// Downscale the drawing onto the view canvas, post its RGBA pixels to the
// server, and display the returned per-digit scores.
judgeButton.addEventListener("click", () => {
  viewContext.drawImage(drawCanvas, 0, 0, viewCanvas.width, viewCanvas.height);
  const { data } = viewContext.getImageData(0, 0, viewCanvas.width, viewCanvas.height);
  // Declared with const so no implicit global is created (a ReferenceError
  // in strict mode and in ES modules).
  const params = {
    img: Base64.encode(data),
    width: viewCanvas.width,
    height: viewCanvas.height,
  };
  HttpRequest.post("/predict", params, (res) => {
    updateResult(JSON.parse(res.response));
  });
});
34
+
35
// Reset both canvases to black and clear the previous result.
clearButton.addEventListener("click", () => {
  drawContext.fillStyle = "black";
  drawContext.fillRect(0, 0, drawCanvas.width, drawCanvas.height);
  viewContext.fillStyle = "black";
  viewContext.fillRect(0, 0, viewCanvas.width, viewCanvas.height);
  // Use the resultArea handle, not the implicit window.result global that
  // browsers create for id="result".
  resultArea.innerHTML = "";
});
42
+
43
// Track the mouse button on the whole window, so a stroke continues even
// when the button was pressed outside the canvas.
let mouseDown = false;

window.addEventListener("mousedown", () => { mouseDown = true; });
window.addEventListener("mouseup", () => { mouseDown = false; });

// While the button is held, paint a 20x20 white square centred on the cursor.
drawCanvas.addEventListener("mousemove", (e) => {
  if (!mouseDown) return;
  const { left, top } = e.target.getBoundingClientRect();
  const x = e.clientX - 10 - left;
  const y = e.clientY - 10 - top;
  drawContext.fillStyle = "white";
  drawContext.fillRect(x, y, 20, 20);
});
@@ -0,0 +1,19 @@
1
+ require "sinatra"
2
+ require "sinatra/reloader"
3
+ require "json"
4
+ require "base64"
5
+ require_relative "mnist_predict"
6
+
7
# Serve the drawing page rendered from the :index view.
get("/") { erb :index }
10
+
11
# Classify a digit drawing posted as JSON:
#   { "img": <Base64 of RGBA bytes>, "width": <Integer>, "height": <Integer> }
# Responds with a JSON array of per-digit scores, or with 400 and a JSON
# error object when the request is malformed.
post "/predict" do
  content_type :json
  # Named `payload` so it does not shadow Sinatra's own `params` helper.
  payload = begin
    JSON.parse(request.body.read, symbolize_names: true)
  rescue JSON::ParserError
    halt 400, JSON.dump({ error: "request body is not valid JSON" })
  end
  unless payload.is_a?(Hash) && payload[:img].is_a?(String)
    halt 400, JSON.dump({ error: "img must be a Base64 string" })
  end

  img = Base64.decode64(payload[:img])
  width = payload[:width].to_i
  height = payload[:height].to_i
  # mnist_predict decodes the bytes as RGBA, i.e. 4 bytes per pixel.
  unless width.positive? && height.positive? && img.bytesize == width * height * 4
    halt 400, JSON.dump({ error: "img must contain width * height RGBA pixels" })
  end

  JSON.dump(mnist_predict(img, width, height))
end
@@ -0,0 +1,7 @@
1
<canvas id="draw" width="256" height="256"></canvas>
<canvas id="view" width="28" height="28"></canvas>
<button id="judge">Judge</button>
<button id="clear">Clear</button>
<p id="result"></p>
<!-- Load the helpers (HttpRequest, Base64) before the page script that uses them. -->
<script src="httpRequest.js"></script>
<script src="judgeNumber.js"></script>
@@ -42,12 +42,6 @@ module DNN
42
42
  loss
43
43
  end
44
44
 
45
- def regularizers_backward(layers)
46
- layers.select { |layer| layer.respond_to?(:regularizers) }.each do |layer|
47
- layer.regularizers.each(&:backward)
48
- end
49
- end
50
-
51
45
  def to_hash(merge_hash = nil)
52
46
  hash = { class: self.class.name }
53
47
  hash.merge!(merge_hash) if merge_hash
@@ -474,6 +474,15 @@ module DNN
474
474
  @callbacks << callback
475
475
  end
476
476
 
477
# Add a callback built from a block.
# @param [Symbol] event Event to execute the callback on (passed to
#   Callbacks::LambdaCallback; TODO confirm the accepted event names there).
# @yield Body of the callback.
# @raise [ArgumentError] If no block is given, so the mistake surfaces here
#   rather than when the event later fires during training.
# @return [Array] The registered callbacks.
def add_lambda_callback(event, &block)
  raise ArgumentError, "add_lambda_callback requires a block" unless block

  callback = Callbacks::LambdaCallback.new(event, &block)
  callback.model = self
  @callbacks << callback
end
485
+
477
486
  # Clear the callback function registered for each event.
478
487
  def clear_callbacks
479
488
  @callbacks = []
@@ -64,6 +64,19 @@ module DNN
64
64
  raise ImageWriteError, "Image write failed." if res == 0
65
65
  end
66
66
 
67
# Create an image from binary.
# @param [String] bin Raw pixel bytes, laid out as height x width x channel.
# @param [Integer] height Image height.
# @param [Integer] width Image width.
# @param [Integer] channel Number of channels per pixel (default DNN::Image::RGB).
# @return [Numo::UInt8] Array reshaped to [height, width, channel].
# @raise [ImageError] If the byte length of bin is not height * width * channel.
def self.from_binary(bin, height, width, channel = DNN::Image::RGB)
  expected_size = height * width * channel
  # Compare bytes, not characters: String#size counts characters, so a
  # non-binary (e.g. UTF-8) string with multibyte sequences would be
  # validated against the wrong length.
  unless bin.bytesize == expected_size
    raise ImageError, "binary size is #{bin.bytesize}, but expected binary size is #{expected_size}"
  end
  Numo::UInt8.from_binary(bin).reshape(height, width, channel)
end
79
+
67
80
  # Resize the image.
68
81
  # @param [Numo::UInt8] img Image to resize.
69
82
  # @param [Integer] out_height Image height to resize.
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "1.1.5"
2
+ VERSION = "1.1.6"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.5
4
+ version: 1.1.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-03-20 00:00:00.000000000 Z
11
+ date: 2020-05-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -121,6 +121,17 @@ files:
121
121
  - examples/dcgan/imgen.rb
122
122
  - examples/dcgan/train.rb
123
123
  - examples/iris_example.rb
124
+ - examples/judge-number/README.md
125
+ - examples/judge-number/capture.PNG
126
+ - examples/judge-number/convnet8.rb
127
+ - examples/judge-number/make_weights.rb
128
+ - examples/judge-number/mnist_predict.rb
129
+ - examples/judge-number/mnist_train.rb
130
+ - examples/judge-number/public/httpRequest.js
131
+ - examples/judge-number/public/judgeNumber.js
132
+ - examples/judge-number/server.rb
133
+ - examples/judge-number/trained_mnist_params.marshal
134
+ - examples/judge-number/views/index.erb
124
135
  - examples/mnist_conv2d_example.rb
125
136
  - examples/mnist_define_by_run.rb
126
137
  - examples/mnist_example.rb