ruby-dnn 0.10.4 → 0.12.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. checksums.yaml +4 -4
  2. data/.travis.yml +1 -2
  3. data/README.md +33 -6
  4. data/examples/cifar100_example.rb +3 -3
  5. data/examples/cifar10_example.rb +3 -3
  6. data/examples/dcgan/dcgan.rb +112 -0
  7. data/examples/dcgan/imgen.rb +20 -0
  8. data/examples/dcgan/train.rb +41 -0
  9. data/examples/iris_example.rb +3 -6
  10. data/examples/mnist_conv2d_example.rb +5 -5
  11. data/examples/mnist_define_by_run.rb +52 -0
  12. data/examples/mnist_example.rb +3 -3
  13. data/examples/mnist_lstm_example.rb +3 -3
  14. data/examples/xor_example.rb +4 -5
  15. data/ext/rb_stb_image/rb_stb_image.c +103 -0
  16. data/lib/dnn.rb +10 -10
  17. data/lib/dnn/cifar10.rb +1 -1
  18. data/lib/dnn/cifar100.rb +1 -1
  19. data/lib/dnn/core/activations.rb +21 -22
  20. data/lib/dnn/core/cnn_layers.rb +94 -111
  21. data/lib/dnn/core/embedding.rb +30 -9
  22. data/lib/dnn/core/initializers.rb +31 -21
  23. data/lib/dnn/core/iterator.rb +52 -0
  24. data/lib/dnn/core/layers.rb +99 -66
  25. data/lib/dnn/core/link.rb +24 -0
  26. data/lib/dnn/core/losses.rb +69 -59
  27. data/lib/dnn/core/merge_layers.rb +71 -0
  28. data/lib/dnn/core/models.rb +393 -0
  29. data/lib/dnn/core/normalizations.rb +27 -14
  30. data/lib/dnn/core/optimizers.rb +212 -134
  31. data/lib/dnn/core/param.rb +8 -6
  32. data/lib/dnn/core/regularizers.rb +10 -7
  33. data/lib/dnn/core/rnn_layers.rb +78 -85
  34. data/lib/dnn/core/utils.rb +6 -3
  35. data/lib/dnn/downloader.rb +3 -3
  36. data/lib/dnn/fashion-mnist.rb +89 -0
  37. data/lib/dnn/image.rb +57 -18
  38. data/lib/dnn/iris.rb +1 -3
  39. data/lib/dnn/mnist.rb +38 -34
  40. data/lib/dnn/version.rb +1 -1
  41. data/third_party/stb_image.h +16 -4
  42. data/third_party/stb_image_resize.h +2630 -0
  43. data/third_party/stb_image_write.h +4 -7
  44. metadata +12 -4
  45. data/lib/dnn/core/dataset.rb +0 -34
  46. data/lib/dnn/core/model.rb +0 -440
@@ -1,9 +1,11 @@
1
- class DNN::Param
2
- attr_accessor :data
3
- attr_accessor :grad
1
+ module DNN
2
+ class Param
3
+ attr_accessor :data
4
+ attr_accessor :grad
4
5
 
5
- def initialize(data = nil, grad = nil)
6
- @data = data
7
- @grad = grad
6
+ def initialize(data = nil, grad = nil)
7
+ @data = data
8
+ @grad = grad
9
+ end
8
10
  end
9
11
  end
@@ -13,7 +13,7 @@ module DNN
13
13
  end
14
14
 
15
15
  def to_hash(merge_hash)
16
- hash = {class: self.class.name}
16
+ hash = { class: self.class.name }
17
17
  hash.merge!(merge_hash)
18
18
  hash
19
19
  end
@@ -23,9 +23,10 @@ module DNN
23
23
  attr_accessor :l1_lambda
24
24
 
25
25
  def self.from_hash(hash)
26
- L1.new(hash[:l1_lambda])
26
+ self.new(hash[:l1_lambda])
27
27
  end
28
28
 
29
+ # @param [Float] l1_lambda L1 regularizer coefficient.
29
30
  def initialize(l1_lambda = 0.01)
30
31
  @l1_lambda = l1_lambda
31
32
  end
@@ -50,15 +51,16 @@ module DNN
50
51
  attr_accessor :l2_lambda
51
52
 
52
53
  def self.from_hash(hash)
53
- L2.new(hash[:l2_lambda])
54
+ self.new(hash[:l2_lambda])
54
55
  end
55
56
 
57
+ # @param [Float] l2_lambda L2 regularizer coefficient.
56
58
  def initialize(l2_lambda = 0.01)
57
59
  @l2_lambda = l2_lambda
58
60
  end
59
61
 
60
62
  def forward(x)
61
- x + 0.5 * @l2_lambda * (@param.data**2).sum
63
+ x + 0.5 * @l2_lambda * (@param.data ** 2).sum
62
64
  end
63
65
 
64
66
  def backward
@@ -75,9 +77,11 @@ module DNN
75
77
  attr_accessor :l2_lambda
76
78
 
77
79
  def self.from_hash(hash)
78
- L1L2.new(hash[:l1_lambda], hash[:l2_lambda])
80
+ self.new(hash[:l1_lambda], hash[:l2_lambda])
79
81
  end
80
82
 
83
+ # @param [Float] l1_lambda L1 regularizer coefficient.
84
+ # @param [Float] l2_lambda L2 regularizer coefficient.
81
85
  def initialize(l1_lambda = 0.01, l2_lambda = 0.01)
82
86
  @l1_lambda = l1_lambda
83
87
  @l2_lambda = l2_lambda
@@ -85,7 +89,7 @@ module DNN
85
89
 
86
90
  def forward(x)
87
91
  l1 = @l1_lambda * @param.data.abs.sum
88
- l2 = 0.5 * @l2_lambda * (@param.data**2).sum
92
+ l2 = 0.5 * @l2_lambda * (@param.data ** 2).sum
89
93
  x + l1 + l2
90
94
  end
91
95
 
@@ -99,7 +103,6 @@ module DNN
99
103
  def to_hash
100
104
  super(l1_lambda: l1_lambda, l2_lambda: l2_lambda)
101
105
  end
102
-
103
106
  end
104
107
 
105
108
  end
@@ -3,25 +3,38 @@ module DNN
3
3
 
4
4
  # Super class of all RNN classes.
5
5
  class RNN < Connection
6
- include Initializers
7
-
8
- # @return [Integer] number of nodes.
9
6
  attr_reader :num_nodes
10
- # @return [Bool] Maintain state between batches.
7
+ attr_reader :recurrent_weight
8
+ attr_reader :hidden
11
9
  attr_reader :stateful
12
- # @return [Bool] Set the false, only the last of each cell of RNN is left.
13
10
  attr_reader :return_sequences
14
- # @return [DNN::Initializers::Initializer] Recurrent weight initializer.
15
11
  attr_reader :recurrent_weight_initializer
16
- # @return [DNN::Regularizers::Regularizer] Recurrent weight regularization.
17
12
  attr_reader :recurrent_weight_regularizer
18
13
 
14
+ def self.from_hash(hash)
15
+ self.new(hash[:num_nodes],
16
+ stateful: hash[:stateful],
17
+ return_sequences: hash[:return_sequences],
18
+ weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
19
+ recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
20
+ bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
21
+ weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
22
+ recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
23
+ bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
24
+ use_bias: hash[:use_bias])
25
+ end
26
+
27
+ # @param [Integer] num_nodes Number of nodes.
28
+ # @param [Boolean] stateful Maintain state between batches.
29
+ # @param [Boolean] return_sequences Set the false, only the last of each cell of RNN is left.
30
+ # @param [DNN::Initializers::Initializer] recurrent_weight_initializer Recurrent weight initializer.
31
+ # @param [DNN::Regularizers::Regularizer | NilClass] recurrent_weight_regularizer Recurrent weight regularizer.
19
32
  def initialize(num_nodes,
20
33
  stateful: false,
21
34
  return_sequences: true,
22
- weight_initializer: RandomNormal.new,
23
- recurrent_weight_initializer: RandomNormal.new,
24
- bias_initializer: Zeros.new,
35
+ weight_initializer: Initializers::RandomNormal.new,
36
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
37
+ bias_initializer: Initializers::Zeros.new,
25
38
  weight_regularizer: nil,
26
39
  recurrent_weight_regularizer: nil,
27
40
  bias_regularizer: nil,
@@ -32,8 +45,8 @@ module DNN
32
45
  @stateful = stateful
33
46
  @return_sequences = return_sequences
34
47
  @layers = []
35
- @hidden = @params[:hidden] = Param.new
36
- @params[:recurrent_weight] = @recurrent_weight = Param.new(nil, 0)
48
+ @hidden = Param.new
49
+ @recurrent_weight = Param.new(nil, 0)
37
50
  @recurrent_weight_initializer = recurrent_weight_initializer
38
51
  @recurrent_weight_regularizer = recurrent_weight_regularizer
39
52
  end
@@ -49,7 +62,7 @@ module DNN
49
62
  def forward(xs)
50
63
  @xs_shape = xs.shape
51
64
  hs = Xumo::SFloat.zeros(xs.shape[0], @time_length, @num_nodes)
52
- h = (@stateful && @hidden.data) ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
65
+ h = @stateful && @hidden.data ? @hidden.data : Xumo::SFloat.zeros(xs.shape[0], @num_nodes)
53
66
  xs.shape[1].times do |t|
54
67
  x = xs[true, t, false]
55
68
  @layers[t].trainable = @trainable
@@ -92,6 +105,10 @@ module DNN
92
105
  super(hash)
93
106
  end
94
107
 
108
+ def get_params
109
+ { weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden }
110
+ end
111
+
95
112
  # Reset the state of RNN.
96
113
  def reset_state
97
114
  @hidden.data = @hidden.data.fill(0) if @hidden.data
@@ -113,7 +130,7 @@ module DNN
113
130
  end
114
131
 
115
132
 
116
- class SimpleRNN_Dense
133
+ class SimpleRNNDense
117
134
  attr_accessor :trainable
118
135
 
119
136
  def initialize(weight, recurrent_weight, bias, activation)
@@ -147,32 +164,30 @@ module DNN
147
164
 
148
165
 
149
166
  class SimpleRNN < RNN
150
- include Activations
151
-
152
167
  attr_reader :activation
153
-
168
+
154
169
  def self.from_hash(hash)
155
- simple_rnn = self.new(hash[:num_nodes],
156
- stateful: hash[:stateful],
157
- return_sequences: hash[:return_sequences],
158
- activation: Utils.from_hash(hash[:activation]),
159
- weight_initializer: Utils.from_hash(hash[:weight_initializer]),
160
- recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
161
- bias_initializer: Utils.from_hash(hash[:bias_initializer]),
162
- weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
163
- recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
164
- bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
165
- use_bias: hash[:use_bias])
166
- simple_rnn
170
+ self.new(hash[:num_nodes],
171
+ stateful: hash[:stateful],
172
+ return_sequences: hash[:return_sequences],
173
+ activation: Utils.hash_to_obj(hash[:activation]),
174
+ weight_initializer: Utils.hash_to_obj(hash[:weight_initializer]),
175
+ recurrent_weight_initializer: Utils.hash_to_obj(hash[:recurrent_weight_initializer]),
176
+ bias_initializer: Utils.hash_to_obj(hash[:bias_initializer]),
177
+ weight_regularizer: Utils.hash_to_obj(hash[:weight_regularizer]),
178
+ recurrent_weight_regularizer: Utils.hash_to_obj(hash[:recurrent_weight_regularizer]),
179
+ bias_regularizer: Utils.hash_to_obj(hash[:bias_regularizer]),
180
+ use_bias: hash[:use_bias])
167
181
  end
168
182
 
183
+ # @param [DNN::Layers::Layer] activation Activation function to use in a recurrent network.
169
184
  def initialize(num_nodes,
170
185
  stateful: false,
171
186
  return_sequences: true,
172
- activation: Tanh.new,
173
- weight_initializer: RandomNormal.new,
174
- recurrent_weight_initializer: RandomNormal.new,
175
- bias_initializer: Zeros.new,
187
+ activation: Activations::Tanh.new,
188
+ weight_initializer: Initializers::RandomNormal.new,
189
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
190
+ bias_initializer: Initializers::Zeros.new,
176
191
  weight_regularizer: nil,
177
192
  recurrent_weight_regularizer: nil,
178
193
  bias_regularizer: nil,
@@ -197,29 +212,29 @@ module DNN
197
212
  @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
198
213
  @bias.data = Xumo::SFloat.new(@num_nodes) if @bias
199
214
  init_weight_and_bias
200
- @time_length.times do |t|
201
- @layers << SimpleRNN_Dense.new(@weight, @recurrent_weight, @bias, @activation)
215
+ @time_length.times do
216
+ @layers << SimpleRNNDense.new(@weight, @recurrent_weight, @bias, @activation)
202
217
  end
203
218
  end
204
219
 
205
220
  def to_hash
206
- super({activation: @activation.to_hash})
221
+ super(activation: @activation.to_hash)
207
222
  end
208
223
  end
209
224
 
210
225
 
211
- class LSTM_Dense
226
+ class LSTMDense
212
227
  attr_accessor :trainable
213
228
 
214
229
  def initialize(weight, recurrent_weight, bias)
215
230
  @weight = weight
216
231
  @recurrent_weight = recurrent_weight
217
232
  @bias = bias
218
- @tanh = Tanh.new
219
- @g_tanh = Tanh.new
220
- @forget_sigmoid = Sigmoid.new
221
- @in_sigmoid = Sigmoid.new
222
- @out_sigmoid = Sigmoid.new
233
+ @tanh = Activations::Tanh.new
234
+ @g_tanh = Activations::Tanh.new
235
+ @forget_sigmoid = Activations::Sigmoid.new
236
+ @in_sigmoid = Activations::Sigmoid.new
237
+ @out_sigmoid = Activations::Sigmoid.new
223
238
  @trainable = true
224
239
  end
225
240
 
@@ -267,32 +282,20 @@ module DNN
267
282
 
268
283
 
269
284
  class LSTM < RNN
270
- def self.from_hash(hash)
271
- lstm = self.new(hash[:num_nodes],
272
- stateful: hash[:stateful],
273
- return_sequences: hash[:return_sequences],
274
- weight_initializer: Utils.from_hash(hash[:weight_initializer]),
275
- recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
276
- bias_initializer: Utils.from_hash(hash[:bias_initializer]),
277
- weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
278
- recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
279
- bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
280
- use_bias: hash[:use_bias])
281
- lstm
282
- end
285
+ attr_reader :cell
283
286
 
284
287
  def initialize(num_nodes,
285
288
  stateful: false,
286
289
  return_sequences: true,
287
- weight_initializer: RandomNormal.new,
288
- recurrent_weight_initializer: RandomNormal.new,
289
- bias_initializer: Zeros.new,
290
+ weight_initializer: Initializers::RandomNormal.new,
291
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
292
+ bias_initializer: Initializers::Zeros.new,
290
293
  weight_regularizer: nil,
291
294
  recurrent_weight_regularizer: nil,
292
295
  bias_regularizer: nil,
293
296
  use_bias: true)
294
297
  super
295
- @cell = @params[:cell] = Param.new
298
+ @cell = Param.new
296
299
  end
297
300
 
298
301
  def build(input_shape)
@@ -302,8 +305,8 @@ module DNN
302
305
  @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
303
306
  @bias.data = Xumo::SFloat.new(@num_nodes * 4) if @bias
304
307
  init_weight_and_bias
305
- @time_length.times do |t|
306
- @layers << LSTM_Dense.new(@weight, @recurrent_weight, @bias)
308
+ @time_length.times do
309
+ @layers << LSTMDense.new(@weight, @recurrent_weight, @bias)
307
310
  end
308
311
  end
309
312
 
@@ -350,19 +353,23 @@ module DNN
350
353
  super()
351
354
  @cell.data = @cell.data.fill(0) if @cell.data
352
355
  end
356
+
357
+ def get_params
358
+ { weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias, hidden: @hidden, cell: @cell }
359
+ end
353
360
  end
354
361
 
355
362
 
356
- class GRU_Dense
363
+ class GRUDense
357
364
  attr_accessor :trainable
358
365
 
359
366
  def initialize(weight, recurrent_weight, bias)
360
367
  @weight = weight
361
368
  @recurrent_weight = recurrent_weight
362
369
  @bias = bias
363
- @update_sigmoid = Sigmoid.new
364
- @reset_sigmoid = Sigmoid.new
365
- @tanh = Tanh.new
370
+ @update_sigmoid = Activations::Sigmoid.new
371
+ @reset_sigmoid = Activations::Sigmoid.new
372
+ @tanh = Activations::Tanh.new
366
373
  @trainable = true
367
374
  end
368
375
 
@@ -423,33 +430,19 @@ module DNN
423
430
 
424
431
 
425
432
  class GRU < RNN
426
- def self.from_hash(hash)
427
- gru = self.new(hash[:num_nodes],
428
- stateful: hash[:stateful],
429
- return_sequences: hash[:return_sequences],
430
- weight_initializer: Utils.from_hash(hash[:weight_initializer]),
431
- recurrent_weight_initializer: Utils.from_hash(hash[:recurrent_weight_initializer]),
432
- bias_initializer: Utils.from_hash(hash[:bias_initializer]),
433
- weight_regularizer: Utils.from_hash(hash[:weight_regularizer]),
434
- recurrent_weight_regularizer: Utils.from_hash(hash[:recurrent_weight_regularizer]),
435
- bias_regularizer: Utils.from_hash(hash[:bias_regularizer]),
436
- use_bias: hash[:use_bias])
437
- gru
438
- end
439
-
440
433
  def initialize(num_nodes,
441
434
  stateful: false,
442
435
  return_sequences: true,
443
- weight_initializer: RandomNormal.new,
444
- recurrent_weight_initializer: RandomNormal.new,
445
- bias_initializer: Zeros.new,
436
+ weight_initializer: Initializers::RandomNormal.new,
437
+ recurrent_weight_initializer: Initializers::RandomNormal.new,
438
+ bias_initializer: Initializers::Zeros.new,
446
439
  weight_regularizer: nil,
447
440
  recurrent_weight_regularizer: nil,
448
441
  bias_regularizer: nil,
449
442
  use_bias: true)
450
443
  super
451
444
  end
452
-
445
+
453
446
  def build(input_shape)
454
447
  super
455
448
  num_prev_nodes = @input_shape[1]
@@ -457,8 +450,8 @@ module DNN
457
450
  @recurrent_weight.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
458
451
  @bias.data = Xumo::SFloat.new(@num_nodes * 3) if @bias
459
452
  init_weight_and_bias
460
- @time_length.times do |t|
461
- @layers << GRU_Dense.new(@weight, @recurrent_weight, @bias)
453
+ @time_length.times do
454
+ @layers << GRUDense.new(@weight, @recurrent_weight, @bias)
462
455
  end
463
456
  end
464
457
  end
@@ -2,6 +2,9 @@ module DNN
2
2
  # This module provides utility functions.
3
3
  module Utils
4
4
  # Categorize labels into "num_classes" classes.
5
+ # @param [Numo::SFloat] y Label data.
6
+ # @param [Numo::SFloat] num_classes Number of classes to classify.
7
+ # @param [Class] narray_type Type of Numo::Narray data after classification.
5
8
  def self.to_categorical(y, num_classes, narray_type = nil)
6
9
  narray_type ||= y.class
7
10
  y2 = narray_type.zeros(y.shape[0], num_classes)
@@ -12,7 +15,7 @@ module DNN
12
15
  end
13
16
 
14
17
  # Convert hash to an object.
15
- def self.from_hash(hash)
18
+ def self.hash_to_obj(hash)
16
19
  return nil if hash == nil
17
20
  dnn_class = DNN.const_get(hash[:class])
18
21
  if dnn_class.respond_to?(:from_hash)
@@ -23,12 +26,12 @@ module DNN
23
26
 
24
27
  # Return the result of the sigmoid function.
25
28
  def self.sigmoid(x)
26
- Sigmoid.new.forward(x)
29
+ Activations::Sigmoid.new.forward(x)
27
30
  end
28
31
 
29
32
  # Return the result of the softmax function.
30
33
  def self.softmax(x)
31
- SoftmaxCrossEntropy.softmax(x)
34
+ Losses::SoftmaxCrossEntropy.softmax(x)
32
35
  end
33
36
  end
34
37
  end
@@ -11,8 +11,8 @@ module DNN
11
11
  dir_path = "#{__dir__}/downloads"
12
12
  end
13
13
  Downloader.new(url).download(dir_path)
14
- rescue => ex
15
- raise DNN_DownloadError.new(ex.message)
14
+ rescue => e
15
+ raise DNN_DownloadError.new(e.message)
16
16
  end
17
17
 
18
18
  def initialize(url)
@@ -42,7 +42,7 @@ module DNN
42
42
  end
43
43
  puts ""
44
44
  end
45
- file_name = @path.match(%r`.+/(.+)`)[1]
45
+ file_name = @path.match(%r`.*/(.+)`)[1]
46
46
  File.binwrite("#{dir_path}/#{file_name}", buf)
47
47
  end
48
48
  end
@@ -0,0 +1,89 @@
1
+ require "zlib"
2
+ require_relative "core/error"
3
+ require_relative "downloader"
4
+ require_relative "mnist"
5
+
6
module DNN
  # Loader for the Fashion-MNIST dataset.
  #
  # The four gzipped IDX files are downloaded on first use into
  # "<this file's dir>/downloads/fashion-mnist" and parsed into Numo::UInt8
  # arrays: images shaped (num_images, rows, cols, 1) and labels shaped
  # (num_labels).
  module FashionMNIST
    # Raised when a dataset file is missing, truncated, or not a valid IDX file.
    class DNN_MNIST_LoadError < DNN_Error; end

    URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"

    TRAIN_IMAGES_FILE_NAME = "train-images-idx3-ubyte.gz"
    TRAIN_LABELS_FILE_NAME = "train-labels-idx1-ubyte.gz"
    TEST_IMAGES_FILE_NAME = "t10k-images-idx3-ubyte.gz"
    TEST_LABELS_FILE_NAME = "t10k-labels-idx1-ubyte.gz"

    URL_TRAIN_IMAGES = URL_BASE + TRAIN_IMAGES_FILE_NAME
    URL_TRAIN_LABELS = URL_BASE + TRAIN_LABELS_FILE_NAME
    URL_TEST_IMAGES = URL_BASE + TEST_IMAGES_FILE_NAME
    URL_TEST_LABELS = URL_BASE + TEST_LABELS_FILE_NAME

    # Big-endian magic numbers that open IDX image (idx3) and label (idx1) files.
    IMAGES_MAGIC = 0x00000803
    LABELS_MAGIC = 0x00000801
    private_constant :IMAGES_MAGIC, :LABELS_MAGIC

    # Download every dataset file that is not already present locally.
    def self.downloads
      Dir.mkdir("#{__dir__}/downloads") unless Dir.exist?("#{__dir__}/downloads")
      Dir.mkdir(mnist_dir) unless Dir.exist?(mnist_dir)
      [
        [URL_TRAIN_IMAGES, TRAIN_IMAGES_FILE_NAME],
        [URL_TRAIN_LABELS, TRAIN_LABELS_FILE_NAME],
        [URL_TEST_IMAGES, TEST_IMAGES_FILE_NAME],
        [URL_TEST_LABELS, TEST_LABELS_FILE_NAME],
      ].each do |url, file_name|
        Downloader.download(url, mnist_dir) unless File.exist?(get_file_path(file_name))
      end
    end

    # Load the training split (downloading it first if necessary).
    # @return [Array] [images, labels] as Numo::UInt8 arrays.
    # @raise [DNN_MNIST_LoadError] if a file is missing or malformed.
    def self.load_train
      load_dataset(TRAIN_IMAGES_FILE_NAME, TRAIN_LABELS_FILE_NAME)
    end

    # Load the test split (downloading it first if necessary).
    # @return [Array] [images, labels] as Numo::UInt8 arrays.
    # @raise [DNN_MNIST_LoadError] if a file is missing or malformed.
    def self.load_test
      load_dataset(TEST_IMAGES_FILE_NAME, TEST_LABELS_FILE_NAME)
    end

    # Shared body of load_train/load_test: download, check existence, parse.
    private_class_method def self.load_dataset(images_file_name, labels_file_name)
      downloads
      images_file_path = get_file_path(images_file_name)
      labels_file_path = get_file_path(labels_file_name)
      [images_file_path, labels_file_path].each do |file_path|
        raise DNN_MNIST_LoadError.new(%`file "#{file_path}" is not found.`) unless File.exist?(file_path)
      end
      [load_images(images_file_path), load_labels(labels_file_path)]
    end

    # Parse a gzipped idx3 image file into a (num_images, rows, cols, 1) array.
    private_class_method def self.load_images(file_name)
      Zlib::GzipReader.open(file_name) do |f|
        magic, num_images = read_uint32s(f, file_name, 2)
        check_magic(file_name, magic, IMAGES_MAGIC)
        rows, cols = read_uint32s(f, file_name, 2)
        images = Numo::UInt8.from_binary(f.read)
        check_size(file_name, images.size, num_images * rows * cols)
        # IDX stores each image row-major, so the pixel grid is (rows, cols).
        images.reshape(num_images, rows, cols, 1)
      end
    end

    # Parse a gzipped idx1 label file into a flat array of class ids.
    private_class_method def self.load_labels(file_name)
      Zlib::GzipReader.open(file_name) do |f|
        magic, num_labels = read_uint32s(f, file_name, 2)
        check_magic(file_name, magic, LABELS_MAGIC)
        labels = Numo::UInt8.from_binary(f.read)
        check_size(file_name, labels.size, num_labels)
        labels
      end
    end

    # Read `count` big-endian 32-bit unsigned integers, failing cleanly on truncation.
    private_class_method def self.read_uint32s(io, file_name, count)
      bytes = io.read(4 * count)
      unless bytes && bytes.bytesize == 4 * count
        raise DNN_MNIST_LoadError.new(%`file "#{file_name}" is truncated.`)
      end
      bytes.unpack("N*")
    end

    # Reject files whose IDX magic number does not match the expected kind.
    private_class_method def self.check_magic(file_name, magic, expected)
      return if magic == expected
      raise DNN_MNIST_LoadError.new(%`file "#{file_name}" has magic number #{magic}, expected #{expected}.`)
    end

    # Reject files whose payload length disagrees with the header's item count.
    private_class_method def self.check_size(file_name, actual, expected)
      return if actual == expected
      raise DNN_MNIST_LoadError.new(%`file "#{file_name}" has #{actual} items, expected #{expected}.`)
    end

    private_class_method def self.mnist_dir
      "#{__dir__}/downloads/fashion-mnist"
    end

    private_class_method def self.get_file_path(file_name)
      "#{mnist_dir}/#{file_name}"
    end
  end
end