ruby-dnn 0.12.4 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,8 +3,19 @@ module DNN
3
3
 
4
4
  # Super class of all optimizer classes.
5
5
  class Optimizer
6
+ attr_reader :status
6
7
  attr_accessor :clip_norm
7
8
 
9
# Rebuild an optimizer from the structure produced by #dump.
# Hyperparameters are restored with Utils.hash_to_obj(dumped[:hash]). Each entry of
# dumped[:status] is then stored both in the optimizer's status table and in the
# instance variable of the same name (e.g. :v -> @v). This makes the status hash and
# the ivars read by update_params refer to the same object.
# @param [Hash] dumped Hash with :hash (hyperparameters) and :status (per-key state) entries.
# @return [Optimizer] The restored optimizer.
def self.load(dumped)
  opt = Utils.hash_to_obj(dumped[:hash])
  # A dump whose status is nil (e.g. from an optimizer that never assigned @status)
  # carries no state to restore; treat it as empty instead of calling nil.each.
  (dumped[:status] || {}).each do |key, state|
    # Shallow copy, so key writes in the restored optimizer do not alter `dumped`.
    state = state.clone
    opt.status[key] = state
    opt.instance_variable_set("@#{key}", state)
  end
  opt
end
18
+
8
19
  # @param [Float | NilClass] clip_norm Gradient clip norm.
9
20
  def initialize(clip_norm: nil)
10
21
  @clip_norm = clip_norm
@@ -22,6 +33,10 @@ module DNN
22
33
  end
23
34
  end
24
35
 
36
# Serialize this optimizer into a plain structure that Optimizer.load can restore.
# @return [Hash] :hash holds the hyperparameters (to_hash). :status maps every key of
#   @status to the current value of the instance variable of the same name, which is
#   the ivar Optimizer.load writes back to. It is {} when @status is nil.
def dump
  # @status is assigned once in initialize, so a scalar captured there goes stale
  # when its ivar is reassigned. Adam's `t: @t` stays 0 while `@t += 1` advances.
  # Re-reading each ivar keeps the dumped state current.
  status = (@status || {}).keys.map { |key| [key, instance_variable_get("@#{key}")] }.to_h
  { hash: to_hash, status: status }
end
39
+
25
40
  def to_hash(merge_hash = nil)
26
41
  hash = { class: self.class.name, clip_norm: @clip_norm }
27
42
  hash.merge!(merge_hash) if merge_hash
@@ -59,6 +74,7 @@ module DNN
59
74
  @lr = lr
60
75
  @momentum = momentum
61
76
  @v = {}
77
+ @status = { v: @v }
62
78
  end
63
79
 
64
80
  def to_hash
@@ -69,9 +85,9 @@ module DNN
69
85
  params.each do |param|
70
86
  amount = param.grad * @lr
71
87
  if @momentum > 0
72
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
73
- amount += @momentum * @v[param]
74
- @v[param] = amount
88
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
89
+ amount += @momentum * @v[param.name]
90
+ @v[param.name] = amount
75
91
  end
76
92
  param.data -= amount
77
93
  end
@@ -79,33 +95,17 @@ module DNN
79
95
  end
80
96
 
81
97
 
