ruby-dnn 0.5.12 → 0.6.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 96c9ff5d01c5d85fc4aff092d5985db64e8848d8e0ce2da952afb28c49794d7c
-  data.tar.gz: 7860cf4b52fa8742fe01b42039a91b183efe6b0d50a41c14533e844d3c687c95
+  metadata.gz: '08a74d28175b0e881a1cb7b93d9d973ac601e60498f1422890214d6888f7671a'
+  data.tar.gz: ebde049f70c4e43b90daae6c3eea3bafbf2bd04d9a8556726da46e6cb75ed313
 SHA512:
-  metadata.gz: d5df63c654235145dcecc470b1510c2ad920c3fae83559888e79a30932cc0aea53cda0d11f6c4b732cece9a58dded024e5a9c8baf1f6df4433bb580dbd8f3139
-  data.tar.gz: a73018f21c4432062ff7cac5fc4317667d28ee78607e9cf80573c8dc8d884cf98314cfd8f7eb284e5d26d623129541c2e3b9b2b89bf477f543748b4c245767c0
+  metadata.gz: 46a59448402d63fc8e594259d8bb084e029595acd58a73c7f9558bfe949029f12a406cf9b996b41822a093a3c130e9b7a1249b3b9db61a4ae11611df5ff91264
+  data.tar.gz: 31fd9ff92eb43999f30162650677ac470bec417845ca2cc613cf282b4840e4b5f3d698724d05392ac631c19b2a705ff0034c303684a88998996be8214231779f
@@ -1,8 +1,7 @@
 module DNN
   module Activations
-    Layer = Layers::Layer
 
-    class Sigmoid < Layer
+    class Sigmoid < Layers::Layer
       def forward(x)
         @out = 1 / (1 + Xumo::NMath.exp(-x))
       end
@@ -13,7 +12,7 @@ module DNN
     end
 
 
-    class Tanh < Layer
+    class Tanh < Layers::Layer
       def forward(x)
         @out = Xumo::NMath.tanh(x)
       end
@@ -24,7 +23,7 @@ module DNN
     end
 
 
-    class ReLU < Layer
+    class ReLU < Layers::Layer
       def forward(x)
         @x = x.clone
         x[x < 0] = 0
@@ -39,7 +38,7 @@ module DNN
     end
 
 
-    class LeakyReLU < Layer
+    class LeakyReLU < Layers::Layer
       attr_reader :alpha
 
       def initialize(alpha = 0.3)
@@ -64,7 +63,7 @@ module DNN
       end
 
       def to_hash
-        {name: self.class.name, alpha: alpha}
+        {class: self.class.name, alpha: alpha}
       end
     end
 
@@ -7,7 +7,7 @@ module DNN
       end
 
       def to_hash(merge_hash = nil)
-        hash = {name: self.class.name}
+        hash = {class: self.class.name}
         hash.merge!(merge_hash) if merge_hash
         hash
       end
@@ -31,7 +31,7 @@ module DNN
 
       # Layer to a hash.
       def to_hash(merge_hash = nil)
-        hash = {name: self.class.name}
+        hash = {class: self.class.name}
         hash.merge!(merge_hash) if merge_hash
         hash
       end
@@ -22,7 +22,6 @@ module DNN
     def initialize
       @layers = []
       @optimizer = nil
-      @batch_size = nil
       @training = false
       @compiled = false
     end
@@ -103,13 +102,12 @@ module DNN
       unless compiled?
         raise DNN_Error.new("The model is not compiled.")
       end
-      @batch_size = batch_size
       num_train_data = x.shape[0]
       (1..epochs).each do |epoch|
         puts "【 epoch #{epoch}/#{epochs} 】" if verbose
-        (num_train_data.to_f / @batch_size).ceil.times do |index|
-          x_batch, y_batch = Util.get_minibatch(x, y, @batch_size)
-          loss = train_on_batch(x_batch, y_batch, @batch_size, &batch_proc)
+        (num_train_data.to_f / batch_size).ceil.times do |index|
+          x_batch, y_batch = Util.get_minibatch(x, y, batch_size)
+          loss = train_on_batch(x_batch, y_batch, &batch_proc)
           if loss.nan?
             puts "\nloss is nan" if verbose
             return
@@ -130,7 +128,7 @@ module DNN
           print log if verbose
         end
         if verbose && test
-          acc = accurate(test[0], test[1], batch_size,&batch_proc)
+          acc = accurate(test[0], test[1], batch_size, &batch_proc)
           print " accurate: #{acc}"
         end
         puts "" if verbose
@@ -138,8 +136,7 @@ module DNN
       end
     end
 
-    def train_on_batch(x, y, batch_size, &batch_proc)
-      @batch_size = batch_size
+    def train_on_batch(x, y, &batch_proc)
       x, y = batch_proc.call(x, y) if batch_proc
       forward(x, true)
       loss = @layers[-1].loss(y)
@@ -148,27 +145,21 @@ module DNN
       loss
     end
 
-    def accurate(x, y, batch_size = nil, &batch_proc)
-      unless batch_size
-        if @batch_size
-          batch_size = @batch_size >= x.shape[0] ? @batch_size : x.shape[0]
-        else
-          batch_size = 1
-        end
-      end
+    def accurate(x, y, batch_size = 1, &batch_proc)
+      batch_size = batch_size >= x.shape[0] ? batch_size : x.shape[0]
       correct = 0
-      (x.shape[0].to_f / @batch_size).ceil.times do |i|
-        x_batch = Xumo::SFloat.zeros(@batch_size, *x.shape[1..-1])
-        y_batch = Xumo::SFloat.zeros(@batch_size, *y.shape[1..-1])
-        @batch_size.times do |j|
-          k = i * @batch_size + j
+      (x.shape[0].to_f / batch_size).ceil.times do |i|
+        x_batch = Xumo::SFloat.zeros(batch_size, *x.shape[1..-1])
+        y_batch = Xumo::SFloat.zeros(batch_size, *y.shape[1..-1])
+        batch_size.times do |j|
+          k = i * batch_size + j
           break if k >= x.shape[0]
           x_batch[j, false] = x[k, false]
           y_batch[j, false] = y[k, false]
         end
         x_batch, y_batch = batch_proc.call(x_batch, y_batch) if batch_proc
         out = forward(x_batch, false)
-        @batch_size.times do |j|
+        batch_size.times do |j|
           if @layers[-1].shape == [1]
             correct += 1 if out[j, 0].round == y_batch[j, 0].round
           else
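
Note: in 0.6.0 the model no longer caches @batch_size. train_on_batch drops its batch_size argument (the size comes from the batch itself), and accurate now defaults batch_size to 1 instead of reading @batch_size. A minimal usage sketch of the new signatures, assuming a compiled model and prepared data (only the two calls are taken from this diff):

    # Sketch only: model, x_batch, y_batch, x_test, y_test are assumed to exist.
    loss = model.train_on_batch(x_batch, y_batch)   # batch_size argument removed in 0.6.0
    acc  = model.accurate(x_test, y_test, 100)      # batch_size defaults to 1 if omitted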
@@ -13,7 +13,7 @@ module DNN
       def update(layer) end
 
       def to_hash(merge_hash = nil)
-        hash = {name: self.class.name, learning_rate: @learning_rate}
+        hash = {class: self.class.name, learning_rate: @learning_rate}
         hash.merge!(merge_hash) if merge_hash
         hash
       end
@@ -6,6 +6,7 @@ module DNN
       include Initializers
       include Activations
 
+      attr_accessor :h
       attr_reader :num_nodes
       attr_reader :stateful
       attr_reader :weight_decay
@@ -29,7 +30,7 @@ module DNN
 
       def to_hash(merge_hash = nil)
         hash = {
-          name: self.class.name,
+          class: self.class.name,
           num_nodes: @num_nodes,
           stateful: @stateful,
           return_sequences: @return_sequences,
@@ -48,7 +49,7 @@ module DNN
 
       def ridge
         if @weight_decay > 0
-          0.5 * (@weight_decay * (@params[:weight]**2).sum + @weight_decay * (@params[:weight]**2).sum)
+          0.5 * (@weight_decay * ((@params[:weight]**2).sum + (@params[:weight2]**2).sum))
         else
           0
         end
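
Note: the old ridge term applied the decay to @params[:weight] twice and never penalized the second weight matrix; the new expression is the usual L2 penalty over both matrices, i.e. (param names taken from the diff, the rest restated as a sketch):

    # ridge = 0.5 * weight_decay * (||weight||^2 + ||weight2||^2)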
@@ -219,6 +220,8 @@ module DNN
 
 
     class LSTM < RNN
+      attr_accessor :cell
+
       def self.load_hash(hash)
         self.new(hash[:num_nodes],
                  stateful: hash[:stateful],
@@ -235,6 +238,7 @@ module DNN
                      bias_initializer: nil,
                      weight_decay: 0)
         super
+        @cell = nil
       end
 
       def forward(xs)
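
Note: the new attr_accessor declarations expose the recurrent state, h on RNN and cell on LSTM. A sketch of reading or clearing that state on an existing layer instance (the lstm variable and the reset-between-sequences use are assumptions, not shown in the diff):

    hidden = lstm.h      # hidden state, via the new attr_accessor :h on RNN
    memory = lstm.cell   # cell state, via the new attr_accessor :cell on LSTM
    lstm.h = nil         # assumed use: drop carried state before an unrelated sequence
    lstm.cell = nil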
data/lib/dnn/core/util.rb CHANGED
@@ -24,7 +24,7 @@ module DNN
 
     # Convert hash to an object.
     def self.load_hash(hash)
-      dnn_class = DNN.const_get(hash[:name])
+      dnn_class = DNN.const_get(hash[:class])
       if dnn_class.respond_to?(:load_hash)
         return dnn_class.load_hash(hash)
       end
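
Note: across the gem the serialized hash key changes from :name to :class, and Util.load_hash now resolves the class from hash[:class]. A round-trip sketch under that rename (LeakyReLU#to_hash and Util.load_hash appear in this diff; the surrounding call shape is otherwise assumed):

    hash  = DNN::Activations::LeakyReLU.new(0.2).to_hash
    # => {class: "DNN::Activations::LeakyReLU", alpha: 0.2}
    layer = DNN::Util.load_hash(hash)   # looks up hash[:class] instead of hash[:name]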
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
 module DNN
-  VERSION = "0.5.12"
+  VERSION = "0.6.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: ruby-dnn
 version: !ruby/object:Gem::Version
-  version: 0.5.12
+  version: 0.6.0
 platform: ruby
 authors:
 - unagiootoro
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2018-08-15 00:00:00.000000000 Z
+date: 2018-08-16 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray