ruby-dnn 1.2.3 → 1.3.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -14,15 +14,13 @@ module DNN
14
14
  load_bin(File.binread(file_name))
15
15
  end
16
16
 
17
- private
18
-
19
17
  def load_bin(bin)
20
18
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
21
19
  end
22
20
  end
23
21
 
24
22
  class MarshalLoader < Loader
25
- private def load_bin(bin)
23
+ def load_bin(bin)
26
24
  data = Marshal.load(Zlib::Inflate.inflate(bin))
27
25
  unless @model.class.name == data[:class]
28
26
  raise DNNError, "Class name is mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
@@ -38,8 +36,6 @@ module DNN
38
36
  end
39
37
 
40
38
  class JSONLoader < Loader
41
- private
42
-
43
39
  def load_bin(bin)
44
40
  data = JSON.parse(bin, symbolize_names: true)
45
41
  unless @model.class.name == data[:class]
@@ -48,7 +44,7 @@ module DNN
48
44
  set_all_params_base64_data(data[:params])
49
45
  end
50
46
 
51
- def set_all_params_base64_data(params_data)
47
+ private def set_all_params_base64_data(params_data)
52
48
  @model.trainable_layers.each.with_index do |layer, i|
53
49
  params_data[i].each do |(key, (shape, base64_data))|
54
50
  bin = Base64.decode64(base64_data)
@@ -79,8 +75,6 @@ module DNN
79
75
  end
80
76
  end
81
77
 
82
- private
83
-
84
78
  def dump_bin
85
79
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
86
80
  end
@@ -92,7 +86,7 @@ module DNN
92
86
  @include_model = include_model
93
87
  end
94
88
 
95
- private def dump_bin
89
+ def dump_bin
96
90
  params_data = @model.get_all_params_data
97
91
  if @include_model
98
92
  @model.clean_layers
@@ -110,14 +104,12 @@ module DNN
110
104
  end
111
105
 
112
106
  class JSONSaver < Saver
113
- private
114
-
115
107
  def dump_bin
116
108
  data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
117
109
  JSON.dump(data)
118
110
  end
119
111
 
120
- def get_all_params_base64_data
112
+ private def get_all_params_base64_data
121
113
  @model.trainable_layers.map do |layer|
122
114
  layer.get_params.to_h do |key, param|
123
115
  base64_data = Base64.encode64(param.data.to_binary)
@@ -39,6 +39,20 @@ module DNN
39
39
  Losses::SoftmaxCrossEntropy.softmax(x)
40
40
  end
41
41
 
42
+ # Check training or evaluate input data type.
43
+ def self.check_input_data_type(data_name, data, expected_type)
44
+ if !data.is_a?(expected_type) && !data.is_a?(Array)
45
+ raise TypeError, "#{data_name}:#{data.class.name} is not an instance of #{expected_type.name} class or Array class."
46
+ end
47
+ if data.is_a?(Array)
48
+ data.each.with_index do |v, i|
49
+ unless v.is_a?(expected_type)
50
+ raise TypeError, "#{data_name}[#{i}]:#{v.class.name} is not an instance of #{expected_type.name} class."
51
+ end
52
+ end
53
+ end
54
+ end
55
+
42
56
  # Perform numerical differentiation.
43
57
  def self.numerical_grad(x, func)
44
58
  (func.(x + 1e-7) - func.(x)) / 1e-7
@@ -41,7 +41,11 @@ module DNN
41
41
  end
42
42
  end
43
43
  if shuffle
44
- orig_seed = Random::DEFAULT.seed
44
+ if RUBY_VERSION.split(".")[0].to_i >= 3
45
+ orig_seed = Random.seed
46
+ else
47
+ orig_seed = Random::DEFAULT.seed
48
+ end
45
49
  srand(shuffle_seed)
46
50
  indexs = (0...csv_array.length).to_a.shuffle
47
51
  x[indexs, true] = x
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "1.2.3"
2
+ VERSION = "1.3.0"
3
3
  end
data/lib/dnn.rb CHANGED
@@ -1,4 +1,8 @@
1
- require "numo/narray"
1
+ if RUBY_PLATFORM == "wasm32-wasi"
2
+ require "narray.so"
3
+ else
4
+ require "numo/narray"
5
+ end
2
6
 
3
7
  module DNN
4
8
  if ENV["RUBY_DNN_USE_CUMO"] == "ENABLE"
@@ -27,28 +31,30 @@ module DNN
27
31
  end
28
32
  end
29
33
 
30
- require_relative "dnn/version"
31
- require_relative "dnn/core/monkey_patch"
32
- require_relative "dnn/core/error"
33
- require_relative "dnn/core/global"
34
- require_relative "dnn/core/tensor"
35
- require_relative "dnn/core/param"
36
- require_relative "dnn/core/link"
37
- require_relative "dnn/core/iterator"
38
- require_relative "dnn/core/models"
39
- require_relative "dnn/core/layers/basic_layers"
40
- require_relative "dnn/core/layers/normalizations"
41
- require_relative "dnn/core/layers/activations"
42
- require_relative "dnn/core/layers/merge_layers"
43
- require_relative "dnn/core/layers/split_layers"
44
- require_relative "dnn/core/layers/cnn_layers"
45
- require_relative "dnn/core/layers/embedding"
46
- require_relative "dnn/core/layers/rnn_layers"
47
- require_relative "dnn/core/layers/math_layers"
48
- require_relative "dnn/core/optimizers"
49
- require_relative "dnn/core/losses"
50
- require_relative "dnn/core/initializers"
51
- require_relative "dnn/core/regularizers"
52
- require_relative "dnn/core/callbacks"
53
- require_relative "dnn/core/savers"
54
- require_relative "dnn/core/utils"
34
+ if RUBY_PLATFORM != "wasm32-wasi"
35
+ require_relative "dnn/version"
36
+ require_relative "dnn/core/monkey_patch"
37
+ require_relative "dnn/core/error"
38
+ require_relative "dnn/core/global"
39
+ require_relative "dnn/core/tensor"
40
+ require_relative "dnn/core/param"
41
+ require_relative "dnn/core/link"
42
+ require_relative "dnn/core/iterator"
43
+ require_relative "dnn/core/models"
44
+ require_relative "dnn/core/layers/basic_layers"
45
+ require_relative "dnn/core/layers/normalizations"
46
+ require_relative "dnn/core/layers/activations"
47
+ require_relative "dnn/core/layers/merge_layers"
48
+ require_relative "dnn/core/layers/split_layers"
49
+ require_relative "dnn/core/layers/cnn_layers"
50
+ require_relative "dnn/core/layers/embedding"
51
+ require_relative "dnn/core/layers/rnn_layers"
52
+ require_relative "dnn/core/layers/math_layers"
53
+ require_relative "dnn/core/optimizers"
54
+ require_relative "dnn/core/losses"
55
+ require_relative "dnn/core/initializers"
56
+ require_relative "dnn/core/regularizers"
57
+ require_relative "dnn/core/callbacks"
58
+ require_relative "dnn/core/savers"
59
+ require_relative "dnn/core/utils"
60
+ end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.3
4
+ version: 1.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-11 00:00:00.000000000 Z
11
+ date: 2023-03-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -139,6 +139,7 @@ files:
139
139
  - examples/dcgan/imgen.rb
140
140
  - examples/dcgan/train.rb
141
141
  - examples/iris_example.rb
142
+ - examples/iris_example_unused_model.rb
142
143
  - examples/judge-number/README.md
143
144
  - examples/judge-number/capture.PNG
144
145
  - examples/judge-number/convnet8.rb