ruby-dnn 0.13.0 → 0.13.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5a0f021b01f03d45c51e52a1147b05703f8da8895380560d4add4fd3f2fc3d90
4
- data.tar.gz: 2c5b3b04d23b0285a0acd1fbde3ca4f774948645d016c471264484f45230fa6f
3
+ metadata.gz: 58ef35a4277a86304c39350b743d308145da921fdc2db4308fa6ed208be47d93
4
+ data.tar.gz: b8238e52c849e222277284c20723c1a1ab72e70fed409ad1f2962a9ba93190b8
5
5
  SHA512:
6
- metadata.gz: 33758acb29f83b8523accd213293c84d4b887591926de256e38bd41895f86358c80ee64f2876cd4fb679a9a2c71a64929780149a324dd80020f5b6d397c9a608
7
- data.tar.gz: 6a8e95ca8c57a968a846562ad4fa921e382cef83c0839599651ab93ad2d7f6d6cb0c36ada945fb125c0b13e10aa1333030ce019b138146b00e19ed9882237ae7
6
+ metadata.gz: fd335af25c5d11745960364dc6096f18fbf78387358f5a1b24959328b001e87c3454132960133a74ba9cb81ac90f98e8d2663a9a9492b78d7a8558124471dc44
7
+ data.tar.gz: 132da9ac8fccee2c0de894543e2800b8066b75e746e4faed609594096e6d894a3938e69c0e1155056aef384dc522d39c647b9dfce8b9a5ceb659716f6ba961f0
@@ -66,9 +66,7 @@ module DNN
66
66
  end
67
67
 
68
68
  def backward(dy)
69
- dx = Xumo::SFloat.ones(@x.shape)
70
- dx[@x <= 0] = 0
71
- dy * dx
69
+ dy * Xumo::SFloat.cast(@x > 0)
72
70
  end
73
71
  end
74
72
 
@@ -353,15 +353,14 @@ module DNN
353
353
  x = zero_padding(x, @pad_size) if @padding
354
354
  @x_shape = x.shape
355
355
  col = im2col(x, *@out_size, *@pool_size, @strides)
356
- col = col.reshape(x.shape[0] * @out_size.reduce(:*), @pool_size.reduce(:*), x.shape[3]).transpose(0, 2, 1)
357
- .reshape(x.shape[0] * @out_size.reduce(:*) * x.shape[3], @pool_size.reduce(:*))
356
+ col = col.reshape(x.shape[0] * @out_size.reduce(:*), @pool_size.reduce(:*), x.shape[3])
358
357
  @max_index = col.max_index(1)
359
358
  col.max(1).reshape(x.shape[0], *@out_size, x.shape[3])
360
359
  end
361
360
 
362
361
  def backward(dy)
363
362
  dmax = Xumo::SFloat.zeros(dy.size * @pool_size.reduce(:*))
364
- dmax[@max_index] = dy.flatten
363
+ dmax[@max_index.flatten] = dy.flatten
365
364
  dcol = dmax.reshape(dy.shape[0..2].reduce(:*), @pool_size.reduce(:*) * dy.shape[3])
366
365
  dx = col2im(dcol, @x_shape, *@out_size, *@pool_size, @strides)
367
366
  @padding ? zero_padding_bwd(dx, @pad_size) : dx
@@ -374,8 +373,7 @@ module DNN
374
373
  x = zero_padding(x, @pad_size) if @padding
375
374
  @x_shape = x.shape
376
375
  col = im2col(x, *@out_size, *@pool_size, @strides)
377
- col = col.reshape(x.shape[0] * @out_size.reduce(:*), @pool_size.reduce(:*), x.shape[3]).transpose(0, 2, 1)
378
- .reshape(x.shape[0] * @out_size.reduce(:*) * x.shape[3], @pool_size.reduce(:*))
376
+ col = col.reshape(x.shape[0] * @out_size.reduce(:*), @pool_size.reduce(:*), x.shape[3])
379
377
  col.mean(1).reshape(x.shape[0], *@out_size, x.shape[3])
