ruby-dnn 0.14.3 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,11 @@
1
1
  module DNN
2
2
  class Param
3
- attr_accessor :name
4
3
  attr_accessor :data
5
4
  attr_accessor :grad
6
5
 
7
6
  def initialize(data = nil, grad = nil)
8
7
  @data = data
9
8
  @grad = grad
10
- @name = nil
11
9
  end
12
10
  end
13
11
  end
@@ -19,15 +19,6 @@ module DNN
19
19
  def load_bin(bin)
20
20
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
21
21
  end
22
-
23
- def set_all_params_data(params_data)
24
- all_params = @model.has_param_layers.map { |layer|
25
- layer.get_params.values
26
- }.flatten
27
- all_params.each do |param|
28
- param.data = params_data[param.name]
29
- end
30
- end
31
22
  end
32
23
 
33
24
  class MarshalLoader < Loader
@@ -36,12 +27,13 @@ module DNN
36
27
  unless @model.class.name == data[:class]
37
28
  raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
38
29
  end
39
- opt = Optimizers::Optimizer.load(data[:optimizer])
40
- loss_func = Losses::Loss.from_hash(data[:loss_func])
41
- @model.setup(opt, loss_func)
42
- @model.instance_variable_set(:@built, false)
43
- @model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
44
- set_all_params_data(data[:params])
30
+ if data[:model]
31
+ data[:model].instance_variables.each do |ivar|
32
+ obj = data[:model].instance_variable_get(ivar)
33
+ @model.instance_variable_set(ivar, obj)
34
+ end
35
+ end
36
+ @model.set_all_params_data(data[:params])
45
37
  end
46
38
  end
47
39
 
@@ -50,20 +42,20 @@ module DNN
50
42
 
51
43
  def load_bin(bin)
52
44
  data = JSON.parse(bin, symbolize_names: true)
53
- opt = Optimizers::Optimizer.from_hash(data[:optimizer])
54
- loss_func = Losses::Loss.from_hash(data[:loss_func])
55
- @model.setup(opt, loss_func)
56
- @model.instance_variable_set(:@built, false)
57
- @model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
58
- base64_to_params_data(data[:params])
45
+ unless @model.class.name == data[:class]
46
+ raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
47
+ end
48
+ set_all_params_base64_data(data[:params])
59
49
  end
60
50
 
61
- def base64_to_params_data(base64_params_data)
62
- params_data = base64_params_data.to_h do |key, (shape, base64_data)|
63
- bin = Base64.decode64(base64_data)
64
- [key, Xumo::SFloat.from_binary(bin).reshape(*shape)]
51
+ def set_all_params_base64_data(params_data)
52
+ @model.trainable_layers.each.with_index do |layer, i|
53
+ params_data[i].each do |(key, (shape, base64_data))|
54
+ bin = Base64.decode64(base64_data)
55
+ data = Xumo::SFloat.from_binary(bin).reshape(*shape)
56
+ layer.get_params[key].data = data
57
+ end
65
58
  end
66
- set_all_params_data(params_data)
67
59
  end
68
60
  end
69
61
 
@@ -92,28 +84,28 @@ module DNN
92
84
  def dump_bin
93
85
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
94
86
  end
95
-
96
- def get_all_params_data
97
- all_params = @model.has_param_layers.map { |layer|
98
- layer.get_params.values
99
- }.flatten
100
- all_params.to_h { |param| [param.name, param.data] }
101
- end
102
87
  end
103
88
 
104
89
  class MarshalSaver < Saver
105
- def initialize(model, include_optimizer: true)
90
+ def initialize(model, include_model: true)
106
91
  super(model)
107
- @include_optimizer = include_optimizer
92
+ @include_model = include_model
108
93
  end
109
94
 
110
95
  private def dump_bin
111
- require_status = @include_optimizer ? true : false
112
- data = {
113
- version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape, params: get_all_params_data,
114
- optimizer: @model.optimizer.dump(require_status), loss_func: @model.loss_func.to_hash
115
- }
116
- Zlib::Deflate.deflate(Marshal.dump(data))
96
+ params_data = @model.get_all_params_data
97
+ if @include_model
98
+ @model.clean_layers
99
+ data = {
100
+ version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape,
101
+ params: params_data, model: @model
102
+ }
103
+ else
104
+ data = { version: VERSION, class: @model.class.name, params: params_data }
105
+ end
106
+ bin = Zlib::Deflate.deflate(Marshal.dump(data))
107
+ @model.set_all_params_data(params_data) if @include_model
108
+ bin
117
109
  end
118
110
  end
119
111
 
@@ -121,17 +113,16 @@ module DNN
121
113
  private
122
114
 
123
115
  def dump_bin
124
- data = {
125
- version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape, params: params_data_to_base64,
126
- optimizer: @model.optimizer.to_hash, loss_func: @model.loss_func.to_hash
127
- }
116
+ data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
128
117
  JSON.dump(data)
129
118
  end
130
119
 
131
- def params_data_to_base64
132
- get_all_params_data.to_h do |key, data|
133
- base64_data = Base64.encode64(data.to_binary)
134
- [key, [data.shape, base64_data]]
120
+ def get_all_params_base64_data
121
+ @model.trainable_layers.map do |layer|
122
+ layer.get_params.to_h do |key, param|
123
+ base64_data = Base64.encode64(param.data.to_binary)
124
+ [key, [param.data.shape, base64_data]]
125
+ end
135
126
  end
136
127
  end
137
128
  end
@@ -0,0 +1,65 @@
require "zlib"
require "fileutils"
require "archive/tar/minitar"
require_relative "downloader"

URL_STL10 = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
DIR_STL10 = "stl10_binary"

module DNN
  # Loaders for the STL-10 binary release.
  # The archive is fetched and unpacked under DOWNLOADS_PATH/downloads on first use.
  module STL10
    class DNN_STL10_LoadError < DNN_Error; end

    # Width and height of every image in pixels.
    IMAGE_SIZE = 96
    # Channel values stored per pixel.
    NUM_CHANNELS = 3
    # Bytes occupied by one image in the *_X.bin files (one UInt8 per channel value).
    IMAGE_BYTES = NUM_CHANNELS * IMAGE_SIZE * IMAGE_SIZE
    # Image counts per split (the same values the loaders always assumed).
    NUM_TRAIN = 5000
    NUM_TEST = 8000
    NUM_UNLABELED = 100_000

    # Downloads and unpacks the archive unless the extracted directory already exists.
    # The downloaded .tar.gz is deleted afterwards whether or not unpacking succeeded.
    def self.downloads
      return if Dir.exist?(dataset_dir)
      Downloader.download(URL_STL10)
      archive_path = DOWNLOADS_PATH + "/downloads/" + File.basename(URL_STL10)
      unpacked = false
      begin
        Zlib::GzipReader.open(archive_path) do |gz|
          Archive::Tar::Minitar.unpack(gz, DOWNLOADS_PATH + "/downloads")
        end
        unpacked = true
      ensure
        # A half-extracted directory would make every later call skip the download and
        # then fail on missing files, so discard it when unpacking did not complete.
        FileUtils.rm_rf(dataset_dir) unless unpacked
        File.unlink(archive_path)
      end
    end

    # Loads the labeled training split.
    # @return [Array] [images, labels]; images are Numo::UInt8 shaped (5000, 96, 96, 3).
    # @raise [DNN_STL10_LoadError] if a data file is missing after download.
    def self.load_train
      load_labeled("train", NUM_TRAIN)
    end

    # Loads the labeled test split.
    # @return [Array] [images, labels]; images are Numo::UInt8 shaped (8000, 96, 96, 3).
    # @raise [DNN_STL10_LoadError] if a data file is missing after download.
    def self.load_test
      load_labeled("test", NUM_TEST)
    end

    # Loads a slice of the unlabeled split, reading only the requested bytes.
    # @param range [Range] image indices to load. Exclusive (0...n), inclusive (0..n-1),
    #   beginless and endless ranges are all honoured as written.
    # @return [Numo::UInt8] images shaped (count, 96, 96, 3).
    # @raise [DNN_Error] if the range is reversed or reaches outside 0...100000.
    # @raise [DNN_STL10_LoadError] if the data file is missing after download.
    def self.load_unlabeled(range = 0...NUM_UNLABELED)
      first, last = unlabeled_bounds(range)
      # Validate before downloading so a bad argument never triggers a large fetch.
      unless 0 <= first && first <= last && last <= NUM_UNLABELED
        raise DNN_Error, "Range must be within 0...#{NUM_UNLABELED}."
      end
      downloads
      x_fname = data_file("unlabeled_X.bin")
      num_datas = last - first
      x_bin = File.binread(x_fname, num_datas * IMAGE_BYTES, first * IMAGE_BYTES)
      to_images(x_bin, num_datas)
    end

    # Converts +range+ into half-open integer bounds [first, last).
    # A nil begin means 0 and a nil end means "through the last unlabeled image".
    def self.unlabeled_bounds(range)
      first = range.begin || 0
      last = if range.end.nil?
               NUM_UNLABELED
             elsif range.exclude_end?
               range.end
             else
               range.end + 1
             end
      [first, last]
    end

    # Reads "<split>_X.bin" and "<split>_y.bin" and returns [images, labels].
    def self.load_labeled(split, num_images)
      downloads
      x_fname = data_file("#{split}_X.bin")
      y_fname = data_file("#{split}_y.bin")
      images = to_images(File.binread(x_fname), num_images)
      labels = Numo::UInt8.from_binary(File.binread(y_fname))
      [images, labels]
    end

    # Decodes raw image bytes into a (num_images, 96, 96, 3) UInt8 array.
    # The file layout is viewed as (n, C, 96, 96), and the three per-image axes are
    # reversed so channels come last. +clone+ copies the transposed result.
    def self.to_images(bin, num_images)
      Numo::UInt8.from_binary(bin)
                 .reshape(num_images, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
                 .transpose(0, 3, 2, 1)
                 .clone
    end

    # Directory the archive unpacks into.
    def self.dataset_dir
      DOWNLOADS_PATH + "/downloads/" + DIR_STL10
    end

    # Returns the path of +name+ inside the dataset directory.
    # @raise [DNN_STL10_LoadError] if the file does not exist.
    def self.data_file(name)
      path = "#{dataset_dir}/#{name}"
      raise DNN_STL10_LoadError, %`file "#{path}" is not found.` unless File.exist?(path)
      path
    end

    private_class_method :unlabeled_bounds, :load_labeled, :to_images, :dataset_dir, :data_file
  end
end
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.14.3"
2
+ VERSION = "0.15.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.3
4
+ version: 0.15.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-11-03 00:00:00.000000000 Z
11
+ date: 2019-11-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -118,24 +118,24 @@ files:
118
118
  - ext/rb_stb_image/extconf.rb
119
119
  - ext/rb_stb_image/rb_stb_image.c
120
120
  - lib/dnn.rb
121
- - lib/dnn/core/activations.rb
122
121
  - lib/dnn/core/callbacks.rb
123
- - lib/dnn/core/cnn_layers.rb
124
- - lib/dnn/core/embedding.rb
125
122
  - lib/dnn/core/error.rb
126
123
  - lib/dnn/core/global.rb
127
124
  - lib/dnn/core/initializers.rb
128
125
  - lib/dnn/core/iterator.rb
129
- - lib/dnn/core/layers.rb
126
+ - lib/dnn/core/layers/activations.rb
127
+ - lib/dnn/core/layers/basic_layers.rb
128
+ - lib/dnn/core/layers/cnn_layers.rb
129
+ - lib/dnn/core/layers/embedding.rb
130
+ - lib/dnn/core/layers/merge_layers.rb
131
+ - lib/dnn/core/layers/normalizations.rb
132
+ - lib/dnn/core/layers/rnn_layers.rb
130
133
  - lib/dnn/core/link.rb
131
134
  - lib/dnn/core/losses.rb
132
- - lib/dnn/core/merge_layers.rb
133
135
  - lib/dnn/core/models.rb
134
- - lib/dnn/core/normalizations.rb
135
136
  - lib/dnn/core/optimizers.rb
136
137
  - lib/dnn/core/param.rb
137
138
  - lib/dnn/core/regularizers.rb
138
- - lib/dnn/core/rnn_layers.rb
139
139
  - lib/dnn/core/savers.rb
140
140
  - lib/dnn/core/tensor.rb
141
141
  - lib/dnn/core/utils.rb
@@ -145,6 +145,7 @@ files:
145
145
  - lib/dnn/datasets/fashion-mnist.rb
146
146
  - lib/dnn/datasets/iris.rb
147
147
  - lib/dnn/datasets/mnist.rb
148
+ - lib/dnn/datasets/stl-10.rb
148
149
  - lib/dnn/image.rb
149
150
  - lib/dnn/version.rb
150
151
  - ruby-dnn.gemspec