ruby-dnn 0.10.4 → 0.12.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.travis.yml +1 -2
- data/README.md +33 -6
- data/examples/cifar100_example.rb +3 -3
- data/examples/cifar10_example.rb +3 -3
- data/examples/dcgan/dcgan.rb +112 -0
- data/examples/dcgan/imgen.rb +20 -0
- data/examples/dcgan/train.rb +41 -0
- data/examples/iris_example.rb +3 -6
- data/examples/mnist_conv2d_example.rb +5 -5
- data/examples/mnist_define_by_run.rb +52 -0
- data/examples/mnist_example.rb +3 -3
- data/examples/mnist_lstm_example.rb +3 -3
- data/examples/xor_example.rb +4 -5
- data/ext/rb_stb_image/rb_stb_image.c +103 -0
- data/lib/dnn.rb +10 -10
- data/lib/dnn/cifar10.rb +1 -1
- data/lib/dnn/cifar100.rb +1 -1
- data/lib/dnn/core/activations.rb +21 -22
- data/lib/dnn/core/cnn_layers.rb +94 -111
- data/lib/dnn/core/embedding.rb +30 -9
- data/lib/dnn/core/initializers.rb +31 -21
- data/lib/dnn/core/iterator.rb +52 -0
- data/lib/dnn/core/layers.rb +99 -66
- data/lib/dnn/core/link.rb +24 -0
- data/lib/dnn/core/losses.rb +69 -59
- data/lib/dnn/core/merge_layers.rb +71 -0
- data/lib/dnn/core/models.rb +393 -0
- data/lib/dnn/core/normalizations.rb +27 -14
- data/lib/dnn/core/optimizers.rb +212 -134
- data/lib/dnn/core/param.rb +8 -6
- data/lib/dnn/core/regularizers.rb +10 -7
- data/lib/dnn/core/rnn_layers.rb +78 -85
- data/lib/dnn/core/utils.rb +6 -3
- data/lib/dnn/downloader.rb +3 -3
- data/lib/dnn/fashion-mnist.rb +89 -0
- data/lib/dnn/image.rb +57 -18
- data/lib/dnn/iris.rb +1 -3
- data/lib/dnn/mnist.rb +38 -34
- data/lib/dnn/version.rb +1 -1
- data/third_party/stb_image.h +16 -4
- data/third_party/stb_image_resize.h +2630 -0
- data/third_party/stb_image_write.h +4 -7
- metadata +12 -4
- data/lib/dnn/core/dataset.rb +0 -34
- data/lib/dnn/core/model.rb +0 -440
data/lib/dnn/core/param.rb
CHANGED
@@ -1,9 +1,11 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
1
|
+
module DNN
|
2
|
+
class Param
|
3
|
+
attr_accessor :data
|
4
|
+
attr_accessor :grad
|
4
5
|
|
5
|
-
|
6
|
-
|
7
|
-
|
6
|
+
def initialize(data = nil, grad = nil)
|
7
|
+
@data = data
|
8
|
+
@grad = grad
|
9
|
+
end
|
8
10
|
end
|
9
11
|
end
|
data/lib/dnn/core/regularizers.rb
CHANGED
@@ -13,7 +13,7 @@ module DNN
|
|
13
13
|
end
|
14
14
|
|
15
15
|
def to_hash(merge_hash)
|
16
|
-
hash = {class: self.class.name}
|
16
|
+
hash = { class: self.class.name }
|
17
17
|
hash.merge!(merge_hash)
|
18
18
|
hash
|
19
19
|
end
|
@@ -23,9 +23,10 @@ module DNN
|
|
23
23
|
attr_accessor :l1_lambda
|
24
24
|
|
25
25
|
def self.from_hash(hash)
|
26
|
-
|
26
|
+
self.new(hash[:l1_lambda])
|
27
27
|
end
|
28
28
|
|
29
|
+
# @param [Float] l1_lambda L1 regularizer coefficient.
|
29
30
|
def initialize(l1_lambda = 0.01)
|
30
31
|
@l1_lambda = l1_lambda
|
31
32
|
end
|
@@ -50,15 +51,16 @@ module DNN
|
|
50
51
|
attr_accessor :l2_lambda
|
51
52
|
|
52
53
|
def self.from_hash(hash)
|
53
|
-
|
54
|
+
self.new(hash[:l2_lambda])
|
54
55
|
end
|
55
56
|
|
57
|
+
# @param [Float] l2_lambda L2 regularizer coefficient.
|
56
58
|
def initialize(l2_lambda = 0.01)
|
57
59
|
@l2_lambda = l2_lambda
|
58
60
|
end
|
59
61
|
|
60
62
|
def forward(x)
|
61
|
-
x + 0.5 * @l2_lambda * (@param.data**2).sum
|
63
|
+
x + 0.5 * @l2_lambda * (@param.data ** 2).sum
|
62
64
|
end
|
63
65
|
|
64
66
|
def backward
|
@@ -75,9 +77,11 @@ module DNN
|
|
75
77
|
attr_accessor :l2_lambda
|
76
78
|
|
77
79
|
def self.from_hash(hash)
|
78
|
-
|
80
|
+
self.new(hash[:l1_lambda], hash[:l2_lambda])
|
79
81
|
end
|
80
82
|
|
83
|
+
# @param [Float] l1_lambda L1 regularizer coefficient.
|
84
|
+
# @param [Float] l2_lambda L2 regularizer coefficient.
|
81
85
|
def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
|
82
86
|
@l1_lambda = l1_lambda
|
83
87
|
@l2_lambda = l2_lambda
|
@@ -85,7 +89,7 @@ module DNN
|
|
85
89
|
|
86
90
|
def forward(x)
|
87
91
|
l1 = @l1_lambda * @param.data.abs.sum
|
88
|
-
l2 = 0.5 * @l2_lambda * (@param.data**2).sum
|
92
|
+
l2 = 0.5 * @l2_lambda * (@param.data ** 2).sum
|
89
93
|
x + l1 + l2
|
90
94
|
end
|
91
95
|
|
@@ -99,7 +103,6 @@ module DNN
|
|
99
103
|
def to_hash
|
100
104
|
super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
|
101
105
|
end
|
102
|
-
|
103
106
|
end
|
104
107
|
|
105
108
|
end
|
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -3,25 +3,38 @@ module DNN
|
|
3
3
|
|
4
4
|
# Super class of all RNN classes.
|
5
5
|
class RNN < Connection
|
6
|
-
include Initializers
|
7
|
-
|
8
|
-
# @return [Integer] number of nodes.
|
9
6
|
attr_reader :num_nodes
|
10
|
-
|
7
|
+
attr_reader :recurrent_weight
|
8
|
+
attr_reader :hidden
|
11
9
|
attr_reader :stateful
|
12
|
-
# @return [Bool] Set the false, only the last of each cell of RNN is left.
|
13
10
|
attr_reader :return_sequences
|
14
|
-
# @return [DNN::Initializers::Initializer] Recurrent weight initializer.
|
15
11
|
attr_reader :recurrent_weight_initializer
|
16
|
-
# @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
|
17
12
|
attr_reader :recurrent_weight_regularizer
|
18
13
|
|
14
|
+
def self.from_hash(hash)
|
15
|
+
self.new(hash[:num_nodes],
|
16
|
+
stateful: hash[:stateful],
|
17
|
+
return_sequences: hash[:return_sequences],
|
18
|
+
weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
|
19
|
+
recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
|
20
|
+
bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
|
21
|
+
weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
|
22
|
+
recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
|
23
|
+
bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
|
24
|
+
use_bias: hash[:use_bias])
|
25
|
+
end
|
26
|
+
|
27
|
+
# @param [Integer] num_nodes Number of nodes.
|
28
|
+
# @param [Boolean] stateful Maintain state between batches.
|
29
|
+
# @param [Boolean] return_sequences Set the false, only the last of each cell of RNN is left.
|
30
|
+
# @param [DNN::Initializers::Initializer] recurrent_weight_initializer Recurrent weight initializer.
|
31
|
+
# @param [DNN::Regularizers::Regularizer | NilClass] recurrent_weight_regularizer Recurrent weight regularizer.
|
19
32
|
def initialize(num_nodes,
|
20
33
|
stateful: false,
|
21
34
|
return_sequences: true,
|
22
|
-
weight_initializer: RandomNormal.new,
|
23
|
-
recurrent_weight_initializer: RandomNormal.new,
|
24
|
-
bias_initializer: Zeros.new,
|
35
|
+
weight_initializer: Initializers::RandomNormal.new,
|
36
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
37
|
+
bias_initializer: Initializers::Zeros.new,
|
25
38
|
weight_regularizer: nil,
|
26
39
|
recurrent_weight_regularizer: nil,
|
27
40
|
bias_regularizer: nil,
|
@@ -32,8 +45,8 @@ module DNN
|
|
32
45
|
@stateful = stateful
|
33
46
|
@return_sequences = return_sequences
|
34
47
|
@layers = []
|
35
|
-
@hidden =
|
36
|
-
@
|
48
|
+
@hidden = Param.new
|
49
|
+
@recurrent_weight = Param.new(nil, 0)
|
37
50
|
@recurrent_weight_initializer = recurrent_weight_initializer
|
38
51
|
@recurrent_weight_regularizer = recurrent_weight_regularizer
|
39
52
|
end
|
@@ -49,7 +62,7 @@ module DNN
|
|
49
62
|
def forward(xs)
|
50
63
|
@xs_shape = xs.shape
|
51
64
|
hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
|
52
|
-
h =
|
65
|
+
h = @stateful && @hidden.data ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
|
53
66
|
xs.shape[1].times do |t|
|
54
67
|
x = xs[true, t, false]
|
55
68
|
@layers[t].trainable = @trainable
|
@@ -92,6 +105,10 @@ module DNN
|
|
92
105
|
super(hash)
|
93
106
|
end
|
94
107
|
|
108
|
+
def get_params
|
109
|
+
{ weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden }
|
110
|
+
end
|
111
|
+
|
95
112
|
# Reset the state of RNN.
|
96
113
|
def reset_state
|
97
114
|
@hidden.data = @hidden.data.fill(0) if @hidden.data
|
@@ -113,7 +130,7 @@ module DNN
|
|
113
130
|
end
|
114
131
|
|
115
132
|
|
116
|
-
class
|
133
|
+
class SimpleRNNDense
|
117
134
|
attr_accessor :trainable
|
118
135
|
|
119
136
|
def initialize(weight, recurrent_weight, bias, activation)
|
@@ -147,32 +164,30 @@ module DNN
|
|
147
164
|
|
148
165
|
|
149
166
|
class SimpleRNN < RNN
|
150
|
-
include Activations
|
151
|
-
|
152
167
|
attr_reader :activation
|
153
|
-
|
168
|
+
|
154
169
|
def self.from_hash(hash)
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
simple_rnn
|
170
|
+
self.new(hash[:num_nodes],
|
171
|
+
stateful: hash[:stateful],
|
172
|
+
return_sequences: hash[:return_sequences],
|
173
|
+
activation: Utils.hash_to_obj(hash[:activation]),
|
174
|
+
weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
|
175
|
+
recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
|
176
|
+
bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
|
177
|
+
weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
|
178
|
+
recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
|
179
|
+
bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
|
180
|
+
use_bias: hash[:use_bias])
|
167
181
|
end
|
168
182
|
|
183
|
+
# @param [DNN::Layers::Layer] activation Activation function to use in a recurrent network.
|
169
184
|
def initialize(num_nodes,
|
170
185
|
stateful: false,
|
171
186
|
return_sequences: true,
|
172
|
-
activation: Tanh.new,
|
173
|
-
weight_initializer: RandomNormal.new,
|
174
|
-
recurrent_weight_initializer: RandomNormal.new,
|
175
|
-
bias_initializer: Zeros.new,
|
187
|
+
activation: Activations::Tanh.new,
|
188
|
+
weight_initializer: Initializers::RandomNormal.new,
|
189
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
190
|
+
bias_initializer: Initializers::Zeros.new,
|
176
191
|
weight_regularizer: nil,
|
177
192
|
recurrent_weight_regularizer: nil,
|
178
193
|
bias_regularizer: nil,
|
@@ -197,29 +212,29 @@ module DNN
|
|
197
212
|
@recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
|
198
213
|
@bias.data = Xumo::SFloat.new(@num_nodes) if @bias
|
199
214
|
init_weight_and_bias
|
200
|
-
@time_length.times do
|
201
|
-
@layers <<
|
215
|
+
@time_length.times do
|
216
|
+
@layers << SimpleRNNDense.new(@weight, @recurrent_weight, @bias, @activation)
|
202
217
|
end
|
203
218
|
end
|
204
219
|
|
205
220
|
def to_hash
|
206
|
-
super(
|
221
|
+
super(activation: @activation.to_hash)
|
207
222
|
end
|
208
223
|
end
|
209
224
|
|
210
225
|
|
211
|
-
class
|
226
|
+
class LSTMDense
|
212
227
|
attr_accessor :trainable
|
213
228
|
|
214
229
|
def initialize(weight, recurrent_weight, bias)
|
215
230
|
@weight = weight
|
216
231
|
@recurrent_weight = recurrent_weight
|
217
232
|
@bias = bias
|
218
|
-
@tanh = Tanh.new
|
219
|
-
@g_tanh = Tanh.new
|
220
|
-
@forget_sigmoid = Sigmoid.new
|
221
|
-
@in_sigmoid = Sigmoid.new
|
222
|
-
@out_sigmoid = Sigmoid.new
|
233
|
+
@tanh = Activations::Tanh.new
|
234
|
+
@g_tanh = Activations::Tanh.new
|
235
|
+
@forget_sigmoid = Activations::Sigmoid.new
|
236
|
+
@in_sigmoid = Activations::Sigmoid.new
|
237
|
+
@out_sigmoid = Activations::Sigmoid.new
|
223
238
|
@trainable = true
|
224
239
|
end
|
225
240
|
|
@@ -267,32 +282,20 @@ module DNN
|
|
267
282
|
|
268
283
|
|
269
284
|
class LSTM < RNN
|
270
|
-
|
271
|
-
lstm = self.new(hash[:num_nodes],
|
272
|
-
stateful: hash[:stateful],
|
273
|
-
return_sequences: hash[:return_sequences],
|
274
|
-
weight_initializer: Utils.from_hash(hash[:weight_initializer]),
|
275
|
-
recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
|
276
|
-
bias_initializer: Utils.from_hash(hash[:bias_initializer]),
|
277
|
-
weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
|
278
|
-
recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
|
279
|
-
bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
|
280
|
-
use_bias: hash[:use_bias])
|
281
|
-
lstm
|
282
|
-
end
|
285
|
+
attr_reader :cell
|
283
286
|
|
284
287
|
def initialize(num_nodes,
|
285
288
|
stateful: false,
|
286
289
|
return_sequences: true,
|
287
|
-
weight_initializer: RandomNormal.new,
|
288
|
-
recurrent_weight_initializer: RandomNormal.new,
|
289
|
-
bias_initializer: Zeros.new,
|
290
|
+
weight_initializer: Initializers::RandomNormal.new,
|
291
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
292
|
+
bias_initializer: Initializers::Zeros.new,
|
290
293
|
weight_regularizer: nil,
|
291
294
|
recurrent_weight_regularizer: nil,
|
292
295
|
bias_regularizer: nil,
|
293
296
|
use_bias: true)
|
294
297
|
super
|
295
|
-
@cell =
|
298
|
+
@cell = Param.new
|
296
299
|
end
|
297
300
|
|
298
301
|
def build(input_shape)
|
@@ -302,8 +305,8 @@ module DNN
|
|
302
305
|
@recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
|
303
306
|
@bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
|
304
307
|
init_weight_and_bias
|
305
|
-
@time_length.times do
|
306
|
-
@layers <<
|
308
|
+
@time_length.times do
|
309
|
+
@layers << LSTMDense.new(@weight, @recurrent_weight, @bias)
|
307
310
|
end
|
308
311
|
end
|
309
312
|
|
@@ -350,19 +353,23 @@ module DNN
|
|
350
353
|
super()
|
351
354
|
@cell.data = @cell.data.fill(0) if @cell.data
|
352
355
|
end
|
356
|
+
|
357
|
+
def get_params
|
358
|
+
{ weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden, cell: @cell }
|
359
|
+
end
|
353
360
|
end
|
354
361
|
|
355
362
|
|
356
|
-
class
|
363
|
+
class GRUDense
|
357
364
|
attr_accessor :trainable
|
358
365
|
|
359
366
|
def initialize(weight, recurrent_weight, bias)
|
360
367
|
@weight = weight
|
361
368
|
@recurrent_weight = recurrent_weight
|
362
369
|
@bias = bias
|
363
|
-
@update_sigmoid = Sigmoid.new
|
364
|
-
@reset_sigmoid = Sigmoid.new
|
365
|
-
@tanh = Tanh.new
|
370
|
+
@update_sigmoid = Activations::Sigmoid.new
|
371
|
+
@reset_sigmoid = Activations::Sigmoid.new
|
372
|
+
@tanh = Activations::Tanh.new
|
366
373
|
@trainable = true
|
367
374
|
end
|
368
375
|
|
@@ -423,33 +430,19 @@ module DNN
|
|
423
430
|
|
424
431
|
|
425
432
|
class GRU < RNN
|
426
|
-
def self.from_hash(hash)
|
427
|
-
gru = self.new(hash[:num_nodes],
|
428
|
-
stateful: hash[:stateful],
|
429
|
-
return_sequences: hash[:return_sequences],
|
430
|
-
weight_initializer: Utils.from_hash(hash[:weight_initializer]),
|
431
|
-
recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
|
432
|
-
bias_initializer: Utils.from_hash(hash[:bias_initializer]),
|
433
|
-
weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
|
434
|
-
recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
|
435
|
-
bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
|
436
|
-
use_bias: hash[:use_bias])
|
437
|
-
gru
|
438
|
-
end
|
439
|
-
|
440
433
|
def initialize(num_nodes,
|
441
434
|
stateful: false,
|
442
435
|
return_sequences: true,
|
443
|
-
weight_initializer: RandomNormal.new,
|
444
|
-
recurrent_weight_initializer: RandomNormal.new,
|
445
|
-
bias_initializer: Zeros.new,
|
436
|
+
weight_initializer: Initializers::RandomNormal.new,
|
437
|
+
recurrent_weight_initializer: Initializers::RandomNormal.new,
|
438
|
+
bias_initializer: Initializers::Zeros.new,
|
446
439
|
weight_regularizer: nil,
|
447
440
|
recurrent_weight_regularizer: nil,
|
448
441
|
bias_regularizer: nil,
|
449
442
|
use_bias: true)
|
450
443
|
super
|
451
444
|
end
|
452
|
-
|
445
|
+
|
453
446
|
def build(input_shape)
|
454
447
|
super
|
455
448
|
num_prev_nodes = @input_shape[1]
|
@@ -457,8 +450,8 @@ module DNN
|
|
457
450
|
@recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
|
458
451
|
@bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
|
459
452
|
init_weight_and_bias
|
460
|
-
@time_length.times do
|
461
|
-
@layers <<
|
453
|
+
@time_length.times do
|
454
|
+
@layers << GRUDense.new(@weight, @recurrent_weight, @bias)
|
462
455
|
end
|
463
456
|
end
|
464
457
|
end
|
data/lib/dnn/core/utils.rb
CHANGED
@@ -2,6 +2,9 @@ module DNN
|
|
2
2
|
# This module provides utility functions.
|
3
3
|
module Utils
|
4
4
|
# Categorize labels into "num_classes" classes.
|
5
|
+
# @param [Numo::SFloat] y Label data.
|
6
|
+
# @param [Numo::SFloat] num_classes Number of classes to classify.
|
7
|
+
# @param [Class] narray_type Type of Numo::Narray data after classification.
|
5
8
|
def self.to_categorical(y, num_classes, narray_type = nil)
|
6
9
|
narray_type ||= y.class
|
7
10
|
y2 = narray_type.zeros(y.shape[0], num_classes)
|
@@ -12,7 +15,7 @@ module DNN
|
|
12
15
|
end
|
13
16
|
|
14
17
|
# Convert hash to an object.
|
15
|
-
def self.
|
18
|
+
def self.hash_to_obj(hash)
|
16
19
|
return nil if hash == nil
|
17
20
|
dnn_class = DNN.const_get(hash[:class])
|
18
21
|
if dnn_class.respond_to?(:from_hash)
|
@@ -23,12 +26,12 @@ module DNN
|
|
23
26
|
|
24
27
|
# Return the result of the sigmoid function.
|
25
28
|
def self.sigmoid(x)
|
26
|
-
Sigmoid.new.forward(x)
|
29
|
+
Activations::Sigmoid.new.forward(x)
|
27
30
|
end
|
28
31
|
|
29
32
|
# Return the result of the softmax function.
|
30
33
|
def self.softmax(x)
|
31
|
-
SoftmaxCrossEntropy.softmax(x)
|
34
|
+
Losses::SoftmaxCrossEntropy.softmax(x)
|
32
35
|
end
|
33
36
|
end
|
34
37
|
end
|
data/lib/dnn/downloader.rb
CHANGED
@@ -11,8 +11,8 @@ module DNN
|
|
11
11
|
dir_path = "#{__dir__}/downloads"
|
12
12
|
end
|
13
13
|
Downloader.new(url).download(dir_path)
|
14
|
-
rescue =>
|
15
|
-
raise DNN_DownloadError.new(
|
14
|
+
rescue => e
|
15
|
+
raise DNN_DownloadError.new(e.message)
|
16
16
|
end
|
17
17
|
|
18
18
|
def initialize(url)
|
@@ -42,7 +42,7 @@ module DNN
|
|
42
42
|
end
|
43
43
|
puts ""
|
44
44
|
end
|
45
|
-
file_name = @path.match(%r
|
45
|
+
file_name = @path.match(%r`.*/(.+)`)[1]
|
46
46
|
File.binwrite("#{dir_path}/#{file_name}", buf)
|
47
47
|
end
|
48
48
|
end
|
data/lib/dnn/fashion-mnist.rb
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
require "zlib"
|
2
|
+
require_relative "core/error"
|
3
|
+
require_relative "downloader"
|
4
|
+
require_relative "mnist"
|
5
|
+
|
6
|
+
module DNN
  # Loader for the Fashion-MNIST dataset, stored as gzip-compressed IDX files.
  # Missing files are downloaded on demand into "#{__dir__}/downloads/fashion-mnist".
  module FashionMNIST
    # Raised when a dataset file is missing or malformed.
    # The name mirrors the MNIST loader's error class; it is kept unchanged so
    # existing `rescue DNN::FashionMNIST::DNN_MNIST_LoadError` clauses keep working.
    class DNN_MNIST_LoadError < DNN_Error; end

    URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"

    TRAIN_IMAGES_FILE_NAME = "train-images-idx3-ubyte.gz"
    TRAIN_LABELS_FILE_NAME = "train-labels-idx1-ubyte.gz"
    TEST_IMAGES_FILE_NAME = "t10k-images-idx3-ubyte.gz"
    TEST_LABELS_FILE_NAME = "t10k-labels-idx1-ubyte.gz"

    URL_TRAIN_IMAGES = URL_BASE + TRAIN_IMAGES_FILE_NAME
    URL_TRAIN_LABELS = URL_BASE + TRAIN_LABELS_FILE_NAME
    URL_TEST_IMAGES = URL_BASE + TEST_IMAGES_FILE_NAME
    URL_TEST_LABELS = URL_BASE + TEST_LABELS_FILE_NAME

    # IDX magic numbers: 0x0803 = unsigned-byte data with 3 dimensions (images),
    # 0x0801 = unsigned-byte data with 1 dimension (labels).
    IMAGES_MAGIC = 0x00000803
    LABELS_MAGIC = 0x00000801
    private_constant :IMAGES_MAGIC, :LABELS_MAGIC

    # Download every dataset file that is not already present locally.
    def self.downloads
      downloads_dir = "#{__dir__}/downloads"
      Dir.mkdir(downloads_dir) unless Dir.exist?(downloads_dir)
      Dir.mkdir(mnist_dir) unless Dir.exist?(mnist_dir)
      {
        TRAIN_IMAGES_FILE_NAME => URL_TRAIN_IMAGES,
        TRAIN_LABELS_FILE_NAME => URL_TRAIN_LABELS,
        TEST_IMAGES_FILE_NAME => URL_TEST_IMAGES,
        TEST_LABELS_FILE_NAME => URL_TEST_LABELS,
      }.each do |file_name, url|
        Downloader.download(url, mnist_dir) unless File.exist?(get_file_path(file_name))
      end
    end

    # Load the training split, downloading it first if needed.
    # @return [Array] [images, labels]; images are Numo::UInt8 shaped (N, rows, cols, 1),
    #   labels are a 1-D Numo::UInt8.
    # @raise [DNN_MNIST_LoadError] if a file is missing or malformed.
    def self.load_train
      load_dataset(TRAIN_IMAGES_FILE_NAME, TRAIN_LABELS_FILE_NAME)
    end

    # Load the test split, downloading it first if needed.
    # @return [Array] [images, labels]; same layout as {load_train}.
    # @raise [DNN_MNIST_LoadError] if a file is missing or malformed.
    def self.load_test
      load_dataset(TEST_IMAGES_FILE_NAME, TEST_LABELS_FILE_NAME)
    end

    # Shared implementation of load_train / load_test.
    private_class_method def self.load_dataset(images_file_name, labels_file_name)
      downloads
      images_file_path = get_file_path(images_file_name)
      labels_file_path = get_file_path(labels_file_name)
      # Images are checked before labels, matching the original error order.
      [images_file_path, labels_file_path].each do |path|
        raise DNN_MNIST_LoadError.new(%`file "#{path}" is not found.`) unless File.exist?(path)
      end
      [load_images(images_file_path), load_labels(labels_file_path)]
    end

    # Read an IDX3 image file into a Numo::UInt8 of shape (num_images, rows, cols, 1).
    private_class_method def self.load_images(file_name)
      Zlib::GzipReader.open(file_name) do |f|
        magic, num_images, rows, cols = read_uint32s(f, 4, file_name)
        validate_magic(magic, IMAGES_MAGIC, file_name)
        # IDX stores each image row-major, so the axis order is (rows, cols).
        Numo::UInt8.from_binary(f.read).reshape(num_images, rows, cols, 1)
      end
    end

    # Read an IDX1 label file into a 1-D Numo::UInt8.
    private_class_method def self.load_labels(file_name)
      Zlib::GzipReader.open(file_name) do |f|
        magic, _num_labels = read_uint32s(f, 2, file_name)
        validate_magic(magic, LABELS_MAGIC, file_name)
        Numo::UInt8.from_binary(f.read)
      end
    end

    # Read `count` big-endian unsigned 32-bit header fields.
    # Raises instead of failing with NoMethodError on a short or empty file.
    private_class_method def self.read_uint32s(io, count, file_name)
      bytes = io.read(4 * count)
      unless bytes && bytes.bytesize == 4 * count
        raise DNN_MNIST_LoadError.new(%`file "#{file_name}" is truncated.`)
      end
      bytes.unpack("N#{count}")
    end

    # Reject files whose IDX magic number does not match the expected kind.
    private_class_method def self.validate_magic(actual, expected, file_name)
      return if actual == expected
      raise DNN_MNIST_LoadError.new(%`file "#{file_name}" has an invalid magic number (#{actual}).`)
    end

    private_class_method def self.mnist_dir
      "#{__dir__}/downloads/fashion-mnist"
    end

    private_class_method def self.get_file_path(file_name)
      File.join(mnist_dir, file_name)
    end
  end
end
|