ruby-dnn 0.16.2 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +22 -0
  3. data/examples/api-examples/early_stopping_example.rb +1 -1
  4. data/examples/api-examples/initializer_example.rb +1 -1
  5. data/examples/api-examples/regularizer_example.rb +1 -1
  6. data/examples/dcgan/dcgan.rb +10 -3
  7. data/examples/pix2pix/dcgan.rb +4 -0
  8. data/examples/pix2pix/train.rb +5 -2
  9. data/examples/vae.rb +0 -6
  10. data/lib/dnn/core/callbacks.rb +7 -3
  11. data/lib/dnn/core/error.rb +2 -2
  12. data/lib/dnn/core/initializers.rb +5 -5
  13. data/lib/dnn/core/iterator.rb +4 -1
  14. data/lib/dnn/core/layers/basic_layers.rb +42 -65
  15. data/lib/dnn/core/layers/cnn_layers.rb +34 -35
  16. data/lib/dnn/core/layers/embedding.rb +3 -24
  17. data/lib/dnn/core/layers/math_layers.rb +12 -0
  18. data/lib/dnn/core/layers/merge_layers.rb +13 -13
  19. data/lib/dnn/core/layers/normalizations.rb +4 -4
  20. data/lib/dnn/core/layers/rnn_layers.rb +46 -46
  21. data/lib/dnn/core/link.rb +8 -8
  22. data/lib/dnn/core/losses.rb +10 -20
  23. data/lib/dnn/core/models.rb +23 -46
  24. data/lib/dnn/core/monkey_patch.rb +10 -0
  25. data/lib/dnn/core/optimizers.rb +1 -2
  26. data/lib/dnn/core/param.rb +2 -2
  27. data/lib/dnn/core/regularizers.rb +1 -1
  28. data/lib/dnn/core/savers.rb +2 -2
  29. data/lib/dnn/core/tensor.rb +1 -1
  30. data/lib/dnn/datasets/cifar10.rb +1 -1
  31. data/lib/dnn/datasets/cifar100.rb +1 -1
  32. data/lib/dnn/datasets/downloader.rb +1 -1
  33. data/lib/dnn/datasets/fashion-mnist.rb +1 -1
  34. data/lib/dnn/datasets/iris.rb +1 -1
  35. data/lib/dnn/datasets/mnist.rb +1 -1
  36. data/lib/dnn/datasets/stl-10.rb +2 -2
  37. data/lib/dnn/version.rb +1 -1
  38. metadata +2 -2
@@ -1,15 +1,15 @@
1
1
  module DNN
2
2
  class Link
3
3
  attr_accessor :prev
4
- attr_accessor :layer
4
+ attr_accessor :layer_node
5
5
 
6
- def initialize(prev = nil, layer = nil)
6
+ def initialize(prev = nil, layer_node = nil)
7
7
  @prev = prev
8
- @layer = layer
8
+ @layer_node = layer_node
9
9
  end
10
10
 
11
11
  def backward(dy = Numo::SFloat[1])
12
- dy = @layer.backward(dy)
12
+ dy = @layer_node.backward_node(dy)
13
13
  @prev&.backward(dy)
14
14
  end
15
15
  end
@@ -17,16 +17,16 @@ module DNN
17
17
  class TwoInputLink
18
18
  attr_accessor :prev1
19
19
  attr_accessor :prev2
20
- attr_accessor :layer
20
+ attr_accessor :layer_node
21
21
 
22
- def initialize(prev1 = nil, prev2 = nil, layer = nil)
22
+ def initialize(prev1 = nil, prev2 = nil, layer_node = nil)
23
23
  @prev1 = prev1
24
24
  @prev2 = prev2
25
- @layer = layer
25
+ @layer_node = layer_node
26
26
  end
27
27
 
28
28
  def backward(dy = Numo::SFloat[1])
29
- dys = @layer.backward(dy)
29
+ dys = @layer_node.backward_node(dy)
30
30
  if dys.is_a?(Array)
31
31
  dy1, dy2 = *dys
32
32
  else
@@ -10,7 +10,7 @@ module DNN
10
10
  return nil unless hash
