ruby-dnn 0.10.4 → 0.12.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46):
  1. checksums.yaml +4 -4
  2. data/.travis.yml +1 -2
  3. data/README.md +33 -6
  4. data/examples/cifar100_example.rb +3 -3
  5. data/examples/cifar10_example.rb +3 -3
  6. data/examples/dcgan/dcgan.rb +112 -0
  7. data/examples/dcgan/imgen.rb +20 -0
  8. data/examples/dcgan/train.rb +41 -0
  9. data/examples/iris_example.rb +3 -6
  10. data/examples/mnist_conv2d_example.rb +5 -5
  11. data/examples/mnist_define_by_run.rb +52 -0
  12. data/examples/mnist_example.rb +3 -3
  13. data/examples/mnist_lstm_example.rb +3 -3
  14. data/examples/xor_example.rb +4 -5
  15. data/ext/rb_stb_image/rb_stb_image.c +103 -0
  16. data/lib/dnn.rb +10 -10
  17. data/lib/dnn/cifar10.rb +1 -1
  18. data/lib/dnn/cifar100.rb +1 -1
  19. data/lib/dnn/core/activations.rb +21 -22
  20. data/lib/dnn/core/cnn_layers.rb +94 -111
  21. data/lib/dnn/core/embedding.rb +30 -9
  22. data/lib/dnn/core/initializers.rb +31 -21
  23. data/lib/dnn/core/iterator.rb +52 -0
  24. data/lib/dnn/core/layers.rb +99 -66
  25. data/lib/dnn/core/link.rb +24 -0
  26. data/lib/dnn/core/losses.rb +69 -59
  27. data/lib/dnn/core/merge_layers.rb +71 -0
  28. data/lib/dnn/core/models.rb +393 -0
  29. data/lib/dnn/core/normalizations.rb +27 -14
  30. data/lib/dnn/core/optimizers.rb +212 -134
  31. data/lib/dnn/core/param.rb +8 -6
  32. data/lib/dnn/core/regularizers.rb +10 -7
  33. data/lib/dnn/core/rnn_layers.rb +78 -85
  34. data/lib/dnn/core/utils.rb +6 -3
  35. data/lib/dnn/downloader.rb +3 -3
  36. data/lib/dnn/fashion-mnist.rb +89 -0
  37. data/lib/dnn/image.rb +57 -18
  38. data/lib/dnn/iris.rb +1 -3
  39. data/lib/dnn/mnist.rb +38 -34
  40. data/lib/dnn/version.rb +1 -1
  41. data/third_party/stb_image.h +16 -4
  42. data/third_party/stb_image_resize.h +2630 -0
  43. data/third_party/stb_image_write.h +4 -7
  44. metadata +12 -4
  45. data/lib/dnn/core/dataset.rb +0 -34
  46. data/lib/dnn/core/model.rb +0 -440
@@ -1,9 +1,11 @@
1
- class DNN::Param
2
- attr_accessor :data
3
- attr_accessor :grad
1
+ module DNN
2
+ class Param
3
+ attr_accessor :data
4
+ attr_accessor :grad
4
5
 
5
- def initialize(data = nil, grad = nil)
6
- @data = data
7
- @grad = grad
6
+ def initialize(data = nil, grad = nil)
7
+ @data = data
8
+ @grad = grad
9
+ end
8
10
  end
9
11
  end
@@ -13,7 +13,7 @@ module DNN
13
13
  end
14
14
 
15
15
  def to_hash(merge_hash)
16
- hash = {class: self.class.name}
16
+ hash = { class: self.class.name }
17
17
  hash.merge!(merge_hash)
18
18
  hash
19
19
  end
@@ -23,9 +23,10 @@ module DNN
23
23
  attr_accessor :l1_lambda
24
24
 
25
25
  def self.from_hash(hash)
26
- L1.new(hash[:l1_lambda])
26
+ self.new(hash[:l1_lambda])
27
27
  end
28
28
 
29
+ # @param [Float] l1_lambda L1 regularizer coefficient.
29
30
  def initialize(l1_lambda = 0.01)
30
31
  @l1_lambda = l1_lambda
31
32
  end
@@ -50,15 +51,16 @@ module DNN
50
51
  attr_accessor :l2_lambda
51
52
 
52
53
  def self.from_hash(hash)
53
- L2.new(hash[:l2_lambda])
54
+ self.new(hash[:l2_lambda])
54
55
  end
55
56
 
57
+ # @param [Float] l2_lambda L2 regularizer coefficient.
56
58
  def initialize(l2_lambda = 0.01)
57
59
  @l2_lambda = l2_lambda
58
60
  end
59
61
 
60
62
  def forward(x)
61
- x + 0.5 * @l2_lambda * (@param.data**2).sum
63
+ x + 0.5 * @l2_lambda * (@param.data ** 2).sum
62
64
  end
63
65
 
64
66
  def backward
@@ -75,9 +77,11 @@ module DNN
75
77
  attr_accessor :l2_lambda
76
78
 
77
79
  def self.from_hash(hash)
78
- L1L2.new(hash[:l1_lambda], hash[:l2_lambda])
80
+ self.new(hash[:l1_lambda], hash[:l2_lambda])
79
81
  end
80
82
 
83
+ # @param [Float] l1_lambda L1 regularizer coefficient.
84
+ # @param [Float] l2_lambda L2 regularizer coefficient.
81
85
  def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
82
86
  @l1_lambda = l1_lambda
83
87
  @l2_lambda = l2_lambda
@@ -85,7 +89,7 @@ module DNN
85
89
 
86
90
  def forward(x)
87
91
  l1 = @l1_lambda * @param.data.abs.sum
88
- l2 = 0.5 * @l2_lambda * (@param.data**2).sum
92
+ l2 = 0.5 * @l2_lambda * (@param.data ** 2).sum
89
93
  x + l1 + l2
90
94
  end
91
95
 
@@ -99,7 +103,6 @@ module DNN
99
103
  def to_hash
100
104
  super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
101
105
  end
102
-
103
106
  end
104
107
 
105
108
  end
@@ -3,25 +3,38 @@ module DNN
3
3
 
4
4
  # Super class of all RNN classes.
5
5
  class RNN < Connection
6
- include Initializers
7
-
8
- # @return [Integer] number of nodes.
9
6
  attr_reader :num_nodes
10
- # @return [Bool] Maintain state between batches.
7
+ attr_reader :recurrent_weight
8
+ attr_reader :hidden
11
9
  attr_reader :stateful
12
- # @return [Bool] Set the false, only the last of each cell of RNN is left.
13
10
  attr_reader :return_sequences
14
- # @return [DNN::Initializers::Initializer] Recurrent weight initializer.
15
11
  attr_reader :recurrent_weight_initializer
16
- # @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
17
12
  attr_reader :recurrent_weight_regularizer
18
13
 
14
+ def self.from_hash(hash)
15
+ self.new(hash[:num_nodes],
16
+ stateful: hash[:stateful],
17
+ return_sequences: hash[:return_sequences],
18
+ weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
19
+ recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
20
+ bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
21
+ weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
22
+ recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
23
+ bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
24
+ use_bias: hash[:use_bias])
25
+ end
26
+
27
+ # @param [Integer] num_nodes Number of nodes.
28
+ # @param [Boolean] stateful Maintain state between batches.
29
+ # @param [Boolean] return_sequences Set the false, only the last of each cell of RNN is left.
30
+ # @param [DNN::Initializers::Initializer] recurrent_weight_initializer Recurrent weight initializer.
31
+ # @param [DNN::Regularizers::Regularizer | NilClass] recurrent_weight_regularizer Recurrent weight regularizer.
19
32
  def initialize(num_nodes,
20
33
  stateful: false,
21
34
  return_sequences: true,
22
- weight_initializer: RandomNormal.new,
23
- recurrent_weight_initializer: RandomNormal.new,
24
- bias_initializer: Zeros.new,
35
+ weight_initializer: Initializers::RandomNormal.new,
36
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
37
+ bias_initializer: Initializers::Zeros.new,
25
38
  weight_regularizer: nil,
26
39
  recurrent_weight_regularizer: nil,
27
40
  bias_regularizer: nil,
@@ -32,8 +45,8 @@ module DNN
32
45
  @stateful = stateful
33
46
  @return_sequences = return_sequences
34
47
  @layers = []
35
- @hidden = @params[:hidden] = Param.new
36
- @params[:recurrent_weight] = @recurrent_weight = Param.new(nil, 0)
48
+ @hidden = Param.new
49
+ @recurrent_weight = Param.new(nil, 0)
37
50
  @recurrent_weight_initializer = recurrent_weight_initializer
38
51
  @recurrent_weight_regularizer = recurrent_weight_regularizer
39
52
  end
@@ -49,7 +62,7 @@ module DNN
49
62
  def forward(xs)
50
63
  @xs_shape = xs.shape
51
64
  hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
52
- h = (@stateful && @hidden.data) ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
65
+ h = @stateful && @hidden.data ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
53
66
  xs.shape[1].times do |t|
54
67
  x = xs[true, t, false]
55
68
  @layers[t].trainable = @trainable
@@ -92,6 +105,10 @@ module DNN
92
105
  super(hash)
93
106
  end
94
107
 
108
+ def get_params
109
+ { weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden }
110
+ end
111
+
95
112
  # Reset the state of RNN.
96
113
  def reset_state
97
114
  @hidden.data = @hidden.data.fill(0) if @hidden.data
@@ -113,7 +130,7 @@ module DNN
113
130
  end
114
131
 
115
132
 
116
- class SimpleRNN_Dense
133
+ class SimpleRNNDense
117
134
  attr_accessor :trainable
118
135
 
119
136
  def initialize(weight, recurrent_weight, bias, activation)
@@ -147,32 +164,30 @@ module DNN
147
164
 
148
165
 
149
166
  class SimpleRNN < RNN
150
- include Activations
151
-
152
167
  attr_reader :activation
153
-
168
+
154
169
  def self.from_hash(hash)
155
- simple_rnn = self.new(hash[:num_nodes],
156
- stateful: hash[:stateful],
157
- return_sequences: hash[:return_sequences],
158
- activation: Utils.from_hash(hash[:activation]),
159
- weight_initializer: Utils.from_hash(hash[:weight_initializer]),
160
- recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
161
- bias_initializer: Utils.from_hash(hash[:bias_initializer]),
162
- weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
163
- recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
164
- bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
165
- use_bias: hash[:use_bias])
166
- simple_rnn
170
+ self.new(hash[:num_nodes],
171
+ stateful: hash[:stateful],
172
+ return_sequences: hash[:return_sequences],
173
+ activation: Utils.hash_to_obj(hash[:activation]),
174
+ weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
175
+ recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
176
+ bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
177
+ weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
178
+ recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
179
+ bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
180
+ use_bias: hash[:use_bias])
167
181
  end
168
182
 
183
+ # @param [DNN::Layers::Layer] activation Activation function to use in a recurrent network.
169
184
  def initialize(num_nodes,
170
185
  stateful: false,
171
186
  return_sequences: true,
172
- activation: Tanh.new,
173
- weight_initializer: RandomNormal.new,
174
- recurrent_weight_initializer: RandomNormal.new,
175
- bias_initializer: Zeros.new,
187
+ activation: Activations::Tanh.new,
188
+ weight_initializer: Initializers::RandomNormal.new,
189
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
190
+ bias_initializer: Initializers::Zeros.new,
176
191
  weight_regularizer: nil,
177
192
  recurrent_weight_regularizer: nil,
178
193
  bias_regularizer: nil,
@@ -197,29 +212,29 @@ module DNN
197
212
  @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
198
213
  @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
199
214
  init_weight_and_bias
200
- @time_length.times do |t|
201
- @layers << SimpleRNN_Dense.new(@weight, @recurrent_weight, @bias, @activation)
215
+ @time_length.times do
216
+ @layers << SimpleRNNDense.new(@weight, @recurrent_weight, @bias, @activation)
202
217
  end
203
218
  end
204
219
 
205
220
  def to_hash
206
- super({activation: @activation.to_hash})
221
+ super(activation: @activation.to_hash)
207
222
  end
208
223
  end
209
224
 
210
225
 
211
- class LSTM_Dense
226
+ class LSTMDense
212
227
  attr_accessor :trainable
213
228
 
214
229
  def initialize(weight, recurrent_weight, bias)
215
230
  @weight = weight
216
231
  @recurrent_weight = recurrent_weight
217
232
  @bias = bias
218
- @tanh = Tanh.new
219
- @g_tanh = Tanh.new
220
- @forget_sigmoid = Sigmoid.new
221
- @in_sigmoid = Sigmoid.new
222
- @out_sigmoid = Sigmoid.new
233
+ @tanh = Activations::Tanh.new
234
+ @g_tanh = Activations::Tanh.new
235
+ @forget_sigmoid = Activations::Sigmoid.new
236
+ @in_sigmoid = Activations::Sigmoid.new
237
+ @out_sigmoid = Activations::Sigmoid.new
223
238
  @trainable = true
224
239
  end
225
240
 
@@ -267,32 +282,20 @@ module DNN
267
282
 
268
283
 
269
284
  class LSTM < RNN
270
- def self.from_hash(hash)
271
- lstm = self.new(hash[:num_nodes],
272
- stateful: hash[:stateful],
273
- return_sequences: hash[:return_sequences],
274
- weight_initializer: Utils.from_hash(hash[:weight_initializer]),
275
- recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
276
- bias_initializer: Utils.from_hash(hash[:bias_initializer]),
277
- weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
278
- recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
279
- bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
280
- use_bias: hash[:use_bias])
281
- lstm
282
- end
285
+ attr_reader :cell
283
286
 
284
287
  def initialize(num_nodes,
285
288
  stateful: false,
286
289
  return_sequences: true,
287
- weight_initializer: RandomNormal.new,
288
- recurrent_weight_initializer: RandomNormal.new,
289
- bias_initializer: Zeros.new,
290
+ weight_initializer: Initializers::RandomNormal.new,
291
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
292
+ bias_initializer: Initializers::Zeros.new,
290
293
  weight_regularizer: nil,
291
294
  recurrent_weight_regularizer: nil,
292
295
  bias_regularizer: nil,
293
296
  use_bias: true)
294
297
  super
295
- @cell = @params[:cell] = Param.new
298
+ @cell = Param.new
296
299
  end
297
300
 
298
301
  def build(input_shape)
@@ -302,8 +305,8 @@ module DNN
302
305
  @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
303
306
  @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
304
307
  init_weight_and_bias
305
- @time_length.times do |t|
306
- @layers << LSTM_Dense.new(@weight, @recurrent_weight, @bias)
308
+ @time_length.times do
309
+ @layers << LSTMDense.new(@weight, @recurrent_weight, @bias)
307
310
  end
308
311
  end
309
312
 
@@ -350,19 +353,23 @@ module DNN
350
353
  super()
351
354
  @cell.data = @cell.data.fill(0) if @cell.data
352
355
  end
356
+
357
+ def get_params
358
+ { weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden, cell: @cell }
359
+ end
353
360
  end
354
361
 
355
362
 
356
- class GRU_Dense
363
+ class GRUDense
357
364
  attr_accessor :trainable
358
365
 
359
366
  def initialize(weight, recurrent_weight, bias)
360
367
  @weight = weight
361
368
  @recurrent_weight = recurrent_weight
362
369
  @bias = bias
363
- @update_sigmoid = Sigmoid.new
364
- @reset_sigmoid = Sigmoid.new
365
- @tanh = Tanh.new
370
+ @update_sigmoid = Activations::Sigmoid.new
371
+ @reset_sigmoid = Activations::Sigmoid.new
372
+ @tanh = Activations::Tanh.new
366
373
  @trainable = true
367
374
  end
368
375
 
@@ -423,33 +430,19 @@ module DNN
423
430
 
424
431
 
425
432
  class GRU < RNN
426
- def self.from_hash(hash)
427
- gru = self.new(hash[:num_nodes],
428
- stateful: hash[:stateful],
429
- return_sequences: hash[:return_sequences],
430
- weight_initializer: Utils.from_hash(hash[:weight_initializer]),
431
- recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
432
- bias_initializer: Utils.from_hash(hash[:bias_initializer]),
433
- weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
434
- recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
435
- bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
436
- use_bias: hash[:use_bias])
437
- gru
438
- end
439
-
440
433
  def initialize(num_nodes,
441
434
  stateful: false,
442
435
  return_sequences: true,
443
- weight_initializer: RandomNormal.new,
444
- recurrent_weight_initializer: RandomNormal.new,
445
- bias_initializer: Zeros.new,
436
+ weight_initializer: Initializers::RandomNormal.new,
437
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
438
+ bias_initializer: Initializers::Zeros.new,
446
439
  weight_regularizer: nil,
447
440
  recurrent_weight_regularizer: nil,
448
441
  bias_regularizer: nil,
449
442
  use_bias: true)
450
443
  super
451
444
  end
452
-
445
+
453
446
  def build(input_shape)
454
447
  super
455
448
  num_prev_nodes = @input_shape[1]
@@ -457,8 +450,8 @@ module DNN
457
450
  @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
458
451
  @bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
459
452
  init_weight_and_bias
460
- @time_length.times do |t|
461
- @layers << GRU_Dense.new(@weight, @recurrent_weight, @bias)
453
+ @time_length.times do
454
+ @layers << GRUDense.new(@weight, @recurrent_weight, @bias)
462
455
  end
463
456
  end
464
457
  end
@@ -2,6 +2,9 @@ module DNN
2
2
  # This module provides utility functions.
3
3
  module Utils
4
4
  # Categorize labels into "num_classes" classes.
5
+ # @param [Numo::SFloat] y Label data.
6
+ # @param [Numo::SFloat] num_classes Number of classes to classify.
7
+ # @param [Class] narray_type Type of Numo::Narray data after classification.
5
8
  def self.to_categorical(y, num_classes, narray_type = nil)
6
9
  narray_type ||= y.class
7
10
  y2 = narray_type.zeros(y.shape[0], num_classes)
@@ -12,7 +15,7 @@ module DNN
12
15
  end
13
16
 
14
17
  # Convert hash to an object.
15
- def self.from_hash(hash)
18
+ def self.hash_to_obj(hash)
16
19
  return nil if hash == nil
17
20
  dnn_class = DNN.const_get(hash[:class])
18
21
  if dnn_class.respond_to?(:from_hash)
@@ -23,12 +26,12 @@ module DNN
23
26
 
24
27
  # Return the result of the sigmoid function.
25
28
  def self.sigmoid(x)
26
- Sigmoid.new.forward(x)
29
+ Activations::Sigmoid.new.forward(x)
27
30
  end
28
31
 
29
32
  # Return the result of the softmax function.
30
33
  def self.softmax(x)
31
- SoftmaxCrossEntropy.softmax(x)
34
+ Losses::SoftmaxCrossEntropy.softmax(x)
32
35
  end
33
36
  end
34
37
  end
@@ -11,8 +11,8 @@ module DNN
11
11
  dir_path = "#{__dir__}/downloads"
12
12
  end
13
13
  Downloader.new(url).download(dir_path)
14
- rescue => ex
15
- raise DNN_DownloadError.new(ex.message)
14
+ rescue => e
15
+ raise DNN_DownloadError.new(e.message)
16
16
  end
17
17
 
18
18
  def initialize(url)
@@ -42,7 +42,7 @@ module DNN
42
42
  end
43
43
  puts ""
44
44
  end
45
- file_name = @path.match(%r`.+/(.+)`)[1]
45
+ file_name = @path.match(%r`.*/(.+)`)[1]
46
46
  File.binwrite("#{dir_path}/#{file_name}", buf)
47
47
  end
48
48
  end
@@ -0,0 +1,89 @@
require "zlib"
require_relative "core/error"
require_relative "downloader"
require_relative "mnist"

module DNN
  # Downloads and loads the Fashion-MNIST dataset.
  # The files are gzip-compressed IDX files of unsigned bytes (see the *-idx3-ubyte / *-idx1-ubyte names).
  module FashionMNIST
    # Raised when a Fashion-MNIST file is missing or malformed.
    # NOTE(review): the name appears to be carried over from the MNIST loader; it is kept unchanged so
    # existing `rescue DNN::FashionMNIST::DNN_MNIST_LoadError` clauses keep working.
    class DNN_MNIST_LoadError < DNN_Error; end

    URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"

    TRAIN_IMAGES_FILE_NAME = "train-images-idx3-ubyte.gz"
    TRAIN_LABELS_FILE_NAME = "train-labels-idx1-ubyte.gz"
    TEST_IMAGES_FILE_NAME = "t10k-images-idx3-ubyte.gz"
    TEST_LABELS_FILE_NAME = "t10k-labels-idx1-ubyte.gz"

    URL_TRAIN_IMAGES = URL_BASE + TRAIN_IMAGES_FILE_NAME
    URL_TRAIN_LABELS = URL_BASE + TRAIN_LABELS_FILE_NAME
    URL_TEST_IMAGES = URL_BASE + TEST_IMAGES_FILE_NAME
    URL_TEST_LABELS = URL_BASE + TEST_LABELS_FILE_NAME

    # IDX magic numbers: third byte 0x08 = unsigned-byte data, fourth byte = number of dimensions.
    IMAGES_MAGIC = 0x00000803
    LABELS_MAGIC = 0x00000801

    # Create the download directories and fetch every dataset file that is not already on disk.
    def self.downloads
      Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
      Dir.mkdir(mnist_dir) unless Dir.exist?(mnist_dir)
      {
        URL_TRAIN_IMAGES => TRAIN_IMAGES_FILE_NAME,
        URL_TRAIN_LABELS => TRAIN_LABELS_FILE_NAME,
        URL_TEST_IMAGES => TEST_IMAGES_FILE_NAME,
        URL_TEST_LABELS => TEST_LABELS_FILE_NAME,
      }.each do |url, file_name|
        Downloader.download(url, mnist_dir) unless File.exist?(get_file_path(file_name))
      end
    end

    # Load the training split, downloading it first if needed.
    # @return [Array] [images, labels]; images is a Numo::UInt8 of shape [N, rows, cols, 1],
    #   labels is a Numo::UInt8 of shape [N].
    # @raise [DNN_MNIST_LoadError] If a file is missing or is not a valid IDX file.
    def self.load_train
      downloads
      load_dataset(TRAIN_IMAGES_FILE_NAME, TRAIN_LABELS_FILE_NAME)
    end

    # Load the test split, downloading it first if needed.
    # @return [Array] [images, labels], with the same layout as load_train.
    # @raise [DNN_MNIST_LoadError] If a file is missing or is not a valid IDX file.
    def self.load_test
      downloads
      load_dataset(TEST_IMAGES_FILE_NAME, TEST_LABELS_FILE_NAME)
    end

    # Check that both files of a split exist, then load them as [images, labels].
    private_class_method def self.load_dataset(images_file_name, labels_file_name)
      images_file_path = get_file_path(images_file_name)
      labels_file_path = get_file_path(labels_file_name)
      [images_file_path, labels_file_path].each do |path|
        raise DNN_MNIST_LoadError.new(%`file "#{path}" is not found.`) unless File.exist?(path)
      end
      [load_images(images_file_path), load_labels(labels_file_path)]
    end

    # Load an IDX image file as a Numo::UInt8 of shape [num_images, rows, cols, 1].
    private_class_method def self.load_images(file_name)
      (num_images, rows, cols), data = read_idx(file_name, IMAGES_MAGIC, 3)
      # The header lists rows before cols and the pixels are stored row by row,
      # so the axes are (image, row, column).
      data.reshape(num_images, rows, cols, 1)
    end

    # Load an IDX label file as a flat Numo::UInt8 of shape [num_labels].
    private_class_method def self.load_labels(file_name)
      _dims, labels = read_idx(file_name, LABELS_MAGIC, 1)
      labels
    end

    # Read a gzip-compressed IDX file of unsigned bytes and validate its header.
    # @param [String] file_name Path of the .gz file.
    # @param [Integer] expected_magic Magic number the file must start with.
    # @param [Integer] num_dims Number of big-endian 32-bit dimension sizes after the magic number.
    # @return [Array] [dims, data]: the dimension sizes from the header and the payload as a flat Numo::UInt8.
    # @raise [DNN_MNIST_LoadError] If the header is short, the magic number differs,
    #   or the payload size disagrees with the header.
    private_class_method def self.read_idx(file_name, expected_magic, num_dims)
      Zlib::GzipReader.open(file_name) do |f|
        header_size = 4 * (num_dims + 1)
        header = f.read(header_size)
        unless header && header.bytesize == header_size
          raise DNN_MNIST_LoadError.new(%`file "#{file_name}" is not a valid IDX file.`)
        end
        magic, *dims = header.unpack("N*")
        unless magic == expected_magic
          raise DNN_MNIST_LoadError.new(%`file "#{file_name}" is not a valid IDX file.`)
        end
        data = Numo::UInt8.from_binary(f.read)
        unless data.size == dims.reduce(1, :*)
          raise DNN_MNIST_LoadError.new(%`file "#{file_name}" is truncated or corrupted.`)
        end
        [dims, data]
      end
    end

    # Directory that holds the downloaded Fashion-MNIST files.
    private_class_method def self.mnist_dir
      "#{__dir__}/downloads/fashion-mnist"
    end

    # Full path of a dataset file inside mnist_dir.
    private_class_method def self.get_file_path(file_name)
      "#{mnist_dir}/#{file_name}"
    end
  end
end