ruby-dnn 0.10.4 → 0.12.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.travis.yml +1 -2
- data/README.md +33 -6
- data/examples/cifar100_example.rb +3 -3
- data/examples/cifar10_example.rb +3 -3
- data/examples/dcgan/dcgan.rb +112 -0
- data/examples/dcgan/imgen.rb +20 -0
- data/examples/dcgan/train.rb +41 -0
- data/examples/iris_example.rb +3 -6
- data/examples/mnist_conv2d_example.rb +5 -5
- data/examples/mnist_define_by_run.rb +52 -0
- data/examples/mnist_example.rb +3 -3
- data/examples/mnist_lstm_example.rb +3 -3
- data/examples/xor_example.rb +4 -5
- data/ext/rb_stb_image/rb_stb_image.c +103 -0
- data/lib/dnn.rb +10 -10
- data/lib/dnn/cifar10.rb +1 -1
- data/lib/dnn/cifar100.rb +1 -1
- data/lib/dnn/core/activations.rb +21 -22
- data/lib/dnn/core/cnn_layers.rb +94 -111
- data/lib/dnn/core/embedding.rb +30 -9
- data/lib/dnn/core/initializers.rb +31 -21
- data/lib/dnn/core/iterator.rb +52 -0
- data/lib/dnn/core/layers.rb +99 -66
- data/lib/dnn/core/link.rb +24 -0
- data/lib/dnn/core/losses.rb +69 -59
- data/lib/dnn/core/merge_layers.rb +71 -0
- data/lib/dnn/core/models.rb +393 -0
- data/lib/dnn/core/normalizations.rb +27 -14
- data/lib/dnn/core/optimizers.rb +212 -134
- data/lib/dnn/core/param.rb +8 -6
- data/lib/dnn/core/regularizers.rb +10 -7
- data/lib/dnn/core/rnn_layers.rb +78 -85
- data/lib/dnn/core/utils.rb +6 -3
- data/lib/dnn/downloader.rb +3 -3
- data/lib/dnn/fashion-mnist.rb +89 -0
- data/lib/dnn/image.rb +57 -18
- data/lib/dnn/iris.rb +1 -3
- data/lib/dnn/mnist.rb +38 -34
- data/lib/dnn/version.rb +1 -1
- data/third_party/stb_image.h +16 -4
- data/third_party/stb_image_resize.h +2630 -0
- data/third_party/stb_image_write.h +4 -7
- metadata +12 -4
- data/lib/dnn/core/dataset.rb +0 -34
- data/lib/dnn/core/model.rb +0 -440
data/lib/dnn/core/param.rb
CHANGED
@@ -1,9 +1,11 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
1
|
+
module DNN
|
2
|
+
class Param
|
3
|
+
attr_accessor :data
|
4
|
+
attr_accessor :grad
|
4
5
|
|
5
|
-
|
6
|
-
|
7
|
-
|
6
|
+
def initialize(data = nil, grad = nil)
|
7
|
+
@data = data
|
8
|
+
@grad = grad
|
9
|
+
end
|
8
10
|
end
|
9
11
|
end
|
@@ -13,7 +13,7 @@ module DNN
|
|
13
13
|
end
|
14
14
|
|
15
15
|
def to_hash(merge_hash)
|
16
|
-
hash = {class: self.class.name}
|
16
|
+
hash = { class: self.class.name }
|
17
17
|
hash.merge!(merge_hash)
|
18
18
|
hash
|
19
19
|
end
|
@@ -23,9 +23,10 @@ module DNN
|
|
23
23
|
attr_accessor :l1_lambda
|
24
24
|
|
25
25
|
def self.from_hash(hash)
|
26
|
-
|
26
|
+
self.new(hash[:l1_lambda])
|
27
27
|
end
|
28
28
|
|
29
|
+
# @param [Float] l1_lambda L1 regularizer coefficient.
|
29
30
|
def initialize(l1_lambda = 0.01)
|
30
31
|
@l1_lambda = l1_lambda
|
31
32
|
end
|
@@ -50,15 +51,16 @@ module DNN
|
|
50
51
|
attr_accessor :l2_lambda
|
51
52
|
|
52
53
|
def self.from_hash(hash)
|
53
|
-
|
54
|
+
self.new(hash[:l2_lambda])
|
54
55
|
end
|
55
56
|
|
57
|
+
# @param [Float] l2_lambda L2 regularizer coefficient.
|
56
58
|
def initialize(l2_lambda = 0.01)
|
57
59
|
@l2_lambda = l2_lambda
|
58
60
|
end
|
59
61
|
|
60
62
|
def forward(x)
|
61
|
-
x + 0.5 * @l2_lambda * (@param.data**2).sum
|
63
|
+
x + 0.5 * @l2_lambda * (@param.data ** 2).sum
|
62
64
|
end
|
63
65
|
|
64
66
|
def backward
|
@@ -75,9 +77,11 @@ module DNN
|
|
75
77
|
attr_accessor :l2_lambda
|
76
78
|
|
77
79
|
def self.from_hash(hash)
|
78
|
-
|
80
|
+
self.new(hash[:l1_lambda], hash[:l2_lambda])
|
79
81
|
end
|
80
82
|
|
83
|
+
# @param [Float] l1_lambda L1 regularizer coefficient.
|
84
|
+
# @param [Float] l2_lambda L2 regularizer coefficient.
|
81
85
|
def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
|
82
86
|
@l1_lambda = l1_lambda
|
83
87
|
@l2_lambda = l2_lambda
|
@@ -85,7 +89,7 @@ module DNN
|
|
85
89
|
|
86
90
|
def forward(x)
|
87
91
|
l1 = @l1_lambda * @param.data.abs.sum
|
88
|
-
l2 = 0.5 * @l2_lambda * (@param.data**2).sum
|
92
|
+
l2 = 0.5 * @l2_lambda * (@param.data ** 2).sum
|
89
93
|
x + l1 + l2
|
90
94
|
end
|
91
95
|
|
@@ -99,7 +103,6 @@ module DNN
|
|
99
103
|
def to_hash
|
100
104
|
super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
|
101
105
|
end
|
102
|
-
|
103
106
|
end
|
104
107
|
|
105
108
|
end
|
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -3,25 +3,38 @@ module DNN
|
|
3
3
|
|
4
4
|
# Super class of all RNN classes.
|
5
5
|
class RNN < Connection
|
6
|
-
include Initializers
|
7
|
-
|
8
|
-
# @return [Integer] number of nodes.
|
9
6
|
attr_reader :num_nodes
|
10
|
-
|
7
|
+
attr_reader :recurrent_weight
|
8
|
+
attr_reader :hidden
|
11
9
|
attr_reader :stateful
|
12
|
-
# @return [Bool] Set the false, only the last of each cell of RNN is left.
|
13
10
|
attr_reader :return_sequences
|
14
|
-
# @return [DNN::Initializers::Initializer] Recurrent weight initializer.
|
15
11
|
attr_reader :recurrent_weight_initializer
|
16
|
-
# @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
|
17
12
|
attr_reader :recurrent_weight_regularizer
|
18
13
|
|
14
|
+
def self.from_hash(hash)
|
15
|
+
self.new(hash[:num_nodes],
|
16
|
+
stateful: hash[:stateful],
|
17
|
+
return_sequences: hash[:return_sequences],
|
18
|
+
weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
|
19
|
+
recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
|
20
|
+
bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
|
21
|
+
weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
|
22
|
+
recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
|
23
|
+
bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
|
24
|
+
use_bias: hash[:use_bias])
|
25
|
+
end
|
26
|
+
|
27
|
+
# @param [Integer] num_nodes Number of nodes.
|
28
|
+
# @param [Boolean] stateful Maintain state between batches.
|
29
|
+
# @param [Boolean] return_sequences Set the false, only the last of each cell of RNN is left.
|
30
|
+
# @param [DNN::Initializers::Initializer] recurrent_weight_initializer Recurrent weight initializer.
|
31
|
+
# @param [DNN::Regularizers::Regularizer | NilClass] recurrent_weight_regularizer Recurrent weight regularizer.
|
19
32
|
def initialize(num_nodes,
|
20
33
|
stateful: false,
|
21
34
|
return_sequences: true,
|
22
|
-
weight_initializer: RandomNormal.new,
|
23
|
-
recurrent_weight_initializer: RandomNormal.new,
|
24
|
-
bias_initializer: Zeros.new,
|
35
|
+
weight_initializer: Initializers::RandomNormal.new,
|
36
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
37
|
+
bias_initializer: Initializers::Zeros.new,
|
25
38
|
weight_regularizer: nil,
|
26
39
|
recurrent_weight_regularizer: nil,
|
27
40
|
bias_regularizer: nil,
|
@@ -32,8 +45,8 @@ module DNN
|
|
32
45
|
@stateful = stateful
|
33
46
|
@return_sequences = return_sequences
|
34
47
|
@layers = []
|
35
|
-
@hidden =
|
36
|
-
@
|
48
|
+
@hidden = Param.new
|
49
|
+
@recurrent_weight = Param.new(nil, 0)
|
37
50
|
@recurrent_weight_initializer = recurrent_weight_initializer
|
38
51
|
@recurrent_weight_regularizer = recurrent_weight_regularizer
|
39
52
|
end
|
@@ -49,7 +62,7 @@ module DNN
|
|
49
62
|
def forward(xs)
|
50
63
|
@xs_shape = xs.shape
|
51
64
|
hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
|
52
|
-
h =
|
65
|
+
h = @stateful && @hidden.data ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
|
53
66
|
xs.shape[1].times do |t|
|
54
67
|
x = xs[true, t, false]
|
55
68
|
@layers[t].trainable = @trainable
|
@@ -92,6 +105,10 @@ module DNN
|
|
92
105
|
super(hash)
|
93
106
|
end
|
94
107
|
|
108
|
+
def get_params
|
109
|
+
{ weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden }
|
110
|
+
end
|
111
|
+
|
95
112
|
# Reset the state of RNN.
|
96
113
|
def reset_state
|
97
114
|
@hidden.data = @hidden.data.fill(0) if @hidden.data
|
@@ -113,7 +130,7 @@ module DNN
|
|
113
130
|
end
|
114
131
|
|
115
132
|
|
116
|
-
class
|
133
|
+
class SimpleRNNDense
|
117
134
|
attr_accessor :trainable
|
118
135
|
|
119
136
|
def initialize(weight, recurrent_weight, bias, activation)
|
@@ -147,32 +164,30 @@ module DNN
|
|
147
164
|
|
148
165
|
|
149
166
|
class SimpleRNN < RNN
|
150
|
-
include Activations
|
151
|
-
|
152
167
|
attr_reader :activation
|
153
|
-
|
168
|
+
|
154
169
|
def self.from_hash(hash)
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
simple_rnn
|
170
|
+
self.new(hash[:num_nodes],
|
171
|
+
stateful: hash[:stateful],
|
172
|
+
return_sequences: hash[:return_sequences],
|
173
|
+
activation: Utils.hash_to_obj(hash[:activation]),
|
174
|
+
weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
|
175
|
+
recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
|
176
|
+
bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
|
177
|
+
weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
|
178
|
+
recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
|
179
|
+
bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
|
180
|
+
use_bias: hash[:use_bias])
|
167
181
|
end
|
168
182
|
|
183
|
+
# @param [DNN::Layers::Layer] activation Activation function to use in a recurrent network.
|
169
184
|
def initialize(num_nodes,
|
170
185
|
stateful: false,
|
171
186
|
return_sequences: true,
|
172
|
-
activation: Tanh.new,
|
173
|
-
weight_initializer: RandomNormal.new,
|
174
|
-
recurrent_weight_initializer: RandomNormal.new,
|
175
|
-
bias_initializer: Zeros.new,
|
187
|
+
activation: Activations::Tanh.new,
|
188
|
+
weight_initializer: Initializers::RandomNormal.new,
|
189
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
190
|
+
bias_initializer: Initializers::Zeros.new,
|
176
191
|
weight_regularizer: nil,
|
177
192
|
recurrent_weight_regularizer: nil,
|
178
193
|
bias_regularizer: nil,
|
@@ -197,29 +212,29 @@ module DNN
|
|
197
212
|
@recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
|
198
213
|
@bias.data = Xumo::SFloat.new(@num_nodes) if @bias
|
199
214
|
init_weight_and_bias
|
200
|
-
@time_length.times do
|
201
|
-
@layers <<
|
215
|
+
@time_length.times do
|
216
|
+
@layers << SimpleRNNDense.new(@weight, @recurrent_weight, @bias, @activation)
|
202
217
|
end
|
203
218
|
end
|
204
219
|
|
205
220
|
def to_hash
|
206
|
-
super(
|
221
|
+
super(activation: @activation.to_hash)
|
207
222
|
end
|
208
223
|
end
|
209
224
|
|
210
225
|
|
211
|
-
class
|
226
|
+
class LSTMDense
|
212
227
|
attr_accessor :trainable
|
213
228
|
|
214
229
|
def initialize(weight, recurrent_weight, bias)
|
215
230
|
@weight = weight
|
216
231
|
@recurrent_weight = recurrent_weight
|
217
232
|
@bias = bias
|
218
|
-
@tanh = Tanh.new
|
219
|
-
@g_tanh = Tanh.new
|
220
|
-
@forget_sigmoid = Sigmoid.new
|
221
|
-
@in_sigmoid = Sigmoid.new
|
222
|
-
@out_sigmoid = Sigmoid.new
|
233
|
+
@tanh = Activations::Tanh.new
|
234
|
+
@g_tanh = Activations::Tanh.new
|
235
|
+
@forget_sigmoid = Activations::Sigmoid.new
|
236
|
+
@in_sigmoid = Activations::Sigmoid.new
|
237
|
+
@out_sigmoid = Activations::Sigmoid.new
|
223
238
|
@trainable = true
|
224
239
|
end
|
225
240
|
|
@@ -267,32 +282,20 @@ module DNN
|
|
267
282
|
|
268
283
|
|
269
284
|
class LSTM < RNN
|
270
|
-
|
271
|
-
lstm = self.new(hash[:num_nodes],
|
272
|
-
stateful: hash[:stateful],
|
273
|
-
return_sequences: hash[:return_sequences],
|
274
|
-
weight_initializer: Utils.from_hash(hash[:weight_initializer]),
|
275
|
-
recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
|
276
|
-
bias_initializer: Utils.from_hash(hash[:bias_initializer]),
|
277
|
-
weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
|
278
|
-
recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
|
279
|
-
bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
|
280
|
-
use_bias: hash[:use_bias])
|
281
|
-
lstm
|
282
|
-
end
|
285
|
+
attr_reader :cell
|
283
286
|
|
284
287
|
def initialize(num_nodes,
|
285
288
|
stateful: false,
|
286
289
|
return_sequences: true,
|
287
|
-
weight_initializer: RandomNormal.new,
|
288
|
-
recurrent_weight_initializer: RandomNormal.new,
|
289
|
-
bias_initializer: Zeros.new,
|
290
|
+
weight_initializer: Initializers::RandomNormal.new,
|
291
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
292
|
+
bias_initializer: Initializers::Zeros.new,
|
290
293
|
weight_regularizer: nil,
|
291
294
|
recurrent_weight_regularizer: nil,
|
292
295
|
bias_regularizer: nil,
|
293
296
|
use_bias: true)
|
294
297
|
super
|
295
|
-
@cell =
|
298
|
+
@cell = Param.new
|
296
299
|
end
|
297
300
|
|
298
301
|
def build(input_shape)
|
@@ -302,8 +305,8 @@ module DNN
|
|
302
305
|
@recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
|
303
306
|
@bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
|
304
307
|
init_weight_and_bias
|
305
|
-
@time_length.times do
|
306
|
-
@layers <<
|
308
|
+
@time_length.times do
|
309
|
+
@layers << LSTMDense.new(@weight, @recurrent_weight, @bias)
|
307
310
|
end
|
308
311
|
end
|
309
312
|
|
@@ -350,19 +353,23 @@ module DNN
|
|
350
353
|
super()
|
351
354
|
@cell.data = @cell.data.fill(0) if @cell.data
|
352
355
|
end
|
356
|
+
|
357
|
+
def get_params
|
358
|
+
{ weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden, cell: @cell }
|
359
|
+
end
|
353
360
|
end
|
354
361
|
|
355
362
|
|
356
|
-
class
|
363
|
+
class GRUDense
|
357
364
|
attr_accessor :trainable
|
358
365
|
|
359
366
|
def initialize(weight, recurrent_weight, bias)
|
360
367
|
@weight = weight
|
361
368
|
@recurrent_weight = recurrent_weight
|
362
369
|
@bias = bias
|
363
|
-
@update_sigmoid = Sigmoid.new
|
364
|
-
@reset_sigmoid = Sigmoid.new
|
365
|
-
@tanh = Tanh.new
|
370
|
+
@update_sigmoid = Activations::Sigmoid.new
|
371
|
+
@reset_sigmoid = Activations::Sigmoid.new
|
372
|
+
@tanh = Activations::Tanh.new
|
366
373
|
@trainable = true
|
367
374
|
end
|
368
375
|
|
@@ -423,33 +430,19 @@ module DNN
|
|
423
430
|
|
424
431
|
|
425
432
|
class GRU < RNN
|
426
|
-
def self.from_hash(hash)
|
427
|
-
gru = self.new(hash[:num_nodes],
|
428
|
-
stateful: hash[:stateful],
|
429
|
-
return_sequences: hash[:return_sequences],
|
430
|
-
weight_initializer: Utils.from_hash(hash[:weight_initializer]),
|
431
|
-
recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
|
432
|
-
bias_initializer: Utils.from_hash(hash[:bias_initializer]),
|
433
|
-
weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
|
434
|
-
recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
|
435
|
-
bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
|
436
|
-
use_bias: hash[:use_bias])
|
437
|
-
gru
|
438
|
-
end
|
439
|
-
|
440
433
|
def initialize(num_nodes,
|
441
434
|
stateful: false,
|
442
435
|
return_sequences: true,
|
443
|
-
weight_initializer: RandomNormal.new,
|
444
|
-
recurrent_weight_initializer: RandomNormal.new,
|
445
|
-
bias_initializer: Zeros.new,
|
436
|
+
weight_initializer: Initializers::RandomNormal.new,
|
437
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
438
|
+
bias_initializer: Initializers::Zeros.new,
|
446
439
|
weight_regularizer: nil,
|
447
440
|
recurrent_weight_regularizer: nil,
|
448
441
|
bias_regularizer: nil,
|
449
442
|
use_bias: true)
|
450
443
|
super
|
451
444
|
end
|
452
|
-
|
445
|
+
|
453
446
|
def build(input_shape)
|
454
447
|
super
|
455
448
|
num_prev_nodes = @input_shape[1]
|
@@ -457,8 +450,8 @@ module DNN
|
|
457
450
|
@recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
|
458
451
|
@bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
|
459
452
|
init_weight_and_bias
|
460
|
-
@time_length.times do
|
461
|
-
@layers <<
|
453
|
+
@time_length.times do
|
454
|
+
@layers << GRUDense.new(@weight, @recurrent_weight, @bias)
|
462
455
|
end
|
463
456
|
end
|
464
457
|
end
|
data/lib/dnn/core/utils.rb
CHANGED
@@ -2,6 +2,9 @@ module DNN
|
|
2
2
|
# This module provides utility functions.
|
3
3
|
module Utils
|
4
4
|
# Categorize labels into "num_classes" classes.
|
5
|
+
# @param [Numo::SFloat] y Label data.
|
6
|
+
# @param [Numo::SFloat] num_classes Number of classes to classify.
|
7
|
+
# @param [Class] narray_type Type of Numo::Narray data after classification.
|
5
8
|
def self.to_categorical(y, num_classes, narray_type = nil)
|
6
9
|
narray_type ||= y.class
|
7
10
|
y2 = narray_type.zeros(y.shape[0], num_classes)
|
@@ -12,7 +15,7 @@ module DNN
|
|
12
15
|
end
|
13
16
|
|
14
17
|
# Convert hash to an object.
|
15
|
-
def self.
|
18
|
+
def self.hash_to_obj(hash)
|
16
19
|
return nil if hash == nil
|
17
20
|
dnn_class = DNN.const_get(hash[:class])
|
18
21
|
if dnn_class.respond_to?(:from_hash)
|
@@ -23,12 +26,12 @@ module DNN
|
|
23
26
|
|
24
27
|
# Return the result of the sigmoid function.
|
25
28
|
def self.sigmoid(x)
|
26
|
-
Sigmoid.new.forward(x)
|
29
|
+
Activations::Sigmoid.new.forward(x)
|
27
30
|
end
|
28
31
|
|
29
32
|
# Return the result of the softmax function.
|
30
33
|
def self.softmax(x)
|
31
|
-
SoftmaxCrossEntropy.softmax(x)
|
34
|
+
Losses::SoftmaxCrossEntropy.softmax(x)
|
32
35
|
end
|
33
36
|
end
|
34
37
|
end
|
data/lib/dnn/downloader.rb
CHANGED
@@ -11,8 +11,8 @@ module DNN
|
|
11
11
|
dir_path = "#{__dir__}/downloads"
|
12
12
|
end
|
13
13
|
Downloader.new(url).download(dir_path)
|
14
|
-
rescue =>
|
15
|
-
raise DNN_DownloadError.new(
|
14
|
+
rescue => e
|
15
|
+
raise DNN_DownloadError.new(e.message)
|
16
16
|
end
|
17
17
|
|
18
18
|
def initialize(url)
|
@@ -42,7 +42,7 @@ module DNN
|
|
42
42
|
end
|
43
43
|
puts ""
|
44
44
|
end
|
45
|
-
file_name = @path.match(%r
|
45
|
+
file_name = @path.match(%r`.*/(.+)`)[1]
|
46
46
|
File.binwrite("#{dir_path}/#{file_name}", buf)
|
47
47
|
end
|
48
48
|
end
|
@@ -0,0 +1,89 @@
|
|
1
|
+
require "zlib"
|
2
|
+
require_relative "core/error"
|
3
|
+
require_relative "downloader"
|
4
|
+
require_relative "mnist"
|
5
|
+
|
6
|
+
module DNN
|
7
|
+
module FashionMNIST
|
8
|
+
class DNN_MNIST_LoadError < DNN_Error; end
|
9
|
+
|
10
|
+
URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
|
11
|
+
|
12
|
+
TRAIN_IMAGES_FILE_NAME = "train-images-idx3-ubyte.gz"
|
13
|
+
TRAIN_LABELS_FILE_NAME = "train-labels-idx1-ubyte.gz"
|
14
|
+
TEST_IMAGES_FILE_NAME = "t10k-images-idx3-ubyte.gz"
|
15
|
+
TEST_LABELS_FILE_NAME = "t10k-labels-idx1-ubyte.gz"
|
16
|
+
|
17
|
+
URL_TRAIN_IMAGES = URL_BASE + TRAIN_IMAGES_FILE_NAME
|
18
|
+
URL_TRAIN_LABELS = URL_BASE + TRAIN_LABELS_FILE_NAME
|
19
|
+
URL_TEST_IMAGES = URL_BASE + TEST_IMAGES_FILE_NAME
|
20
|
+
URL_TEST_LABELS = URL_BASE + TEST_LABELS_FILE_NAME
|
21
|
+
|
22
|
+
def self.downloads
|
23
|
+
Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
|
24
|
+
Dir.mkdir(mnist_dir) unless Dir.exist?(mnist_dir)
|
25
|
+
Downloader.download(URL_TRAIN_IMAGES, mnist_dir) unless File.exist?(get_file_path(TRAIN_IMAGES_FILE_NAME))
|
26
|
+
Downloader.download(URL_TRAIN_LABELS, mnist_dir) unless File.exist?(get_file_path(TRAIN_LABELS_FILE_NAME))
|
27
|
+
Downloader.download(URL_TEST_IMAGES, mnist_dir) unless File.exist?(get_file_path(TEST_IMAGES_FILE_NAME))
|
28
|
+
Downloader.download(URL_TEST_LABELS, mnist_dir) unless File.exist?(get_file_path(TEST_LABELS_FILE_NAME))
|
29
|
+
end
|
30
|
+
|
31
|
+
def self.load_train
|
32
|
+
downloads
|
33
|
+
train_images_file_path = get_file_path(TRAIN_IMAGES_FILE_NAME)
|
34
|
+
train_labels_file_path = get_file_path(TRAIN_LABELS_FILE_NAME)
|
35
|
+
unless File.exist?(train_images_file_path)
|
36
|
+
raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_path}" is not found.`)
|
37
|
+
end
|
38
|
+
unless File.exist?(train_labels_file_path)
|
39
|
+
raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_path}" is not found.`)
|
40
|
+
end
|
41
|
+
images = load_images(train_images_file_path)
|
42
|
+
labels = load_labels(train_labels_file_path)
|
43
|
+
[images, labels]
|
44
|
+
end
|
45
|
+
|
46
|
+
def self.load_test
|
47
|
+
downloads
|
48
|
+
test_images_file_path = get_file_path(TEST_IMAGES_FILE_NAME)
|
49
|
+
test_labels_file_path = get_file_path(TEST_LABELS_FILE_NAME)
|
50
|
+
unless File.exist?(test_images_file_path)
|
51
|
+
raise DNN_MNIST_LoadError.new(%`file "#{test_images_file_path}" is not found.`)
|
52
|
+
end
|
53
|
+
unless File.exist?(test_labels_file_path)
|
54
|
+
raise DNN_MNIST_LoadError.new(%`file "#{test_labels_file_path}" is not found.`)
|
55
|
+
end
|
56
|
+
images = load_images(test_images_file_path)
|
57
|
+
labels = load_labels(test_labels_file_path)
|
58
|
+
[images, labels]
|
59
|
+
end
|
60
|
+
|
61
|
+
private_class_method def self.load_images(file_name)
|
62
|
+
images = nil
|
63
|
+
Zlib::GzipReader.open(file_name) do |f|
|
64
|
+
magic, num_images = f.read(8).unpack("N2")
|
65
|
+
rows, cols = f.read(8).unpack("N2")
|
66
|
+
images = Numo::UInt8.from_binary(f.read)
|
67
|
+
images = images.reshape(num_images, cols, rows, 1)
|
68
|
+
end
|
69
|
+
images
|
70
|
+
end
|
71
|
+
|
72
|
+
private_class_method def self.load_labels(file_name)
|
73
|
+
labels = nil
|
74
|
+
Zlib::GzipReader.open(file_name) do |f|
|
75
|
+
magic, num_labels = f.read(8).unpack("N2")
|
76
|
+
labels = Numo::UInt8.from_binary(f.read)
|
77
|
+
end
|
78
|
+
labels
|
79
|
+
end
|
80
|
+
|
81
|
+
private_class_method def self.mnist_dir
|
82
|
+
"#{__dir__}/downloads/fashion-mnist"
|
83
|
+
end
|
84
|
+
|
85
|
+
private_class_method def self.get_file_path(file_name)
|
86
|
+
mnist_dir + "/" + file_name
|
87
|
+
end
|
88
|
+
end
|
89
|
+
end
|