82
- class Nesterov < Optimizer
83
- attr_accessor :lr
84
- attr_accessor :momentum
85
-
86
- def self.from_hash(hash)
87
- self.new(hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
88
- end
89
-
90
- # @param [Float] lr Learning rate.
91
- # @param [Float] momentum Momentum coefficient.
98
+ class Nesterov < SGD
92
99
  def initialize(lr = 0.01, momentum: 0.9, clip_norm: nil)
93
- super(clip_norm: clip_norm)
94
- @lr = lr
95
- @momentum = momentum
96
- @v = {}
97
- end
98
-
99
- def to_hash
100
- super(lr: @lr, momentum: @momentum)
100
+ super(lr, momentum: momentum, clip_norm: clip_norm)
101
101
  end
102
102
 
103
103
  private def update_params(params)
104
104
  params.each do |param|
105
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
105
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
106
106
  amount = param.grad * @lr
107
- @v[param] = @v[param] * @momentum - amount
108
- param.data = (param.data + @momentum ** 2 * @v[param]) - (1 + @momentum) * amount
107
+ @v[param.name] = @v[param.name] * @momentum - amount
108
+ param.data = (param.data + @momentum ** 2 * @v[param.name]) - (1 + @momentum) * amount
109
109
  end
110
110
  end
111
111
  end
@@ -126,13 +126,14 @@ module DNN
126
126
  @lr = lr
127
127
  @eps = eps
128
128
  @g = {}
129
+ @status = { g: @g }
129
130
  end
130
131
 
131
132
  private def update_params(params)
132
133
  params.each do |param|
133
- @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
134
- @g[param] += param.grad ** 2
135
- param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
134
+ @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
135
+ @g[param.name] += param.grad ** 2
136
+ param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
136
137
  end
137
138
  end
138
139
 
@@ -160,6 +161,7 @@ module DNN
160
161
  @alpha = alpha
161
162
  @eps = eps
162
163
  @g = {}
164
+ @status = { g: @g }
163
165
  end
164
166
 
165
167
  def to_hash
@@ -168,9 +170,9 @@ module DNN
168
170
 
169
171
  private def update_params(params)
170
172
  params.each do |param|
171
- @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
172
- @g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad ** 2
173
- param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
173
+ @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
174
+ @g[param.name] = @alpha * @g[param.name] + (1 - @alpha) * param.grad ** 2
175
+ param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
174
176
  end
175
177
  end
176
178
  end
@@ -192,6 +194,7 @@ module DNN
192
194
  @eps = eps
193
195
  @h = {}
194
196
  @s = {}
197
+ @status = { h: @h, s: @s }
195
198
  end
196
199
 
197
200
  def to_hash
@@ -200,11 +203,11 @@ module DNN
200
203
 
201
204
  private def update_params(params)
202
205
  params.each do |param|
203
- @h[param] ||= Xumo::SFloat.zeros(*param.data.shape)
204
- @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
205
- @h[param] = @rho * @h[param] + (1 - @rho) * param.grad ** 2
206
- v = (Xumo::NMath.sqrt(@s[param] + @eps) / Xumo::NMath.sqrt(@h[param] + @eps)) * param.grad
207
- @s[param] = @rho * @s[param] + (1 - @rho) * v ** 2
206
+ @h[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
207
+ @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
208
+ @h[param.name] = @rho * @h[param.name] + (1 - @rho) * param.grad ** 2
209
+ v = (Xumo::NMath.sqrt(@s[param.name] + @eps) / Xumo::NMath.sqrt(@h[param.name] + @eps)) * param.grad
210
+ @s[param.name] = @rho * @s[param.name] + (1 - @rho) * v ** 2
208
211
  param.data -= v
209
212
  end
210
213
  end
@@ -230,6 +233,7 @@ module DNN
230
233
  @eps = eps
231
234
  @m = {}
232
235
  @v = {}
236
+ @status = { m: @m, v: @v }
233
237
  end
234
238
 
235
239
  def to_hash
@@ -238,11 +242,11 @@ module DNN
238
242
 
239
243
  private def update_params(params)
240
244
  params.each do |param|
241
- @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
242
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
243
- @m[param] = @alpha * @m[param] + (1 - @alpha) * param.grad
244
- @v[param] = @alpha * @v[param] + (1 - @alpha) * param.grad ** 2
245
- param.data -= (@lr / Xumo::NMath.sqrt(@v[param] - @m[param] ** 2 + @eps)) * param.grad
245
+ @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
246
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
247
+ @m[param.name] = @alpha * @m[param.name] + (1 - @alpha) * param.grad
248
+ @v[param.name] = @alpha * @v[param.name] + (1 - @alpha) * param.grad ** 2
249
+ param.data -= (@lr / Xumo::NMath.sqrt(@v[param.name] - @m[param.name] ** 2 + @eps)) * param.grad
246
250
  end
247
251
  end
248
252
  end
@@ -275,7 +279,8 @@ module DNN
275
279
  @t = 0
276
280
  @m = {}
277
281
  @v = {}
278
- @s = {} if amsgrad
282
+ @s = amsgrad ? {} : nil
283
+ @status = { t: @t, m: @m, v: @v, s: @s }
279
284
  end
280
285
 
281
286
  def to_hash
@@ -289,16 +294,16 @@ module DNN
289
294
  @t += 1
290
295
  lr = @alpha * Math.sqrt(1 - @beta2 ** @t) / (1 - @beta1 ** @t)
291
296
  params.each do |param|
292
- @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
293
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
294
- @m[param] += (1 - @beta1) * (param.grad - @m[param])
295
- @v[param] += (1 - @beta2) * (param.grad ** 2 - @v[param])
297
+ @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
298
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
299
+ @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
300
+ @v[param.name] += (1 - @beta2) * (param.grad ** 2 - @v[param.name])
296
301
  if @amsgrad
297
- @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
298
- @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
299
- param.data -= lr * @m[param] / Xumo::NMath.sqrt(@s[param] + @eps)
302
+ @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
303
+ @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
304
+ param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@s[param.name] + @eps)
300
305
  else
301
- param.data -= lr * @m[param] / Xumo::NMath.sqrt(@v[param] + @eps)
306
+ param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@v[param.name] + @eps)
302
307
  end
