ruby-dnn 0.1.7 → 0.1.8

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d86f8c32dcd08b4a71ae4ce3892e0a4f0677d04ac32c0090b72c8f7fe008b591
4
- data.tar.gz: 4003068dfacd2739426e6c8d19411a47b9401aaf2c7c09f214962de3cdf57986
3
+ metadata.gz: a9bd42ca5cbea3b4a3abd984157df19cf87ca2a6ba80a3a76cf155e020582508
4
+ data.tar.gz: 2be2b617347618bf18d140e1ed4b5ede5d3a0f6540d6bc18217ae9b86c22aa2d
5
5
  SHA512:
6
- metadata.gz: 1699c885afb1494981784e435dd93c5f06c7f2041412eefacf0c2f24e700ab1a134a155dfc3babe6791b15c4936f15083463d32f3bd02f181c0f457273f4d4bd
7
- data.tar.gz: '09b397dba80f67bf9dc5a66a557d3e358b11c397776cad73b7cd4d2a1fbe601d93b1df0d5dfa9a3675736e4e3ab7386ad59bf71825689714439c5480c85feb3a'
6
+ metadata.gz: 9077ebdf09fbc33d5b163c9ef2bd9ec4bb6681b87e5f638db25d6fbc46603ba50a8a6cb81123a42e07984636d2055fae3ed54712f11df815d856854a5845973d
7
+ data.tar.gz: 5b3c56027727a7dee1f1cbae31c6ddd887d5ea7cf015c1d111942e6c518579ffc607c8c38c13702e77f86ad9b4428d27296287fd3c2b13e02a77992b5360900b
@@ -188,6 +188,12 @@ module DNN
188
188
  img2[true, true, pad...(ih + pad), pad...(iw + pad)] = img
189
189
  img2
190
190
  end
191
+
192
# Inverse of `padding`: slices away a border of `pad` elements from both
# ends of axes 2 and 3 of a 4-D array. These are the axes that `padding`
# enlarges by `pad` on each side. The backward passes in this diff call it
# on the col2im result so the returned gradient matches the unpadded input.
#
# @param img [Numo::NArray] 4-D array whose axes 2 and 3 carry the padding
#   (assumed to be the padded image/gradient; leading axes are kept whole)
# @param pad [Integer] border width that `padding` added on each side
# @return [Numo::NArray] img restricted to the interior region
def back_padding(img, pad)
  _, _, height, width = img.shape
  rows = pad...(height - pad)
  cols = pad...(width - pad)
  img[true, true, rows, cols]
end
191
197
  end
192
198
 
193
199
 
@@ -246,7 +252,8 @@ module DNN
246
252
  end
247
253
  @grads[:bias] = dout.sum(0)
248
254
  dcol = dout.dot(@params[:weight].transpose)
249
- col2im(dcol, @x_shape, @out_height, @out_width, @filter_height, @filter_width, @strides)
255
+ dx = col2im(dcol, @x_shape, @out_height, @out_width, @filter_height, @filter_width, @strides)
256
+ @padding ? back_padding(dx, @padding) : dx
250
257
  end
251
258
 
252
259
  def shape
@@ -293,11 +300,12 @@ module DNN
293
300
  super
294
301
  prev_height, prev_width = prev_layer.shape[1], prev_layer.shape[2]
295
302
  @num_channel = prev_layer.shape[0]
296
- @out_height = (prev_height - @pool_height) / @strides[0] + 1
297
- @out_width = (prev_width - @pool_width) / @strides[1] + 1
303
+ @out_height = (prev_height + @padding * 2 - @pool_height) / @strides[0] + 1
304
+ @out_width = (prev_width + @padding * 2 - @pool_width) / @strides[1] + 1
298
305
  end
299
306
 
300
307
  def forward(x)
308
+ x = padding(x, 2) if @padding > 0
301
309
  @x_shape = x.shape
302
310
  col = im2col(x, @out_height, @out_width, @pool_height, @pool_width, @strides)
303
311
  col = col.reshape(x.shape[0] * @out_height * @out_width * x.shape[1], @pool_height * @pool_width)
@@ -311,7 +319,8 @@ module DNN
311
319
  dmax = SFloat.zeros(dout.size * pool_size)
312
320
  dmax[@max_index] = dout.flatten
313
321
  dcol = dmax.reshape(dout.shape[0..2].reduce(:*), dout.shape[3] * pool_size)
314
- col2im(dcol, @x_shape, @out_height, @out_width, @pool_height, @pool_width, @strides)
322
+ dx = col2im(dcol, @x_shape, @out_height, @out_width, @pool_height, @pool_width, @strides)
323
+ @padding ? back_padding(dx, @padding) : dx
315
324
  end
316
325
 
317
326
  def shape
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.1.7"
2
+ VERSION = "0.1.8"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.7
4
+ version: 0.1.8
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-07-09 00:00:00.000000000 Z
11
+ date: 2018-07-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray