ruby-dnn 0.14.3 → 0.15.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +5 -3
- data/Rakefile +4 -2
- data/examples/api-examples/save_example.rb +7 -5
- data/examples/dcgan/imgen.rb +2 -7
- data/examples/dcgan/train.rb +0 -1
- data/lib/dnn.rb +10 -10
- data/lib/dnn/core/callbacks.rb +6 -2
- data/lib/dnn/core/iterator.rb +10 -2
- data/lib/dnn/core/{activations.rb → layers/activations.rb} +0 -0
- data/lib/dnn/core/{layers.rb → layers/basic_layers.rb} +31 -8
- data/lib/dnn/core/{cnn_layers.rb → layers/cnn_layers.rb} +0 -0
- data/lib/dnn/core/{embedding.rb → layers/embedding.rb} +5 -4
- data/lib/dnn/core/{merge_layers.rb → layers/merge_layers.rb} +1 -1
- data/lib/dnn/core/{normalizations.rb → layers/normalizations.rb} +9 -5
- data/lib/dnn/core/{rnn_layers.rb → layers/rnn_layers.rb} +25 -16
- data/lib/dnn/core/losses.rb +8 -0
- data/lib/dnn/core/models.rb +164 -68
- data/lib/dnn/core/optimizers.rb +49 -72
- data/lib/dnn/core/param.rb +0 -2
- data/lib/dnn/core/savers.rb +40 -49
- data/lib/dnn/datasets/stl-10.rb +65 -0
- data/lib/dnn/version.rb +1 -1
- metadata +10 -9
data/lib/dnn/core/param.rb
CHANGED
data/lib/dnn/core/savers.rb
CHANGED
@@ -19,15 +19,6 @@ module DNN
|
|
19
19
|
def load_bin(bin)
|
20
20
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
|
21
21
|
end
|
22
|
-
|
23
|
-
def set_all_params_data(params_data)
|
24
|
-
all_params = @model.has_param_layers.map { |layer|
|
25
|
-
layer.get_params.values
|
26
|
-
}.flatten
|
27
|
-
all_params.each do |param|
|
28
|
-
param.data = params_data[param.name]
|
29
|
-
end
|
30
|
-
end
|
31
22
|
end
|
32
23
|
|
33
24
|
class MarshalLoader < Loader
|
@@ -36,12 +27,13 @@ module DNN
|
|
36
27
|
unless @model.class.name == data[:class]
|
37
28
|
raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
|
38
29
|
end
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
30
|
+
if data[:model]
|
31
|
+
data[:model].instance_variables.each do |ivar|
|
32
|
+
obj = data[:model].instance_variable_get(ivar)
|
33
|
+
@model.instance_variable_set(ivar, obj)
|
34
|
+
end
|
35
|
+
end
|
36
|
+
@model.set_all_params_data(data[:params])
|
45
37
|
end
|
46
38
|
end
|
47
39
|
|
@@ -50,20 +42,20 @@ module DNN
|
|
50
42
|
|
51
43
|
def load_bin(bin)
|
52
44
|
data = JSON.parse(bin, symbolize_names: true)
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
@model.predict1(Xumo::SFloat.zeros(*data[:input_shape]))
|
58
|
-
base64_to_params_data(data[:params])
|
45
|
+
unless @model.class.name == data[:class]
|
46
|
+
raise DNN_Error, "Class name is not mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
|
47
|
+
end
|
48
|
+
set_all_params_base64_data(data[:params])
|
59
49
|
end
|
60
50
|
|
61
|
-
def
|
62
|
-
|
63
|
-
|
64
|
-
|
51
|
+
def set_all_params_base64_data(params_data)
|
52
|
+
@model.trainable_layers.each.with_index do |layer, i|
|
53
|
+
params_data[i].each do |(key, (shape, base64_data))|
|
54
|
+
bin = Base64.decode64(base64_data)
|
55
|
+
data = Xumo::SFloat.from_binary(bin).reshape(*shape)
|
56
|
+
layer.get_params[key].data = data
|
57
|
+
end
|
65
58
|
end
|
66
|
-
set_all_params_data(params_data)
|
67
59
|
end
|
68
60
|
end
|
69
61
|
|
@@ -92,28 +84,28 @@ module DNN
|
|
92
84
|
def dump_bin
|
93
85
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
|
94
86
|
end
|
95
|
-
|
96
|
-
def get_all_params_data
|
97
|
-
all_params = @model.has_param_layers.map { |layer|
|
98
|
-
layer.get_params.values
|
99
|
-
}.flatten
|
100
|
-
all_params.to_h { |param| [param.name, param.data] }
|
101
|
-
end
|
102
87
|
end
|
103
88
|
|
104
89
|
class MarshalSaver < Saver
|
105
|
-
def initialize(model,
|
90
|
+
def initialize(model, include_model: true)
|
106
91
|
super(model)
|
107
|
-
@
|
92
|
+
@include_model = include_model
|
108
93
|
end
|
109
94
|
|
110
95
|
private def dump_bin
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
96
|
+
params_data = @model.get_all_params_data
|
97
|
+
if @include_model
|
98
|
+
@model.clean_layers
|
99
|
+
data = {
|
100
|
+
version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape,
|
101
|
+
params: params_data, model: @model
|
102
|
+
}
|
103
|
+
else
|
104
|
+
data = { version: VERSION, class: @model.class.name, params: params_data }
|
105
|
+
end
|
106
|
+
bin = Zlib::Deflate.deflate(Marshal.dump(data))
|
107
|
+
@model.set_all_params_data(params_data) if @include_model
|
108
|
+
bin
|
117
109
|
end
|
118
110
|
end
|
119
111
|
|
@@ -121,17 +113,16 @@ module DNN
|
|
121
113
|
private
|
122
114
|
|
123
115
|
def dump_bin
|
124
|
-
data = {
|
125
|
-
version: VERSION, class: @model.class.name, input_shape: @model.layers.first.input_shape, params: params_data_to_base64,
|
126
|
-
optimizer: @model.optimizer.to_hash, loss_func: @model.loss_func.to_hash
|
127
|
-
}
|
116
|
+
data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
|
128
117
|
JSON.dump(data)
|
129
118
|
end
|
130
119
|
|
131
|
-
def
|
132
|
-
|
133
|
-
|
134
|
-
|
120
|
+
def get_all_params_base64_data
|
121
|
+
@model.trainable_layers.map do |layer|
|
122
|
+
layer.get_params.to_h do |key, param|
|
123
|
+
base64_data = Base64.encode64(param.data.to_binary)
|
124
|
+
[key, [param.data.shape, base64_data]]
|
125
|
+
end
|
135
126
|
end
|
136
127
|
end
|
137
128
|
end
|
@@ -0,0 +1,65 @@
|
|
1
|
+
require "zlib"
require "fileutils"
require "archive/tar/minitar"
require_relative "downloader"

URL_STL10 = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
DIR_STL10 = "stl10_binary"

module DNN
  # Loader for the STL-10 dataset (binary release).
  #
  # On first use, the archive at URL_STL10 is fetched with Downloader and
  # unpacked into "#{DOWNLOADS_PATH}/downloads/#{DIR_STL10}". Images are
  # returned as Numo::UInt8 arrays of shape (N, 96, 96, 3). Labels are the raw
  # Numo::UInt8 values read from the *_y.bin files.
  module STL10
    class DNN_STL10_LoadError < DNN_Error; end

    # Bytes per image: 3 channels x 96 x 96 pixels, one UInt8 each.
    IMAGE_BYTES = 3 * 96 * 96
    # Image counts the labeled loaders reshape their data to.
    NUM_TRAIN = 5000
    NUM_TEST = 8000
    # Number of images addressable through load_unlabeled.
    NUM_UNLABELED = 100_000

    # Downloads and unpacks the dataset unless its directory already exists.
    #
    # The downloaded .tar.gz is always deleted afterwards. If unpacking does not
    # finish (error or interrupt), the partially extracted directory is removed
    # as well. Otherwise the Dir.exist? check below would skip the download on
    # every later call, and the loaders would read incomplete files.
    def self.downloads
      return if Dir.exist?(dataset_dir)
      Downloader.download(URL_STL10)
      archive_path = "#{downloads_dir}/#{File.basename(URL_STL10)}"
      extracted = false
      begin
        Zlib::GzipReader.open(archive_path) do |gz|
          Archive::Tar::Minitar.unpack(gz, downloads_dir)
        end
        extracted = true
      ensure
        FileUtils.rm_rf(dataset_dir) unless extracted
        # Guarded so that a missing archive cannot mask the exception being propagated.
        File.unlink(archive_path) if File.exist?(archive_path)
      end
    end

    # Loads the labeled training split.
    #
    # @return [Array(Numo::UInt8, Numo::UInt8)] [x_train, y_train]; x_train is (5000, 96, 96, 3)
    # @raise [DNN_STL10_LoadError] if train_X.bin or train_y.bin is missing
    def self.load_train
      load_labeled("train", NUM_TRAIN)
    end

    # Loads the labeled test split.
    #
    # @return [Array(Numo::UInt8, Numo::UInt8)] [x_test, y_test]; x_test is (8000, 96, 96, 3)
    # @raise [DNN_STL10_LoadError] if test_X.bin or test_y.bin is missing
    def self.load_test
      load_labeled("test", NUM_TEST)
    end

    # Loads a slice of the unlabeled images without reading the whole file.
    #
    # @param range [Range] image indices to load. Inclusive (0..9), exclusive
    #   (0...10), endless (10..) and beginless (..9) ranges are all accepted.
    # @return [Numo::UInt8] images of shape (number of indices, 96, 96, 3)
    # @raise [DNN_Error] unless the range lies within 0...100000 and is not reversed
    # @raise [DNN_STL10_LoadError] if unlabeled_X.bin is missing
    def self.load_unlabeled(range = 0...NUM_UNLABELED)
      first, stop = index_bounds(range)
      downloads
      x_fname = dataset_file("unlabeled_X.bin")
      num_datas = stop - first
      # Read only the requested images; length and offset are byte counts.
      x_bin = File.binread(x_fname, num_datas * IMAGE_BYTES, first * IMAGE_BYTES)
      to_images(x_bin, num_datas)
    end

    # Converts +range+ into [first_index, exclusive_end_index] and validates it.
    # range.end alone is not the exclusive bound: for 0..9 it would drop the
    # last image, and endless ranges have no end at all.
    def self.index_bounds(range)
      first = range.begin || 0
      stop = if range.end.nil?
               NUM_UNLABELED
             elsif range.exclude_end?
               range.end
             else
               range.end + 1
             end
      unless 0 <= first && first <= stop && stop <= NUM_UNLABELED
        raise DNN_Error, "Range must between 0 and 100000. (But the end is excluded)"
      end
      [first, stop]
    end

    # Shared body of load_train and load_test.
    # Reads "<prefix>_X.bin" (images) and "<prefix>_y.bin" (labels).
    def self.load_labeled(prefix, num_images)
      downloads
      x_fname = dataset_file("#{prefix}_X.bin")
      y_fname = dataset_file("#{prefix}_y.bin")
      x = to_images(File.binread(x_fname), num_images)
      # Labels are returned exactly as stored; no offset is applied.
      # NOTE(review): the STL-10 README documents labels as 1..10. Confirm that callers expect that range.
      y = Numo::UInt8.from_binary(File.binread(y_fname))
      [x, y]
    end

    # Path of +file_name+ inside the extracted dataset directory.
    # @raise [DNN_STL10_LoadError] if the file does not exist
    def self.dataset_file(file_name)
      path = "#{dataset_dir}/#{file_name}"
      raise DNN_STL10_LoadError, %`file "#{path}" is not found.` unless File.exist?(path)
      path
    end

    # Reshapes raw bytes into (num_images, 3, 96, 96), then transposes to
    # channel-last (num_images, 96, 96, 3). transpose(0, 3, 2, 1) also swaps the
    # two 96-pixel axes. This presumably compensates for column-major pixel
    # order in the file; confirm. clone copies the transposed result into a
    # fresh array.
    def self.to_images(bin, num_images)
      Numo::UInt8.from_binary(bin).reshape(num_images, 3, 96, 96).transpose(0, 3, 2, 1).clone
    end

    # Directory that Downloader saves into and the archive is unpacked into.
    def self.downloads_dir
      "#{DOWNLOADS_PATH}/downloads"
    end

    # Directory created by unpacking the STL-10 archive.
    def self.dataset_dir
      "#{downloads_dir}/#{DIR_STL10}"
    end

    private_class_method :index_bounds, :load_labeled, :dataset_file, :to_images,
                         :downloads_dir, :dataset_dir
  end
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.15.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-11-
|
11
|
+
date: 2019-11-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -118,24 +118,24 @@ files:
|
|
118
118
|
- ext/rb_stb_image/extconf.rb
|
119
119
|
- ext/rb_stb_image/rb_stb_image.c
|
120
120
|
- lib/dnn.rb
|
121
|
-
- lib/dnn/core/activations.rb
|
122
121
|
- lib/dnn/core/callbacks.rb
|
123
|
-
- lib/dnn/core/cnn_layers.rb
|
124
|
-
- lib/dnn/core/embedding.rb
|
125
122
|
- lib/dnn/core/error.rb
|
126
123
|
- lib/dnn/core/global.rb
|
127
124
|
- lib/dnn/core/initializers.rb
|
128
125
|
- lib/dnn/core/iterator.rb
|
129
|
-
- lib/dnn/core/layers.rb
|
126
|
+
- lib/dnn/core/layers/activations.rb
|
127
|
+
- lib/dnn/core/layers/basic_layers.rb
|
128
|
+
- lib/dnn/core/layers/cnn_layers.rb
|
129
|
+
- lib/dnn/core/layers/embedding.rb
|
130
|
+
- lib/dnn/core/layers/merge_layers.rb
|
131
|
+
- lib/dnn/core/layers/normalizations.rb
|
132
|
+
- lib/dnn/core/layers/rnn_layers.rb
|
130
133
|
- lib/dnn/core/link.rb
|
131
134
|
- lib/dnn/core/losses.rb
|
132
|
-
- lib/dnn/core/merge_layers.rb
|
133
135
|
- lib/dnn/core/models.rb
|
134
|
-
- lib/dnn/core/normalizations.rb
|
135
136
|
- lib/dnn/core/optimizers.rb
|
136
137
|
- lib/dnn/core/param.rb
|
137
138
|
- lib/dnn/core/regularizers.rb
|
138
|
-
- lib/dnn/core/rnn_layers.rb
|
139
139
|
- lib/dnn/core/savers.rb
|
140
140
|
- lib/dnn/core/tensor.rb
|
141
141
|
- lib/dnn/core/utils.rb
|
@@ -145,6 +145,7 @@ files:
|
|
145
145
|
- lib/dnn/datasets/fashion-mnist.rb
|
146
146
|
- lib/dnn/datasets/iris.rb
|
147
147
|
- lib/dnn/datasets/mnist.rb
|
148
|
+
- lib/dnn/datasets/stl-10.rb
|
148
149
|
- lib/dnn/image.rb
|
149
150
|
- lib/dnn/version.rb
|
150
151
|
- ruby-dnn.gemspec
|