11
11
  loss_class = DNN.const_get(hash[:class])
12
12
  loss = loss_class.allocate
13
- raise DNN_Error, "#{loss.class} is not an instance of #{self} class." unless loss.is_a?(self)
13
+ raise DNNError, "#{loss.class} is not an instance of #{self} class." unless loss.is_a?(self)
14
14
  loss.load_hash(hash)
15
15
  loss
16
16
  end
@@ -21,7 +21,7 @@ module DNN
21
21
 
22
22
  def loss(y, t, layers = nil)
23
23
  unless y.shape == t.shape
24
- raise DNN_ShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
24
+ raise DNNShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
25
25
  end
26
26
  loss = call(y, t)
27
27
  loss = regularizers_forward(loss, layers) if layers
@@ -76,7 +76,7 @@ module DNN
76
76
  end
77
77
 
78
78
  def backward_node(d)
79
- (@y - @t) / @y.shape[0]
79
+ d * (@y - @t) / @y.shape[0]
80
80
  end
81
81
  end
82
82
 
@@ -93,7 +93,7 @@ module DNN
93
93
  dy = (@y - @t)
94
94
  dy[dy >= 0] = 1
95
95
  dy[dy < 0] = -1
96
- dy / @y.shape[0]
96
+ d * dy / @y.shape[0]
97
97
  end
98
98
  end
99
99
 
@@ -109,7 +109,7 @@ module DNN
109
109
  def backward_node(d)
110
110
  a = Xumo::SFloat.ones(*@a.shape)
111
111
  a[@a <= 0] = 0
112
- (a * -@t) / a.shape[0]
112
+ d * (a * -@t) / a.shape[0]
113
113
  end
114
114
  end
115
115
 
@@ -119,8 +119,8 @@ module DNN
119
119
  def forward_node(y, t)
120
120
  @y = y
121
121
  @t = t
122
- loss_l1_value = loss_l1(y, t)
123
- @loss_value = loss_l1_value > 1 ? loss_l1_value : loss_l2(y, t)
122
+ loss_l1_value = (y - t).abs.mean(0).sum
123
+ @loss_value = loss_l1_value > 1 ? loss_l1_value : 0.5 * ((y - t)**2).mean(0).sum
124
124
  end
125
125
 
126
126
  def backward_node(d)
@@ -129,17 +129,7 @@ module DNN
129
129
  dy[dy >= 0] = 1
130
130
  dy[dy < 0] = -1
131
131
  end
132
- dy / @y.shape[0]
133
- end
134
-
135
- private
136
-
137
- def loss_l1(y, t)
138
- (y - t).abs.mean(0).sum
139
- end
140
-
141
- def loss_l2(y, t)
142
- 0.5 * ((y - t)**2).mean(0).sum
132
+ d * dy / @y.shape[0]
143
133
  end
144
134
  end
145
135
 
@@ -168,7 +158,7 @@ module DNN
168
158
  end
169
159
 
170
160
  def backward_node(d)
171
- (@x - @t) / @x.shape[0]
161
+ d * (@x - @t) / @x.shape[0]
172
162
  end
173
163
 
174
164
  def to_hash
@@ -205,7 +195,7 @@ module DNN
205
195
  end
206
196
 
207
197
  def backward_node(d)
208
- (@x - @t) / @x.shape[0]
198
+ d * (@x - @t) / @x.shape[0]
209
199
  end
210
200
 
211
201
  def to_hash
@@ -39,23 +39,28 @@ module DNN
39
39
  end
40
40
 
41
41
  class Chain
42
+ def initialize
43
+ @layers_cache = nil
44
+ end
45
+
42
46
  # Forward propagation.
43
- # @param [Tensor] input_tensor Input tensor.
47
+ # @param [Tensor] input_tensors Input tensors.
44
48
  # @return [Tensor] Output tensor.
45
- def forward(input_tensor)
49
+ def forward(input_tensors)
46
50
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward'"
47
51
  end
48
52
 
49
53
  # Forward propagation and create a link.
50
- # @param [Tensor] input_tensor Input tensor.
54
+ # @param [Tensor | Array] input_tensors Input tensors.
51
55
  # @return [Tensor] Output tensor.
52
- def call(input_tensor)
53
- forward(input_tensor)
56
+ def call(input_tensors)
57
+ forward(input_tensors)
54
58
  end
55
59
 
56
60
  # Get the all layers.
57
61
  # @return [Array] All layers array.
58
62
  def layers
63
+ return @layers_cache if @layers_cache
59
64
  layers_array = []
60
65
  instance_variables.sort.each do |ivar|
61
66
  obj = instance_variable_get(ivar)
@@ -65,39 +70,7 @@ module DNN
65
70
  layers_array.concat(obj.layers)
66
71
  end
67
72
  end
68
- layers_array
69
- end
70
-
71
- def to_hash
72
- layers_hash = { class: self.class.name }
73
- instance_variables.sort.each do |ivar|
74
- obj = instance_variable_get(ivar)
75
- if obj.is_a?(Layers::Layer) || obj.is_a?(Chain)
76
- layers_hash[ivar] = obj.to_hash
77
- elsif obj.is_a?(LayersList)
78
- layers_hash[ivar] = obj.to_hash_list
79
- end
80
- end
81
- layers_hash
82
- end
83
-
84
- def load_hash(layers_hash)
85
- instance_variables.sort.each do |ivar|
86
- hash_or_array = layers_hash[ivar]
87
- if hash_or_array.is_a?(Array)
88
- instance_variable_set(ivar, LayersList.from_hash_list(hash_or_array))
89
- elsif hash_or_array.is_a?(Hash)
90
- obj_class = DNN.const_get(hash_or_array[:class])
91
- obj = obj_class.allocate
92
- if obj.is_a?(Chain)
93
- obj = obj_class.new
94
- obj.load_hash(hash_or_array)
95
- instance_variable_set(ivar, obj)
96
- else
97
- instance_variable_set(ivar, Layers::Layer.from_hash(hash_or_array))
98
- end
99
- end
100
- end
73
+ @layers_cache = layers_array
101
74
  end
102
75
  end
103
76
 
@@ -118,17 +91,16 @@ module DNN
118
91
  end
119
92
 
120
93
  def initialize
94
+ super
121
95
  @optimizer = nil
122
96
  @loss_func = nil
123
97
  @built = false
124
98
  @callbacks = []
125
- @layers_cache = nil
126
99
  @last_log = {}
127
100
  end
128
101
 
129
- def call(inputs)
130
- @layers_cache = nil
131
- output_tensor = forward(inputs)
102
+ def call(input_tensors)
103
+ output_tensor = forward(input_tensors)
132
104
  @built = true unless @built
133
105
  output_tensor
134
106
  end
@@ -187,8 +159,8 @@ module DNN
187
159
  initial_epoch: 1,
188
160
  test: nil,
189
161
  verbose: true)
190
- raise DNN_Error, "The model is not optimizer setup complete." unless @optimizer
191
- raise DNN_Error, "The model is not loss_func setup complete." unless @loss_func
162
+ raise DNNError, "The model is not optimizer setup complete." unless @optimizer
163
+ raise DNNError, "The model is not loss_func setup complete." unless @loss_func
192
164
 
193
165
  num_train_datas = train_iterator.num_datas
194
166
  num_train_datas = num_train_datas / batch_size * batch_size if train_iterator.last_round_down
@@ -255,8 +227,8 @@ module DNN
255
227
  # @param [Numo::SFloat] y Output training data.
256
228
  # @return [Float | Numo::SFloat] Return loss value in the form of Float or Numo::SFloat.
257
229
  def train_on_batch(x, y)
258
- raise DNN_Error, "The model is not optimizer setup complete." unless @optimizer
259
- raise DNN_Error, "The model is not loss_func setup complete." unless @loss_func
230
+ raise DNNError, "The model is not optimizer setup complete." unless @optimizer
231
+ raise DNNError, "The model is not loss_func setup complete." unless @loss_func
260
232
  check_xy_type(x, y)
261
233
  call_callbacks(:before_train_on_batch)
262
234
  DNN.learning_phase = true
@@ -419,12 +391,15 @@ module DNN
419
391
  @built
420
392
  end
421
393
 
394
+ # Clean all layers.
422
395
  def clean_layers
423
396
  layers.each(&:clean)
424
397
  @loss_func.clean
425
398
  @layers_cache = nil
426
399
  end
427
400
 
401
+ # Get parameter data of all layers.
402
+ # @return [Array] Parameter data.
428
403
  def get_all_params_data
429
404
  trainable_layers.map do |layer|
430
405
  layer.get_params.to_h do |key, param|
@@ -433,6 +408,8 @@ module DNN
433
408
  end
434
409
  end
435
410
 
411
+ # Set parameter data of all layers.
412
+ # @param [Array] params_data Parameter data obtained by get_all_params_data.
436
413
  def set_all_params_data(params_data)
437
414
  trainable_layers.each.with_index do |layer, i|
438
415
  params_data[i].each do |(key, data)|
@@ -73,3 +73,13 @@ class Float
73
73
  end
74
74
  end
75
75
  end
76
+
77
+ if RUBY_VERSION < "2.6.0"
78
+ class Hash
79
+ alias dnn__to_h to_h
80
+ def to_h(&block)
81
+ dnn__to_h unless block
82
+ map(&block).to_h
83
+ end
84
+ end
85
+ end
@@ -9,7 +9,7 @@ module DNN
9
9
  return nil unless hash
10
10
  optimizer_class = DNN.const_get(hash[:class])
11
11
  optimizer = optimizer_class.allocate
12
- raise DNN_Error, "#{optimizer.class} is not an instance of #{self} class." unless optimizer.is_a?(self)
12
+ raise DNNError, "#{optimizer.class} is not an instance of #{self} class." unless optimizer.is_a?(self)
13
13
  optimizer.load_hash(hash)
14
14
  optimizer
15
15
  end
@@ -49,7 +49,6 @@ module DNN
49
49
  private def clip_grads(params)
50
50
  norm = Math.sqrt(params.reduce(0) { |total, param| total + (param.grad**2).sum })
51
51
  return if norm <= @clip_norm
52
-
53
52
  rate = @clip_norm / (norm + 1e-7)
54
53
  params.each do |param|
55
54
  param.grad *= rate
@@ -18,7 +18,7 @@ module DNN
18
18
  elsif @data.shape == grad.shape[1..-1]
19
19
  @grad += grad.sum(0)
20
20
  else
21
- raise DNN_Error, "Shape is missmatch."
21
+ raise DNNError, "Shape is missmatch."
22
22
  end
23
23
  else
24
24
  @grad = Xumo::SFloat[0]
@@ -34,7 +34,7 @@ module DNN
34
34
  end
35
35
 
36
36
  def -@
37
- self * -1
37
+ Neg.(self)
38
38
  end
39
39
 
40
40
  def +(other)
@@ -8,7 +8,7 @@ module DNN
8
8
  return nil unless hash
9
9
  regularizer_class = DNN.const_get(hash[:class])
10
10
  regularizer = regularizer_class.allocate
11
- raise DNN_Error, "#{regularizer.class} is not an instance of #{self} class." unless regularizer.is_a?(self)
11
+ raise DNNError, "#{regularizer.class} is not an instance of #{self} class." unless regularizer.is_a?(self)
12
12
  regularizer.load_hash(hash)
13
13
  regularizer
14
14
  end
@@ -25,7 +25,7 @@ module DNN
25
25
  private def load_bin(bin)
26
26
  data = Marshal.load(Zlib::Inflate.inflate(bin))
27
27
  unless @model.class.name == data[:class]
28
- raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
28
+ raise DNNError, "Class name is mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
29
29
  end
30
30
  if data[:model]
31
31
  data[:model].instance_variables.each do |ivar|
@@ -43,7 +43,7 @@ module DNN
43
43
  def load_bin(bin)
44
44
  data = JSON.parse(bin, symbolize_names: true)
45
45
  unless @model.class.name == data[:class]
46
- raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
46
+ raise DNNError, "Class name is mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
47
47
  end
48
48
  set_all_params_base64_data(data[:params])
49
49
  end
@@ -25,7 +25,7 @@ module DNN
25
25
  end
26
26
 
27
27
  def -@
28
- self * -1
28
+ Neg.(self)
29
29
  end
30
30
 
31
31
  def +(other)
@@ -7,7 +7,7 @@ DIR_CIFAR10 = "cifar-10-batches-bin"
7
7
 
8
8
  module DNN
9
9
  module CIFAR10
10
- class DNN_CIFAR10_LoadError < DNN_Error; end
10
+ class DNN_CIFAR10_LoadError < DNNError; end
11
11
 
12
12
  def self.downloads
13
13
  return if Dir.exist?(DOWNLOADS_PATH + "/downloads/" + DIR_CIFAR10)
@@ -7,7 +7,7 @@ DIR_CIFAR100 = "cifar-100-binary"
7
7
 
8
8
  module DNN
9
9
  module CIFAR100
10
- class DNN_CIFAR100_LoadError < DNN_Error; end
10
+ class DNN_CIFAR100_LoadError < DNNError; end
11
11
 
12
12
  def self.downloads
13
13
  return if Dir.exist?(DOWNLOADS_PATH + "/downloads/" + DIR_CIFAR100)
@@ -3,7 +3,7 @@ require "net/http"
3
3
  module DNN
4
4
  DOWNLOADS_PATH = ENV["RUBY_DNN_DOWNLOADS_PATH"] || __dir__
5
5
 
6
- class DNN_DownloadError < DNN_Error; end
6
+ class DNN_DownloadError < DNNError; end
7
7
 
8
8
  class Downloader
9
9
  def self.download(url, dir_path = nil)
@@ -5,7 +5,7 @@ require_relative "mnist"
5
5
 
6
6
  module DNN
7
7
  module FashionMNIST
8
- class DNN_MNIST_LoadError < DNN_Error; end
8
+ class DNN_MNIST_LoadError < DNNError; end
9
9
 
10
10
  URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
11
11
 
@@ -2,7 +2,7 @@ require "csv"
2
2
  require_relative "downloader"
3
3
 
4
4
  module DNN
5
- class DNN_Iris_LoadError < DNN_Error; end
5
+ class DNN_Iris_LoadError < DNNError; end
6
6
 
7
7
  module Iris
8
8
  URL_CSV = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
@@ -4,7 +4,7 @@ require_relative "downloader"
4
4
 
5
5
  module DNN
6
6
  module MNIST
7
- class DNN_MNIST_LoadError < DNN_Error; end
7
+ class DNN_MNIST_LoadError < DNNError; end
8
8
 
9
9
  URL_BASE = "http://yann.lecun.com/exdb/mnist/"
10
10
 
@@ -7,7 +7,7 @@ DIR_STL10 = "stl10_binary"
7
7
 
8
8
  module DNN
9
9
  module STL10
10
- class DNN_STL10_LoadError < DNN_Error; end
10
+ class DNN_STL10_LoadError < DNNError; end
11
11
 
12
12
  def self.downloads
13
13
  return if Dir.exist?(DOWNLOADS_PATH + "/downloads/" + DIR_STL10)
@@ -49,7 +49,7 @@ module DNN
49
49
  end
50
50
 
51
51
  def self.load_unlabeled(range = 0...100000)
52
- raise DNN_Error, "Range must between 0 and 100000. (But the end is excluded)" unless range.begin >= 0 && range.end <= 100000
52
+ raise DNNError, "Range must between 0 and 100000. (But the end is excluded)" unless range.begin >= 0 && range.end <= 100000
53
53
  downloads
54
54
  x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/unlabeled_X.bin"
55
55
  raise DNN_STL10_LoadError, %`file "#{x_fname}" is not found.` unless File.exist?(x_fname)
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.16.2"
2
+ VERSION = "1.0.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.16.2
4
+ version: 1.0.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-01-11 00:00:00.000000000 Z
11
+ date: 2020-01-13 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray