ruby-dnn 1.1.3 → 1.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/ext/rb_stb_image/rb_stb_image.c +5 -0
- data/lib/dnn/core/layers/cnn_layers.rb +1 -1
- data/lib/dnn/core/layers/math_layers.rb +33 -17
- data/lib/dnn/keras-model-convertor.rb +2 -6
- data/lib/dnn/version.rb +1 -1
- data/ruby-dnn.gemspec +1 -1
- metadata +6 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d33f1d5472229d184630d37063d4c42888230044a9e1b069035a144e8aae0964
|
4
|
+
data.tar.gz: 530d31b5fc5073fa02253eb0f7d7b78a007ccaed16ea4a49aba56a87174afb59
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: acf104aec7d661da52623930dbd8a60e105aff6b0c7fd4bdc675f507a50f55372451d7c4066992d13a390629d89ec7ba42d7c316376797e439e6b399d702e5aa
|
7
|
+
data.tar.gz: 1e338be765295d9c9827dc82c70e7af79473386be1574ae647ef2568c24be9f9142d94104b62726bda800aff064c37ede794066cfb098a40744dd24a2e7addb4
|
@@ -5,6 +5,11 @@
|
|
5
5
|
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
6
6
|
#define STB_IMAGE_RESIZE_IMPLEMENTATION
|
7
7
|
|
8
|
+
#if defined(_WIN32) || defined(_WIN64)
|
9
|
+
#define STBI_WINDOWS_UTF8
|
10
|
+
#define STBIW_WINDOWS_UTF8
|
11
|
+
#endif
|
12
|
+
|
8
13
|
#include "../../third_party/stb_image.h"
|
9
14
|
#include "../../third_party/stb_image_write.h"
|
10
15
|
#include "../../third_party/stb_image_resize.h"
|
@@ -1,7 +1,9 @@
|
|
1
1
|
module DNN
|
2
2
|
module Layers
|
3
3
|
module MathUtils
|
4
|
-
|
4
|
+
module_function
|
5
|
+
|
6
|
+
def align_ndim(shape1, shape2)
|
5
7
|
if shape1.length < shape2.length
|
6
8
|
shape2.length.times do |axis|
|
7
9
|
unless shape1[axis] == shape2[axis]
|
@@ -18,7 +20,7 @@ module DNN
|
|
18
20
|
[shape1, shape2]
|
19
21
|
end
|
20
22
|
|
21
|
-
def
|
23
|
+
def broadcast_to(x, target_shape)
|
22
24
|
return x if x.shape == target_shape
|
23
25
|
x_shape, target_shape = align_ndim(x.shape, target_shape)
|
24
26
|
x = x.reshape(*x_shape)
|
@@ -33,7 +35,7 @@ module DNN
|
|
33
35
|
x
|
34
36
|
end
|
35
37
|
|
36
|
-
def
|
38
|
+
def sum_to(x, target_shape)
|
37
39
|
return x if x.shape == target_shape
|
38
40
|
x_shape, target_shape = align_ndim(x.shape, target_shape)
|
39
41
|
x = x.reshape(*x_shape)
|
@@ -192,9 +194,13 @@ module DNN
|
|
192
194
|
class Sum < Layer
|
193
195
|
include LayerNode
|
194
196
|
|
195
|
-
|
197
|
+
attr_reader :axis
|
198
|
+
attr_reader :keepdims
|
199
|
+
|
200
|
+
def initialize(axis: 0, keepdims: true)
|
196
201
|
super()
|
197
202
|
@axis = axis
|
203
|
+
@keepdims = keepdims
|
198
204
|
end
|
199
205
|
|
200
206
|
def forward_node(x)
|
@@ -204,21 +210,28 @@ module DNN
|
|
204
210
|
end
|
205
211
|
|
206
212
|
def backward_node(dy)
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
+
MathUtils.broadcast_to(dy, @x_shape)
|
214
|
+
end
|
215
|
+
|
216
|
+
def to_hash
|
217
|
+
super(axis: @axis, keepdims: @keepdims)
|
218
|
+
end
|
219
|
+
|
220
|
+
def load_hash(hash)
|
221
|
+
initialize(axis: hash[:axis], keepdims: hash[:keepdims])
|
213
222
|
end
|
214
223
|
end
|
215
224
|
|
216
225
|
class Mean < Layer
|
217
226
|
include LayerNode
|
218
227
|
|
219
|
-
|
228
|
+
attr_reader :axis
|
229
|
+
attr_reader :keepdims
|
230
|
+
|
231
|
+
def initialize(axis: 0, keepdims: true)
|
220
232
|
super()
|
221
233
|
@axis = axis
|
234
|
+
@keepdims = keepdims
|
222
235
|
end
|
223
236
|
|
224
237
|
def forward_node(x)
|
@@ -228,12 +241,15 @@ module DNN
|
|
228
241
|
end
|
229
242
|
|
230
243
|
def backward_node(dy)
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
244
|
+
MathUtils.broadcast_to(dy, @x_shape) / @dim
|
245
|
+
end
|
246
|
+
|
247
|
+
def to_hash
|
248
|
+
super(axis: @axis, keepdims: @keepdims)
|
249
|
+
end
|
250
|
+
|
251
|
+
def load_hash(hash)
|
252
|
+
initialize(axis: hash[:axis], keepdims: hash[:keepdims])
|
237
253
|
end
|
238
254
|
end
|
239
255
|
|
@@ -9,11 +9,7 @@ require_relative "numo2numpy"
|
|
9
9
|
|
10
10
|
include PyCall::Import
|
11
11
|
|
12
|
-
pyimport :numpy, as: :np
|
13
12
|
pyimport :keras
|
14
|
-
pyfrom :"keras.models", import: :Sequential
|
15
|
-
pyfrom :"keras.layers", import: [:Dense, :Dropout, :Conv2D, :Activation, :MaxPooling2D, :Flatten]
|
16
|
-
pyfrom :"keras.layers.normalization", import: :BatchNormalization
|
17
13
|
|
18
14
|
module DNN
|
19
15
|
module Layers
|
@@ -185,9 +181,9 @@ class KerasModelConvertor
|
|
185
181
|
conv2d_t.filters = Numpy.to_na(k_conv2d_t.get_weights[0])
|
186
182
|
conv2d_t.bias.data = Numpy.to_na(k_conv2d_t.get_weights[1])
|
187
183
|
returns = [conv2d_t]
|
188
|
-
unless
|
184
|
+
unless k_conv2d_t.get_config[:activation] == "linear"
|
189
185
|
input_shape, output_shape = get_k_layer_shape(k_conv2d)
|
190
|
-
returns << activation_to_dnn_layer(
|
186
|
+
returns << activation_to_dnn_layer(k_conv2d_t.get_config[:activation], output_shape)
|
191
187
|
end
|
192
188
|
returns
|
193
189
|
end
|
data/lib/dnn/version.rb
CHANGED
data/ruby-dnn.gemspec
CHANGED
@@ -40,6 +40,6 @@ Gem::Specification.new do |spec|
|
|
40
40
|
spec.require_paths = ["lib"]
|
41
41
|
|
42
42
|
spec.add_development_dependency "bundler", "~> 1.16"
|
43
|
-
spec.add_development_dependency "rake", "
|
43
|
+
spec.add_development_dependency "rake", ">= 12.3.3"
|
44
44
|
spec.add_development_dependency "minitest", "~> 5.0"
|
45
45
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.1.
|
4
|
+
version: 1.1.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-02-
|
11
|
+
date: 2020-02-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -70,16 +70,16 @@ dependencies:
|
|
70
70
|
name: rake
|
71
71
|
requirement: !ruby/object:Gem::Requirement
|
72
72
|
requirements:
|
73
|
-
- - "
|
73
|
+
- - ">="
|
74
74
|
- !ruby/object:Gem::Version
|
75
|
-
version:
|
75
|
+
version: 12.3.3
|
76
76
|
type: :development
|
77
77
|
prerelease: false
|
78
78
|
version_requirements: !ruby/object:Gem::Requirement
|
79
79
|
requirements:
|
80
|
-
- - "
|
80
|
+
- - ">="
|
81
81
|
- !ruby/object:Gem::Version
|
82
|
-
version:
|
82
|
+
version: 12.3.3
|
83
83
|
- !ruby/object:Gem::Dependency
|
84
84
|
name: minitest
|
85
85
|
requirement: !ruby/object:Gem::Requirement
|