ruby-dnn 1.1.4 → 1.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/.travis.yml +2 -1
- data/README.md +39 -22
- data/examples/api-examples/early_stopping_example.rb +6 -6
- data/examples/api-examples/initializer_example.rb +6 -6
- data/examples/api-examples/regularizer_example.rb +6 -6
- data/examples/api-examples/save_example.rb +6 -6
- data/examples/dcgan/dcgan.rb +27 -27
- data/examples/judge-number/README.md +29 -0
- data/examples/judge-number/capture.PNG +0 -0
- data/examples/judge-number/convnet8.rb +70 -0
- data/examples/judge-number/make_weights.rb +5 -0
- data/examples/judge-number/mnist_predict.rb +20 -0
- data/examples/judge-number/mnist_train.rb +19 -0
- data/examples/judge-number/public/httpRequest.js +44 -0
- data/examples/judge-number/public/judgeNumber.js +61 -0
- data/examples/judge-number/server.rb +19 -0
- data/examples/judge-number/trained_mnist_params.marshal +0 -0
- data/examples/judge-number/views/index.erb +7 -0
- data/examples/mnist_conv2d_example.rb +3 -3
- data/examples/mnist_define_by_run.rb +7 -7
- data/examples/mnist_gpu.rb +47 -0
- data/examples/mnist_lstm_example.rb +1 -1
- data/examples/pix2pix/dcgan.rb +54 -66
- data/examples/pix2pix/train.rb +2 -2
- data/examples/vae.rb +13 -13
- data/img/cart-pole.gif +0 -0
- data/img/cycle-gan.PNG +0 -0
- data/img/facade-pix2pix.png +0 -0
- data/lib/dnn.rb +24 -3
- data/lib/dnn/core/callbacks.rb +6 -4
- data/lib/dnn/core/layers/basic_layers.rb +40 -22
- data/lib/dnn/core/layers/cnn_layers.rb +33 -5
- data/lib/dnn/core/layers/math_layers.rb +17 -9
- data/lib/dnn/core/layers/merge_layers.rb +2 -26
- data/lib/dnn/core/layers/split_layers.rb +39 -0
- data/lib/dnn/core/link.rb +14 -33
- data/lib/dnn/core/losses.rb +6 -12
- data/lib/dnn/core/models.rb +77 -10
- data/lib/dnn/core/optimizers.rb +8 -1
- data/lib/dnn/core/utils.rb +23 -0
- data/lib/dnn/image.rb +48 -0
- data/lib/dnn/version.rb +1 -1
- data/ruby-dnn.gemspec +2 -15
- metadata +40 -20
- data/bin/console +0 -14
- data/bin/setup +0 -8
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
module DNN
|
|
2
|
+
module Layers
|
|
3
|
+
|
|
4
|
+
class Split < Layer
|
|
5
|
+
include LayerNode
|
|
6
|
+
|
|
7
|
+
attr_reader :axis
|
|
8
|
+
attr_reader :dim
|
|
9
|
+
|
|
10
|
+
def initialize(axis: 1, dim: nil)
|
|
11
|
+
super()
|
|
12
|
+
raise DNNError, "dim is nil" if dim == nil
|
|
13
|
+
@axis = axis
|
|
14
|
+
@dim = dim
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
def forward_node(x)
|
|
18
|
+
x1_dim = @dim
|
|
19
|
+
x2_dim = x.shape[@axis] - @dim
|
|
20
|
+
y1, y2others = x.split([x1_dim, x1_dim + x2_dim], axis: @axis)
|
|
21
|
+
y2 = y2others.is_a?(Array) ? y2others[0].concatenate(y2others[1..-1], axis: @axis) : y2others
|
|
22
|
+
[y1, y2]
|
|
23
|
+
end
|
|
24
|
+
|
|
25
|
+
def backward_node(dy1, dy2)
|
|
26
|
+
dy1.concatenate(dy2, axis: @axis)
|
|
27
|
+
end
|
|
28
|
+
|
|
29
|
+
def to_hash
|
|
30
|
+
super(axis: @axis, dim: @dim)
|
|
31
|
+
end
|
|
32
|
+
|
|
33
|
+
def load_hash(hash)
|
|
34
|
+
initialize(axis: hash[:axis], dim: hash[:dim])
|
|
35
|
+
end
|
|
36
|
+
end
|
|
37
|
+
|
|
38
|
+
end
|
|
39
|
+
end
|
data/lib/dnn/core/link.rb
CHANGED
|
@@ -1,57 +1,38 @@
|
|
|
1
1
|
module DNN
|
|
2
2
|
class Link
|
|
3
|
-
attr_accessor :prev
|
|
3
|
+
attr_accessor :prevs
|
|
4
4
|
attr_accessor :next
|
|
5
5
|
attr_accessor :layer_node
|
|
6
|
+
attr_reader :num_outputs
|
|
6
7
|
|
|
7
|
-
def initialize(prev = nil, layer_node = nil)
|
|
8
|
-
@prev = prev
|
|
9
|
-
@layer_node = layer_node
|
|
10
|
-
@next = nil
|
|
11
|
-
end
|
|
12
|
-
|
|
13
|
-
def forward(x)
|
|
14
|
-
x = @layer_node.(x)
|
|
15
|
-
@next ? @next.forward(x) : x
|
|
16
|
-
end
|
|
17
|
-
|
|
18
|
-
def backward(dy = Xumo::SFloat[1])
|
|
19
|
-
dy = @layer_node.backward_node(dy)
|
|
20
|
-
@prev&.backward(dy)
|
|
21
|
-
end
|
|
22
|
-
end
|
|
23
|
-
|
|
24
|
-
class TwoInputLink
|
|
25
|
-
attr_accessor :prev1
|
|
26
|
-
attr_accessor :prev2
|
|
27
|
-
attr_accessor :next
|
|
28
|
-
attr_accessor :layer_node
|
|
29
|
-
|
|
30
|
-
def initialize(prev1 = nil, prev2 = nil, layer_node = nil)
|
|
31
|
-
@prev1 = prev1
|
|
32
|
-
@prev2 = prev2
|
|
8
|
+
def initialize(prevs: nil, layer_node: nil, num_outputs: 1)
|
|
9
|
+
@prevs = prevs
|
|
33
10
|
@layer_node = layer_node
|
|
11
|
+
@num_outputs = num_outputs
|
|
34
12
|
@next = nil
|
|
35
13
|
@hold = []
|
|
36
14
|
end
|
|
37
15
|
|
|
38
16
|
def forward(x)
|
|
39
17
|
@hold << x
|
|
40
|
-
return if @hold.length < 2
|
|
18
|
+
return if @hold.length < @prevs.length
|
|
41
19
|
x = @layer_node.(*@hold)
|
|
42
20
|
@hold = []
|
|
43
21
|
@next ? @next.forward(x) : x
|
|
44
22
|
end
|
|
45
23
|
|
|
46
24
|
def backward(dy = Xumo::SFloat[1])
|
|
47
|
-
dys = @layer_node.backward_node(dy)
|
25
|
+
@hold << dy
|
|
26
|
+
return if @hold.length < @num_outputs
|
|
27
|
+
dys = @layer_node.backward_node(*@hold)
|
|
28
|
+
@hold = []
|
|
48
29
|
if dys.is_a?(Array)
|
|
49
|
-
dy1, dy2 = *dys
|
30
|
+
dys.each.with_index do |dy, i|
|
|
31
|
+
@prevs[i]&.backward(dy)
|
|
32
|
+
end
|
|
50
33
|
else
|
|
51
|
-
dy1 = dys
|
34
|
+
@prevs.first&.backward(dys)
|
|
52
35
|
end
|
|
53
|
-
@prev1&.backward(dy1)
|
|
54
|
-
@prev2&.backward(dy2) if dy2
|
|
55
36
|
end
|
|
56
37
|
end
|
|
57
38
|
end
|
data/lib/dnn/core/losses.rb
CHANGED
|
@@ -42,12 +42,6 @@ module DNN
|
|
|
42
42
|
loss
|
|
43
43
|
end
|
|
44
44
|
|
|
45
|
-
def regularizers_backward(layers)
|
|
46
|
-
layers.select { |layer| layer.respond_to?(:regularizers) }.each do |layer|
|
|
47
|
-
layer.regularizers.each(&:backward)
|
|
48
|
-
end
|
|
49
|
-
end
|
|
50
|
-
|
|
51
45
|
def to_hash(merge_hash = nil)
|
|
52
46
|
hash = { class: self.class.name }
|
|
53
47
|
hash.merge!(merge_hash) if merge_hash
|
|
@@ -68,7 +62,7 @@ module DNN
|
|
|
68
62
|
end
|
|
69
63
|
|
|
70
64
|
class MeanSquaredError < Loss
|
|
71
|
-
include Layers::MergeLayerNode
|
|
65
|
+
include Layers::LayerNode
|
|
72
66
|
|
|
73
67
|
def forward_node(y, t)
|
|
74
68
|
@y = y
|
|
@@ -82,7 +76,7 @@ module DNN
|
|
|
82
76
|
end
|
|
83
77
|
|
|
84
78
|
class MeanAbsoluteError < Loss
|
|
85
|
-
include Layers::MergeLayerNode
|
|
79
|
+
include Layers::LayerNode
|
|
86
80
|
|
|
87
81
|
def forward_node(y, t)
|
|
88
82
|
@y = y
|
|
@@ -99,7 +93,7 @@ module DNN
|
|
|
99
93
|
end
|
|
100
94
|
|
|
101
95
|
class Hinge < Loss
|
|
102
|
-
include Layers::MergeLayerNode
|
|
96
|
+
include Layers::LayerNode
|
|
103
97
|
|
|
104
98
|
def forward_node(y, t)
|
|
105
99
|
@t = t
|
|
@@ -115,7 +109,7 @@ module DNN
|
|
|
115
109
|
end
|
|
116
110
|
|
|
117
111
|
class HuberLoss < Loss
|
|
118
|
-
include Layers::MergeLayerNode
|
|
112
|
+
include Layers::LayerNode
|
|
119
113
|
|
|
120
114
|
def forward_node(y, t)
|
|
121
115
|
@y = y
|
|
@@ -135,7 +129,7 @@ module DNN
|
|
|
135
129
|
end
|
|
136
130
|
|
|
137
131
|
class SoftmaxCrossEntropy < Loss
|
|
138
|
-
include Layers::MergeLayerNode
|
|
132
|
+
include Layers::LayerNode
|
|
139
133
|
|
|
140
134
|
attr_accessor :eps
|
|
141
135
|
|
|
@@ -172,7 +166,7 @@ module DNN
|
|
|
172
166
|
end
|
|
173
167
|
|
|
174
168
|
class SigmoidCrossEntropy < Loss
|
|
175
|
-
include Layers::MergeLayerNode
|
|
169
|
+
include Layers::LayerNode
|
|
176
170
|
|
|
177
171
|
attr_accessor :eps
|
|
178
172
|
|
data/lib/dnn/core/models.rb
CHANGED
|
@@ -230,6 +230,7 @@ module DNN
|
|
|
230
230
|
puts "【 epoch #{epoch}/#{epochs} 】" if verbose
|
|
231
231
|
|
|
232
232
|
train_iterator.foreach(batch_size) do |x_batch, y_batch, index|
|
|
233
|
+
@last_log[:step] = index
|
|
233
234
|
train_step_met = train_step(x_batch, y_batch)
|
|
234
235
|
num_trained_datas = (index + 1) * batch_size
|
|
235
236
|
num_trained_datas = num_trained_datas > num_train_datas ? num_train_datas : num_trained_datas
|
|
@@ -305,13 +306,13 @@ module DNN
|
|
|
305
306
|
loss_opt[:layers] = layers if i == 0
|
|
306
307
|
loss_opt[:loss_weight] = @loss_weights[i] if @loss_weights
|
|
307
308
|
loss = @loss_func[i].loss(out, Tensor.convert(y[i]), **loss_opt)
|
|
308
|
-
loss_data << loss.data
|
|
309
|
+
loss_data << Utils.to_f(loss.data)
|
|
309
310
|
loss.link.backward(Xumo::SFloat.ones(y[i][0...1, false].shape[0], 1))
|
|
310
311
|
end
|
|
311
312
|
else
|
|
312
313
|
out = output_tensors
|
|
313
314
|
loss = @loss_func.loss(out, Tensor.convert(y), layers: layers)
|
|
314
|
-
loss_data = loss.data
|
|
315
|
+
loss_data = Utils.to_f(loss.data)
|
|
315
316
|
loss.link.backward(Xumo::SFloat.ones(y[0...1, false].shape[0], 1))
|
|
316
317
|
end
|
|
317
318
|
@optimizer.update(get_all_trainable_params)
|
|
@@ -392,13 +393,13 @@ module DNN
|
|
|
392
393
|
output_tensors.each.with_index do |out, i|
|
|
393
394
|
correct << accuracy(out.data, y[i]) if accuracy
|
|
394
395
|
loss = @loss_func[i].(out, Tensor.convert(y[i]))
|
|
395
|
-
loss_data << loss.data
|
|
396
|
+
loss_data << Utils.to_f(loss.data)
|
|
396
397
|
end
|
|
397
398
|
else
|
|
398
399
|
out = output_tensors
|
|
399
400
|
correct = accuracy(out.data, y) if accuracy
|
|
400
401
|
loss = @loss_func.(out, Tensor.convert(y))
|
|
401
|
-
loss_data = loss.data
|
|
402
|
+
loss_data = Utils.to_f(loss.data)
|
|
402
403
|
end
|
|
403
404
|
call_callbacks(:after_test_on_batch)
|
|
404
405
|
[correct, loss_data]
|
|
@@ -441,8 +442,9 @@ module DNN
|
|
|
441
442
|
ys = []
|
|
442
443
|
ary_output_tensors.each.with_index do |out, i|
|
|
443
444
|
y = out.data
|
|
444
|
-
|
|
445
|
-
|
|
445
|
+
lf = lfs[i]
|
|
446
|
+
if use_loss_activation && lf && lf.class.respond_to?(:activation)
|
|
447
|
+
y = lf.class.activation(y)
|
|
446
448
|
end
|
|
447
449
|
ys << y
|
|
448
450
|
end
|
|
@@ -458,7 +460,12 @@ module DNN
|
|
|
458
460
|
else
|
|
459
461
|
x.reshape(1, *x.shape)
|
|
460
462
|
end
|
|
461
|
-
predict(input, use_loss_activation: use_loss_activation)
|
|
463
|
+
y = predict(input, use_loss_activation: use_loss_activation)
|
|
464
|
+
if y.is_a?(Array)
|
|
465
|
+
y.map { |v| v[0, false] }
|
|
466
|
+
else
|
|
467
|
+
y[0, false]
|
|
468
|
+
end
|
|
462
469
|
end
|
|
463
470
|
|
|
464
471
|
# Add callback function.
|
|
@@ -468,6 +475,15 @@ module DNN
|
|
|
468
475
|
@callbacks << callback
|
|
469
476
|
end
|
|
470
477
|
|
|
478
|
+
# Add lambda callback.
|
|
479
|
+
# @param [Symbol] event Event to execute callback.
|
|
480
|
+
# @yield Register the contents of the callback.
|
|
481
|
+
def add_lambda_callback(event, &block)
|
|
482
|
+
callback = Callbacks::LambdaCallback.new(event, &block)
|
|
483
|
+
callback.model = self
|
|
484
|
+
@callbacks << callback
|
|
485
|
+
end
|
|
486
|
+
|
|
471
487
|
# Clear the callback function registered for each event.
|
|
472
488
|
def clear_callbacks
|
|
473
489
|
@callbacks = []
|
|
@@ -526,7 +542,7 @@ module DNN
|
|
|
526
542
|
@loss_func.each do |lf|
|
|
527
543
|
lf.clean
|
|
528
544
|
end
|
|
529
|
-
else
|
545
|
+
elsif @loss_func.is_a?(Losses::Loss)
|
|
530
546
|
@loss_func.clean
|
|
531
547
|
end
|
|
532
548
|
@layers_cache = nil
|
|
@@ -552,6 +568,56 @@ module DNN
|
|
|
552
568
|
end
|
|
553
569
|
end
|
|
554
570
|
|
|
571
|
+
# Convert the parameters of model and optimizer for cpu.
|
|
572
|
+
# @return [DNN::Models::Model] Return self.
|
|
573
|
+
def to_cpu
|
|
574
|
+
params_data = get_all_params_data
|
|
575
|
+
clean_layers
|
|
576
|
+
set_all_params_data(params_data)
|
|
577
|
+
trainable_layers.each do |layer|
|
|
578
|
+
layer.get_params.each do |key, param|
|
|
579
|
+
data = param.data
|
|
580
|
+
if DNN.use_cumo? && data.is_a?(Cumo::NArray)
|
|
581
|
+
param.data = Utils.cumo2numo(data)
|
|
582
|
+
end
|
|
583
|
+
end
|
|
584
|
+
end
|
|
585
|
+
@optimizer.status.each do |key, state|
|
|
586
|
+
next unless state
|
|
587
|
+
state.each do |param, data|
|
|
588
|
+
if DNN.use_cumo? && data.is_a?(Cumo::NArray)
|
|
589
|
+
state[param] = Utils.cumo2numo(data)
|
|
590
|
+
end
|
|
591
|
+
end
|
|
592
|
+
end
|
|
593
|
+
self
|
|
594
|
+
end
|
|
595
|
+
|
|
596
|
+
# Convert the parameters of model and optimizer for gpu.
|
|
597
|
+
# @return [DNN::Models::Model] Return self.
|
|
598
|
+
def to_gpu
|
|
599
|
+
params_data = get_all_params_data
|
|
600
|
+
clean_layers
|
|
601
|
+
set_all_params_data(params_data)
|
|
602
|
+
trainable_layers.each do |layer|
|
|
603
|
+
layer.get_params.each do |(key, param)|
|
|
604
|
+
data = param.data
|
|
605
|
+
if DNN.use_cumo? && data.is_a?(Numo::NArray)
|
|
606
|
+
param.data = Utils.numo2cumo(data)
|
|
607
|
+
end
|
|
608
|
+
end
|
|
609
|
+
end
|
|
610
|
+
@optimizer.status.each do |(key, state)|
|
|
611
|
+
next unless state
|
|
612
|
+
state.each do |(param, data)|
|
|
613
|
+
if DNN.use_cumo? && data.is_a?(Numo::NArray)
|
|
614
|
+
state[param] = Utils.numo2cumo(data)
|
|
615
|
+
end
|
|
616
|
+
end
|
|
617
|
+
end
|
|
618
|
+
self
|
|
619
|
+
end
|
|
620
|
+
|
|
555
621
|
private
|
|
556
622
|
|
|
557
623
|
def get_all_trainable_params
|
|
@@ -569,10 +635,10 @@ module DNN
|
|
|
569
635
|
def metrics_to_str(mertics)
|
|
570
636
|
mertics.map { |key, values|
|
|
571
637
|
str_values = if values.is_a?(Array)
|
|
572
|
-
values_fmt = values.map { |v| sprintf('%.4f', v) }
|
|
638
|
+
values_fmt = values.map { |v| sprintf('%.4f', Utils.to_f(v)) }
|
|
573
639
|
"[#{values_fmt.join(", ")}]"
|
|
574
640
|
else
|
|
575
|
-
sprintf('%.4f', values)
|
|
641
|
+
sprintf('%.4f', Utils.to_f(values))
|
|
576
642
|
end
|
|
577
643
|
"#{key}: #{str_values}"
|
|
578
644
|
}.join(", ")
|
|
@@ -641,6 +707,7 @@ module DNN
|
|
|
641
707
|
raise TypeError, "layer: #{layer.class.name} is not an instance of the DNN::Layers::Layer class or DNN::Models::Chain class."
|
|
642
708
|
end
|
|
643
709
|
@stack.insert(index, layer)
|
|
710
|
+
self
|
|
644
711
|
end
|
|
645
712
|
|
|
646
713
|
# Remove layer to the model.
|
data/lib/dnn/core/optimizers.rb
CHANGED
|
@@ -3,6 +3,7 @@ module DNN
|
|
|
3
3
|
|
|
4
4
|
# Super class of all optimizer classes.
|
|
5
5
|
class Optimizer
|
|
6
|
+
attr_reader :status
|
|
6
7
|
attr_accessor :clip_norm
|
|
7
8
|
|
|
8
9
|
def self.from_hash(hash)
|
|
@@ -47,7 +48,7 @@ module DNN
|
|
|
47
48
|
end
|
|
48
49
|
|
|
49
50
|
private def clip_grads(params)
|
|
50
|
-
norm = Math.sqrt(params.reduce(0) { |total, param| total + (param.grad**2).sum })
|
|
51
|
+
norm = Math.sqrt(params.reduce(0) { |total, param| total + (param.grad**2).sum.to_f })
|
|
51
52
|
return if norm <= @clip_norm
|
|
52
53
|
rate = @clip_norm / (norm + 1e-7)
|
|
53
54
|
params.each do |param|
|
|
@@ -71,6 +72,7 @@ module DNN
|
|
|
71
72
|
@lr = lr
|
|
72
73
|
@momentum = momentum
|
|
73
74
|
@v = {}
|
|
75
|
+
@status = { v: @v }
|
|
74
76
|
end
|
|
75
77
|
|
|
76
78
|
def to_hash
|
|
@@ -120,6 +122,7 @@ module DNN
|
|
|
120
122
|
@lr = lr
|
|
121
123
|
@eps = eps
|
|
122
124
|
@g = {}
|
|
125
|
+
@status = { g: @g }
|
|
123
126
|
end
|
|
124
127
|
|
|
125
128
|
private def update_params(params)
|
|
@@ -153,6 +156,7 @@ module DNN
|
|
|
153
156
|
@alpha = alpha
|
|
154
157
|
@eps = eps
|
|
155
158
|
@g = {}
|
|
159
|
+
@status = { g: @g }
|
|
156
160
|
end
|
|
157
161
|
|
|
158
162
|
def to_hash
|
|
@@ -184,6 +188,7 @@ module DNN
|
|
|
184
188
|
@eps = eps
|
|
185
189
|
@h = {}
|
|
186
190
|
@s = {}
|
|
191
|
+
@status = { h: @h, s: @s }
|
|
187
192
|
end
|
|
188
193
|
|
|
189
194
|
def to_hash
|
|
@@ -221,6 +226,7 @@ module DNN
|
|
|
221
226
|
@eps = eps
|
|
222
227
|
@m = {}
|
|
223
228
|
@v = {}
|
|
229
|
+
@status = { m: @m, v: @v }
|
|
224
230
|
end
|
|
225
231
|
|
|
226
232
|
def to_hash
|
|
@@ -265,6 +271,7 @@ module DNN
|
|
|
265
271
|
@m = {}
|
|
266
272
|
@v = {}
|
|
267
273
|
@s = amsgrad ? {} : nil
|
|
274
|
+
@status = { m: @m, v: @v, s: @s }
|
|
268
275
|
end
|
|
269
276
|
|
|
270
277
|
def to_hash
|
data/lib/dnn/core/utils.rb
CHANGED
|
@@ -43,5 +43,28 @@ module DNN
|
|
|
43
43
|
def self.numerical_grad(x, func)
|
|
44
44
|
(func.(x + 1e-7) - func.(x)) / 1e-7
|
|
45
45
|
end
|
|
46
|
+
|
|
47
|
+
# Convert numo to cumo.
|
|
48
|
+
def self.numo2cumo(na)
|
|
49
|
+
b = na.to_binary
|
|
50
|
+
ca = Cumo::SFloat.from_binary(b)
|
|
51
|
+
ca.reshape(*na.shape)
|
|
52
|
+
end
|
|
53
|
+
|
|
54
|
+
# Convert cumo to numo.
|
|
55
|
+
def self.cumo2numo(ca)
|
|
56
|
+
b = ca.to_binary
|
|
57
|
+
na = Numo::SFloat.from_binary(b)
|
|
58
|
+
na.reshape(*ca.shape)
|
|
59
|
+
end
|
|
60
|
+
|
|
61
|
+
# Force convert to Float.
|
|
62
|
+
def self.to_f(x)
|
|
63
|
+
if x.is_a?(Xumo::NArray)
|
|
64
|
+
x[0].to_f
|
|
65
|
+
else
|
|
66
|
+
x.to_f
|
|
67
|
+
end
|
|
68
|
+
end
|
|
46
69
|
end
|
|
47
70
|
end
|
data/lib/dnn/image.rb
CHANGED
|
@@ -64,6 +64,19 @@ module DNN
|
|
|
64
64
|
raise ImageWriteError, "Image write failed." if res == 0
|
|
65
65
|
end
|
|
66
66
|
|
|
67
|
+
# Create an image from binary.
|
|
68
|
+
# @param [String] bin binary data.
|
|
69
|
+
# @param [Integer] height Image height.
|
|
70
|
+
# @param [Integer] width Image width.
|
|
71
|
+
# @param [Integer] channel Image channel.
|
|
72
|
+
def self.from_binary(bin, height, width, channel = DNN::Image::RGB)
|
|
73
|
+
expected_size = height * width * channel
|
|
74
|
+
unless bin.size == expected_size
|
|
75
|
+
raise ImageError, "binary size is #{bin.size}, but expected binary size is #{expected_size}"
|
|
76
|
+
end
|
|
77
|
+
Numo::UInt8.from_binary(bin).reshape(height, width, channel)
|
|
78
|
+
end
|
|
79
|
+
|
|
67
80
|
# Resize the image.
|
|
68
81
|
# @param [Numo::UInt8] img Image to resize.
|
|
69
82
|
# @param [Integer] out_height Image height to resize.
|
|
@@ -101,6 +114,41 @@ module DNN
|
|
|
101
114
|
Numo::UInt8.cast(x)
|
|
102
115
|
end
|
|
103
116
|
|
|
117
|
+
# Image convert image channel to RGB.
|
|
118
|
+
# @param [Numo::UInt8] img Image to RGB.
|
|
119
|
+
def self.to_rgb(img)
|
|
120
|
+
img_check(img)
|
|
121
|
+
case img.shape[2]
|
|
122
|
+
when 1
|
|
123
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
|
|
124
|
+
when 2
|
|
125
|
+
img = img[true, true, 0...1]
|
|
126
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2)
|
|
127
|
+
when 4
|
|
128
|
+
return img[true, true, 0...3].clone
|
|
129
|
+
end
|
|
130
|
+
img
|
|
131
|
+
end
|
|
132
|
+
|
|
133
|
+
# Image convert image channel to RGBA.
|
|
134
|
+
# @param [Numo::UInt8] img Image to RGBA.
|
|
135
|
+
def self.to_rgba(img)
|
|
136
|
+
img_check(img)
|
|
137
|
+
case img.shape[2]
|
|
138
|
+
when 1
|
|
139
|
+
alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
|
|
140
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
|
|
141
|
+
when 2
|
|
142
|
+
alpha = img[true, true, 1...2]
|
|
143
|
+
img = img[true, true, 0...1]
|
|
144
|
+
return img.concatenate(img, axis: 2).concatenate(img, axis: 2).concatenate(alpha, axis: 2)
|
|
145
|
+
when 3
|
|
146
|
+
alpha = Numo::UInt8.new(*img.shape[0..1], 1).fill(255)
|
|
147
|
+
return img.concatenate(alpha, axis: 2)
|
|
148
|
+
end
|
|
149
|
+
img
|
|
150
|
+
end
|
|
151
|
+
|
|
104
152
|
private_class_method def self.img_check(img)
|
|
105
153
|
raise TypeError, "img: #{img.class} is not an instance of the Numo::UInt8 class." unless img.is_a?(Numo::UInt8)
|
|
106
154
|
if img.shape.length != 3
|