ruby-dnn 1.2.3 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/examples/dcgan/dcgan.rb +1 -1
- data/examples/iris_example.rb +17 -41
- data/examples/iris_example_unused_model.rb +57 -0
- data/examples/vae.rb +1 -1
- data/lib/dnn/core/callbacks.rb +18 -8
- data/lib/dnn/core/iterator.rb +20 -4
- data/lib/dnn/core/layers/rnn_layers.rb +20 -24
- data/lib/dnn/core/models.rb +474 -149
- data/lib/dnn/core/savers.rb +4 -12
- data/lib/dnn/core/utils.rb +14 -0
- data/lib/dnn/datasets/iris.rb +5 -1
- data/lib/dnn/version.rb +1 -1
- data/lib/dnn.rb +32 -26
- metadata +3 -2
data/lib/dnn/core/savers.rb
CHANGED
@@ -14,15 +14,13 @@ module DNN
|
|
14
14
|
load_bin(File.binread(file_name))
|
15
15
|
end
|
16
16
|
|
17
|
-
private
|
18
|
-
|
19
17
|
# Abstract hook for turning serialized model data back into model state.
# The base loader calls it with the bytes read via File.binread; concrete
# loaders (MarshalLoader, JSONLoader) override it with their own decoding.
#
# @param bin [String] serialized model data
# @raise [NotImplementedError] always, unless a subclass overrides this method
def load_bin(bin)
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'load_bin'"
end
|
22
20
|
end
|
23
21
|
|
24
22
|
class MarshalLoader < Loader
|
25
|
-
|
23
|
+
def load_bin(bin)
|
26
24
|
data = Marshal.load(Zlib::Inflate.inflate(bin))
|
27
25
|
unless @model.class.name == data[:class]
|
28
26
|
raise DNNError, "Class name is mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
|
@@ -38,8 +36,6 @@ module DNN
|
|
38
36
|
end
|
39
37
|
|
40
38
|
class JSONLoader < Loader
|
41
|
-
private
|
42
|
-
|
43
39
|
def load_bin(bin)
|
44
40
|
data = JSON.parse(bin, symbolize_names: true)
|
45
41
|
unless @model.class.name == data[:class]
|
@@ -48,7 +44,7 @@ module DNN
|
|
48
44
|
set_all_params_base64_data(data[:params])
|
49
45
|
end
|
50
46
|
|
51
|
-
def set_all_params_base64_data(params_data)
|
47
|
+
private def set_all_params_base64_data(params_data)
|
52
48
|
@model.trainable_layers.each.with_index do |layer, i|
|
53
49
|
params_data[i].each do |(key, (shape, base64_data))|
|
54
50
|
bin = Base64.decode64(base64_data)
|
@@ -79,8 +75,6 @@ module DNN
|
|
79
75
|
end
|
80
76
|
end
|
81
77
|
|
82
|
-
private
|
83
|
-
|
84
78
|
# Abstract hook for serializing the model. Concrete savers (for example
# JSONSaver) override it and return the serialized data; presumably the caller
# writes that data out — TODO confirm against the (not visible) save method.
#
# @raise [NotImplementedError] always, unless a subclass overrides this method
def dump_bin
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'dump_bin'"
end
|
@@ -92,7 +86,7 @@ module DNN
|
|
92
86
|
@include_model = include_model
|
93
87
|
end
|
94
88
|
|
95
|
-
|
89
|
+
def dump_bin
|
96
90
|
params_data = @model.get_all_params_data
|
97
91
|
if @include_model
|
98
92
|
@model.clean_layers
|
@@ -110,14 +104,12 @@ module DNN
|
|
110
104
|
end
|
111
105
|
|
112
106
|
class JSONSaver < Saver
|
113
|
-
private
|
114
|
-
|
115
107
|
# Build the JSON document for the model: the library VERSION, the model's
# class name (checked again on load), and the Base64-encoded parameters of
# every trainable layer.
def dump_bin
  payload = {
    version: VERSION,
    class: @model.class.name,
    params: get_all_params_base64_data
  }
  JSON.dump(payload)
end
|
119
111
|
|
120
|
-
def get_all_params_base64_data
|
112
|
+
private def get_all_params_base64_data
|
121
113
|
@model.trainable_layers.map do |layer|
|
122
114
|
layer.get_params.to_h do |key, param|
|
123
115
|
base64_data = Base64.encode64(param.data.to_binary)
|
data/lib/dnn/core/utils.rb
CHANGED
@@ -39,6 +39,20 @@ module DNN
|
|
39
39
|
Losses::SoftmaxCrossEntropy.softmax(x)
|
40
40
|
end
|
41
41
|
|
42
|
+
# Validate the type of training/evaluation input.
# +data+ must be either a single +expected_type+ instance, or an Array whose
# elements are all +expected_type+ instances.
#
# @param data_name [String] name used in the error message
# @param data [Object] value to validate
# @param expected_type [Class] required class for the value or each element
# @raise [TypeError] naming the offending value (with its index for arrays)
def self.check_input_data_type(data_name, data, expected_type)
  if data.is_a?(Array)
    data.each_with_index do |elem, idx|
      next if elem.is_a?(expected_type)

      raise TypeError, "#{data_name}[#{idx}]:#{elem.class.name} is not an instance of #{expected_type.name} class."
    end
  elsif !data.is_a?(expected_type)
    raise TypeError, "#{data_name}:#{data.class.name} is not an instance of #{expected_type.name} class or Array class."
  end
end
|
55
|
+
|
42
56
|
# Perform numerical differentiation.
|
43
57
|
def self.numerical_grad(x, func)
|
44
58
|
(func.(x + 1e-7) - func.(x)) / 1e-7
|
data/lib/dnn/datasets/iris.rb
CHANGED
@@ -41,7 +41,11 @@ module DNN
|
|
41
41
|
end
|
42
42
|
end
|
43
43
|
if shuffle
|
44
|
-
|
44
|
+
if RUBY_VERSION.split(".")[0].to_i >= 3
|
45
|
+
orig_seed = Random.seed
|
46
|
+
else
|
47
|
+
orig_seed = Random::DEFAULT.seed
|
48
|
+
end
|
45
49
|
srand(shuffle_seed)
|
46
50
|
indexs = (0...csv_array.length).to_a.shuffle
|
47
51
|
x[indexs, true] = x
|
data/lib/dnn/version.rb
CHANGED
data/lib/dnn.rb
CHANGED
@@ -1,4 +1,8 @@
|
|
1
|
-
|
1
|
+
if RUBY_PLATFORM == "wasm32-wasi"
|
2
|
+
require "narray.so"
|
3
|
+
else
|
4
|
+
require "numo/narray"
|
5
|
+
end
|
2
6
|
|
3
7
|
module DNN
|
4
8
|
if ENV["RUBY_DNN_USE_CUMO"] == "ENABLE"
|
@@ -27,28 +31,30 @@ module DNN
|
|
27
31
|
end
|
28
32
|
end
|
29
33
|
|
30
|
-
|
31
|
-
require_relative "dnn/
|
32
|
-
require_relative "dnn/core/
|
33
|
-
require_relative "dnn/core/
|
34
|
-
require_relative "dnn/core/
|
35
|
-
require_relative "dnn/core/
|
36
|
-
require_relative "dnn/core/
|
37
|
-
require_relative "dnn/core/
|
38
|
-
require_relative "dnn/core/
|
39
|
-
require_relative "dnn/core/
|
40
|
-
require_relative "dnn/core/layers/
|
41
|
-
require_relative "dnn/core/layers/
|
42
|
-
require_relative "dnn/core/layers/
|
43
|
-
require_relative "dnn/core/layers/
|
44
|
-
require_relative "dnn/core/layers/
|
45
|
-
require_relative "dnn/core/layers/
|
46
|
-
require_relative "dnn/core/layers/
|
47
|
-
require_relative "dnn/core/layers/
|
48
|
-
require_relative "dnn/core/
|
49
|
-
require_relative "dnn/core/
|
50
|
-
require_relative "dnn/core/
|
51
|
-
require_relative "dnn/core/
|
52
|
-
require_relative "dnn/core/
|
53
|
-
require_relative "dnn/core/
|
54
|
-
require_relative "dnn/core/
|
34
|
+
if RUBY_PLATFORM != "wasm32-wasi"
|
35
|
+
require_relative "dnn/version"
|
36
|
+
require_relative "dnn/core/monkey_patch"
|
37
|
+
require_relative "dnn/core/error"
|
38
|
+
require_relative "dnn/core/global"
|
39
|
+
require_relative "dnn/core/tensor"
|
40
|
+
require_relative "dnn/core/param"
|
41
|
+
require_relative "dnn/core/link"
|
42
|
+
require_relative "dnn/core/iterator"
|
43
|
+
require_relative "dnn/core/models"
|
44
|
+
require_relative "dnn/core/layers/basic_layers"
|
45
|
+
require_relative "dnn/core/layers/normalizations"
|
46
|
+
require_relative "dnn/core/layers/activations"
|
47
|
+
require_relative "dnn/core/layers/merge_layers"
|
48
|
+
require_relative "dnn/core/layers/split_layers"
|
49
|
+
require_relative "dnn/core/layers/cnn_layers"
|
50
|
+
require_relative "dnn/core/layers/embedding"
|
51
|
+
require_relative "dnn/core/layers/rnn_layers"
|
52
|
+
require_relative "dnn/core/layers/math_layers"
|
53
|
+
require_relative "dnn/core/optimizers"
|
54
|
+
require_relative "dnn/core/losses"
|
55
|
+
require_relative "dnn/core/initializers"
|
56
|
+
require_relative "dnn/core/regularizers"
|
57
|
+
require_relative "dnn/core/callbacks"
|
58
|
+
require_relative "dnn/core/savers"
|
59
|
+
require_relative "dnn/core/utils"
|
60
|
+
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 1.
|
4
|
+
version: 1.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-03-
|
11
|
+
date: 2023-03-19 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -139,6 +139,7 @@ files:
|
|
139
139
|
- examples/dcgan/imgen.rb
|
140
140
|
- examples/dcgan/train.rb
|
141
141
|
- examples/iris_example.rb
|
142
|
+
- examples/iris_example_unused_model.rb
|
142
143
|
- examples/judge-number/README.md
|
143
144
|
- examples/judge-number/capture.PNG
|
144
145
|
- examples/judge-number/convnet8.rb
|