ruby-dnn 0.10.4 → 0.12.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. checksums.yaml +4 -4
  2. data/.travis.yml +1 -2
  3. data/README.md +33 -6
  4. data/examples/cifar100_example.rb +3 -3
  5. data/examples/cifar10_example.rb +3 -3
  6. data/examples/dcgan/dcgan.rb +112 -0
  7. data/examples/dcgan/imgen.rb +20 -0
  8. data/examples/dcgan/train.rb +41 -0
  9. data/examples/iris_example.rb +3 -6
  10. data/examples/mnist_conv2d_example.rb +5 -5
  11. data/examples/mnist_define_by_run.rb +52 -0
  12. data/examples/mnist_example.rb +3 -3
  13. data/examples/mnist_lstm_example.rb +3 -3
  14. data/examples/xor_example.rb +4 -5
  15. data/ext/rb_stb_image/rb_stb_image.c +103 -0
  16. data/lib/dnn.rb +10 -10
  17. data/lib/dnn/cifar10.rb +1 -1
  18. data/lib/dnn/cifar100.rb +1 -1
  19. data/lib/dnn/core/activations.rb +21 -22
  20. data/lib/dnn/core/cnn_layers.rb +94 -111
  21. data/lib/dnn/core/embedding.rb +30 -9
  22. data/lib/dnn/core/initializers.rb +31 -21
  23. data/lib/dnn/core/iterator.rb +52 -0
  24. data/lib/dnn/core/layers.rb +99 -66
  25. data/lib/dnn/core/link.rb +24 -0
  26. data/lib/dnn/core/losses.rb +69 -59
  27. data/lib/dnn/core/merge_layers.rb +71 -0
  28. data/lib/dnn/core/models.rb +393 -0
  29. data/lib/dnn/core/normalizations.rb +27 -14
  30. data/lib/dnn/core/optimizers.rb +212 -134
  31. data/lib/dnn/core/param.rb +8 -6
  32. data/lib/dnn/core/regularizers.rb +10 -7
  33. data/lib/dnn/core/rnn_layers.rb +78 -85
  34. data/lib/dnn/core/utils.rb +6 -3
  35. data/lib/dnn/downloader.rb +3 -3
  36. data/lib/dnn/fashion-mnist.rb +89 -0
  37. data/lib/dnn/image.rb +57 -18
  38. data/lib/dnn/iris.rb +1 -3
  39. data/lib/dnn/mnist.rb +38 -34
  40. data/lib/dnn/version.rb +1 -1
  41. data/third_party/stb_image.h +16 -4
  42. data/third_party/stb_image_resize.h +2630 -0
  43. data/third_party/stb_image_write.h +4 -7
  44. metadata +12 -4
  45. data/lib/dnn/core/dataset.rb +0 -34
  46. data/lib/dnn/core/model.rb +0 -440
@@ -1,4 +1,4 @@
1
- /* stb_image_write - v1.13 - public domain - http://nothings.org/stb/stb_image_write.h
1
+ /* stb_image_write - v1.13 - public domain - http://nothings.org/stb
2
2
  writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015
3
3
  no warranty implied; use at your own risk
4
4
 
@@ -10,11 +10,6 @@
10
10
 
11
11
  Will probably not work correctly with strict-aliasing optimizations.
12
12
 
13
- If using a modern Microsoft Compiler, non-safe versions of CRT calls may cause
14
- compilation warnings or even errors. To avoid this, also before #including,
15
-
16
- #define STBI_MSC_SECURE_CRT
17
-
18
13
  ABOUT:
19
14
 
20
15
  This header file is a library for writing images to C stdio or a callback.
@@ -873,7 +868,7 @@ STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, i
873
868
  unsigned int bitbuf=0;
874
869
  int i,j, bitcount=0;
875
870
  unsigned char *out = NULL;
876
- unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(char**));
871
+ unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**));
877
872
  if (hash_table == NULL)
878
873
  return NULL;
879
874
  if (quality < 5) quality = 5;
@@ -1535,6 +1530,8 @@ STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const
1535
1530
  #endif // STB_IMAGE_WRITE_IMPLEMENTATION
1536
1531
 
1537
1532
  /* Revision history
1533
+ 1.11 (2019-08-11)
1534
+
1538
1535
  1.10 (2019-02-07)
1539
1536
  support utf8 filenames in Windows; fix warnings and platform ifdefs
1540
1537
  1.09 (2018-02-11)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.4
4
+ version: 0.12.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-06-26 00:00:00.000000000 Z
11
+ date: 2019-09-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -100,8 +100,12 @@ files:
100
100
  - bin/setup
101
101
  - examples/cifar100_example.rb
102
102
  - examples/cifar10_example.rb
103
+ - examples/dcgan/dcgan.rb
104
+ - examples/dcgan/imgen.rb
105
+ - examples/dcgan/train.rb
103
106
  - examples/iris_example.rb
104
107
  - examples/mnist_conv2d_example.rb
108
+ - examples/mnist_define_by_run.rb
105
109
  - examples/mnist_example.rb
106
110
  - examples/mnist_lstm_example.rb
107
111
  - examples/xor_example.rb
@@ -114,13 +118,15 @@ files:
114
118
  - lib/dnn/cifar100.rb
115
119
  - lib/dnn/core/activations.rb
116
120
  - lib/dnn/core/cnn_layers.rb
117
- - lib/dnn/core/dataset.rb
118
121
  - lib/dnn/core/embedding.rb
119
122
  - lib/dnn/core/error.rb
120
123
  - lib/dnn/core/initializers.rb
124
+ - lib/dnn/core/iterator.rb
121
125
  - lib/dnn/core/layers.rb
126
+ - lib/dnn/core/link.rb
122
127
  - lib/dnn/core/losses.rb
123
- - lib/dnn/core/model.rb
128
+ - lib/dnn/core/merge_layers.rb
129
+ - lib/dnn/core/models.rb
124
130
  - lib/dnn/core/normalizations.rb
125
131
  - lib/dnn/core/optimizers.rb
126
132
  - lib/dnn/core/param.rb
@@ -128,12 +134,14 @@ files:
128
134
  - lib/dnn/core/rnn_layers.rb
129
135
  - lib/dnn/core/utils.rb
130
136
  - lib/dnn/downloader.rb
137
+ - lib/dnn/fashion-mnist.rb
131
138
  - lib/dnn/image.rb
132
139
  - lib/dnn/iris.rb
133
140
  - lib/dnn/mnist.rb
134
141
  - lib/dnn/version.rb
135
142
  - ruby-dnn.gemspec
136
143
  - third_party/stb_image.h
144
+ - third_party/stb_image_resize.h
137
145
  - third_party/stb_image_write.h
138
146
  homepage: https://github.com/unagiootoro/ruby-dnn.git
139
147
  licenses:
@@ -1,34 +0,0 @@
1
- # This class manages input datas and output datas together.
2
- class DNN::Dataset
3
- # @param [Numo::SFloat] x_datas input datas.
4
- # @param [Numo::SFloat] y_datas output datas.
5
- # @param [Bool] random Set true to return batches randomly. Setting false returns batches in order of index.
6
- def initialize(x_datas, y_datas, random = true)
7
- @x_datas = x_datas
8
- @y_datas = y_datas
9
- @random = random
10
- @num_datas = x_datas.shape[0]
11
- reset_indexs
12
- end
13
-
14
- # Return the next batch.
15
- # If the number of remaining data < batch size, if random = true, shuffle the data again and return a batch.
16
- # If random = false, all remaining data will be returned regardless of the batch size.
17
- def next_batch(batch_size)
18
- if @indexes.length < batch_size
19
- batch_indexes = @indexes unless @random
20
- reset_indexs
21
- batch_indexes = @indexes.shift(batch_size) if @random
22
- else
23
- batch_indexes = @indexes.shift(batch_size)
24
- end
25
- x_batch = @x_datas[batch_indexes, false]
26
- y_batch = @y_datas[batch_indexes, false]
27
- [x_batch, y_batch]
28
- end
29
-
30
- private def reset_indexs
31
- @indexes = @num_datas.times.to_a
32
- @indexes.shuffle! if @random
33
- end
34
- end
@@ -1,440 +0,0 @@
1
- require "zlib"
2
- require "json"
3
- require "base64"
4
-
5
- module DNN
6
-
7
- # This class deals with the model of the network.
8
- class Model
9
- # @return [Array] All layers possessed by the model.
10
- attr_accessor :layers
11
- # @return [Bool] Setting false prevents learning of parameters.
12
- attr_accessor :trainable
13
-
14
- # Load marshal model.
15
- # @param [String] file_name File name of marshal model to load.
16
- def self.load(file_name)
17
- Marshal.load(Zlib::Inflate.inflate(File.binread(file_name)))
18
- end
19
-
20
- # Load json model.
21
- # @param [String] json_str json string to load model.
22
- # @return [DNN::Model]
23
- def self.load_json(json_str)
24
- hash = JSON.parse(json_str, symbolize_names: true)
25
- model = self.from_hash(hash)
26
- model.compile(Utils.from_hash(hash[:optimizer]), Utils.from_hash(hash[:loss]))
27
- model
28
- end
29
-
30
- def self.from_hash(hash)
31
- model = self.new
32
- model.layers = hash[:layers].map { |hash_layer| Utils.from_hash(hash_layer) }
33
- model
34
- end
35
-
36
- def initialize
37
- @layers = []
38
- @trainable = true
39
- @optimizer = nil
40
- @compiled = false
41
- end
42
-
43
- # Load json model parameters.
44
- # @param [String] json_str json string to load model parameters.
45
- def load_json_params(json_str)
46
- hash = JSON.parse(json_str, symbolize_names: true)
47
- has_param_layers_params = hash[:params]
48
- has_param_layers_index = 0
49
- has_param_layers = get_all_layers.select { |layer| layer.is_a?(Layers::HasParamLayer) }
50
- has_param_layers.each do |layer|
51
- hash_params = has_param_layers_params[has_param_layers_index]
52
- hash_params.each do |key, (shape, base64_param)|
53
- bin = Base64.decode64(base64_param)
54
- data = Xumo::SFloat.from_binary(bin).reshape(*shape)
55
- layer.params[key].data = data
56
- end
57
- has_param_layers_index += 1
58
- end
59
- end
60
-
61
- # Save the model in marshal format.
62
- # @param [String] file_name name to save model.
63
- def save(file_name)
64
- bin = Zlib::Deflate.deflate(Marshal.dump(self))
65
- begin
66
- File.binwrite(file_name, bin)
67
- rescue Errno::ENOENT => ex
68
- dir_name = file_name.match(%r`(.*)/.+$`)[1]
69
- Dir.mkdir(dir_name)
70
- File.binwrite(file_name, bin)
71
- end
72
- end
73
-
74
- # Convert model to json string.
75
- # @return [String] json string.
76
- def to_json
77
- hash = self.to_hash
78
- hash[:version] = VERSION
79
- JSON.pretty_generate(hash)
80
- end
81
-
82
- # Convert model parameters to json string.
83
- # @return [String] json string.
84
- def params_to_json
85
- has_param_layers = get_all_layers.select { |layer| layer.is_a?(Layers::HasParamLayer) }
86
- has_param_layers_params = has_param_layers.map do |layer|
87
- layer.params.map { |key, param|
88
- base64_data = Base64.encode64(param.data.to_binary)
89
- [key, [param.data.shape, base64_data]]
90
- }.to_h
91
- end
92
- hash = {version: VERSION, params: has_param_layers_params}
93
- JSON.dump(hash)
94
- end
95
-
96
- # Add layer to the model.
97
- # @param [DNN::Layers::Layer] layer Layer to add to the model.
98
- # @return [DNN::Model] return self.
99
- def <<(layer)
100
- if !layer.is_a?(Layers::Layer) && !layer.is_a?(Model)
101
- raise TypeError.new("layer is not an instance of the DNN::Layers::Layer class or DNN::Model class.")
102
- end
103
- @layers << layer
104
- self
105
- end
106
-
107
- # Set optimizer and loss_func to model and build all layers.
108
- # @param [DNN::Optimizers::Optimizer] optimizer Optimizer to use for learning.
109
- # @param [DNN::Losses::Loss] loss_func Loss function to use for learning.
110
- def compile(optimizer, loss_func)
111
- raise DNN_Error.new("The model is already compiled.") if compiled?
112
- unless optimizer.is_a?(Optimizers::Optimizer)
113
- raise TypeError.new("optimizer:#{optimizer.class} is not an instance of DNN::Optimizers::Optimizer class.")
114
- end
115
- unless loss_func.is_a?(Losses::Loss)
116
- raise TypeError.new("loss_func:#{loss_func.class} is not an instance of DNN::Losses::Loss class.")
117
- end
118
- @compiled = true
119
- layers_check
120
- @optimizer = optimizer
121
- @loss_func = loss_func
122
- build
123
- end
124
-
125
- # Set optimizer and loss_func to model and recompile. But does not build layers.
126
- # @param [DNN::Optimizers::Optimizer] optimizer Optimizer to use for learning.
127
- # @param [DNN::Losses::Loss] loss_func Loss function to use for learning.
128
- def recompile(optimizer, loss_func)
129
- unless optimizer.is_a?(Optimizers::Optimizer)
130
- raise TypeError.new("optimizer:#{optimizer.class} is not an instance of DNN::Optimizers::Optimizer class.")
131
- end
132
- unless loss_func.is_a?(Losses::Loss)
133
- raise TypeError.new("loss_func:#{loss_func.class} is not an instance of DNN::Losses::Loss class.")
134
- end
135
- @compiled = true
136
- layers_check
137
- @optimizer = optimizer
138
- @loss_func = loss_func
139
- end
140
-
141
- def build(super_model = nil)
142
- @super_model = super_model
143
- shape = if super_model
144
- super_model.get_prev_layer(self).output_shape
145
- else
146
- @layers.first.build
147
- end
148
- layers = super_model ? @layers : @layers[1..-1]
149
- layers.each do |layer|
150
- if layer.is_a?(Model)
151
- layer.build(self)
152
- layer.recompile(@optimizer, @loss_func)
153
- else
154
- layer.build(shape)
155
- end
156
- shape = layer.output_shape
157
- end
158
- end
159
-
160
- # @return [Array] Return the input shape of the model.
161
- def input_shape
162
- @layers.first.input_shape
163
- end
164
-
165
- # @return [Array] Return the output shape of the model.
166
- def output_shape
167
- @layers.last.output_shape
168
- end
169
-
170
- # @return [DNN::Optimizers::Optimizer] optimizer Return the optimizer to use for learning.
171
- def optimizer
172
- raise DNN_Error.new("The model is not compiled.") unless compiled?
173
- @optimizer
174
- end
175
-
176
- # @return [DNN::Losses::Loss] loss Return the loss to use for learning.
177
- def loss_func
178
- raise DNN_Error.new("The model is not compiled.") unless compiled?
179
- @loss_func
180
- end
181
-
182
- # @return [Bool] Returns whether the model is learning.
183
- def compiled?
184
- @compiled
185
- end
186
-
187
- # Start training.
188
- # Compile the model before use this method.
189
- # @param [Numo::SFloat] x Input training data.
190
- # @param [Numo::SFloat] y Output training data.
191
- # @param [Integer] epochs Number of training.
192
- # @param [Integer] batch_size Batch size used for one training.
193
- # @param [Array or NilClass] test If you to test the model for every 1 epoch,
194
- # specify [x_test, y_test]. Don't test to the model, specify nil.
195
- # @param [Bool] verbose Set true to display the log. If false is set, the log is not displayed.
196
- # @param [Lambda] before_epoch_cbk Process performed before one training.
197
- # @param [Lambda] after_epoch_cbk Process performed after one training.
198
- # @param [Lambda] before_batch_cbk Set the proc to be performed before batch processing.
199
- # @param [Lambda] after_batch_cbk Set the proc to be performed after batch processing.
200
- def train(x, y, epochs,
201
- batch_size: 1,
202
- test: nil,
203
- verbose: true,
204
- before_epoch_cbk: nil,
205
- after_epoch_cbk: nil,
206
- before_batch_cbk: nil,
207
- after_batch_cbk: nil)
208
- raise DNN_Error.new("The model is not compiled.") unless compiled?
209
- check_xy_type(x, y)
210
- dataset = Dataset.new(x, y)
211
- num_train_datas = x.shape[0]
212
- (1..epochs).each do |epoch|
213
- before_epoch_cbk.call(epoch) if before_epoch_cbk
214
- puts "【 epoch #{epoch}/#{epochs} 】" if verbose
215
- (num_train_datas.to_f / batch_size).ceil.times do |index|
216
- x_batch, y_batch = dataset.next_batch(batch_size)
217
- loss_value = train_on_batch(x_batch, y_batch,
218
- before_batch_cbk: before_batch_cbk, after_batch_cbk: after_batch_cbk)
219
- if loss_value.is_a?(Numo::SFloat)
220
- loss_value = loss_value.mean
221
- elsif loss_value.nan?
222
- puts "\nloss is nan" if verbose
223
- return
224
- end
225
- num_trained_datas = (index + 1) * batch_size
226
- num_trained_datas = num_trained_datas > num_train_datas ? num_train_datas : num_trained_datas
227
- log = "\r"
228
- 40.times do |i|
229
- if i < num_trained_datas * 40 / num_train_datas
230
- log << "="
231
- elsif i == num_trained_datas * 40 / num_train_datas
232
- log << ">"
233
- else
234
- log << "_"
235
- end
236
- end
237
- log << " #{num_trained_datas}/#{num_train_datas} loss: #{sprintf('%.8f', loss_value)}"
238
- print log if verbose
239
- end
240
- if verbose && test
241
- acc, test_loss = accurate(test[0], test[1], batch_size,
242
- before_batch_cbk: before_batch_cbk, after_batch_cbk: after_batch_cbk)
243
- print " accurate: #{acc}, test loss: #{sprintf('%.8f', test_loss)}"
244
- end
245
- puts "" if verbose
246
- after_epoch_cbk.call(epoch) if after_epoch_cbk
247
- end
248
- end
249
-
250
- # Training once.
251
- # Compile the model before use this method.
252
- # @param [Numo::SFloat] x Input training data.
253
- # @param [Numo::SFloat] y Output training data.
254
- # @param [Lambda] before_batch_cbk Set the proc to be performed before batch processing.
255
- # @param [Lambda] after_batch_cbk Set the proc to be performed after batch processing.
256
- # @return [Float | Numo::SFloat] Return loss value in the form of Float or Numo::SFloat.
257
- def train_on_batch(x, y, before_batch_cbk: nil, after_batch_cbk: nil)
258
- raise DNN_Error.new("The model is not compiled.") unless compiled?
259
- check_xy_type(x, y)
260
- input_data_shape_check(x, y)
261
- x, y = before_batch_cbk.call(x, y, true) if before_batch_cbk
262
- x = forward(x, true)
263
- loss_value = @loss_func.forward(x, y, get_all_layers)
264
- dy = @loss_func.backward(y, get_all_layers)
265
- backward(dy)
266
- update
267
- after_batch_cbk.call(loss_value, true) if after_batch_cbk
268
- loss_value
269
- end
270
-
271
- # Evaluate model and get accurate of test data.
272
- # @param [Numo::SFloat] x Input test data.
273
- # @param [Numo::SFloat] y Output test data.
274
- # @param [Lambda] before_batch_cbk Set the proc to be performed before batch processing.
275
- # @param [Lambda] after_batch_cbk Set the proc to be performed after batch processing.
276
- # @return [Array] Returns the test data accurate and mean loss in the form [accurate, mean_loss].
277
- def accurate(x, y, batch_size = 100, before_batch_cbk: nil, after_batch_cbk: nil)
278
- check_xy_type(x, y)
279
- input_data_shape_check(x, y)
280
- batch_size = batch_size >= x.shape[0] ? x.shape[0] : batch_size
281
- dataset = Dataset.new(x, y, false)
282
- correct = 0
283
- sum_loss = 0
284
- max_iter = (x.shape[0].to_f / batch_size)
285
- max_iter.ceil.times do |i|
286
- x_batch, y_batch = dataset.next_batch(batch_size)
287
- x_batch, y_batch = before_batch_cbk.call(x_batch, y_batch, false) if before_batch_cbk
288
- x_batch = forward(x_batch, false)
289
- sigmoid = Sigmoid.new
290
- x_batch.shape[0].times do |j|
291
- if @layers.last.output_shape == [1]
292
- if @loss_func.is_a?(SigmoidCrossEntropy)
293
- correct += 1 if sigmoid.forward(x_batch[j, 0]).round == y_batch[j, 0].round
294
- else
295
- correct += 1 if x_batch[j, 0].round == y_batch[j, 0].round
296
- end
297
- else
298
- correct += 1 if x_batch[j, true].max_index == y_batch[j, true].max_index
299
- end
300
- end
301
- loss_value = @loss_func.forward(x_batch, y_batch, get_all_layers)
302
- after_batch_cbk.call(loss_value, false) if after_batch_cbk
303
- sum_loss += loss_value.is_a?(Xumo::SFloat) ? loss_value.mean : loss_value
304
- end
305
- mean_loss = sum_loss / max_iter
306
- [correct.to_f / x.shape[0], mean_loss]
307
- end
308
-
309
- # Predict data.
310
- # @param [Numo::SFloat] x Input data.
311
- def predict(x)
312
- check_xy_type(x)
313
- input_data_shape_check(x)
314
- forward(x, false)
315
- end
316
-
317
- # Predict one data.
318
- # @param [Numo::SFloat] x Input data. However, x is single data.
319
- def predict1(x)
320
- check_xy_type(x)
321
- predict(x.reshape(1, *x.shape))[0, false]
322
- end
323
-
324
- # Get loss value.
325
- # @param [Numo::SFloat] x Input data.
326
- # @param [Numo::SFloat] y Output data.
327
- # @return [Float | Numo::SFloat] Return loss value in the form of Float or Numo::SFloat.
328
- def loss(x, y)
329
- check_xy_type(x, y)
330
- input_data_shape_check(x, y)
331
- x = forward(x, false)
332
- @loss_func.forward(x, y, get_all_layers)
333
- end
334
-
335
- # @return [DNN::Model] Copy this model.
336
- def copy
337
- Marshal.load(Marshal.dump(self))
338
- end
339
-
340
- # Get the layer that the model has.
341
- def get_layer(*args)
342
- if args.length == 1
343
- index = args[0]
344
- @layers[index]
345
- else
346
- layer_class, index = args
347
- @layers.select { |layer| layer.is_a?(layer_class) }[index]
348
- end
349
- end
350
-
351
- # Get the all layers.
352
- # @return [Array] all layers array.
353
- def get_all_layers
354
- @layers.map { |layer|
355
- layer.is_a?(Model) ? layer.get_all_layers : layer
356
- }.flatten
357
- end
358
-
359
- def forward(x, learning_phase)
360
- @layers.each do |layer|
361
- x = if layer.is_a?(Model)
362
- layer.forward(x, learning_phase)
363
- else
364
- layer.learning_phase = learning_phase
365
- layer.forward(x)
366
- end
367
- end
368
- x
369
- end
370
-
371
- def backward(dy)
372
- @layers.reverse.each do |layer|
373
- dy = layer.backward(dy)
374
- end
375
- dy
376
- end
377
-
378
- def update
379
- return unless @trainable
380
- all_trainable_layers = @layers.map { |layer|
381
- if layer.is_a?(Model)
382
- layer.trainable ? layer.get_all_layers : nil
383
- else
384
- layer
385
- end
386
- }.flatten.compact.uniq
387
- @optimizer.update(all_trainable_layers)
388
- end
389
-
390
- def get_prev_layer(layer)
391
- layer_index = @layers.index(layer)
392
- prev_layer = if layer_index == 0
393
- if @super_model
394
- @super_model.layers[@super_model.layers.index(self) - 1]
395
- else
396
- self
397
- end
398
- else
399
- @layers[layer_index - 1]
400
- end
401
- if prev_layer.is_a?(Layers::Layer)
402
- prev_layer
403
- elsif prev_layer.is_a?(Model)
404
- prev_layer.layers.last
405
- end
406
- end
407
-
408
- def to_hash
409
- hash_layers = @layers.map { |layer| layer.to_hash }
410
- {class: Model.name, layers: hash_layers, optimizer: @optimizer.to_hash, loss: @loss_func.to_hash}
411
- end
412
-
413
- private
414
-
415
- def layers_check
416
- if !@layers.first.is_a?(Layers::InputLayer) && !@layers.first.is_a?(Layers::Embedding) && !@super_model
417
- raise TypeError.new("The first layer is not an InputLayer or Embedding.")
418
- end
419
- end
420
-
421
- def input_data_shape_check(x, y = nil)
422
- unless @layers.first.input_shape == x.shape[1..-1]
423
- raise DNN_ShapeError.new("The shape of x does not match the input shape. x shape is #{x.shape[1..-1]}, but input shape is #{@layers.first.input_shape}.")
424
- end
425
- if y && @layers.last.output_shape != y.shape[1..-1]
426
- raise DNN_ShapeError.new("The shape of y does not match the input shape. y shape is #{y.shape[1..-1]}, but output shape is #{@layers.last.output_shape}.")
427
- end
428
- end
429
-
430
- def check_xy_type(x, y = nil)
431
- unless x.is_a?(Xumo::SFloat)
432
- raise TypeError.new("x:#{x.class.name} is not an instance of #{Xumo::SFloat.name} class.")
433
- end
434
- if y && !y.is_a?(Xumo::SFloat)
435
- raise TypeError.new("y:#{y.class.name} is not an instance of #{Xumo::SFloat.name} class.")
436
- end
437
- end
438
- end
439
-
440
- end