303
308
  end
304
309
  end
@@ -336,16 +341,16 @@ module DNN
336
341
  lower_bound = final_lr * (1 - 1 / (@gamma * @t + 1))
337
342
  upper_bound = final_lr * (1 + 1 / (@gamma * @t))
338
343
  params.each do |param|
339
- @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
340
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
341
- @m[param] += (1 - @beta1) * (param.grad - @m[param])
342
- @v[param] += (1 - @beta2) * (param.grad ** 2 - @v[param])
344
+ @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
345
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
346
+ @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
347
+ @v[param.name] += (1 - @beta2) * (param.grad ** 2 - @v[param.name])
343
348
  if @amsgrad
344
- @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
345
- @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
346
- param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param]) + @eps), lower_bound, upper_bound) * @m[param]
349
+ @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
350
+ @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
351
+ param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
347
352
  else
348
- param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param]) + @eps), lower_bound, upper_bound) * @m[param]
353
+ param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
349
354
  end
350
355
  end
351
356
  end
@@ -1,11 +1,13 @@
1
1
  module DNN
2
2
  class Param
3
+ attr_accessor :name
3
4
  attr_accessor :data
4
5
  attr_accessor :grad
5
6
 
6
7
  def initialize(data = nil, grad = nil)
7
8
  @data = data
8
9
  @grad = grad
10
+ @name = nil
9
11
  end
10
12
  end
11
13
  end
@@ -0,0 +1,138 @@
1
require "base64"
require "fileutils"
require "json"
require "zlib"
4
+
5
+ module DNN
6
+ module Loaders
7
+
8
# Base class for model loaders. A loader restores a model from a file written by the
# matching saver. Subclasses supply the file format by implementing #load_bin.
class Loader
  # @param model Model to restore into. #set_all_params_data needs it to respond to
  #   has_param_layers (layers responding to get_params); subclasses may need more.
  def initialize(model)
    @model = model
  end

  # Read +file_name+ in binary mode and hand its contents to #load_bin.
  # @param [String] file_name Path of the file to load.
  def load(file_name)
    load_bin(File.binread(file_name))
  end

  private

  # Restore the model from raw file contents. Subclasses must override this.
  # @param [String] bin Raw (binary) file contents.
  # @raise [NotImplementedError] Always, in this base class.
  def load_bin(bin)
    raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'load_bin'"
  end

  # Assign loaded arrays to the model's parameters, matched by Param#name.
  # @param [Hash] params_data Mapping of param name => array data.
  # @raise [KeyError] If a parameter of the model has no entry in +params_data+.
  def set_all_params_data(params_data)
    all_params = @model.has_param_layers.uniq.flat_map { |layer| layer.get_params.values }
    all_params.each do |param|
      # Raise instead of silently assigning nil, which would only fail later and far away.
      param.data = params_data.fetch(param.name) do
        raise KeyError, "Param '#{param.name}' was not found in the loaded data"
      end
    end
  end