380
378
  end
381
379
 
@@ -436,8 +434,7 @@ module DNN
436
434
  def backward(dy)
437
435
  in_size = input_shape[0..1]
438
436
  col = im2col(dy, *in_size, *@unpool_size, @unpool_size)
439
- col = col.reshape(dy.shape[0] * in_size.reduce(:*), @unpool_size.reduce(:*), dy.shape[3]).transpose(0, 2, 1)
440
- .reshape(dy.shape[0] * in_size.reduce(:*) * dy.shape[3], @unpool_size.reduce(:*))
437
+ col = col.reshape(dy.shape[0] * in_size.reduce(:*), @unpool_size.reduce(:*), dy.shape[3])
441
438
  col.sum(1).reshape(dy.shape[0], *in_size, dy.shape[3])
442
439
  end
443
440
 
@@ -120,7 +120,7 @@ module DNN
120
120
  end
121
121
 
122
122
  def self.softmax(y)
123
- Xumo::NMath.exp(y) / Xumo::NMath.exp(y).sum(1).reshape(y.shape[0], 1)
123
+ Xumo::NMath.exp(y) / Xumo::NMath.exp(y).sum(1, keepdims: true)
124
124
  end
125
125
 
126
126
  # @param [Float] eps Value to avoid nan.
@@ -81,7 +81,7 @@ module DNN
81
81
  end
82
82
  dxs = Xumo::SFloat.zeros(@xs_shape)
83
83
  dh = 0
84
- (0...dh2s.shape[1]).to_a.reverse.each do |t|
84
+ (dh2s.shape[1] - 1).downto(0) do |t|
85
85
  dh2 = dh2s[true, t, false]
86
86
  dx, dh = @layers[t].backward(dh2 + dh)
87
87
  dxs[true, t, false] = dx
@@ -341,7 +341,7 @@ module DNN
341
341
  dxs = Xumo::SFloat.zeros(@xs_shape)
342
342
  dh = 0
343
343
  dc = 0
344
- (0...dh2s.shape[1]).to_a.reverse.each do |t|
344
+ (dh2s.shape[1] - 1).downto(0) do |t|
345
345
  dh2 = dh2s[true, t, false]
346
346
  dx, dh, dc = @layers[t].backward(dh2 + dh, dc)
347
347
  dxs[true, t, false] = dx
@@ -55,10 +55,10 @@ module DNN
55
55
  end
56
56
 
57
57
  def base64_to_params_data(base64_params_data)
58
- params_data = base64_params_data.map { |key, (shape, base64_data)|
58
+ params_data = base64_params_data.to_h do |key, (shape, base64_data)|
59
59
  bin = Base64.decode64(base64_data)
60
60
  [key, Xumo::SFloat.from_binary(bin).reshape(*shape)]
61
- }.to_h
61
+ end
62
62
  set_all_params_data(params_data)
63
63
  end
64
64
  end
@@ -94,7 +94,7 @@ module DNN
94
94
  all_params = @model.has_param_layers.uniq.map { |layer|
95
95
  layer.get_params.values
96
96
  }.flatten
97
- all_params.map { |param| [param.name, param.data] }.to_h
97
+ all_params.to_h { |param| [param.name, param.data] }
98
98
  end
99
99
  end
100
100
 
@@ -127,10 +127,10 @@ module DNN
127
127
  end
128
128
 
129
129
  def params_data_to_base64
130
- get_all_params_data.map { |key, data|
130
+ get_all_params_data.to_h do |key, data|
131
131
  base64_data = Base64.encode64(data.to_binary)
132
132
  [key, [data.shape, base64_data]]
133
- }.to_h
133
+ end
134
134
  end
135
135
  end
136
136
 
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.13.0"
2
+ VERSION = "0.13.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.13.0
4
+ version: 0.13.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-09-15 00:00:00.000000000 Z
11
+ date: 2019-09-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray