ruby-dnn 0.8.8 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/API-Reference.ja.md +83 -46
- data/examples/cifar10_example.rb +5 -5
- data/examples/mnist_conv2d_example.rb +5 -5
- data/examples/mnist_example.rb +5 -5
- data/examples/mnist_lstm_example.rb +5 -5
- data/examples/xor_example.rb +4 -3
- data/lib/dnn.rb +3 -3
- data/lib/dnn/core/activations.rb +1 -112
- data/lib/dnn/core/cnn_layers.rb +14 -14
- data/lib/dnn/core/dataset.rb +18 -0
- data/lib/dnn/core/initializers.rb +28 -8
- data/lib/dnn/core/layers.rb +62 -90
- data/lib/dnn/core/losses.rb +120 -0
- data/lib/dnn/core/model.rb +124 -66
- data/lib/dnn/core/rnn_layers.rb +17 -13
- data/lib/dnn/core/{util.rb → utils.rb} +10 -6
- data/lib/dnn/version.rb +1 -1
- metadata +5 -3
data/lib/dnn/core/rnn_layers.rb
CHANGED
@@ -60,6 +60,10 @@ module DNN
|
|
60
60
|
dxs
|
61
61
|
end
|
62
62
|
|
63
|
+
def output_shape
|
64
|
+
@return_sequences ? [@time_length, @num_nodes] : [@num_nodes]
|
65
|
+
end
|
66
|
+
|
63
67
|
def to_hash(merge_hash = nil)
|
64
68
|
hash = {
|
65
69
|
num_nodes: @num_nodes,
|
@@ -94,7 +98,7 @@ module DNN
|
|
94
98
|
end
|
95
99
|
end
|
96
100
|
|
97
|
-
def
|
101
|
+
def d_lasso
|
98
102
|
if @l1_lambda > 0
|
99
103
|
dlasso = Xumo::SFloat.ones(*@weight.data.shape)
|
100
104
|
dlasso[@weight.data < 0] = -1
|
@@ -105,7 +109,7 @@ module DNN
|
|
105
109
|
end
|
106
110
|
end
|
107
111
|
|
108
|
-
def
|
112
|
+
def d_ridge
|
109
113
|
if @l2_lambda > 0
|
110
114
|
@weight.grad += l2_lambda * @weight.data
|
111
115
|
@weight2.grad += l2_lambda * @weight2.data
|
@@ -115,7 +119,7 @@ module DNN
|
|
115
119
|
private
|
116
120
|
|
117
121
|
def init_params
|
118
|
-
@time_length =
|
122
|
+
@time_length = @input_shape[0]
|
119
123
|
end
|
120
124
|
end
|
121
125
|
|
@@ -154,9 +158,9 @@ module DNN
|
|
154
158
|
simple_rnn = self.new(hash[:num_nodes],
|
155
159
|
stateful: hash[:stateful],
|
156
160
|
return_sequences: hash[:return_sequences],
|
157
|
-
activation:
|
158
|
-
weight_initializer:
|
159
|
-
bias_initializer:
|
161
|
+
activation: Utils.load_hash(hash[:activation]),
|
162
|
+
weight_initializer: Utils.load_hash(hash[:weight_initializer]),
|
163
|
+
bias_initializer: Utils.load_hash(hash[:bias_initializer]),
|
160
164
|
l1_lambda: hash[:l1_lambda],
|
161
165
|
l2_lambda: hash[:l2_lambda])
|
162
166
|
simple_rnn
|
@@ -188,7 +192,7 @@ module DNN
|
|
188
192
|
|
189
193
|
def init_params
|
190
194
|
super()
|
191
|
-
num_prev_nodes =
|
195
|
+
num_prev_nodes = @input_shape[1]
|
192
196
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes)
|
193
197
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes)
|
194
198
|
@bias.data = Xumo::SFloat.new(@num_nodes)
|
@@ -259,8 +263,8 @@ module DNN
|
|
259
263
|
lstm = self.new(hash[:num_nodes],
|
260
264
|
stateful: hash[:stateful],
|
261
265
|
return_sequences: hash[:return_sequences],
|
262
|
-
weight_initializer:
|
263
|
-
bias_initializer:
|
266
|
+
weight_initializer: Utils.load_hash(hash[:weight_initializer]),
|
267
|
+
bias_initializer: Utils.load_hash(hash[:bias_initializer]),
|
264
268
|
l1_lambda: hash[:l1_lambda],
|
265
269
|
l2_lambda: hash[:l2_lambda])
|
266
270
|
lstm
|
@@ -327,7 +331,7 @@ module DNN
|
|
327
331
|
|
328
332
|
def init_params
|
329
333
|
super()
|
330
|
-
num_prev_nodes =
|
334
|
+
num_prev_nodes = @input_shape[1]
|
331
335
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 4)
|
332
336
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 4)
|
333
337
|
@bias.data = Xumo::SFloat.new(@num_nodes * 4)
|
@@ -402,8 +406,8 @@ module DNN
|
|
402
406
|
gru = self.new(hash[:num_nodes],
|
403
407
|
stateful: hash[:stateful],
|
404
408
|
return_sequences: hash[:return_sequences],
|
405
|
-
weight_initializer:
|
406
|
-
bias_initializer:
|
409
|
+
weight_initializer: Utils.load_hash(hash[:weight_initializer]),
|
410
|
+
bias_initializer: Utils.load_hash(hash[:bias_initializer]),
|
407
411
|
l1_lambda: hash[:l1_lambda],
|
408
412
|
l2_lambda: hash[:l2_lambda])
|
409
413
|
gru
|
@@ -423,7 +427,7 @@ module DNN
|
|
423
427
|
|
424
428
|
def init_params
|
425
429
|
super()
|
426
|
-
num_prev_nodes =
|
430
|
+
num_prev_nodes = @input_shape[1]
|
427
431
|
@weight.data = Xumo::SFloat.new(num_prev_nodes, @num_nodes * 3)
|
428
432
|
@weight2.data = Xumo::SFloat.new(@num_nodes, @num_nodes * 3)
|
429
433
|
@bias.data = Xumo::SFloat.new(@num_nodes * 3)
|
@@ -1,11 +1,7 @@
|
|
1
1
|
module DNN
|
2
2
|
# This module provides utility functions.
|
3
|
-
module
|
4
|
-
|
5
|
-
def self.get_minibatch(x, y, batch_size)
|
6
|
-
indexes = (0...x.shape[0]).to_a.sample(batch_size)
|
7
|
-
[x[indexes, false], y[indexes, false]]
|
8
|
-
end
|
3
|
+
module Utils
|
4
|
+
NMath = Xumo::NMath
|
9
5
|
|
10
6
|
# Categorize labels into "num_classes" classes.
|
11
7
|
def self.to_categorical(y, num_classes, narray_type = nil)
|
@@ -25,5 +21,13 @@ module DNN
|
|
25
21
|
end
|
26
22
|
dnn_class.new
|
27
23
|
end
|
24
|
+
|
25
|
+
def self.sigmoid(x)
|
26
|
+
1 / (1 + NMath.exp(-x))
|
27
|
+
end
|
28
|
+
|
29
|
+
def self.softmax(x)
|
30
|
+
NMath.exp(x) / NMath.exp(x).sum(1).reshape(x.shape[0], 1)
|
31
|
+
end
|
28
32
|
end
|
29
33
|
end
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-05-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -112,14 +112,16 @@ files:
|
|
112
112
|
- lib/dnn.rb
|
113
113
|
- lib/dnn/core/activations.rb
|
114
114
|
- lib/dnn/core/cnn_layers.rb
|
115
|
+
- lib/dnn/core/dataset.rb
|
115
116
|
- lib/dnn/core/error.rb
|
116
117
|
- lib/dnn/core/initializers.rb
|
117
118
|
- lib/dnn/core/layers.rb
|
119
|
+
- lib/dnn/core/losses.rb
|
118
120
|
- lib/dnn/core/model.rb
|
119
121
|
- lib/dnn/core/optimizers.rb
|
120
122
|
- lib/dnn/core/param.rb
|
121
123
|
- lib/dnn/core/rnn_layers.rb
|
122
|
-
- lib/dnn/core/
|
124
|
+
- lib/dnn/core/utils.rb
|
123
125
|
- lib/dnn/lib/cifar10.rb
|
124
126
|
- lib/dnn/lib/downloader.rb
|
125
127
|
- lib/dnn/lib/image.rb
|