ruby-dnn 0.14.3 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +5 -3
- data/Rakefile +4 -2
- data/examples/api-examples/save_example.rb +7 -5
- data/examples/dcgan/imgen.rb +2 -7
- data/examples/dcgan/train.rb +0 -1
- data/lib/dnn.rb +10 -10
- data/lib/dnn/core/callbacks.rb +6 -2
- data/lib/dnn/core/iterator.rb +10 -2
- data/lib/dnn/core/{activations.rb → layers/activations.rb} +0 -0
- data/lib/dnn/core/{layers.rb → layers/basic_layers.rb} +31 -8
- data/lib/dnn/core/{cnn_layers.rb → layers/cnn_layers.rb} +0 -0
- data/lib/dnn/core/{embedding.rb → layers/embedding.rb} +5 -4
- data/lib/dnn/core/{merge_layers.rb → layers/merge_layers.rb} +1 -1
- data/lib/dnn/core/{normalizations.rb → layers/normalizations.rb} +9 -5
- data/lib/dnn/core/{rnn_layers.rb → layers/rnn_layers.rb} +25 -16
- data/lib/dnn/core/losses.rb +8 -0
- data/lib/dnn/core/models.rb +164 -68
- data/lib/dnn/core/optimizers.rb +49 -72
- data/lib/dnn/core/param.rb +0 -2
- data/lib/dnn/core/savers.rb +40 -49
- data/lib/dnn/datasets/stl-10.rb +65 -0
- data/lib/dnn/version.rb +1 -1
- metadata +10 -9
data/lib/dnn/core/param.rb
CHANGED
data/lib/dnn/core/savers.rb
CHANGED
@@ -19,15 +19,6 @@ module DNN
|
|
19
19
|
def load_bin(bin)
|
20
20
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
|
21
21
|
end
|
22
|
-
|
23
|
-
def set_all_params_data(params_data)
|
24
|
-
all_params = @model.has_param_layers.map { |layer|
|
25
|
-
layer.get_params.values
|
26
|
-
}.flatten
|
27
|
-
all_params.each do |param|
|
28
|
-
param.data = params_data[param.name]
|
29
|
-
end
|
30
|
-
end
|
31
22
|
end
|
32
23
|
|
33
24
|
class MarshalLoader < Loader
|
@@ -36,12 +27,13 @@ module DNN
|
|
36
27
|
unless @model.class.name == data[:class]
|
37
28
|
raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
|
38
29
|
end
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
30
|
+
if data[:model]
|
31
|
+
data[:model].instance_variables.each do |ivar|
|
32
|
+
obj = data[:model].instance_variable_get(ivar)
|
33
|
+
@model.instance_variable_set(ivar, obj)
|
34
|
+
end
|
35
|
+
end
|
36
|
+
@model.set_all_params_data(data[:params])
|
45
37
|
end
|
46
38
|
end
|
47
39
|
|
@@ -50,20 +42,20 @@ module DNN
|
|
50
42
|
|
51
43
|
def load_bin(bin)
|
52
44
|
data = JSON.parse(bin, symbolize_names: true)
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
@model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
|
58
|
-
base64_to_params_data(data[:params])
|
45
|
+
unless @model.class.name == data[:class]
|
46
|
+
raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
|
47
|
+
end
|
48
|
+
set_all_params_base64_data(data[:params])
|
59
49
|
end
|
60
50
|
|
61
|
-
def
|
62
|
-
|
63
|
-
|
64
|
-
|
51
|
+
def set_all_params_base64_data(params_data)
|
52
|
+
@model.trainable_layers.each.with_index do |layer, i|
|
53
|
+
params_data[i].each do |(key, (shape, base64_data))|
|
54
|
+
bin = Base64.decode64(base64_data)
|
55
|
+
data = Xumo::SFloat.from_binary(bin).reshape(*shape)
|
56
|
+
layer.get_params[key].data = data
|
57
|
+
end
|
65
58
|
end
|
66
|
-
set_all_params_data(params_data)
|
67
59
|
end
|
68
60
|
end
|
69
61
|
|
@@ -92,28 +84,28 @@ module DNN
|
|
92
84
|
def dump_bin
|
93
85
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
|
94
86
|
end
|
95
|
-
|
96
|
-
def get_all_params_data
|
97
|
-
all_params = @model.has_param_layers.map { |layer|
|
98
|
-
layer.get_params.values
|
99
|
-
}.flatten
|
100
|
-
all_params.to_h { |param| [param.name, param.data] }
|
101
|
-
end
|
102
87
|
end
|
103
88
|
|
104
89
|
class MarshalSaver < Saver
|
105
|
-
def initialize(model,
|
90
|
+
def initialize(model, include_model: true)
|
106
91
|
super(model)
|
107
|
-
@
|
92
|
+
@include_model = include_model
|
108
93
|
end
|
109
94
|
|
110
95
|
private def dump_bin
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
96
|
+
params_data = @model.get_all_params_data
|
97
|
+
if @include_model
|
98
|
+
@model.clean_layers
|
99
|
+
data = {
|
100
|
+
version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape,
|
101
|
+
params: params_data, model: @model
|
102
|
+
}
|
103
|
+
else
|
104
|
+
data = { version: VERSION, class: @model.class.name, params: params_data }
|
105
|
+
end
|
106
|
+
bin = Zlib::Deflate.deflate(Marshal.dump(data))
|
107
|
+
@model.set_all_params_data(params_data) if @include_model
|
108
|
+
bin
|
117
109
|
end
|
118
110
|
end
|
119
111
|
|
@@ -121,17 +113,16 @@ module DNN
|
|
121
113
|
private
|
122
114
|
|
123
115
|
def dump_bin
|
124
|
-
data = {
|
125
|
-
version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape, params: params_data_to_base64,
|
126
|
-
optimizer: @model.optimizer.to_hash, loss_func: @model.loss_func.to_hash
|
127
|
-
}
|
116
|
+
data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
|
128
117
|
JSON.dump(data)
|
129
118
|
end
|
130
119
|
|
131
|
-
def
|
132
|
-
|
133
|
-
|
134
|
-
|
120
|
+
def get_all_params_base64_data
|
121
|
+
@model.trainable_layers.map do |layer|
|
122
|
+
layer.get_params.to_h do |key, param|
|
123
|
+
base64_data = Base64.encode64(param.data.to_binary)
|
124
|
+
[key, [param.data.shape, base64_data]]
|
125
|
+
end
|
135
126
|
end
|
136
127
|
end
|
137
128
|
end
|
@@ -0,0 +1,65 @@
|
|
1
|
+
require "zlib"
|
2
|
+
require "archive/tar/minitar"
|
3
|
+
require_relative "downloader"
|
4
|
+
|
5
|
+
URL_STL10 = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
|
6
|
+
DIR_STL10 = "stl10_binary"
|
7
|
+
|
8
|
+
module DNN
|
9
|
+
module STL10
|
10
|
+
class DNN_STL10_LoadError < DNN_Error; end
|
11
|
+
|
12
|
+
def self.downloads
|
13
|
+
return if Dir.exist?(DOWNLOADS_PATH + "/downloads/" + DIR_STL10)
|
14
|
+
Downloader.download(URL_STL10)
|
15
|
+
stl10_binary_file_name = DOWNLOADS_PATH + "/downloads/" + URL_STL10.match(%r`.+/(.+)`)[1]
|
16
|
+
begin
|
17
|
+
Zlib::GzipReader.open(stl10_binary_file_name) do |gz|
|
18
|
+
Archive::Tar::Minitar.unpack(gz, DOWNLOADS_PATH + "/downloads")
|
19
|
+
end
|
20
|
+
ensure
|
21
|
+
File.unlink(stl10_binary_file_name)
|
22
|
+
end
|
23
|
+
end
|
24
|
+
|
25
|
+
def self.load_train
|
26
|
+
downloads
|
27
|
+
x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/train_X.bin"
|
28
|
+
raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
|
29
|
+
y_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/train_y.bin"
|
30
|
+
raise DNN_STL10_LoadError.new(%`file "#{y_fname}" is not found.`) unless File.exist?(y_fname)
|
31
|
+
x_bin = File.binread(x_fname)
|
32
|
+
y_bin = File.binread(y_fname)
|
33
|
+
x_train = Numo::UInt8.from_binary(x_bin).reshape(5000, 3, 96, 96).transpose(0, 3, 2, 1).clone
|
34
|
+
y_train = Numo::UInt8.from_binary(y_bin)
|
35
|
+
[x_train, y_train]
|
36
|
+
end
|
37
|
+
|
38
|
+
def self.load_test
|
39
|
+
downloads
|
40
|
+
x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/test_X.bin"
|
41
|
+
raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
|
42
|
+
y_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/test_y.bin"
|
43
|
+
raise DNN_STL10_LoadError.new(%`file "#{y_fname}" is not found.`) unless File.exist?(y_fname)
|
44
|
+
x_bin = File.binread(x_fname)
|
45
|
+
y_bin = File.binread(y_fname)
|
46
|
+
x_test = Numo::UInt8.from_binary(x_bin).reshape(8000, 3, 96, 96).transpose(0, 3, 2, 1).clone
|
47
|
+
y_test = Numo::UInt8.from_binary(y_bin)
|
48
|
+
[x_test, y_test]
|
49
|
+
end
|
50
|
+
|
51
|
+
def self.load_unlabeled(range = 0...100000)
|
52
|
+
unless 0 <= range.begin && range.end <= 100000
|
53
|
+
raise DNN_Error, "Range must between 0 and 100000. (But the end is excluded)"
|
54
|
+
end
|
55
|
+
downloads
|
56
|
+
x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/unlabeled_X.bin"
|
57
|
+
raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
|
58
|
+
num_datas = range.end - range.begin
|
59
|
+
length = num_datas * 3 * 96 * 96
|
60
|
+
ofs = range.begin * 3 * 96 * 96
|
61
|
+
x_bin = File.binread(x_fname, length, ofs)
|
62
|
+
Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 96, 96).transpose(0, 3, 2, 1).clone
|
63
|
+
end
|
64
|
+
end
|
65
|
+
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.14.3
|
4
|
+
version: 0.15.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-11-
|
11
|
+
date: 2019-11-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -118,24 +118,24 @@ files:
|
|
118
118
|
- ext/rb_stb_image/extconf.rb
|
119
119
|
- ext/rb_stb_image/rb_stb_image.c
|
120
120
|
- lib/dnn.rb
|
121
|
-
- lib/dnn/core/activations.rb
|
122
121
|
- lib/dnn/core/callbacks.rb
|
123
|
-
- lib/dnn/core/cnn_layers.rb
|
124
|
-
- lib/dnn/core/embedding.rb
|
125
122
|
- lib/dnn/core/error.rb
|
126
123
|
- lib/dnn/core/global.rb
|
127
124
|
- lib/dnn/core/initializers.rb
|
128
125
|
- lib/dnn/core/iterator.rb
|
129
|
-
- lib/dnn/core/layers.rb
|
126
|
+
- lib/dnn/core/layers/activations.rb
|
127
|
+
- lib/dnn/core/layers/basic_layers.rb
|
128
|
+
- lib/dnn/core/layers/cnn_layers.rb
|
129
|
+
- lib/dnn/core/layers/embedding.rb
|
130
|
+
- lib/dnn/core/layers/merge_layers.rb
|
131
|
+
- lib/dnn/core/layers/normalizations.rb
|
132
|
+
- lib/dnn/core/layers/rnn_layers.rb
|
130
133
|
- lib/dnn/core/link.rb
|
131
134
|
- lib/dnn/core/losses.rb
|
132
|
-
- lib/dnn/core/merge_layers.rb
|
133
135
|
- lib/dnn/core/models.rb
|
134
|
-
- lib/dnn/core/normalizations.rb
|
135
136
|
- lib/dnn/core/optimizers.rb
|
136
137
|
- lib/dnn/core/param.rb
|
137
138
|
- lib/dnn/core/regularizers.rb
|
138
|
-
- lib/dnn/core/rnn_layers.rb
|
139
139
|
- lib/dnn/core/savers.rb
|
140
140
|
- lib/dnn/core/tensor.rb
|
141
141
|
- lib/dnn/core/utils.rb
|
@@ -145,6 +145,7 @@ files:
|
|
145
145
|
- lib/dnn/datasets/fashion-mnist.rb
|
146
146
|
- lib/dnn/datasets/iris.rb
|
147
147
|
- lib/dnn/datasets/mnist.rb
|
148
|
+
- lib/dnn/datasets/stl-10.rb
|
148
149
|
- lib/dnn/image.rb
|
149
150
|
- lib/dnn/version.rb
|
150
151
|
- ruby-dnn.gemspec
|