ruby-dnn 0.8.8 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/API-Reference.ja.md +83 -46
- data/examples/cifar10_example.rb +5 -5
- data/examples/mnist_conv2d_example.rb +5 -5
- data/examples/mnist_example.rb +5 -5
- data/examples/mnist_lstm_example.rb +5 -5
- data/examples/xor_example.rb +4 -3
- data/lib/dnn.rb +3 -3
- data/lib/dnn/core/activations.rb +1 -112
- data/lib/dnn/core/cnn_layers.rb +14 -14
- data/lib/dnn/core/dataset.rb +18 -0
- data/lib/dnn/core/initializers.rb +28 -8
- data/lib/dnn/core/layers.rb +62 -90
- data/lib/dnn/core/losses.rb +120 -0
- data/lib/dnn/core/model.rb +124 -66
- data/lib/dnn/core/rnn_layers.rb +17 -13
- data/lib/dnn/core/{util.rb → utils.rb} +10 -6
- data/lib/dnn/version.rb +1 -1
- metadata +5 -3
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -60,6 +60,10 @@ module DNN
|
|
60
60
|
dxs
|
61
61
|
end
|
62
62
|
|
63
|
+
def output_shape
|
64
|
+
@return_sequences ? [@time_length, @num_nodes] : [@num_nodes]
|
65
|
+
end
|
66
|
+
|
63
67
|
def to_hash(merge_hash = nil)
|
64
68
|
hash = {
|
65
69
|
num_nodes: @num_nodes,
|
@@ -94,7 +98,7 @@ module DNN
|
|
94
98
|
end
|
95
99
|
end
|
96
100
|
|
97
|
-
def
|
101
|
+
def d_lasso
|
98
102
|
if @l1_lambda > 0
|
99
103
|
dlasso = Xumo::SFloat.ones(*@weight.data.shape)
|
100
104
|
dlasso[@weight.data < 0] = -1
|
@@ -105,7 +109,7 @@ module DNN
|
|
105
109
|
end
|
106
110
|
end
|
107
111
|
|
108
|
-
def
|
112
|
+
def d_ridge
|
109
113
|
if @l2_lambda > 0
|
110
114
|
@weight.grad += l2_lambda * @weight.data
|
111
115
|
@weight2.grad += l2_lambda * @weight2.data
|
@@ -115,7 +119,7 @@ module DNN
|
|
115
119
|
private
|
116
120
|
|
117
121
|
def init_params
|
118
|
-
@time_length =
|
122
|
+
@time_length = @input_shape[0]
|
119
123
|
end
|
120
124
|
end
|
121
125
|
|
@@ -154,9 +158,9 @@ module DNN
|
|
154
158
|
simple_rnn = self.new(hash[:num_nodes],
|
155
159
|
stateful: hash[:stateful],
|
156
160
|
return_sequences: hash[:return_sequences],
|
157
|
-
activation:
|
158
|
-
weight_initializer:
|
159
|
-
bias_initializer:
|
161
|
+
activation: Utils.load_hash(hash[:activation]),
|
162
|
+
weight_initializer: Utils.load_hash(hash[:weight_initializer]),
|
163
|
+
bias_initializer: Utils.load_hash(hash[:bias_initializer]),
|
160
164
|
l1_lambda: hash[:l1_lambda],
|
161
165
|
l2_lambda: hash[:l2_lambda])
|
162
166
|
simple_rnn
|
@@ -188,7 +192,7 @@ module DNN
|
|
188
192
|
|
189
193
|
def init_params
|
190
194
|
super()
|
191
|
-
num_prev_nodes =
|
195
|
+
num_prev_nodes = @input_shape[1]
|
192
196
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
|
193
197
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
|
194
198
|
@bias.data = Xumo::SFloat.new(@num_nodes)
|
@@ -259,8 +263,8 @@ module DNN
|
|
259
263
|
lstm = self.new(hash[:num_nodes],
|
260
264
|
stateful: hash[:stateful],
|
261
265
|
return_sequences: hash[:return_sequences],
|
262
|
-
weight_initializer:
|
263
|
-
bias_initializer:
|
266
|
+
weight_initializer: Utils.load_hash(hash[:weight_initializer]),
|
267
|
+
bias_initializer: Utils.load_hash(hash[:bias_initializer]),
|
264
268
|
l1_lambda: hash[:l1_lambda],
|
265
269
|
l2_lambda: hash[:l2_lambda])
|
266
270
|
lstm
|
@@ -327,7 +331,7 @@ module DNN
|
|
327
331
|
|
328
332
|
def init_params
|
329
333
|
super()
|
330
|
-
num_prev_nodes =
|
334
|
+
num_prev_nodes = @input_shape[1]
|
331
335
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
|
332
336
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
|
333
337
|
@bias.data = Xumo::SFloat.new(@num_nodes * 4)
|
@@ -402,8 +406,8 @@ module DNN
|
|
402
406
|
gru = self.new(hash[:num_nodes],
|
403
407
|
stateful: hash[:stateful],
|
404
408
|
return_sequences: hash[:return_sequences],
|
405
|
-
weight_initializer:
|
406
|
-
bias_initializer:
|
409
|
+
weight_initializer: Utils.load_hash(hash[:weight_initializer]),
|
410
|
+
bias_initializer: Utils.load_hash(hash[:bias_initializer]),
|
407
411
|
l1_lambda: hash[:l1_lambda],
|
408
412
|
l2_lambda: hash[:l2_lambda])
|
409
413
|
gru
|
@@ -423,7 +427,7 @@ module DNN
|
|
423
427
|
|
424
428
|
def init_params
|
425
429
|
super()
|
426
|
-
num_prev_nodes =
|
430
|
+
num_prev_nodes = @input_shape[1]
|
427
431
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
|
428
432
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
|
429
433
|
@bias.data = Xumo::SFloat.new(@num_nodes * 3)
|
@@ -1,11 +1,7 @@
|
|
1
1
|
module DNN
|
2
2
|
# This module provides utility functions.
|
3
|
-
module
|
4
|
-
|
5
|
-
def self.get_minibatch(x, y, batch_size)
|
6
|
-
indexes = (0...x.shape[0]).to_a.sample(batch_size)
|
7
|
-
[x[indexes, false], y[indexes, false]]
|
8
|
-
end
|
3
|
+
module Utils
|
4
|
+
NMath = Xumo::NMath
|
9
5
|
|
10
6
|
# Categorize labels into "num_classes" classes.
|
11
7
|
def self.to_categorical(y, num_classes, narray_type = nil)
|
@@ -25,5 +21,13 @@ module DNN
|
|
25
21
|
end
|
26
22
|
dnn_class.new
|
27
23
|
end
|
24
|
+
|
25
|
+
def self.sigmoid(x)
|
26
|
+
1 / (1 + NMath.exp(-x))
|
27
|
+
end
|
28
|
+
|
29
|
+
def self.softmax(x)
|
30
|
+
NMath.exp(x) / NMath.exp(x).sum(1).reshape(x.shape[0], 1)
|
31
|
+
end
|
28
32
|
end
|
29
33
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-05-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -112,14 +112,16 @@ files:
|
|
112
112
|
- lib/dnn.rb
|
113
113
|
- lib/dnn/core/activations.rb
|
114
114
|
- lib/dnn/core/cnn_layers.rb
|
115
|
+
- lib/dnn/core/dataset.rb
|
115
116
|
- lib/dnn/core/error.rb
|
116
117
|
- lib/dnn/core/initializers.rb
|
117
118
|
- lib/dnn/core/layers.rb
|
119
|
+
- lib/dnn/core/losses.rb
|
118
120
|
- lib/dnn/core/model.rb
|
119
121
|
- lib/dnn/core/optimizers.rb
|
120
122
|
- lib/dnn/core/param.rb
|
121
123
|
- lib/dnn/core/rnn_layers.rb
|
122
|
-
- lib/dnn/core/
|
124
|
+
- lib/dnn/core/utils.rb
|
123
125
|
- lib/dnn/lib/cifar10.rb
|
124
126
|
- lib/dnn/lib/downloader.rb
|
125
127
|
- lib/dnn/lib/image.rb
|