ruby-dnn 0.3.2 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/API-Reference.ja.md +17 -17
  3. data/LIB-API-Reference.ja.md +9 -9
  4. data/bin/console +1 -1
  5. data/examples/cifar10_example.rb +12 -2
  6. data/lib/dnn/core/layers.rb +46 -48
  7. data/lib/dnn/ext/dataset_loader/dataset_loader.c +4 -4
  8. data/lib/dnn/ext/rb_stb_image/extconf.rb +3 -0
  9. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/compat.h +0 -0
  10. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/extconf.h +0 -0
  11. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/intern.h +0 -0
  12. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/narray.h +0 -0
  13. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/ndloop.h +0 -0
  14. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/template.h +0 -0
  15. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/bit.h +0 -0
  16. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/complex.h +0 -0
  17. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/complex_macro.h +0 -0
  18. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/dcomplex.h +0 -0
  19. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/dfloat.h +0 -0
  20. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/float_def.h +0 -0
  21. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/float_macro.h +0 -0
  22. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/int16.h +0 -0
  23. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/int32.h +0 -0
  24. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/int64.h +0 -0
  25. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/int8.h +0 -0
  26. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/int_macro.h +0 -0
  27. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/real_accum.h +0 -0
  28. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/robj_macro.h +0 -0
  29. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/robject.h +0 -0
  30. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/scomplex.h +0 -0
  31. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/sfloat.h +0 -0
  32. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/uint16.h +0 -0
  33. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/uint32.h +0 -0
  34. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/uint64.h +0 -0
  35. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/uint8.h +0 -0
  36. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/uint_macro.h +0 -0
  37. data/lib/dnn/ext/{image_io → rb_stb_image}/numo/types/xint_macro.h +0 -0
  38. data/lib/dnn/ext/rb_stb_image/rb_stb_image.c +99 -0
  39. data/lib/dnn/ext/{image_io → rb_stb_image}/stb_image.h +0 -0
  40. data/lib/dnn/ext/{image_io → rb_stb_image}/stb_image_write.h +0 -0
  41. data/lib/dnn/lib/cifar10.rb +6 -2
  42. data/lib/dnn/lib/image_io.rb +15 -11
  43. data/lib/dnn/version.rb +1 -1
  44. data/ruby-dnn.gemspec +1 -1
  45. metadata +36 -36
  46. data/lib/dnn/ext/image_io/extconf.rb +0 -3
  47. data/lib/dnn/ext/image_io/image_io_ext.c +0 -89
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: abf5548b3a0a11d3935f476aa172bab5551598a8c411250ef4816ccca1362d7c
4
- data.tar.gz: 1f96a83f5f1ed1e72f10afd1942523f51a80eb32805ee16d10ecc954e7ace669
3
+ metadata.gz: 9f0d01c2a4696fdf876c6f247a7ea7e41bd57f4d42d2e0065a3703c4faf5a7cc
4
+ data.tar.gz: 29bcebf1ae90f3b6f72dd391fbba9118162dc49f490b801d1e634f736cc580a4
5
5
  SHA512:
6
- metadata.gz: 5cb4bfa7d9bea04f287c14bc1e916fa8af45ecf7f3866632d8bd5cf7c3b6fe2a43a64660aade21666b6937ebee2151aa7bd3dbe1f3aa7ea092678d9a9d2c1294
7
- data.tar.gz: 2b6887c4b479861956033e3794d9c4415aaf66c572d165c9910addf73f90b3db2d48c51d73178c7f03591ae6bd35080193c8c11479f436332a1eb5e5f59d3cf9
6
+ metadata.gz: b7b01228b299fbf42ec74f0d4678e6ffd62775a606ecadba3253ce76fd62166699851450d359f13022b01023219b8f710ef75c75af6c1096591df180a49edd62
7
+ data.tar.gz: 2eb11ae75136aa6be186d81c1fe8c04d585167ce31feccdd26406d58cafd6f333080c871723ad533e9f3ad0fb677c08180ee9925e99040f093f90a08a7f24631
data/API-Reference.ja.md CHANGED
@@ -2,7 +2,7 @@
2
2
  ruby-dnnのAPIリファレンスです。このリファレンスでは、APIを利用するうえで必要となるクラスとメソッドしか記載していません。
3
3
  そのため、プログラムの詳細が必要な場合は、ソースコードを参照してください。
4
4
 
5
- 最終更新バージョン:0.3.2
5
+ 最終更新バージョン:0.4.0
6
6
 
7
7
  # module DNN
8
8
  ruby-dnnの名前空間をなすモジュールです。
@@ -287,13 +287,13 @@ Integer
287
287
 
288
288
  ## attr_reader :filter_size
289
289
  Array
290
- フィルターの横と縦の長さ。
291
- [Integer width, Integer height]の形式で取得します。
290
+ フィルターの縦と横の長さ。
291
+ [Integer height, Integer width]の形式で取得します。
292
292
 
293
293
  ## attr_reader :strides
294
294
  Array
295
295
  畳み込みを行う際のストライドの単位。
296
- [Integer width, Integer height]の形式で取得します。
296
+ [Integer height, Integer width]の形式で取得します。
297
297
 
298
298
  ## attr_reader :weight_decay
299
299
  Float
@@ -307,8 +307,8 @@ Float
307
307
  * Integer num_filters
308
308
  出力するフィルターの枚数。
309
309
  * Integer | Array filter_size
310
- フィルターの横と縦の長さ。
311
- Arrayで指定する場合、[Integer width, Integer height]の形式で指定します。
310
+ フィルターの縦と横の長さ。
311
+ Arrayで指定する場合、[Integer height, Integer width]の形式で指定します。
312
312
  * Initializer weight_initializer: nil
313
313
  重みの初期化に使用するイニシャライザーを設定します。
314
314
  nilを指定すると、RandomNormalイニシャライザーが使用されます。
@@ -316,7 +316,7 @@ nilを指定すると、RandomNormalイニシャライザーが使用されま
316
316
  バイアスの初期化に使用するイニシャライザーを設定します。
317
317
  * Array<Integer> strides: 1
318
318
  畳み込みを行う際のストライドの単位を指定します。
319
- Arrayで指定する場合、[Integer width, Integer height]の形式で指定します。
319
+ Arrayで指定する場合、[Integer height, Integer width]の形式で指定します。
320
320
  * bool padding: true
321
321
  イメージに対してゼロパディングを行うか否かを設定します。trueを設定すると、出力されるイメージのサイズが入力されたイメージと同じになるように
322
322
  ゼロパディングを行います。
@@ -331,13 +331,13 @@ maxプーリングを行うレイヤーです。
331
331
 
332
332
  ## attr_reader :pool_size
333
333
  Array
334
- プーリングを行う横と縦の長さ。
335
- [Integer width, Integer height]の形式で取得します。
334
+ プーリングを行う縦と横の長さ。
335
+ [Integer height, Integer width]の形式で取得します。
336
336
 
337
337
  ## attr_reader :strides
338
338
  Array
339
339
  プーリングを行う際のストライドの単位。
340
- [Integer width, Integer height]の形式で取得します。
340
+ [Integer height, Integer width]の形式で取得します。
341
341
 
342
342
  ## 【Instance methods】
343
343
 
@@ -345,11 +345,11 @@ Array
345
345
  コンストラクタ。
346
346
  ### arguments
347
347
  * Integer | Array pool_size
348
- プーリングを行う横と縦の長さ。
349
- Arrayで指定する場合、[Integer width, Integer height]の形式で指定します。
348
+ プーリングを行う縦と横の長さ。
349
+ Arrayで指定する場合、[Integer height, Integer width]の形式で指定します。
350
350
  * Array<Integer> strides: nil
351
351
  プーリングを行う際のストライドの単位を指定します。
352
- Arrayで指定する場合、[Integer width, Integer height]の形式で指定します。
352
+ Arrayで指定する場合、[Integer height, Integer width]の形式で指定します。
353
353
  なお、nilが設定された場合は、pool_sizeがstridesの値となります。
354
354
  * bool padding: true
355
355
  イメージに対してゼロパディングを行うか否かを設定します。trueを設定すると、出力されるイメージのサイズが入力されたイメージと同じになるように
@@ -363,8 +363,8 @@ Arrayで指定する場合、[Integer width, Integer height]の形式で指定
363
363
 
364
364
  ## attr_reader :unpool_size
365
365
  Array
366
- 逆プーリングを行う横と縦の長さ。
367
- [Integer width, Integer height]の形式で取得します。
366
+ 逆プーリングを行う縦と横の長さ。
367
+ [Integer height, Integer width]の形式で取得します。
368
368
 
369
369
  ## 【Instance methods】
370
370
 
@@ -372,8 +372,8 @@ Array
372
372
  コンストラクタ。
373
373
  ### arguments
374
374
  * Integer | Array unpool_size
375
- 逆プーリングを行う横と縦の長さ。
376
- Arrayで指定する場合、[Integer width, Integer height]の形式で指定します。
375
+ 逆プーリングを行う縦と横の長さ。
376
+ Arrayで指定する場合、[Integer height, Integer width]の形式で指定します。
377
377
 
378
378
 
379
379
  # class Flatten
@@ -1,6 +1,6 @@
1
1
  # LIB-APIリファレンス
2
2
  ruby-dnnの付属ライブラリのリファレンスです。
3
- 最終更新バージョン:0.3.0
3
+ 最終更新バージョン:0.4.0
4
4
 
5
5
 
6
6
  # dnn/lib/mnist
@@ -66,7 +66,7 @@ Array
66
66
  Array
67
67
  [イメージデータ, ラベルデータ]の形式で取得します。
68
68
  * イメージデータ
69
- UInt8の[10000, 3, 32, 32]の形式
69
+ UInt8の[10000, 32, 32, 3]の形式
70
70
  * ラベルデータ
71
71
  UInt8の[10000]の形式
72
72
 
@@ -78,20 +78,20 @@ Array
78
78
  # module ImageIO
79
79
 
80
80
  ## def self.read(file_name)
81
- 画像をXumo::UInt8形式で読み込みます。
81
+ 画像をUInt8形式で読み込みます。
82
82
  ### arguments
83
83
  * String file_name
84
84
  読み込む画像のファイル名。
85
85
  ### return
86
- Xumo::UInt8
87
- [width, height, rgb]のXumo::UInt8配列。
86
+ UInt8
87
+ [height, width, rgb]のUInt8配列。
88
88
 
89
- ## def self.write(file_name, nary, quality: 100)
90
- Xumo::UInt8形式の画像を書き込みます。
89
+ ## def self.write(file_name, img, quality: 100)
90
+ UInt8形式の画像を書き込みます。
91
91
  ### arguments
92
92
  * String file_name
93
93
  書き込む画像のファイル名。
94
- * Xumo::UInt8
95
- [width, height, rgb]のXumo::UInt8配列。
94
+ * UInt8 img
95
+ [height, width, rgb]のUInt8配列。
96
96
  * Integer quality: 100
97
97
  画像をJPEGで書き込む場合のクオリティ。
data/bin/console CHANGED
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env ruby
2
2
 
3
3
  require "bundler/setup"
4
- require "ruby/dnn"
4
+ require "dnn"
5
5
 
6
6
  # You can add fixtures and/or initialization code here to make experimenting
7
7
  # with your gem easier. You can also use a different console, if you like.
@@ -12,8 +12,8 @@ CIFAR10 = DNN::CIFAR10
12
12
  x_train, y_train = CIFAR10.load_train
13
13
  x_test, y_test = CIFAR10.load_test
14
14
 
15
- x_train = SFloat.cast(x_train).transpose(0, 2, 3, 1)
16
- x_test = SFloat.cast(x_test).transpose(0, 2, 3, 1)
15
+ x_train = SFloat.cast(x_train)
16
+ x_test = SFloat.cast(x_test)
17
17
 
18
18
  x_train /= 255
19
19
  x_test /= 255
@@ -43,6 +43,16 @@ model << Conv2D.new(32, 5, padding: true)
43
43
  model << BatchNormalization.new
44
44
  model << ReLU.new
45
45
 
46
+ model << MaxPool2D.new(2)
47
+
48
+ model << Conv2D.new(64, 5, padding: true)
49
+ model << BatchNormalization.new
50
+ model << ReLU.new
51
+
52
+ model << Conv2D.new(64, 5, padding: true)
53
+ model << BatchNormalization.new
54
+ model << ReLU.new
55
+
46
56
  model << Flatten.new
47
57
 
48
58
  model << Dense.new(512)
@@ -161,69 +161,69 @@ module DNN
161
161
  end
162
162
 
163
163
 
164
- #private module
165
- module Convert
164
+ #This module is used for convolution.
165
+ module Conv2DModule
166
166
  private
167
167
 
168
- def im2col(img, out_w, out_h, fil_w, fil_h, strides)
168
+ def im2col(img, out_h, out_w, fil_h, fil_w, strides)
169
169
  bsize = img.shape[0]
170
170
  ch = img.shape[3]
171
- col = SFloat.zeros(bsize, ch, fil_w, fil_h, out_w, out_h)
171
+ col = SFloat.zeros(bsize, ch, fil_h, fil_w, out_h, out_w)
172
172
  img = img.transpose(0, 3, 1, 2)
173
173
  (0...fil_h).each do |i|
174
- i_range = (i...(i + strides[1] * out_h)).step(strides[1]).to_a
174
+ i_range = (i...(i + strides[0] * out_h)).step(strides[0]).to_a
175
175
  (0...fil_w).each do |j|
176
- j_range = (j...(j + strides[0] * out_w)).step(strides[0]).to_a
177
- col[true, true, j, i, true, true] = img[true, true, j_range, i_range]
176
+ j_range = (j...(j + strides[1] * out_w)).step(strides[1]).to_a
177
+ col[true, true, i, j, true, true] = img[true, true, i_range, j_range]
178
178
  end
179
179
  end
180
- col.transpose(0, 4, 5, 2, 3, 1).reshape(bsize * out_w * out_h, fil_w * fil_h * ch)
180
+ col.transpose(0, 4, 5, 2, 3, 1).reshape(bsize * out_h * out_w, fil_h * fil_w * ch)
181
181
  end
182
182
 
183
- def col2im(col, img_shape, out_w, out_h, fil_w, fil_h, strides)
184
- bsize, img_w, img_h, ch = img_shape
185
- col = col.reshape(bsize, out_w, out_h, fil_w, fil_h, ch).transpose(0, 5, 3, 4, 1, 2)
186
- img = SFloat.zeros(bsize, ch, img_w, img_h)
183
+ def col2im(col, img_shape, out_h, out_w, fil_h, fil_w, strides)
184
+ bsize, img_h, img_w, ch = img_shape
185
+ col = col.reshape(bsize, out_h, out_w, fil_h, fil_w, ch).transpose(0, 5, 3, 4, 1, 2)
186
+ img = SFloat.zeros(bsize, ch, img_h, img_w)
187
187
  (0...fil_h).each do |i|
188
- i_range = (i...(i + strides[1] * out_h)).step(strides[1]).to_a
188
+ i_range = (i...(i + strides[0] * out_h)).step(strides[0]).to_a
189
189
  (0...fil_w).each do |j|
190
- j_range = (j...(j + strides[0] * out_w)).step(strides[0]).to_a
191
- img[true, true, j_range, i_range] += col[true, true, j, i, true, true]
190
+ j_range = (j...(j + strides[1] * out_w)).step(strides[1]).to_a
191
+ img[true, true, i_range, j_range] += col[true, true, i, j, true, true]
192
192
  end
193
193
  end
194
194
  img.transpose(0, 2, 3, 1)
195
195
  end
196
196
 
197
197
  def padding(img, pad)
198
- bsize, img_w, img_h, ch = img.shape
199
- img2 = SFloat.zeros(bsize, img_w + pad[0], img_h + pad[1], ch)
200
- i_begin = pad[1] / 2
198
+ bsize, img_h, img_w, ch = img.shape
199
+ img2 = SFloat.zeros(bsize, img_h + pad[0], img_w + pad[1], ch)
200
+ i_begin = pad[0] / 2
201
201
  i_end = i_begin + img_h
202
- j_begin = pad[0] / 2
202
+ j_begin = pad[1] / 2
203
203
  j_end = j_begin + img_w
204
- img2[true, j_begin...j_end, i_begin...i_end, true] = img
204
+ img2[true, i_begin...i_end, j_begin...j_end, true] = img
205
205
  img2
206
206
  end
207
207
 
208
208
  def back_padding(img, pad)
209
- i_begin = pad[1] / 2
210
- i_end = img.shape[2] - (pad[1] / 2.0).round
211
- j_begin = pad[0] / 2
212
- j_end = img.shape[1] - (pad[0] / 2.0).round
213
- img[true, j_begin...j_end, i_begin...i_end, true]
209
+ i_begin = pad[0] / 2
210
+ i_end = img.shape[1] - (pad[0] / 2.0).round
211
+ j_begin = pad[1] / 2
212
+ j_end = img.shape[2] - (pad[1] / 2.0).round
213
+ img[true, i_begin...i_end, j_begin...j_end, true]
214
214
  end
215
215
 
216
- def out_size(prev_w, prev_h, fil_w, fil_h, strides)
217
- out_w = (prev_w - fil_w) / strides[0] + 1
218
- out_h = (prev_h - fil_h) / strides[1] + 1
219
- [out_w, out_h]
216
+ def out_size(prev_h, prev_w, fil_h, fil_w, strides)
217
+ out_h = (prev_h - fil_h) / strides[0] + 1
218
+ out_w = (prev_w - fil_w) / strides[1] + 1
219
+ [out_h, out_w]
220
220
  end
221
221
  end
222
222
 
223
223
 
224
224
  class Conv2D < HasParamLayer
225
225
  include Initializers
226
- include Convert
226
+ include Conv2DModule
227
227
 
228
228
  attr_reader :num_filters
229
229
  attr_reader :filter_size
@@ -257,12 +257,12 @@ module DNN
257
257
 
258
258
  def build(model)
259
259
  super
260
- prev_w, prev_h = prev_layer.shape[0..1]
261
- @out_size = out_size(prev_w, prev_h, *@filter_size, @strides)
260
+ prev_h, prev_w = prev_layer.shape[0..1]
261
+ @out_size = out_size(prev_h, prev_w, *@filter_size, @strides)
262
262
  out_w, out_h = @out_size
263
263
  if @padding
264
- @pad = [prev_w - out_w, prev_h - out_h]
265
- @out_size = [prev_w, prev_h]
264
+ @pad = [prev_h - out_h, prev_w - out_w]
265
+ @out_size = [prev_h, prev_w]
266
266
  end
267
267
  end
268
268
 
@@ -317,7 +317,7 @@ module DNN
317
317
 
318
318
 
319
319
  class MaxPool2D < Layer
320
- include Convert
320
+ include Conv2DModule
321
321
 
322
322
  attr_reader :pool_size
323
323
  attr_reader :strides
@@ -341,11 +341,11 @@ module DNN
341
341
  super
342
342
  prev_w, prev_h = prev_layer.shape[0..1]
343
343
  @num_channel = prev_layer.shape[2]
344
- @out_size = out_size(prev_w, prev_h, *@pool_size, @strides)
344
+ @out_size = out_size(prev_h, prev_w, *@pool_size, @strides)
345
345
  out_w, out_h = @out_size
346
346
  if @padding
347
- @pad = [prev_w - out_w, prev_h - out_h]
348
- @out_size = [prev_w, prev_h]
347
+ @pad = [prev_h - out_h, prev_w - out_w]
348
+ @out_size = [prev_h, prev_w]
349
349
  end
350
350
  end
351
351
 
@@ -383,8 +383,6 @@ module DNN
383
383
 
384
384
 
385
385
  class UnPool2D < Layer
386
- include Convert
387
-
388
386
  attr_reader :unpool_size
389
387
 
390
388
  def initialize(unpool_size)
@@ -398,25 +396,25 @@ module DNN
398
396
 
399
397
  def build(model)
400
398
  super
401
- prev_w, prev_h = prev_layer.shape[0..1]
402
- unpool_w, unpool_h = @unpool_size
403
- out_w = prev_w * unpool_w
399
+ prev_h, prev_w = prev_layer.shape[0..1]
400
+ unpool_h, unpool_w = @unpool_size
404
401
  out_h = prev_h * unpool_h
405
- @out_size = [out_w, out_h]
402
+ out_w = prev_w * unpool_w
403
+ @out_size = [out_h, out_w]
406
404
  @num_channel = prev_layer.shape[2]
407
405
  end
408
406
 
409
407
  def forward(x)
410
408
  @x_shape = x.shape
411
- unpool_w, unpool_h = @unpool_size
412
- x2 = SFloat.zeros(x.shape[0], x.shape[1], unpool_w, x.shape[2], unpool_h, @num_channel)
409
+ unpool_h, unpool_w = @unpool_size
410
+ x2 = SFloat.zeros(x.shape[0], x.shape[1], unpool_h, x.shape[2], unpool_w, @num_channel)
413
411
  x2[true, true, 0, true, 0, true] = x
414
412
  x2.reshape(x.shape[0], *@out_size, x.shape[3])
415
413
  end
416
414
 
417
415
  def backward(dout)
418
- unpool_w, unpool_h = @unpool_size
419
- dout = dout.reshape(dout.shape[0], @x_shape[0], unpool_w, @x_shape[1], unpool_h, @num_channel)
416
+ unpool_h, unpool_w = @unpool_size
417
+ dout = dout.reshape(dout.shape[0], @x_shape[0], unpool_h, @x_shape[1], unpool_w, @num_channel)
420
418
  dout[true, true, 0, true, 0, true].clone
421
419
  end
422
420
 
@@ -17,7 +17,7 @@ VALUE mnist_load_images(VALUE self, VALUE rb_bin, VALUE rb_num_images, VALUE rb_
17
17
  VALUE rb_na;
18
18
  narray_data_t* na_data;
19
19
 
20
- sprintf(script, "Xumo::UInt8.zeros(%d, %d, %d)", num_images, cols, rows);
20
+ sprintf(script, "Numo::UInt8.zeros(%d, %d, %d)", num_images, cols, rows);
21
21
  rb_na = rb_eval_string(&script[0]);
22
22
  na_data = RNARRAY_DATA(rb_na);
23
23
 
@@ -35,7 +35,7 @@ VALUE mnist_load_labels(VALUE self, VALUE rb_bin, VALUE rb_num_labels) {
35
35
  VALUE rb_na;
36
36
  narray_data_t* na_data;
37
37
 
38
- sprintf(script, "Xumo::UInt8.zeros(%d)", num_labels);
38
+ sprintf(script, "Numo::UInt8.zeros(%d)", num_labels);
39
39
  rb_na = rb_eval_string(&script[0]);
40
40
  na_data = RNARRAY_DATA(rb_na);
41
41
 
@@ -57,13 +57,13 @@ VALUE cifar10_load(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
57
57
  int k = 0;
58
58
  int size = CIFAR10_WIDTH * CIFAR10_HEIGHT * CIFAR10_CHANNEL;
59
59
 
60
- sprintf(script, "Xumo::UInt8.zeros(%d, %d, %d, %d)", num_datas, CIFAR10_CHANNEL, CIFAR10_WIDTH, CIFAR10_HEIGHT);
60
+ sprintf(script, "Numo::UInt8.zeros(%d, %d, %d, %d)", num_datas, CIFAR10_CHANNEL, CIFAR10_WIDTH, CIFAR10_HEIGHT);
61
61
  rb_na_x = rb_eval_string(&script[0]);
62
62
  na_data_x = RNARRAY_DATA(rb_na_x);
63
63
  for(i = 0; i < 64; i++) {
64
64
  script[i] = 0;
65
65
  }
66
- sprintf(script, "Xumo::UInt8.zeros(%d)", num_datas);
66
+ sprintf(script, "Numo::UInt8.zeros(%d)", num_datas);
67
67
  rb_na_y = rb_eval_string(&script[0]);
68
68
  na_data_y = RNARRAY_DATA(rb_na_y);
69
69
 
@@ -0,0 +1,3 @@
1
+ require "mkmf"
2
+
3
+ create_makefile("rb_stb_image")
@@ -0,0 +1,99 @@
1
+ #include <ruby.h>
2
+ #include "numo/narray.h"
3
+
4
+ #define STB_IMAGE_IMPLEMENTATION
5
+ #define STB_IMAGE_WRITE_IMPLEMENTATION
6
+
7
+ #include "stb_image.h"
8
+ #include "stb_image_write.h"
9
+
10
+ //STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp);
11
+ VALUE rb_stbi_load(VALUE self, VALUE rb_filename, VALUE rb_req_comp) {
12
+ char* filename = StringValuePtr(rb_filename);
13
+ int x, y, n;
14
+ int req_comp = FIX2INT(rb_req_comp);
15
+ unsigned char* pixels;
16
+ narray_data_t* na_data;
17
+ char script[64];
18
+ int ch;
19
+ VALUE rb_x, rb_y, rb_n;
20
+ VALUE rb_pixels;
21
+
22
+ pixels = stbi_load(filename, &x, &y, &n, req_comp);
23
+ rb_x = INT2FIX(x);
24
+ rb_y = INT2FIX(y);
25
+ rb_n = INT2FIX(n);
26
+ ch = req_comp == 0 ? n : req_comp;
27
+ sprintf(script, "Numo::UInt8.zeros(%d, %d, %d)", y, x, ch);
28
+ rb_pixels = rb_eval_string(&script[0]);
29
+ na_data = RNARRAY_DATA(rb_pixels);
30
+ memcpy(na_data->ptr, pixels, na_data->base.size);
31
+ stbi_image_free(pixels);
32
+ return rb_ary_new3(4, rb_pixels, rb_x, rb_y, rb_n);
33
+ }
34
+
35
+ //STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
36
+ VALUE rb_stbi_write_png(VALUE self, VALUE rb_filename, VALUE rb_w, VALUE rb_h, VALUE rb_comp, VALUE rb_pixels, VALUE rb_stride_in_bytes) {
37
+ char* filename = StringValuePtr(rb_filename);
38
+ int w = FIX2INT(rb_w);
39
+ int h = FIX2INT(rb_h);
40
+ int comp = FIX2INT(rb_comp);
41
+ unsigned char* pixels;
42
+ int stride_in_bytes = FIX2INT(rb_stride_in_bytes);
43
+ narray_data_t* na_data;
44
+ int result;
45
+
46
+ na_data = RNARRAY_DATA(rb_pixels);
47
+ pixels = (unsigned char*)malloc(na_data->base.size);
48
+ memcpy(pixels, na_data->ptr, na_data->base.size);
49
+ result = stbi_write_png(filename, w, h, comp, pixels, stride_in_bytes);
50
+ stbi_image_free(pixels);
51
+ return INT2FIX(result);
52
+ }
53
+
54
+ //STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
55
+ VALUE rb_stbi_write_bmp(VALUE self, VALUE rb_filename, VALUE rb_w, VALUE rb_h, VALUE rb_comp, VALUE rb_pixels) {
56
+ char* filename = StringValuePtr(rb_filename);
57
+ int w = FIX2INT(rb_w);
58
+ int h = FIX2INT(rb_h);
59
+ int comp = FIX2INT(rb_comp);
60
+ unsigned char* pixels;
61
+ narray_data_t* na_data;
62
+ int result;
63
+
64
+ na_data = RNARRAY_DATA(rb_pixels);
65
+ pixels = (unsigned char*)malloc(na_data->base.size);
66
+ memcpy(pixels, na_data->ptr, na_data->base.size);
67
+ result = stbi_write_bmp(filename, w, h, comp, pixels);
68
+ stbi_image_free(pixels);
69
+ return INT2FIX(result);
70
+ }
71
+
72
+ //STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);
73
+ VALUE rb_stbi_write_jpg(VALUE self, VALUE rb_filename, VALUE rb_w, VALUE rb_h, VALUE rb_comp, VALUE rb_pixels, VALUE rb_quality) {
74
+ char* filename = StringValuePtr(rb_filename);
75
+ int w = FIX2INT(rb_w);
76
+ int h = FIX2INT(rb_h);
77
+ int comp = FIX2INT(rb_comp);
78
+ unsigned char* pixels;
79
+ int quality = FIX2INT(rb_quality);
80
+ narray_data_t* na_data;
81
+ int result;
82
+
83
+ na_data = RNARRAY_DATA(rb_pixels);
84
+ pixels = (unsigned char*)malloc(na_data->base.size);
85
+ memcpy(pixels, na_data->ptr, na_data->base.size);
86
+ result = stbi_write_jpg(filename, w, h, comp, pixels, quality);
87
+ stbi_image_free(pixels);
88
+ return INT2FIX(result);
89
+ }
90
+
91
+ void Init_rb_stb_image() {
92
+ VALUE rb_dnn = rb_define_module("DNN");
93
+ VALUE rb_stb = rb_define_module_under(rb_dnn, "Stb");
94
+
95
+ rb_define_module_function(rb_stb, "stbi_load", rb_stbi_load, 2);
96
+ rb_define_module_function(rb_stb, "stbi_write_png", rb_stbi_write_png, 6);
97
+ rb_define_module_function(rb_stb, "stbi_write_bmp", rb_stbi_write_bmp, 5);
98
+ rb_define_module_function(rb_stb, "stbi_write_jpg", rb_stbi_write_jpg, 6);
99
+ }
File without changes
@@ -42,7 +42,9 @@ module DNN
42
42
  raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
43
43
  bin << File.binread(fname)
44
44
  end
45
- _cifar10_load(bin, 50000)
45
+ x_train, y_train = _cifar10_load(bin, 50000)
46
+ x_train = x_train.transpose(0, 2, 3, 1).clone
47
+ [x_train, y_train]
46
48
  end
47
49
 
48
50
  def self.load_test
@@ -50,7 +52,9 @@ module DNN
50
52
  fname = __dir__ + "/#{CIFAR10_DIR}/test_batch.bin"
51
53
  raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
52
54
  bin = File.binread(fname)
53
- _cifar10_load(bin, 10000)
55
+ x_test, y_test = _cifar10_load(bin, 10000)
56
+ x_test = x_test.transpose(0, 2, 3, 1).clone
57
+ [x_test, y_test]
54
58
  end
55
59
  end
56
60
  end
@@ -1,26 +1,30 @@
1
1
  require "numo/narray"
2
- require "dnn/ext/image_io/image_io_ext"
2
+ require "dnn/ext/rb_stb_image/rb_stb_image"
3
3
 
4
4
  module DNN
5
5
  module ImageIO
6
- private_class_method :_read
7
- private_class_method :_write_bmp
8
- private_class_method :_write_png
9
- private_class_method :_write_jpg
10
-
11
6
  def self.read(file_name)
12
7
  raise ImageIO::ReadError.new("#{file_name} is not found.") unless File.exist?(file_name)
13
- _read(file_name)
8
+ img, = Stb.stbi_load(file_name, 3)
9
+ img
14
10
  end
15
11
 
16
- def self.write(file_name, nary, quality: 100)
12
+ def self.write(file_name, img, quality: 100)
13
+ img = img.clone
14
+ if img.shape.length == 2
15
+ img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
16
+ elsif img.shape[2] == 1
17
+ img = img.shape(img.shape[0], img.shape[1])
18
+ img = Numo::UInt8[img, img, img].transpose(1, 2, 0).clone
19
+ end
17
20
  case file_name
18
21
  when /\.png$/
19
- _write_png(file_name, nary)
22
+ stride_in_bytes = img.shape[0] * img.shape[2]
23
+ Stb.stbi_write_png(file_name, *img.shape, img, stride_in_bytes)
20
24
  when /\.bmp$/
21
- _write_bmp(file_name, nary)
25
+ Stb.stbi_write_bmp(file_name, *img.shape, img)
22
26
  when /\.jpg$/
23
- _write_jpg(file_name, nary, quality)
27
+ Stb.stbi_write_jpg(file_name, *img.shape, img, quality)
24
28
  end
25
29
  rescue => ex
26
30
  raise ImageIO::WriteError.new(ex.message)
data/lib/dnn/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.3.2"
2
+ VERSION = "0.4.0"
3
3
  end
data/ruby-dnn.gemspec CHANGED
@@ -13,7 +13,7 @@ Gem::Specification.new do |spec|
13
13
  spec.description = %q{ruby-dnn is a ruby deep learning library.}
14
14
  spec.homepage = "https://github.com/unagiootoro/ruby-dnn.git"
15
15
  spec.license = "MIT"
16
- spec.extensions = ["lib/dnn/ext/dataset_loader/extconf.rb", "lib/dnn/ext/image_io/extconf.rb"]
16
+ spec.extensions = ["lib/dnn/ext/dataset_loader/extconf.rb", "lib/dnn/ext/rb_stb_image/extconf.rb"]
17
17
 
18
18
  spec.add_dependency "numo-narray"
19
19
  spec.add_dependency "archive-tar-minitar"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.2
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-07-20 00:00:00.000000000 Z
11
+ date: 2018-07-21 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -86,7 +86,7 @@ email:
86
86
  executables: []
87
87
  extensions:
88
88
  - lib/dnn/ext/dataset_loader/extconf.rb
89
- - lib/dnn/ext/image_io/extconf.rb
89
+ - lib/dnn/ext/rb_stb_image/extconf.rb
90
90
  extra_rdoc_files: []
91
91
  files:
92
92
  - ".gitignore"
@@ -143,39 +143,39 @@ files:
143
143
  - lib/dnn/ext/dataset_loader/numo/types/uint8.h
144
144
  - lib/dnn/ext/dataset_loader/numo/types/uint_macro.h
145
145
  - lib/dnn/ext/dataset_loader/numo/types/xint_macro.h
146
- - lib/dnn/ext/image_io/extconf.rb
147
- - lib/dnn/ext/image_io/image_io_ext.c
148
- - lib/dnn/ext/image_io/numo/compat.h
149
- - lib/dnn/ext/image_io/numo/extconf.h
150
- - lib/dnn/ext/image_io/numo/intern.h
151
- - lib/dnn/ext/image_io/numo/narray.h
152
- - lib/dnn/ext/image_io/numo/ndloop.h
153
- - lib/dnn/ext/image_io/numo/template.h
154
- - lib/dnn/ext/image_io/numo/types/bit.h
155
- - lib/dnn/ext/image_io/numo/types/complex.h
156
- - lib/dnn/ext/image_io/numo/types/complex_macro.h
157
- - lib/dnn/ext/image_io/numo/types/dcomplex.h
158
- - lib/dnn/ext/image_io/numo/types/dfloat.h
159
- - lib/dnn/ext/image_io/numo/types/float_def.h
160
- - lib/dnn/ext/image_io/numo/types/float_macro.h
161
- - lib/dnn/ext/image_io/numo/types/int16.h
162
- - lib/dnn/ext/image_io/numo/types/int32.h
163
- - lib/dnn/ext/image_io/numo/types/int64.h
164
- - lib/dnn/ext/image_io/numo/types/int8.h
165
- - lib/dnn/ext/image_io/numo/types/int_macro.h
166
- - lib/dnn/ext/image_io/numo/types/real_accum.h
167
- - lib/dnn/ext/image_io/numo/types/robj_macro.h
168
- - lib/dnn/ext/image_io/numo/types/robject.h
169
- - lib/dnn/ext/image_io/numo/types/scomplex.h
170
- - lib/dnn/ext/image_io/numo/types/sfloat.h
171
- - lib/dnn/ext/image_io/numo/types/uint16.h
172
- - lib/dnn/ext/image_io/numo/types/uint32.h
173
- - lib/dnn/ext/image_io/numo/types/uint64.h
174
- - lib/dnn/ext/image_io/numo/types/uint8.h
175
- - lib/dnn/ext/image_io/numo/types/uint_macro.h
176
- - lib/dnn/ext/image_io/numo/types/xint_macro.h
177
- - lib/dnn/ext/image_io/stb_image.h
178
- - lib/dnn/ext/image_io/stb_image_write.h
146
+ - lib/dnn/ext/rb_stb_image/extconf.rb
147
+ - lib/dnn/ext/rb_stb_image/numo/compat.h
148
+ - lib/dnn/ext/rb_stb_image/numo/extconf.h
149
+ - lib/dnn/ext/rb_stb_image/numo/intern.h
150
+ - lib/dnn/ext/rb_stb_image/numo/narray.h
151
+ - lib/dnn/ext/rb_stb_image/numo/ndloop.h
152
+ - lib/dnn/ext/rb_stb_image/numo/template.h
153
+ - lib/dnn/ext/rb_stb_image/numo/types/bit.h
154
+ - lib/dnn/ext/rb_stb_image/numo/types/complex.h
155
+ - lib/dnn/ext/rb_stb_image/numo/types/complex_macro.h
156
+ - lib/dnn/ext/rb_stb_image/numo/types/dcomplex.h
157
+ - lib/dnn/ext/rb_stb_image/numo/types/dfloat.h
158
+ - lib/dnn/ext/rb_stb_image/numo/types/float_def.h
159
+ - lib/dnn/ext/rb_stb_image/numo/types/float_macro.h
160
+ - lib/dnn/ext/rb_stb_image/numo/types/int16.h
161
+ - lib/dnn/ext/rb_stb_image/numo/types/int32.h
162
+ - lib/dnn/ext/rb_stb_image/numo/types/int64.h
163
+ - lib/dnn/ext/rb_stb_image/numo/types/int8.h
164
+ - lib/dnn/ext/rb_stb_image/numo/types/int_macro.h
165
+ - lib/dnn/ext/rb_stb_image/numo/types/real_accum.h
166
+ - lib/dnn/ext/rb_stb_image/numo/types/robj_macro.h
167
+ - lib/dnn/ext/rb_stb_image/numo/types/robject.h
168
+ - lib/dnn/ext/rb_stb_image/numo/types/scomplex.h
169
+ - lib/dnn/ext/rb_stb_image/numo/types/sfloat.h
170
+ - lib/dnn/ext/rb_stb_image/numo/types/uint16.h
171
+ - lib/dnn/ext/rb_stb_image/numo/types/uint32.h
172
+ - lib/dnn/ext/rb_stb_image/numo/types/uint64.h
173
+ - lib/dnn/ext/rb_stb_image/numo/types/uint8.h
174
+ - lib/dnn/ext/rb_stb_image/numo/types/uint_macro.h
175
+ - lib/dnn/ext/rb_stb_image/numo/types/xint_macro.h
176
+ - lib/dnn/ext/rb_stb_image/rb_stb_image.c
177
+ - lib/dnn/ext/rb_stb_image/stb_image.h
178
+ - lib/dnn/ext/rb_stb_image/stb_image_write.h
179
179
  - lib/dnn/lib/cifar10.rb
180
180
  - lib/dnn/lib/image_io.rb
181
181
  - lib/dnn/lib/mnist.rb
@@ -1,3 +0,0 @@
1
- require "mkmf"
2
-
3
- create_makefile("image_io_ext")
@@ -1,89 +0,0 @@
1
- #include <ruby.h>
2
- #include "numo/narray.h"
3
-
4
- #define STB_IMAGE_IMPLEMENTATION
5
- #define STB_IMAGE_WRITE_IMPLEMENTATION
6
-
7
- #include "stb_image.h"
8
- #include "stb_image_write.h"
9
-
10
- VALUE image_io_read(VALUE self, VALUE rb_file_name) {
11
- char* file_name = StringValuePtr(rb_file_name);
12
- int width;
13
- int height;
14
- int bpp;
15
- unsigned char* pixels;
16
- char script[64];
17
- VALUE rb_na;
18
- narray_data_t* na_data;
19
- pixels = stbi_load(file_name, &width, &height, &bpp, 3);
20
- sprintf(script, "Xumo::UInt8.zeros(%d, %d, 3)", width, height);
21
- rb_na = rb_eval_string((char*)script);
22
- na_data = RNARRAY_DATA(rb_na);
23
- memcpy(na_data->ptr, pixels, na_data->base.size);
24
- stbi_image_free(pixels);
25
- return rb_na;
26
- }
27
-
28
- VALUE image_io_write_png(VALUE self, VALUE rb_file_name, VALUE rb_na) {
29
- char* file_name = StringValuePtr(rb_file_name);
30
- int width;
31
- int height;
32
- int bpp = 3;
33
- unsigned char* pixels;
34
- narray_data_t* na_data;
35
- na_data = RNARRAY_DATA(rb_na);
36
- pixels = (unsigned char*)malloc(na_data->base.size);
37
- memcpy(pixels, na_data->ptr, na_data->base.size);
38
- width = na_data->base.shape[0];
39
- height = na_data->base.shape[1];
40
- stbi_write_png(file_name, width, height, bpp, pixels, width * bpp);
41
- stbi_image_free(pixels);
42
- return Qnil;
43
- }
44
-
45
- VALUE image_io_write_bmp(VALUE self, VALUE rb_file_name, VALUE rb_na) {
46
- char* file_name = StringValuePtr(rb_file_name);
47
- int width;
48
- int height;
49
- int bpp = 3;
50
- unsigned char* pixels;
51
- narray_data_t* na_data;
52
- na_data = RNARRAY_DATA(rb_na);
53
- pixels = (unsigned char*)malloc(na_data->base.size);
54
- memcpy(pixels, na_data->ptr, na_data->base.size);
55
- width = na_data->base.shape[0];
56
- height = na_data->base.shape[1];
57
- stbi_write_bmp(file_name, width, height, bpp, pixels);
58
- stbi_image_free(pixels);
59
- return Qnil;
60
- }
61
-
62
- VALUE image_io_write_jpg(VALUE self, VALUE rb_file_name, VALUE rb_na, VALUE rb_quality) {
63
- char* file_name = StringValuePtr(rb_file_name);
64
- int width;
65
- int height;
66
- int bpp = 3;
67
- int quality = FIX2INT(rb_quality);
68
- unsigned char* pixels;
69
- narray_data_t* na_data;
70
- na_data = RNARRAY_DATA(rb_na);
71
- pixels = (unsigned char*)malloc(na_data->base.size);
72
- memcpy(pixels, na_data->ptr, na_data->base.size);
73
- width = na_data->base.shape[0];
74
- height = na_data->base.shape[1];
75
- stbi_write_jpg(file_name, width, height, bpp, pixels, quality);
76
- stbi_image_free(pixels);
77
- return Qnil;
78
- }
79
-
80
- void Init_image_io_ext() {
81
- VALUE rb_dnn;
82
- VALUE rb_image_io;
83
- rb_dnn = rb_define_module("DNN");
84
- rb_image_io = rb_define_module_under(rb_dnn, "ImageIO");
85
- rb_define_singleton_method(rb_image_io, "_read", image_io_read, 1);
86
- rb_define_singleton_method(rb_image_io, "_write_png", image_io_write_bmp, 2);
87
- rb_define_singleton_method(rb_image_io, "_write_bmp", image_io_write_png, 2);
88
- rb_define_singleton_method(rb_image_io, "_write_jpg", image_io_write_jpg, 3);
89
- }