ruby-dnn 0.14.3 → 0.15.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,13 +1,11 @@
1
1
  module DNN
2
2
  class Param
3
- attr_accessor :name
4
3
  attr_accessor :data
5
4
  attr_accessor :grad
6
5
 
7
6
  def initialize(data = nil, grad = nil)
8
7
  @data = data
9
8
  @grad = grad
10
- @name = nil
11
9
  end
12
10
  end
13
11
  end
@@ -19,15 +19,6 @@ module DNN
19
19
  def load_bin(bin)
20
20
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
21
21
  end
22
-
23
- def set_all_params_data(params_data)
24
- all_params = @model.has_param_layers.map { |layer|
25
- layer.get_params.values
26
- }.flatten
27
- all_params.each do |param|
28
- param.data = params_data[param.name]
29
- end
30
- end
31
22
  end
32
23
 
33
24
  class MarshalLoader < Loader
@@ -36,12 +27,13 @@ module DNN
36
27
  unless @model.class.name == data[:class]
37
28
  raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
38
29
  end
39
- opt = Optimizers::Optimizer.load(data[:optimizer])
40
- loss_func = Losses::Loss.from_hash(data[:loss_func])
41
- @model.setup(opt, loss_func)
42
- @model.instance_variable_set(:@built, false)
43
- @model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
44
- set_all_params_data(data[:params])
30
+ if data[:model]
31
+ data[:model].instance_variables.each do |ivar|
32
+ obj = data[:model].instance_variable_get(ivar)
33
+ @model.instance_variable_set(ivar, obj)
34
+ end
35
+ end
36
+ @model.set_all_params_data(data[:params])
45
37
  end
46
38
  end
47
39
 
@@ -50,20 +42,20 @@ module DNN
50
42
 
51
43
  def load_bin(bin)
52
44
  data = JSON.parse(bin, symbolize_names: true)
53
- opt = Optimizers::Optimizer.from_hash(data[:optimizer])
54
- loss_func = Losses::Loss.from_hash(data[:loss_func])
55
- @model.setup(opt, loss_func)
56
- @model.instance_variable_set(:@built, false)
57
- @model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
58
- base64_to_params_data(data[:params])
45
+ unless @model.class.name == data[:class]
46
+ raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
47
+ end
48
+ set_all_params_base64_data(data[:params])
59
49
  end
60
50
 
61
- def base64_to_params_data(base64_params_data)
62
- params_data = base64_params_data.to_h do |key, (shape, base64_data)|
63
- bin = Base64.decode64(base64_data)
64
- [key, Xumo::SFloat.from_binary(bin).reshape(*shape)]
51
+ def set_all_params_base64_data(params_data)
52
+ @model.trainable_layers.each.with_index do |layer, i|
53
+ params_data[i].each do |(key, (shape, base64_data))|
54
+ bin = Base64.decode64(base64_data)
55
+ data = Xumo::SFloat.from_binary(bin).reshape(*shape)
56
+ layer.get_params[key].data = data
57
+ end
65
58
  end
66
- set_all_params_data(params_data)
67
59
  end
68
60
  end
69
61
 
@@ -92,28 +84,28 @@ module DNN
92
84
  def dump_bin
93
85
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
94
86
  end
95
-
96
- def get_all_params_data
97
- all_params = @model.has_param_layers.map { |layer|
98
- layer.get_params.values
99
- }.flatten
100
- all_params.to_h { |param| [param.name, param.data] }
101
- end
102
87
  end
103
88
 
104
89
  class MarshalSaver < Saver
105
- def initialize(model, include_optimizer: true)
90
+ def initialize(model, include_model: true)
106
91
  super(model)
107
- @include_optimizer = include_optimizer
92
+ @include_model = include_model
108
93
  end
109
94
 
110
95
  private def dump_bin
111
- require_status = @include_optimizer ? true : false
112
- data = {
113
- version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape, params: get_all_params_data,
114
- optimizer: @model.optimizer.dump(require_status), loss_func: @model.loss_func.to_hash
115
- }
116
- Zlib::Deflate.deflate(Marshal.dump(data))
96
+ params_data = @model.get_all_params_data
97
+ if @include_model
98
+ @model.clean_layers
99
+ data = {
100
+ version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape,
101
+ params: params_data, model: @model
102
+ }
103
+ else
104
+ data = { version: VERSION, class: @model.class.name, params: params_data }
105
+ end
106
+ bin = Zlib::Deflate.deflate(Marshal.dump(data))
107
+ @model.set_all_params_data(params_data) if @include_model
108
+ bin
117
109
  end
118
110
  end
119
111
 
@@ -121,17 +113,16 @@ module DNN
121
113
  private
122
114
 
123
115
  def dump_bin
124
- data = {
125
- version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape, params: params_data_to_base64,
126
- optimizer: @model.optimizer.to_hash, loss_func: @model.loss_func.to_hash
127
- }
116
+ data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
128
117
  JSON.dump(data)
129
118
  end
130
119
 
131
- def params_data_to_base64
132
- get_all_params_data.to_h do |key, data|
133
- base64_data = Base64.encode64(data.to_binary)
134
- [key, [data.shape, base64_data]]
120
+ def get_all_params_base64_data
121
+ @model.trainable_layers.map do |layer|
122
+ layer.get_params.to_h do |key, param|
123
+ base64_data = Base64.encode64(param.data.to_binary)
124
+ [key, [param.data.shape, base64_data]]
125
+ end
135
126
  end
136
127
  end
137
128
  end
@@ -0,0 +1,65 @@
1
+ require "zlib"
2
+ require "archive/tar/minitar"
3
+ require_relative "downloader"
4
+
5
+ URL_STL10 = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
6
+ DIR_STL10 = "stl10_binary"
7
+
8
+ module DNN
9
+ module STL10
10
+ class DNN_STL10_LoadError < DNN_Error; end
11
+
12
+ def self.downloads
13
+ return if Dir.exist?(DOWNLOADS_PATH + "/downloads/" + DIR_STL10)
14
+ Downloader.download(URL_STL10)
15
+ stl10_binary_file_name = DOWNLOADS_PATH + "/downloads/" + URL_STL10.match(%r`.+/(.+)`)[1]
16
+ begin
17
+ Zlib::GzipReader.open(stl10_binary_file_name) do |gz|
18
+ Archive::Tar::Minitar.unpack(gz, DOWNLOADS_PATH + "/downloads")
19
+ end
20
+ ensure
21
+ File.unlink(stl10_binary_file_name)
22
+ end
23
+ end
24
+
25
+ def self.load_train
26
+ downloads
27
+ x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/train_X.bin"
28
+ raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
29
+ y_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/train_y.bin"
30
+ raise DNN_STL10_LoadError.new(%`file "#{y_fname}" is not found.`) unless File.exist?(y_fname)
31
+ x_bin = File.binread(x_fname)
32
+ y_bin = File.binread(y_fname)
33
+ x_train = Numo::UInt8.from_binary(x_bin).reshape(5000, 3, 96, 96).transpose(0, 3, 2, 1).clone
34
+ y_train = Numo::UInt8.from_binary(y_bin)
35
+ [x_train, y_train]
36
+ end
37
+
38
+ def self.load_test
39
+ downloads
40
+ x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/test_X.bin"
41
+ raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
42
+ y_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/test_y.bin"
43
+ raise DNN_STL10_LoadError.new(%`file "#{y_fname}" is not found.`) unless File.exist?(y_fname)
44
+ x_bin = File.binread(x_fname)
45
+ y_bin = File.binread(y_fname)
46
+ x_test = Numo::UInt8.from_binary(x_bin).reshape(8000, 3, 96, 96).transpose(0, 3, 2, 1).clone
47
+ y_test = Numo::UInt8.from_binary(y_bin)
48
+ [x_test, y_test]
49
+ end
50
+
51
+ def self.load_unlabeled(range = 0...100000)
52
+ unless 0 <= range.begin && range.end <= 100000
53
+ raise DNN_Error, "Range must between 0 and 100000. (But the end is excluded)"
54
+ end
55
+ downloads
56
+ x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/unlabeled_X.bin"
57
+ raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
58
+ num_datas = range.end - range.begin
59
+ length = num_datas * 3 * 96 * 96
60
+ ofs = range.begin * 3 * 96 * 96
61
+ x_bin = File.binread(x_fname, length, ofs)
62
+ Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 96, 96).transpose(0, 3, 2, 1).clone
63
+ end
64
+ end
65
+ end
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.14.3"
2
+ VERSION = "0.15.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.3
4
+ version: 0.15.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-11-03 00:00:00.000000000 Z
11
+ date: 2019-11-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -118,24 +118,24 @@ files:
118
118
  - ext/rb_stb_image/extconf.rb
119
119
  - ext/rb_stb_image/rb_stb_image.c
120
120
  - lib/dnn.rb
121
- - lib/dnn/core/activations.rb
122
121
  - lib/dnn/core/callbacks.rb
123
- - lib/dnn/core/cnn_layers.rb
124
- - lib/dnn/core/embedding.rb
125
122
  - lib/dnn/core/error.rb
126
123
  - lib/dnn/core/global.rb
127
124
  - lib/dnn/core/initializers.rb
128
125
  - lib/dnn/core/iterator.rb
129
- - lib/dnn/core/layers.rb
126
+ - lib/dnn/core/layers/activations.rb
127
+ - lib/dnn/core/layers/basic_layers.rb
128
+ - lib/dnn/core/layers/cnn_layers.rb
129
+ - lib/dnn/core/layers/embedding.rb
130
+ - lib/dnn/core/layers/merge_layers.rb
131
+ - lib/dnn/core/layers/normalizations.rb
132
+ - lib/dnn/core/layers/rnn_layers.rb
130
133
  - lib/dnn/core/link.rb
131
134
  - lib/dnn/core/losses.rb
132
- - lib/dnn/core/merge_layers.rb
133
135
  - lib/dnn/core/models.rb
134
- - lib/dnn/core/normalizations.rb
135
136
  - lib/dnn/core/optimizers.rb
136
137
  - lib/dnn/core/param.rb
137
138
  - lib/dnn/core/regularizers.rb
138
- - lib/dnn/core/rnn_layers.rb
139
139
  - lib/dnn/core/savers.rb
140
140
  - lib/dnn/core/tensor.rb
141
141
  - lib/dnn/core/utils.rb
@@ -145,6 +145,7 @@ files:
145
145
  - lib/dnn/datasets/fashion-mnist.rb
146
146
  - lib/dnn/datasets/iris.rb
147
147
  - lib/dnn/datasets/mnist.rb
148
+ - lib/dnn/datasets/stl-10.rb
148
149
  - lib/dnn/image.rb
149
150
  - lib/dnn/version.rb
150
151
  - ruby-dnn.gemspec