ruby-dnn 1.2.2 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +0 -0
- data/.travis.yml +0 -0
- data/CODE_OF_CONDUCT.md +0 -0
- data/Gemfile +0 -0
- data/LICENSE.txt +0 -0
- data/README.md +0 -0
- data/Rakefile +5 -0
- data/examples/api-examples/early_stopping_example.rb +0 -0
- data/examples/api-examples/initializer_example.rb +0 -0
- data/examples/api-examples/regularizer_example.rb +0 -0
- data/examples/api-examples/save_example.rb +0 -0
- data/examples/cifar100_example.rb +0 -0
- data/examples/cifar10_example.rb +0 -0
- data/examples/dcgan/dcgan.rb +1 -1
- data/examples/dcgan/imgen.rb +0 -0
- data/examples/dcgan/train.rb +0 -0
- data/examples/iris_example.rb +17 -41
- data/examples/iris_example_unused_model.rb +57 -0
- data/examples/judge-number/README.md +0 -0
- data/examples/judge-number/capture.PNG +0 -0
- data/examples/judge-number/convnet8.rb +0 -0
- data/examples/judge-number/make_weights.rb +0 -0
- data/examples/judge-number/mnist_predict.rb +0 -0
- data/examples/judge-number/mnist_train.rb +0 -0
- data/examples/judge-number/public/httpRequest.js +0 -0
- data/examples/judge-number/public/judgeNumber.js +0 -0
- data/examples/judge-number/server.rb +0 -0
- data/examples/judge-number/trained_mnist_params.marshal +0 -0
- data/examples/judge-number/views/index.erb +0 -0
- data/examples/mnist_conv2d_example.rb +0 -0
- data/examples/mnist_define_by_run.rb +0 -0
- data/examples/mnist_example.rb +0 -0
- data/examples/mnist_gpu.rb +0 -0
- data/examples/mnist_lstm_example.rb +0 -0
- data/examples/pix2pix/dcgan.rb +0 -0
- data/examples/pix2pix/imgen.rb +0 -0
- data/examples/pix2pix/train.rb +0 -0
- data/examples/vae.rb +1 -1
- data/examples/xor_example.rb +0 -0
- data/ext/rb_stb_image/extconf.rb +0 -0
- data/ext/rb_stb_image/rb_stb_image.c +0 -0
- data/img/cart-pole.gif +0 -0
- data/img/cycle-gan.PNG +0 -0
- data/img/facade-pix2pix.png +0 -0
- data/lib/dnn/core/callbacks.rb +18 -8
- data/lib/dnn/core/error.rb +0 -0
- data/lib/dnn/core/global.rb +0 -0
- data/lib/dnn/core/initializers.rb +0 -0
- data/lib/dnn/core/iterator.rb +20 -4
- data/lib/dnn/core/layers/activations.rb +0 -0
- data/lib/dnn/core/layers/basic_layers.rb +2 -2
- data/lib/dnn/core/layers/cnn_layers.rb +0 -0
- data/lib/dnn/core/layers/embedding.rb +0 -0
- data/lib/dnn/core/layers/math_layers.rb +0 -0
- data/lib/dnn/core/layers/merge_layers.rb +2 -2
- data/lib/dnn/core/layers/normalizations.rb +0 -0
- data/lib/dnn/core/layers/rnn_layers.rb +20 -24
- data/lib/dnn/core/layers/split_layers.rb +0 -0
- data/lib/dnn/core/link.rb +0 -0
- data/lib/dnn/core/losses.rb +2 -2
- data/lib/dnn/core/models.rb +474 -149
- data/lib/dnn/core/monkey_patch.rb +0 -0
- data/lib/dnn/core/optimizers.rb +0 -0
- data/lib/dnn/core/param.rb +0 -0
- data/lib/dnn/core/regularizers.rb +0 -0
- data/lib/dnn/core/savers.rb +4 -12
- data/lib/dnn/core/tensor.rb +0 -0
- data/lib/dnn/core/utils.rb +14 -0
- data/lib/dnn/datasets/cifar10.rb +0 -0
- data/lib/dnn/datasets/cifar100.rb +0 -0
- data/lib/dnn/datasets/downloader.rb +12 -3
- data/lib/dnn/datasets/fashion-mnist.rb +0 -0
- data/lib/dnn/datasets/iris.rb +5 -1
- data/lib/dnn/datasets/mnist.rb +0 -0
- data/lib/dnn/datasets/stl-10.rb +0 -0
- data/lib/dnn/image.rb +1 -1
- data/lib/dnn/keras-model-convertor.rb +0 -0
- data/lib/dnn/numo2numpy.rb +0 -0
- data/lib/dnn/version.rb +1 -1
- data/lib/dnn.rb +32 -26
- data/ruby-dnn.gemspec +1 -0
- data/third_party/stb_image.h +0 -0
- data/third_party/stb_image_resize.h +0 -0
- data/third_party/stb_image_write.h +0 -0
- metadata +21 -6
|
File without changes
|
data/lib/dnn/core/optimizers.rb
CHANGED
|
File without changes
|
data/lib/dnn/core/param.rb
CHANGED
|
File without changes
|
|
File without changes
|
data/lib/dnn/core/savers.rb
CHANGED
|
@@ -14,15 +14,13 @@ module DNN
|
|
|
14
14
|
load_bin(File.binread(file_name))
|
|
15
15
|
end
|
|
16
16
|
|
|
17
|
-
private
|
|
18
|
-
|
|
19
17
|
def load_bin(bin)
|
|
20
18
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
|
|
21
19
|
end
|
|
22
20
|
end
|
|
23
21
|
|
|
24
22
|
class MarshalLoader < Loader
|
|
25
|
-
|
|
23
|
+
def load_bin(bin)
|
|
26
24
|
data = Marshal.load(Zlib::Inflate.inflate(bin))
|
|
27
25
|
unless @model.class.name == data[:class]
|
|
28
26
|
raise DNNError, "Class name is mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
|
|
@@ -38,8 +36,6 @@ module DNN
|
|
|
38
36
|
end
|
|
39
37
|
|
|
40
38
|
class JSONLoader < Loader
|
|
41
|
-
private
|
|
42
|
-
|
|
43
39
|
def load_bin(bin)
|
|
44
40
|
data = JSON.parse(bin, symbolize_names: true)
|
|
45
41
|
unless @model.class.name == data[:class]
|
|
@@ -48,7 +44,7 @@ module DNN
|
|
|
48
44
|
set_all_params_base64_data(data[:params])
|
|
49
45
|
end
|
|
50
46
|
|
|
51
|
-
def set_all_params_base64_data(params_data)
|
|
47
|
+
private def set_all_params_base64_data(params_data)
|
|
52
48
|
@model.trainable_layers.each.with_index do |layer, i|
|
|
53
49
|
params_data[i].each do |(key, (shape, base64_data))|
|
|
54
50
|
bin = Base64.decode64(base64_data)
|
|
@@ -79,8 +75,6 @@ module DNN
|
|
|
79
75
|
end
|
|
80
76
|
end
|
|
81
77
|
|
|
82
|
-
private
|
|
83
|
-
|
|
84
78
|
def dump_bin
|
|
85
79
|
raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
|
|
86
80
|
end
|
|
@@ -92,7 +86,7 @@ module DNN
|
|
|
92
86
|
@include_model = include_model
|
|
93
87
|
end
|
|
94
88
|
|
|
95
|
-
|
|
89
|
+
def dump_bin
|
|
96
90
|
params_data = @model.get_all_params_data
|
|
97
91
|
if @include_model
|
|
98
92
|
@model.clean_layers
|
|
@@ -110,14 +104,12 @@ module DNN
|
|
|
110
104
|
end
|
|
111
105
|
|
|
112
106
|
class JSONSaver < Saver
|
|
113
|
-
private
|
|
114
|
-
|
|
115
107
|
def dump_bin
|
|
116
108
|
data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
|
|
117
109
|
JSON.dump(data)
|
|
118
110
|
end
|
|
119
111
|
|
|
120
|
-
def get_all_params_base64_data
|
|
112
|
+
private def get_all_params_base64_data
|
|
121
113
|
@model.trainable_layers.map do |layer|
|
|
122
114
|
layer.get_params.to_h do |key, param|
|
|
123
115
|
base64_data = Base64.encode64(param.data.to_binary)
|
data/lib/dnn/core/tensor.rb
CHANGED
|
File without changes
|
data/lib/dnn/core/utils.rb
CHANGED
|
@@ -39,6 +39,20 @@ module DNN
|
|
|
39
39
|
Losses::SoftmaxCrossEntropy.softmax(x)
|
|
40
40
|
end
|
|
41
41
|
|
|
42
|
+
# Check training or evaluate input data type.
|
|
43
|
+
def self.check_input_data_type(data_name, data, expected_type)
|
|
44
|
+
if !data.is_a?(expected_type) && !data.is_a?(Array)
|
|
45
|
+
raise TypeError, "#{data_name}:#{data.class.name} is not an instance of #{expected_type.name} class or Array class."
|
|
46
|
+
end
|
|
47
|
+
if data.is_a?(Array)
|
|
48
|
+
data.each.with_index do |v, i|
|
|
49
|
+
unless v.is_a?(expected_type)
|
|
50
|
+
raise TypeError, "#{data_name}[#{i}]:#{v.class.name} is not an instance of #{expected_type.name} class."
|
|
51
|
+
end
|
|
52
|
+
end
|
|
53
|
+
end
|
|
54
|
+
end
|
|
55
|
+
|
|
42
56
|
# Perform numerical differentiation.
|
|
43
57
|
def self.numerical_grad(x, func)
|
|
44
58
|
(func.(x + 1e-7) - func.(x)) / 1e-7
|
data/lib/dnn/datasets/cifar10.rb
CHANGED
|
File without changes
|
|
File without changes
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
require "net/
|
|
1
|
+
require "net/https"
|
|
2
2
|
|
|
3
3
|
module DNN
|
|
4
4
|
DOWNLOADS_PATH = ENV["RUBY_DNN_DOWNLOADS_PATH"] || __dir__
|
|
@@ -18,13 +18,22 @@ module DNN
|
|
|
18
18
|
|
|
19
19
|
def initialize(url)
|
|
20
20
|
@url = url
|
|
21
|
-
*, @fqdn, @path = *url.match(%r`https
|
|
21
|
+
*, @protocol, @fqdn, @path = *url.match(%r`(https?)://(.+?)(/.+)`)
|
|
22
22
|
end
|
|
23
23
|
|
|
24
24
|
def download(dir_path)
|
|
25
25
|
puts %`download "#{@url}"`
|
|
26
26
|
buf = ""
|
|
27
|
-
|
|
27
|
+
if @protocol == "http"
|
|
28
|
+
port = 80
|
|
29
|
+
elsif @protocol == "https"
|
|
30
|
+
port = 443
|
|
31
|
+
else
|
|
32
|
+
raise "Protocol(#{@protocol}) is not supported."
|
|
33
|
+
end
|
|
34
|
+
http = Net::HTTP.new(@fqdn, port)
|
|
35
|
+
http.use_ssl = true if @protocol == "https"
|
|
36
|
+
http.start do |http|
|
|
28
37
|
content_length = http.head(@path).content_length
|
|
29
38
|
http.get(@path) do |body_segment|
|
|
30
39
|
buf << body_segment
|
|
File without changes
|
data/lib/dnn/datasets/iris.rb
CHANGED
|
@@ -41,7 +41,11 @@ module DNN
|
|
|
41
41
|
end
|
|
42
42
|
end
|
|
43
43
|
if shuffle
|
|
44
|
-
|
|
44
|
+
if RUBY_VERSION.split(".")[0].to_i >= 3
|
|
45
|
+
orig_seed = Random.seed
|
|
46
|
+
else
|
|
47
|
+
orig_seed = Random::DEFAULT.seed
|
|
48
|
+
end
|
|
45
49
|
srand(shuffle_seed)
|
|
46
50
|
indexs = (0...csv_array.length).to_a.shuffle
|
|
47
51
|
x[indexs, true] = x
|
data/lib/dnn/datasets/mnist.rb
CHANGED
|
File without changes
|
data/lib/dnn/datasets/stl-10.rb
CHANGED
|
File without changes
|
data/lib/dnn/image.rb
CHANGED
|
File without changes
|
data/lib/dnn/numo2numpy.rb
CHANGED
|
File without changes
|
data/lib/dnn/version.rb
CHANGED
data/lib/dnn.rb
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
|
-
|
|
1
|
+
if RUBY_PLATFORM == "wasm32-wasi"
|
|
2
|
+
require "narray.so"
|
|
3
|
+
else
|
|
4
|
+
require "numo/narray"
|
|
5
|
+
end
|
|
2
6
|
|
|
3
7
|
module DNN
|
|
4
8
|
if ENV["RUBY_DNN_USE_CUMO"] == "ENABLE"
|
|
@@ -27,28 +31,30 @@ module DNN
|
|
|
27
31
|
end
|
|
28
32
|
end
|
|
29
33
|
|
|
30
|
-
|
|
31
|
-
require_relative "dnn/
|
|
32
|
-
require_relative "dnn/core/
|
|
33
|
-
require_relative "dnn/core/
|
|
34
|
-
require_relative "dnn/core/
|
|
35
|
-
require_relative "dnn/core/
|
|
36
|
-
require_relative "dnn/core/
|
|
37
|
-
require_relative "dnn/core/
|
|
38
|
-
require_relative "dnn/core/
|
|
39
|
-
require_relative "dnn/core/
|
|
40
|
-
require_relative "dnn/core/layers/
|
|
41
|
-
require_relative "dnn/core/layers/
|
|
42
|
-
require_relative "dnn/core/layers/
|
|
43
|
-
require_relative "dnn/core/layers/
|
|
44
|
-
require_relative "dnn/core/layers/
|
|
45
|
-
require_relative "dnn/core/layers/
|
|
46
|
-
require_relative "dnn/core/layers/
|
|
47
|
-
require_relative "dnn/core/layers/
|
|
48
|
-
require_relative "dnn/core/
|
|
49
|
-
require_relative "dnn/core/
|
|
50
|
-
require_relative "dnn/core/
|
|
51
|
-
require_relative "dnn/core/
|
|
52
|
-
require_relative "dnn/core/
|
|
53
|
-
require_relative "dnn/core/
|
|
54
|
-
require_relative "dnn/core/
|
|
34
|
+
if RUBY_PLATFORM != "wasm32-wasi"
|
|
35
|
+
require_relative "dnn/version"
|
|
36
|
+
require_relative "dnn/core/monkey_patch"
|
|
37
|
+
require_relative "dnn/core/error"
|
|
38
|
+
require_relative "dnn/core/global"
|
|
39
|
+
require_relative "dnn/core/tensor"
|
|
40
|
+
require_relative "dnn/core/param"
|
|
41
|
+
require_relative "dnn/core/link"
|
|
42
|
+
require_relative "dnn/core/iterator"
|
|
43
|
+
require_relative "dnn/core/models"
|
|
44
|
+
require_relative "dnn/core/layers/basic_layers"
|
|
45
|
+
require_relative "dnn/core/layers/normalizations"
|
|
46
|
+
require_relative "dnn/core/layers/activations"
|
|
47
|
+
require_relative "dnn/core/layers/merge_layers"
|
|
48
|
+
require_relative "dnn/core/layers/split_layers"
|
|
49
|
+
require_relative "dnn/core/layers/cnn_layers"
|
|
50
|
+
require_relative "dnn/core/layers/embedding"
|
|
51
|
+
require_relative "dnn/core/layers/rnn_layers"
|
|
52
|
+
require_relative "dnn/core/layers/math_layers"
|
|
53
|
+
require_relative "dnn/core/optimizers"
|
|
54
|
+
require_relative "dnn/core/losses"
|
|
55
|
+
require_relative "dnn/core/initializers"
|
|
56
|
+
require_relative "dnn/core/regularizers"
|
|
57
|
+
require_relative "dnn/core/callbacks"
|
|
58
|
+
require_relative "dnn/core/savers"
|
|
59
|
+
require_relative "dnn/core/utils"
|
|
60
|
+
end
|
data/ruby-dnn.gemspec
CHANGED
|
@@ -17,6 +17,7 @@ Gem::Specification.new do |spec|
|
|
|
17
17
|
|
|
18
18
|
spec.add_dependency "numo-narray"
|
|
19
19
|
spec.add_dependency "archive-tar-minitar"
|
|
20
|
+
spec.add_development_dependency "rake-compiler"
|
|
20
21
|
|
|
21
22
|
spec.files = Dir.chdir(File.expand_path('..', __FILE__)) do
|
|
22
23
|
`git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
|
data/third_party/stb_image.h
CHANGED
|
File without changes
|
|
File without changes
|
|
File without changes
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: ruby-dnn
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 1.
|
|
4
|
+
version: 1.3.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- unagiootoro
|
|
8
|
-
autorequire:
|
|
8
|
+
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date:
|
|
11
|
+
date: 2023-03-19 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: numo-narray
|
|
@@ -38,6 +38,20 @@ dependencies:
|
|
|
38
38
|
- - ">="
|
|
39
39
|
- !ruby/object:Gem::Version
|
|
40
40
|
version: '0'
|
|
41
|
+
- !ruby/object:Gem::Dependency
|
|
42
|
+
name: rake-compiler
|
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
|
44
|
+
requirements:
|
|
45
|
+
- - ">="
|
|
46
|
+
- !ruby/object:Gem::Version
|
|
47
|
+
version: '0'
|
|
48
|
+
type: :development
|
|
49
|
+
prerelease: false
|
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
51
|
+
requirements:
|
|
52
|
+
- - ">="
|
|
53
|
+
- !ruby/object:Gem::Version
|
|
54
|
+
version: '0'
|
|
41
55
|
- !ruby/object:Gem::Dependency
|
|
42
56
|
name: bundler
|
|
43
57
|
requirement: !ruby/object:Gem::Requirement
|
|
@@ -125,6 +139,7 @@ files:
|
|
|
125
139
|
- examples/dcgan/imgen.rb
|
|
126
140
|
- examples/dcgan/train.rb
|
|
127
141
|
- examples/iris_example.rb
|
|
142
|
+
- examples/iris_example_unused_model.rb
|
|
128
143
|
- examples/judge-number/README.md
|
|
129
144
|
- examples/judge-number/capture.PNG
|
|
130
145
|
- examples/judge-number/convnet8.rb
|
|
@@ -195,7 +210,7 @@ homepage: https://github.com/unagiootoro/ruby-dnn.git
|
|
|
195
210
|
licenses:
|
|
196
211
|
- MIT
|
|
197
212
|
metadata: {}
|
|
198
|
-
post_install_message:
|
|
213
|
+
post_install_message:
|
|
199
214
|
rdoc_options: []
|
|
200
215
|
require_paths:
|
|
201
216
|
- lib
|
|
@@ -210,8 +225,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
|
210
225
|
- !ruby/object:Gem::Version
|
|
211
226
|
version: '0'
|
|
212
227
|
requirements: []
|
|
213
|
-
rubygems_version: 3.
|
|
214
|
-
signing_key:
|
|
228
|
+
rubygems_version: 3.2.33
|
|
229
|
+
signing_key:
|
|
215
230
|
specification_version: 4
|
|
216
231
|
summary: ruby deep learning library.
|
|
217
232
|
test_files: []
|