end
32
+
33
+
34
# Loader for files written by MarshalSaver: a Zlib-deflated Marshal dump that includes
# the optimizer's saved status.
# NOTE(review): Marshal.load can instantiate arbitrary objects from the file contents;
# only load model files from trusted sources.
class MarshalLoader < Loader
  private

  # Inflate and unmarshal +bin+, then set the model up with the stored optimizer and
  # loss function. Next run one forward pass on zeros of the stored input shape
  # (presumably so the layers build their params before data is assigned; confirm).
  # Finally copy the stored parameter arrays into the model.
  def load_bin(bin)
    dumped = Marshal.load(Zlib::Inflate.inflate(bin))
    optimizer = Optimizers::Optimizer.load(dumped[:optimizer])
    loss = Utils.hash_to_obj(dumped[:loss_func])
    @model.setup(optimizer, loss)
    dummy_input = Xumo::SFloat.zeros(*dumped[:input_shape])
    @model.predict1(dummy_input)
    set_all_params_data(dumped[:params])
  end
end
44
+
45
# Loader for files written by JSONSaver. The optimizer is rebuilt from its stored
# hyperparameter hash only; no optimizer status is read here.
# NOTE(review): symbolize_names turns stored param names into Symbols. Matching them
# presumably relies on Param#name being a Symbol; confirm.
class JSONLoader < Loader
  private

  # Parse +bin+ as JSON and set the model up with the stored optimizer and loss
  # function. Run one forward pass on zeros of the stored input shape, then decode
  # and assign the parameter arrays.
  def load_bin(bin)
    data = JSON.parse(bin, symbolize_names: true)
    optimizer = Utils.hash_to_obj(data[:optimizer])
    loss = Utils.hash_to_obj(data[:loss_func])
    @model.setup(optimizer, loss)
    @model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
    base64_to_params_data(data[:params])
  end

  # Decode each name => [shape, base64] entry into an SFloat of that shape, then assign
  # the results to the model's params.
  def base64_to_params_data(base64_params_data)
    params_data = {}
    base64_params_data.each do |name, (shape, encoded)|
      raw = Base64.decode64(encoded)
      params_data[name] = Xumo::SFloat.from_binary(raw).reshape(*shape)
    end
    set_all_params_data(params_data)
  end
end
65
+
66
+ end
67
+
68
+
69
+ module Savers
70
+
71
# Base class for model savers. Subclasses choose the file format by implementing #dump_bin.
class Saver
  # @param model Model to save. #get_all_params_data needs it to respond to
  #   has_param_layers (layers responding to get_params); subclasses may need more.
  def initialize(model)
    @model = model
  end

  # Serialize the model and write it to +file_name+. If the parent directory is
  # missing, it is created (including nested parents) and the write is retried.
  # @param [String] file_name Destination path.
  def save(file_name)
    bin = dump_bin
    begin
      File.binwrite(file_name, bin)
    rescue Errno::ENOENT
      # File.dirname works for bare file names too, and mkdir_p creates every missing level.
      FileUtils.mkdir_p(File.dirname(file_name))
      File.binwrite(file_name, bin)
    end
  end

  private

  # Produce the bytes to write. Subclasses must override this.
  # @raise [NotImplementedError] Always, in this base class.
  def dump_bin
    raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'dump_bin'"
  end

  # Collect every parameter of the model's param layers (each layer once).
  # @return [Hash] Mapping of Param#name => Param#data.
  def get_all_params_data
    all_params = @model.has_param_layers.uniq.flat_map { |layer| layer.get_params.values }
    all_params.map { |param| [param.name, param.data] }.to_h
  end
end
100
+
101
+
102
# Saver that writes a Zlib-deflated Marshal dump of the model. The dump covers the
# version, class name, input shape, params, optimizer and loss function.
class MarshalSaver < Saver
  # @param model Model to save.
  # @param [Boolean] include_optimizer When true, store the model's own optimizer via
  #   #dump. When false, store the dump of a freshly constructed optimizer of the same class.
  #   NOTE(review): the fresh instance uses constructor defaults, so hyperparameters
  #   such as lr are not kept either; confirm this is intended.
  def initialize(model, include_optimizer: true)
    super(model)
    @include_optimizer = include_optimizer
  end

  private

  def dump_bin
    optimizer_dump =
      if @include_optimizer
        @model.optimizer.dump
      else
        @model.optimizer.class.new.dump
      end
    payload = {
      version: VERSION,
      class: @model.class.name,
      input_shape: @model.layers.first.input_shape,
      params: get_all_params_data,
      optimizer: optimizer_dump,
      loss_func: @model.loss_func.to_hash
    }
    Zlib::Deflate.deflate(Marshal.dump(payload))
  end
end
117
+
118
# Saver that writes the model as JSON. The optimizer is stored via to_hash only, so its
# status is not saved. Each param array is stored as [shape, base64 of its binary data].
class JSONSaver < Saver
  private

  def dump_bin
    payload = {
      version: VERSION,
      class: @model.class.name,
      input_shape: @model.layers.first.input_shape,
      params: params_data_to_base64,
      optimizer: @model.optimizer.to_hash,
      loss_func: @model.loss_func.to_hash
    }
    JSON.dump(payload)
  end

  # @return [Hash] Mapping of param name => [shape, base64-encoded binary data].
  def params_data_to_base64
    encoded = {}
    get_all_params_data.each do |name, array|
      blob = Base64.encode64(array.to_binary)
      encoded[name] = [array.shape, blob]
    end
    encoded
  end
end
136
+
137
+ end
138
+ end
@@ -1,6 +1,6 @@
1
1
  require "zlib"
2
2
  require "archive/tar/minitar"
3
- require_relative "../../ext/cifar_loader/cifar_loader"
3
+ require_relative "../../../ext/cifar_loader/cifar_loader"
4
4
  require_relative "downloader"
5
5
 
6
6
  URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
@@ -1,6 +1,6 @@
1
1
  require "zlib"
2
2
  require "archive/tar/minitar"
3
- require_relative "../../ext/cifar_loader/cifar_loader"
3
+ require_relative "../../../ext/cifar_loader/cifar_loader"
4
4
  require_relative "downloader"
5
5
 
6
6
  URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
File without changes
@@ -1,5 +1,5 @@
1
1
  require "zlib"
2
- require_relative "core/error"
2
+ require_relative "../core/error"
3
3
  require_relative "downloader"
4
4
  require_relative "mnist"
5
5
 
File without changes
@@ -1,5 +1,5 @@
1
1
  require "zlib"
2
- require_relative "core/error"
2
+ require_relative "../core/error"
3
3
  require_relative "downloader"
4
4
 
5
5
  module DNN
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.12.4"
2
+ VERSION = "0.13.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.4
4
+ version: 0.13.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-09-08 00:00:00.000000000 Z
11
+ date: 2019-09-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -114,12 +114,11 @@ files:
114
114
  - ext/rb_stb_image/extconf.rb
115
115
  - ext/rb_stb_image/rb_stb_image.c
116
116
  - lib/dnn.rb
117
- - lib/dnn/cifar10.rb
118
- - lib/dnn/cifar100.rb
119
117
  - lib/dnn/core/activations.rb
120
118
  - lib/dnn/core/cnn_layers.rb
121
119
  - lib/dnn/core/embedding.rb
122
120
  - lib/dnn/core/error.rb
121
+ - lib/dnn/core/global.rb
123
122
  - lib/dnn/core/initializers.rb
124
123
  - lib/dnn/core/iterator.rb
125
124
  - lib/dnn/core/layers.rb
@@ -132,12 +131,15 @@ files:
132
131
  - lib/dnn/core/param.rb
133
132
  - lib/dnn/core/regularizers.rb
134
133
  - lib/dnn/core/rnn_layers.rb
134
+ - lib/dnn/core/savers.rb
135
135
  - lib/dnn/core/utils.rb
136
- - lib/dnn/downloader.rb
137
- - lib/dnn/fashion-mnist.rb
136
+ - lib/dnn/datasets/cifar10.rb
137
+ - lib/dnn/datasets/cifar100.rb
138
+ - lib/dnn/datasets/downloader.rb
139
+ - lib/dnn/datasets/fashion-mnist.rb
140
+ - lib/dnn/datasets/iris.rb
141
+ - lib/dnn/datasets/mnist.rb
138
142
  - lib/dnn/image.rb
139
- - lib/dnn/iris.rb
140
- - lib/dnn/mnist.rb
141
143
  - lib/dnn/version.rb
142
144
  - ruby-dnn.gemspec
143
145
  - third_party/stb_image.h