ruby-dnn 0.15.3 → 0.16.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Rakefile +1 -9
- data/examples/api-examples/early_stopping_example.rb +1 -1
- data/examples/api-examples/initializer_example.rb +1 -1
- data/examples/api-examples/regularizer_example.rb +1 -1
- data/examples/api-examples/save_example.rb +1 -1
- data/examples/dcgan/dcgan.rb +3 -3
- data/examples/iris_example.rb +41 -17
- data/examples/mnist_define_by_run.rb +1 -1
- data/examples/pix2pix/dcgan.rb +157 -0
- data/examples/pix2pix/imgen.rb +27 -0
- data/examples/pix2pix/train.rb +52 -0
- data/lib/dnn.rb +2 -0
- data/lib/dnn/core/layers/activations.rb +37 -19
- data/lib/dnn/core/layers/basic_layers.rb +110 -25
- data/lib/dnn/core/layers/cnn_layers.rb +19 -21
- data/lib/dnn/core/layers/embedding.rb +3 -3
- data/lib/dnn/core/layers/math_layers.rb +169 -0
- data/lib/dnn/core/layers/merge_layers.rb +29 -24
- data/lib/dnn/core/layers/normalizations.rb +4 -2
- data/lib/dnn/core/layers/rnn_layers.rb +44 -36
- data/lib/dnn/core/link.rb +7 -2
- data/lib/dnn/core/losses.rb +54 -30
- data/lib/dnn/core/models.rb +47 -47
- data/lib/dnn/core/monkey_patch.rb +75 -0
- data/lib/dnn/core/optimizers.rb +10 -6
- data/lib/dnn/core/param.rb +17 -0
- data/lib/dnn/core/regularizers.rb +35 -33
- data/lib/dnn/core/tensor.rb +40 -0
- data/lib/dnn/core/utils.rb +1 -1
- data/lib/dnn/datasets/cifar10.rb +10 -9
- data/lib/dnn/datasets/cifar100.rb +10 -9
- data/lib/dnn/datasets/downloader.rb +1 -5
- data/lib/dnn/datasets/fashion-mnist.rb +4 -12
- data/lib/dnn/datasets/iris.rb +9 -9
- data/lib/dnn/datasets/mnist.rb +4 -12
- data/lib/dnn/datasets/stl-10.rb +6 -8
- data/lib/dnn/version.rb +1 -1
- data/ruby-dnn.gemspec +1 -1
- metadata +7 -5
- data/ext/cifar_loader/cifar_loader.c +0 -77
- data/ext/cifar_loader/extconf.rb +0 -3
data/lib/dnn/core/layers/basic_layers.rb

@@ -1,6 +1,28 @@
 module DNN
   module Layers

+    module LayerNode
+      def forward(input_tensor)
+        x = input_tensor.data
+        prev_link = (input_tensor.is_a?(Tensor) ? input_tensor.link : input_tensor)
+        y = forward_node(x)
+        link = Link.new(prev_link, self)
+        Tensor.new(y, link)
+      end
+
+      def backward(dy)
+        backward_node(dy)
+      end
+
+      def forward_node(x)
+        raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward_node'"
+      end
+
+      def backward_node(dy)
+        raise NotImplementedError, "Class '#{self.class.name}' has implement method 'backward_node'"
+      end
+    end
+
     # Super class of all layer classes.
     class Layer
       attr_reader :input_shape
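Under the new LayerNode module, a layer implements forward_node/backward_node on raw Numo arrays and inherits the Tensor-and-Link bookkeeping from LayerNode#forward. A minimal sketch of a custom layer written against this API (the Square layer below is hypothetical, not part of the gem):

  module DNN
    module Layers
      class Square < Layer
        include LayerNode

        # Runs on the Numo array unwrapped by LayerNode#forward.
        def forward_node(x)
          @x = x
          x**2
        end

        # Receives the upstream gradient and returns dy * 2x.
        def backward_node(dy)
          dy * 2 * @x
        end
      end
    end
  end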
@@ -26,12 +48,9 @@ module DNN
       # @param [Tensor] input_tensor Input tensor.
       # @return [Tensor] Output tensor.
       def call(input_tensor)
-
-
-
-        y = forward(x)
-        link = Link.new(prev_link, self)
-        Tensor.new(y, link)
+        input_tensor = Tensor.new(input_tensor) if !input_tensor.is_a?(Tensor) && !input_tensor.is_a?(Param)
+        build(input_tensor.data.shape[1..-1]) unless built?
+        forward(input_tensor)
       end

       # Build the layer.
@@ -47,17 +66,12 @@ module DNN
       end

       # Forward propagation.
-      # @param [
-
+      # @param [Tensor] input_tensor Input tensor.
+      # @return [Tensor] Output tensor.
+      def forward(input_tensor)
         raise NotImplementedError, "Class '#{self.class.name}' has implement method 'forward'"
       end

-      # Backward propagation.
-      # @param [Numo::SFloat] dy Differential value of output data.
-      def backward(dy)
-        raise NotImplementedError, "Class '#{self.class.name}' has implement method 'backward'"
-      end
-
       # Please reimplement this method as needed.
       # The default implementation return input_shape.
       # @return [Array] Return the shape of the output data.
@@ -76,6 +90,7 @@ module DNN
         initialize
       end

+      # Clean the layer state.
       def clean
         input_shape = @input_shape
         hash = to_hash
@@ -120,6 +135,8 @@ module DNN
     end

     class InputLayer < Layer
+      include LayerNode
+
       def self.call(input)
         shape = input.is_a?(Tensor) ? input.data.shape : input.shape
         new(shape[1..-1]).(input)
@@ -140,21 +157,21 @@ module DNN
           x = input
           prev_link = nil
         end
-        Tensor.new(
+        Tensor.new(forward_node(x), Link.new(prev_link, self))
       end

       def build(input_shape)
         @built = true
       end

-      def
+      def forward_node(x)
         unless x.shape[1..-1] == @input_shape
           raise DNN_ShapeError, "The shape of x does not match the input shape. input shape is #{@input_shape}, but x shape is #{x.shape[1..-1]}."
         end
         x
       end

-      def
+      def backward_node(dy)
         dy
       end

@@ -248,6 +265,8 @@ module DNN
     end

     class Dense < Connection
+      include LayerNode
+
       attr_reader :num_nodes

       # @param [Integer] num_nodes Number of nodes.
@@ -273,14 +292,14 @@ module DNN
         init_weight_and_bias
       end

-      def
+      def forward_node(x)
         @x = x
         y = x.dot(@weight.data)
         y += @bias.data if @bias
         y
       end

-      def
+      def backward_node(dy)
         if @trainable
           @weight.grad += @x.transpose.dot(dy)
           @bias.grad += dy.sum(0) if @bias
@@ -307,11 +326,13 @@ module DNN
     end

     class Flatten < Layer
-
+      include LayerNode
+
+      def forward_node(x)
         x.reshape(x.shape[0], *output_shape)
       end

-      def
+      def backward_node(dy)
         dy.reshape(dy.shape[0], *@input_shape)
       end

@@ -321,6 +342,8 @@ module DNN
     end

     class Reshape < Layer
+      include LayerNode
+
       attr_reader :output_shape

       def initialize(output_shape)
@@ -328,11 +351,11 @@ module DNN
         @output_shape = output_shape
       end

-      def
+      def forward_node(x)
         x.reshape(x.shape[0], *@output_shape)
       end

-      def
+      def backward_node(dy)
         dy.reshape(dy.shape[0], *@input_shape)
       end

@@ -345,7 +368,69 @@ module DNN
       end
     end

+    class Lasso < Layer
+      include LayerNode
+
+      attr_accessor :l1_lambda
+
+      # @param [Float] l1_lambda L1 regularizer coefficient.
+      def initialize(l1_lambda = 0.01)
+        super()
+        @l1_lambda = l1_lambda
+      end
+
+      def forward_node(x)
+        @x = x
+        @l1_lambda * x.abs.sum
+      end
+
+      def backward_node(dy)
+        dx = Xumo::SFloat.ones(*@x.shape)
+        dx[@x < 0] = -1
+        @l1_lambda * dx
+      end
+
+      def to_hash
+        super(l1_lambda: @l1_lambda)
+      end
+
+      def load_hash(hash)
+        initialize(hash[:l1_lambda])
+      end
+    end
+
+    class Ridge < Layer
+      include LayerNode
+
+      attr_accessor :l2_lambda
+
+      # @param [Float] l2_lambda L2 regularizer coefficient.
+      def initialize(l2_lambda = 0.01)
+        super()
+        @l2_lambda = l2_lambda
+      end
+
+      def forward_node(x)
+        @x = x
+        0.5 * @l2_lambda * (x**2).sum
+      end
+
+      def backward_node(dy)
+        @l2_lambda * @x
+      end
+
+      def to_hash
+        super(l2_lambda: @l2_lambda)
+      end
+
+      def load_hash(hash)
+        initialize(hash[:l2_lambda])
+      end
+    end
+
     class Dropout < Layer
+      include LayerNode
+
       attr_accessor :dropout_ratio
       attr_reader :use_scale

@@ -361,7 +446,7 @@ module DNN
         @rnd = Random.new(@seed)
       end

-      def
+      def forward_node(x)
         if DNN.learning_phase
           Xumo::SFloat.srand(@rnd.rand(1 << 31))
           @mask = Xumo::SFloat.new(*x.shape).rand < @dropout_ratio
@@ -372,7 +457,7 @@ module DNN
         x
       end

-      def
+      def backward_node(dy)
         dy[@mask] = 0
         dy
       end
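The new Lasso and Ridge layers express L1/L2 penalties as ordinary layer nodes: forward_node returns the scalar penalty and backward_node the gradient with respect to the penalized values. A rough sketch of applying them directly to a parameter (the weight variable below is illustrative; how the gem itself wires them into its regularizers and losses is not shown in this hunk):

  # weight is assumed to be a DNN::Param holding a Numo::SFloat matrix W.
  l1_penalty = DNN::Layers::Lasso.new(0.01).(weight)  # data: 0.01 * W.abs.sum
  l2_penalty = DNN::Layers::Ridge.new(0.01).(weight)  # data: 0.5 * 0.01 * (W**2).sum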
data/lib/dnn/core/layers/cnn_layers.rb

@@ -84,6 +84,7 @@ module DNN
     end

     class Conv2D < Connection
+      include LayerNode
       include Conv2DUtils

       attr_reader :num_filters
@@ -130,7 +131,7 @@ module DNN
         @out_size = calc_conv2d_out_size(prev_h, prev_w, *@filter_size, *@pad_size, @strides)
       end

-      def
+      def forward_node(x)
         x = zero_padding(x, @pad_size) if @padding
         @x_shape = x.shape
         @col = im2col(x, *@out_size, *@filter_size, @strides)
@@ -139,7 +140,7 @@ module DNN
         y.reshape(x.shape[0], *@out_size, y.shape[3])
       end

-      def
+      def backward_node(dy)
         dy = dy.reshape(dy.shape[0..2].reduce(:*), dy.shape[3])
         if @trainable
           @weight.grad += @col.transpose.dot(dy)
@@ -186,6 +187,7 @@ module DNN
     end

     class Conv2DTranspose < Connection
+      include LayerNode
       include Conv2DUtils

       attr_reader :num_filters
@@ -232,7 +234,7 @@ module DNN
         @out_size = calc_conv2d_transpose_out_size(prev_h, prev_w, *@filter_size, *@pad_size, @strides)
       end

-      def
+      def forward_node(x)
         bsize = x.shape[0]
         x = x.reshape(x.shape[0..2].reduce(:*), x.shape[3])
         @x = x
@@ -243,7 +245,7 @@ module DNN
         @padding ? zero_padding_bwd(y, @pad_size) : y
       end

-      def
+      def backward_node(dy)
         dy = zero_padding(dy, @pad_size) if @padding
         col = im2col(dy, *input_shape[0..1], *@filter_size, @strides)
         if @trainable
@@ -291,6 +293,7 @@ module DNN

     # Super class of all pooling2D class.
     class Pool2D < Layer
+      include LayerNode
       include Conv2DUtils

       attr_reader :pool_size
@@ -345,7 +348,9 @@ module DNN
     end

     class MaxPool2D < Pool2D
-
+      include LayerNode
+
+      def forward_node(x)
         x = zero_padding(x, @pad_size) if @padding
         @x_shape = x.shape
         col = im2col(x, *@out_size, *@pool_size, @strides)
@@ -354,7 +359,7 @@ module DNN
         col.max(1).reshape(x.shape[0], *@out_size, x.shape[3])
       end

-      def
+      def backward_node(dy)
         dmax = Xumo::SFloat.zeros(dy.size * @pool_size.reduce(:*))
         dmax[@max_index.flatten] = dy.flatten
         dcol = dmax.reshape(dy.shape[0..2].reduce(:*), @pool_size.reduce(:*) * dy.shape[3])
@@ -364,7 +369,9 @@ module DNN
     end

     class AvgPool2D < Pool2D
-
+      include LayerNode
+
+      def forward_node(x)
         x = zero_padding(x, @pad_size) if @padding
         @x_shape = x.shape
         col = im2col(x, *@out_size, *@pool_size, @strides)
@@ -372,7 +379,7 @@ module DNN
         col.mean(1).reshape(x.shape[0], *@out_size, x.shape[3])
       end

-      def
+      def backward_node(dy)
         row_length = @pool_size.reduce(:*)
         dy /= row_length
         davg = Xumo::SFloat.zeros(dy.size, row_length)
@@ -391,24 +398,15 @@ module DNN
           raise DNN_ShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
         end
         super
-        @avg_pool2d = AvgPool2D.new(input_shape[0..1])
-        @avg_pool2d.build(input_shape)
-        @flatten = Flatten.new
-        @flatten.build([1, 1, input_shape[2]])
       end

       def forward(x)
-
-        @flatten.forward(y)
-      end
-
-      def backward(dy)
-        dy = @flatten.backward(dy)
-        @avg_pool2d.backward(dy)
+        Flatten.(AvgPool2D.(x, input_shape[0..1]))
       end
     end

     class UnPool2D < Layer
+      include LayerNode
       include Conv2DUtils

       attr_reader :unpool_size
@@ -432,7 +430,7 @@ module DNN
         @num_channel = input_shape[2]
       end

-      def
+      def forward_node(x)
         @x_shape = x.shape
         unpool_h, unpool_w = @unpool_size
         x2 = Xumo::SFloat.zeros(x.shape[0], x.shape[1], unpool_h, x.shape[2], unpool_w, @num_channel)
@@ -444,7 +442,7 @@ module DNN
         x2.reshape(x.shape[0], *@out_size, x.shape[3])
       end

-      def
+      def backward_node(dy)
         in_size = input_shape[0..1]
         col = im2col(dy, *in_size, *@unpool_size, @unpool_size)
         col = col.reshape(dy.shape[0] * in_size.reduce(:*), @unpool_size.reduce(:*), dy.shape[3])
data/lib/dnn/core/layers/embedding.rb

@@ -24,7 +24,7 @@ module DNN

       def call(input_tensor)
         build(@input_shape) unless built?
-        Tensor.new(
+        Tensor.new(forward_node(input_tensor.data), Link.new(nil, self))
       end

       def build(input_shape)
@@ -34,7 +34,7 @@ module DNN
         @weight_regularizer.param = @weight if @weight_regularizer
       end

-      def
+      def forward_node(x)
         @x = x
         y = Xumo::SFloat.zeros(*x.shape)
         x.shape[0].times do |i|
@@ -43,7 +43,7 @@ module DNN
         y
       end

-      def
+      def backward_node(dy)
         @weight.grad += Xumo::SFloat.zeros(*@weight.data.shape)
         @x.shape[0].times do |i|
           @x.shape[1].times do |j|
data/lib/dnn/core/layers/math_layers.rb

@@ -0,0 +1,169 @@
+module DNN
+  module Layers
+
+    class Add < MergeLayer
+      def forward_node(x1, x2)
+        x1 + x2
+      end
+
+      def backward_node(dy)
+        [dy, dy]
+      end
+    end
+
+    class Sub < MergeLayer
+      def forward_node(x1, x2)
+        x1 - x2
+      end
+
+      def backward_node(dy)
+        [dy, -dy]
+      end
+    end
+
+    class Mul < MergeLayer
+      def forward_node(x1, x2)
+        @x1, @x2 = x1, x2
+        x1 * x2
+      end
+
+      def backward_node(dy)
+        [dy * @x2, dy * @x1]
+      end
+    end
+
+    class Div < MergeLayer
+      def forward_node(x1, x2)
+        @x1, @x2 = x1, x2
+        x1 / x2
+      end
+
+      def backward_node(dy)
+        dx1 = dy / @x2
+        dx2 = dy * -(@x1 / @x2**2)
+        [dx1, dx2]
+      end
+    end
+
+    class Dot < MergeLayer
+      def forward_node(x1, x2)
+        @x1, @x2 = x1, x2
+        x1.dot(x2)
+      end
+
+      def backward_node(dy)
+        [dy.dot(@x2.transpose), @x1.transpose.dot(dy)]
+      end
+    end
+
+    class Exp < Layer
+      include LayerNode
+
+      def forward_node(x)
+        @x = x
+        Xumo::NMath.exp(x)
+      end
+
+      def backward_node(dy)
+        dy * Xumo::NMath.exp(@x)
+      end
+    end
+
+    class Log < Layer
+      include LayerNode
+
+      def forward_node(x)
+        @x = x
+        Xumo::NMath.log(x)
+      end
+
+      def backward_node(dy)
+        dy / @x
+      end
+    end
+
+    class Pow < Layer
+      include LayerNode
+
+      def initialize(index)
+        super()
+        @index = index
+      end
+
+      def forward_node(x)
+        @x = x
+        x**@index
+      end
+
+      def backward_node(dy)
+        @index * @x**(@index - 1)
+      end
+    end
+
+    class Sqrt < Layer
+      include LayerNode
+
+      def forward_node(x)
+        @x = x
+        Xumo::NMath.sqrt(x)
+      end
+
+      def backward_node(dy)
+        dy * (1.0 / 2 * Xumo::NMath.sqrt(@x))
+      end
+    end
+
+    class Sum < Layer
+      include LayerNode
+
+      def initialize(axis: 0)
+        super()
+        @axis = axis
+      end
+
+      def forward_node(x)
+        if @axis
+          @dim = x.shape[@axis]
+          x.sum(axis: @axis, keepdims: true)
+        else
+          x.sum
+        end
+      end
+
+      def backward_node(dy)
+        dx = dy.clone
+        if @axis
+          (@dim - 1).times do
+            dx = dx.concatenate(dy, axis: @axis)
+          end
+        end
+        dx
+      end
+    end
+
+    class Mean < Layer
+      include LayerNode
+
+      def initialize(axis: 0)
+        super()
+        @axis = axis
+      end
+
+      def forward_node(x)
+        @dim = @axis ? x.shape[@axis] : x.size
+        x.mean(axis: @axis, keepdims: true)
+      end
+
+      def backward_node(dy)
+        dx = dy
+        if @axis
+          (@dim - 1).times do
+            dx = dx.concatenate(dy, axis: @axis)
+          end
+        end
+        dx / @dim
+      end
+    end
+
+  end
+end
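These math layers give define-by-run graphs elementwise and reduction arithmetic with hand-written backward rules (the binary ones are MergeLayers, the unary ones plain LayerNode layers). A rough sketch of chaining a few of them (x stands for any Tensor already flowing through a model; the composition itself is illustrative):

  y = DNN::Layers::Exp.new.(x)             # elementwise e**x
  s = DNN::Layers::Sum.new(axis: 1).(y)    # sum along axis 1, keepdims
  z = DNN::Layers::Log.new.(s)             # overall: a log-sum-exp over axis 1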