cumo 0.4.3 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. checksums.yaml +4 -4
  2. data/.gitignore +3 -0
  3. data/.rubocop.yml +15 -0
  4. data/.rubocop_todo.yml +1272 -0
  5. data/3rd_party/mkmf-cu/Gemfile +2 -0
  6. data/3rd_party/mkmf-cu/Rakefile +2 -1
  7. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +2 -0
  8. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +36 -7
  9. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +51 -45
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +2 -0
  11. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +3 -1
  12. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +5 -3
  13. data/CHANGELOG.md +69 -0
  14. data/Gemfile +6 -1
  15. data/README.md +2 -10
  16. data/Rakefile +8 -11
  17. data/bench/broadcast_fp32.rb +28 -26
  18. data/bench/cumo_bench.rb +18 -16
  19. data/bench/numo_bench.rb +18 -16
  20. data/bench/reduction_fp32.rb +14 -12
  21. data/bin/console +1 -0
  22. data/cumo.gemspec +5 -8
  23. data/ext/cumo/cuda/cudnn.c +2 -2
  24. data/ext/cumo/cumo.c +7 -3
  25. data/ext/cumo/depend.erb +15 -13
  26. data/ext/cumo/extconf.rb +32 -46
  27. data/ext/cumo/include/cumo/cuda/cudnn.h +3 -1
  28. data/ext/cumo/include/cumo/intern.h +1 -0
  29. data/ext/cumo/include/cumo/narray.h +13 -1
  30. data/ext/cumo/include/cumo/template.h +2 -4
  31. data/ext/cumo/include/cumo/types/complex_macro.h +1 -1
  32. data/ext/cumo/include/cumo/types/float_macro.h +2 -2
  33. data/ext/cumo/include/cumo/types/xint_macro.h +3 -2
  34. data/ext/cumo/include/cumo.h +2 -2
  35. data/ext/cumo/narray/array.c +3 -3
  36. data/ext/cumo/narray/data.c +23 -2
  37. data/ext/cumo/narray/gen/cogen.rb +8 -7
  38. data/ext/cumo/narray/gen/cogen_kernel.rb +8 -7
  39. data/ext/cumo/narray/gen/def/bit.rb +3 -1
  40. data/ext/cumo/narray/gen/def/dcomplex.rb +2 -0
  41. data/ext/cumo/narray/gen/def/dfloat.rb +2 -0
  42. data/ext/cumo/narray/gen/def/int16.rb +2 -0
  43. data/ext/cumo/narray/gen/def/int32.rb +2 -0
  44. data/ext/cumo/narray/gen/def/int64.rb +2 -0
  45. data/ext/cumo/narray/gen/def/int8.rb +2 -0
  46. data/ext/cumo/narray/gen/def/robject.rb +2 -0
  47. data/ext/cumo/narray/gen/def/scomplex.rb +2 -0
  48. data/ext/cumo/narray/gen/def/sfloat.rb +2 -0
  49. data/ext/cumo/narray/gen/def/uint16.rb +2 -0
  50. data/ext/cumo/narray/gen/def/uint32.rb +2 -0
  51. data/ext/cumo/narray/gen/def/uint64.rb +2 -0
  52. data/ext/cumo/narray/gen/def/uint8.rb +2 -0
  53. data/ext/cumo/narray/gen/erbln.rb +9 -7
  54. data/ext/cumo/narray/gen/erbpp2.rb +26 -24
  55. data/ext/cumo/narray/gen/narray_def.rb +13 -11
  56. data/ext/cumo/narray/gen/spec.rb +58 -55
  57. data/ext/cumo/narray/gen/tmpl/alloc_func.c +1 -1
  58. data/ext/cumo/narray/gen/tmpl/at.c +34 -0
  59. data/ext/cumo/narray/gen/tmpl/batch_norm.c +1 -1
  60. data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +2 -2
  61. data/ext/cumo/narray/gen/tmpl/conv.c +1 -1
  62. data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +3 -1
  63. data/ext/cumo/narray/gen/tmpl/conv_transpose.c +1 -1
  64. data/ext/cumo/narray/gen/tmpl/fixed_batch_norm.c +1 -1
  65. data/ext/cumo/narray/gen/tmpl/init_class.c +1 -0
  66. data/ext/cumo/narray/gen/tmpl/pooling_backward.c +1 -1
  67. data/ext/cumo/narray/gen/tmpl/pooling_forward.c +1 -1
  68. data/ext/cumo/narray/gen/tmpl/qsort.c +1 -5
  69. data/ext/cumo/narray/gen/tmpl/sort.c +1 -1
  70. data/ext/cumo/narray/gen/tmpl_bit/binary.c +42 -14
  71. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +5 -0
  72. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +5 -0
  73. data/ext/cumo/narray/gen/tmpl_bit/mask.c +27 -7
  74. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +21 -7
  75. data/ext/cumo/narray/gen/tmpl_bit/unary.c +21 -7
  76. data/ext/cumo/narray/index.c +243 -39
  77. data/ext/cumo/narray/index_kernel.cu +84 -0
  78. data/ext/cumo/narray/narray.c +38 -1
  79. data/ext/cumo/narray/ndloop.c +1 -1
  80. data/ext/cumo/narray/struct.c +1 -1
  81. data/lib/cumo/cuda/compile_error.rb +1 -1
  82. data/lib/cumo/cuda/compiler.rb +23 -22
  83. data/lib/cumo/cuda/cudnn.rb +1 -1
  84. data/lib/cumo/cuda/device.rb +1 -1
  85. data/lib/cumo/cuda/link_state.rb +2 -2
  86. data/lib/cumo/cuda/module.rb +1 -2
  87. data/lib/cumo/cuda/nvrtc_program.rb +3 -2
  88. data/lib/cumo/cuda.rb +2 -0
  89. data/lib/cumo/linalg.rb +2 -0
  90. data/lib/cumo/narray/extra.rb +137 -185
  91. data/lib/cumo/narray.rb +2 -0
  92. data/lib/cumo.rb +3 -1
  93. data/test/bit_test.rb +157 -0
  94. data/test/cuda/compiler_test.rb +69 -0
  95. data/test/cuda/device_test.rb +30 -0
  96. data/test/cuda/memory_pool_test.rb +45 -0
  97. data/test/cuda/nvrtc_test.rb +51 -0
  98. data/test/cuda/runtime_test.rb +28 -0
  99. data/test/cudnn_test.rb +498 -0
  100. data/test/cumo_test.rb +27 -0
  101. data/test/narray_test.rb +745 -0
  102. data/test/ractor_test.rb +52 -0
  103. data/test/test_helper.rb +31 -0
  104. metadata +31 -54
  105. data/.travis.yml +0 -5
  106. data/numo-narray-version +0 -1
data/test/cudnn_test.rb ADDED
@@ -0,0 +1,498 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test_helper"
4
+
# Tests for the cuDNN-backed NArray methods: convolution, transposed
# convolution, weight gradients, batch normalization and pooling.
#
# Every sub test case is generated once per floating-point dtype. Setting the
# DTYPE environment variable (e.g. DTYPE=sfloat) restricts the run to dtypes
# whose name contains that value (case-insensitive).
#
# Expected shapes follow the usual convolution arithmetic:
#   conv:           out = (in + 2 * pad - kernel) / stride + 1
#   conv_transpose: out = (in - 1) * stride - 2 * pad + kernel
class CUDNNTest < Test::Unit::TestCase
  float_types = [
    Cumo::SFloat,
    Cumo::DFloat,
  ]

  if ENV['DTYPE']
    float_types.select! { |type| type.to_s.downcase.include?(ENV['DTYPE'].downcase) }
  end

  float_types.each do |dtype|
    sub_test_case "conv_2d" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @out_channels = 2
        @in_dims = [10, 7]
        @kernel_size = [2, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @w_shape = [@out_channels, @in_channels].concat(@kernel_size)
        @b_shape = [@out_channels]
        @x = dtype.ones(*@x_shape)
        @w = dtype.ones(*@w_shape)
        @b = dtype.ones(*@b_shape) * 2
      end

      # With all-ones x and w, a fully covered output element is
      # in_channels * 2 * 3 = 18; the bias of 2 raises it to 20.
      test "x.conv(w) #{dtype}" do
        y = @x.conv(@w)
        assert { y.shape == [@batch_size, @out_channels, 9, 5] }
        assert y.to_a.flatten.all? { |e| e.to_i == 18 }
      end

      test "x.conv(w, b) #{dtype}" do
        y = @x.conv(@w, b: @b)
        assert { y.shape == [@batch_size, @out_channels, 9, 5] }
        assert y.to_a.flatten.all? { |e| e.to_i == 20 }
        # The bias argument must not be reshaped in place.
        assert { @b.shape == @b_shape }
      end

      # Padding introduces windows that overlap the zero border partially
      # (8) or completely (bias only, 2).
      test "x.conv(w, b, stride=int, pad=int) #{dtype}" do
        y = @x.conv(@w, b: @b, stride: 2, pad: 2)
        assert { y.shape == [@batch_size, @out_channels, 7, 5] }
        assert y.to_a.flatten.all? { |e| [20, 2, 8].include?(e.to_i) }
        assert { @b.shape == @b_shape }
      end

      test "x.conv(w, b, stride=array, pad=array) #{dtype}" do
        y = @x.conv(@w, b: @b, stride: [3, 2], pad: [2, 0])
        assert { y.shape == [@batch_size, @out_channels, 5, 3] }
        assert y.to_a.flatten.all? { |e| e.to_i == 20 || e.to_i == 2 }
        assert { @b.shape == @b_shape }
      end
    end

    sub_test_case "conv_nd" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @out_channels = 2
        @in_dims = [4, 3, 2]
        @kernel_size = [2, 3, 1]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @w_shape = [@out_channels, @in_channels].concat(@kernel_size)
        @b_shape = [@out_channels]
        @x = dtype.ones(*@x_shape)
        @w = dtype.ones(*@w_shape)
        @b = dtype.ones(*@b_shape) * 2
      end

      # in_channels * 2 * 3 * 1 = 18 per fully covered output element.
      test "x.conv(w) #{dtype}" do
        y = @x.conv(@w)
        assert { y.shape == [@batch_size, @out_channels, 3, 1, 2] }
        assert y.to_a.flatten.all? { |e| e.to_i == 18 }
        assert { @b.shape == @b_shape }
      end

      test "x.conv(w, b) #{dtype}" do
        y = @x.conv(@w, b: @b)
        assert { y.shape == [@batch_size, @out_channels, 3, 1, 2] }
        assert y.to_a.flatten.all? { |e| e.to_i == 20 }
        assert { @b.shape == @b_shape }
      end

      test "x.conv(w, b, stride, pad) #{dtype}" do
        y = @x.conv(@w, b: @b, stride: [3, 2, 1], pad: [2, 1, 0])
        assert { y.shape == [@batch_size, @out_channels, 3, 2, 2] }
        assert y.to_a.flatten.all? { |e| e.to_i == 14 || e.to_i == 2 }
        assert { @b.shape == @b_shape }
      end
    end

    sub_test_case "conv_transpose_2d" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @out_channels = 2
        @in_dims = [5, 3]
        @kernel_size = [2, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        # Transposed convolution weights are laid out (in_channels, out_channels, *kernel).
        @w_shape = [@in_channels, @out_channels].concat(@kernel_size)
        @b_shape = [@out_channels]
        @x = dtype.ones(*@x_shape)
        @w = dtype.ones(*@w_shape)
        @b = dtype.ones(*@b_shape) * 2
      end

      test "x.conv_transpose(w) #{dtype}" do
        y = @x.conv_transpose(@w)
        assert { y.shape == [@batch_size, @out_channels, 6, 5] }
      end

      test "x.conv_transpose(w, b) #{dtype}" do
        y = @x.conv_transpose(@w, b: @b)
        assert { y.shape == [@batch_size, @out_channels, 6, 5] }
        # Adding the bias must shift every element by exactly 2.
        y_no_bias = @x.conv_transpose(@w)
        assert { y == y_no_bias + 2 }
        assert { @b.shape == @b_shape }
      end

      test "x.conv_transpose(w, b, stride=int, pad=int) #{dtype}" do
        y = @x.conv_transpose(@w, b: @b, stride: 2, pad: 2)
        assert { y.shape == [@batch_size, @out_channels, 6, 3] }
        assert y.to_a.flatten.all? { |e| e.to_i == 8 || e.to_i == 5 }
        assert { @b.shape == @b_shape }
      end

      test "x.conv_transpose(w, b, stride=array, pad=array) #{dtype}" do
        y = @x.conv_transpose(@w, b: @b, stride: [3, 2], pad: [2, 0])
        assert { y.shape == [@batch_size, @out_channels, 10, 7] }
        assert y.to_a.flatten.all? { |e| [8, 5, 2].include?(e.to_i) }
        assert { @b.shape == @b_shape }
      end
    end

    sub_test_case "conv_transpose_nd" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @out_channels = 2
        @in_dims = [4, 3, 2]
        @kernel_size = [2, 3, 1]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @w_shape = [@in_channels, @out_channels].concat(@kernel_size)
        @b_shape = [@out_channels]
        @x = dtype.ones(*@x_shape)
        @w = dtype.ones(*@w_shape)
        @b = dtype.ones(*@b_shape) * 2
      end

      test "x.conv_transpose(w) #{dtype}" do
        y = @x.conv_transpose(@w)
        assert { y.shape == [@batch_size, @out_channels, 5, 5, 2] }
        assert y.to_a.flatten.all? { |e| [3, 6, 9, 12, 18].include?(e.to_i) }
        assert { @b.shape == @b_shape }
      end

      test "x.conv_transpose(w, b) #{dtype}" do
        y = @x.conv_transpose(@w, b: @b)
        assert { y.shape == [@batch_size, @out_channels, 5, 5, 2] }
        y_no_bias = @x.conv_transpose(@w)
        assert { y == y_no_bias + 2 }
        assert { @b.shape == @b_shape }
      end

      test "x.conv_transpose(w, b, stride, pad) #{dtype}" do
        y = @x.conv_transpose(@w, b: @b, stride: [3, 2, 1], pad: [2, 1, 0])
        assert { y.shape == [@batch_size, @out_channels, 7, 5, 2] }
        assert y.to_a.flatten.all? { |e| [2, 5, 8].include?(e.to_i) }
        assert { @b.shape == @b_shape }
      end
    end

    sub_test_case "conv_grad_w_2d" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @out_channels = 2
        @in_dims = [10, 7]
        @kernel_size = [2, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @w_shape = [@out_channels, @in_channels].concat(@kernel_size)
        # Shape of x.conv(w) for these dimensions (see conv_2d above).
        @y_shape = [@batch_size, @out_channels, 9, 5]
        @x = dtype.ones(*@x_shape)
        @dy = dtype.ones(*@y_shape)
      end

      test "x.conv_grad_w(dy, w_shape) #{dtype}" do
        dw = @x.conv_grad_w(@dy, @w_shape)
        assert { dw.shape == @w_shape }
        # TODO: assert values
      end
    end

    sub_test_case "conv_grad_w_nd" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @out_channels = 2
        @in_dims = [4, 3, 2]
        @kernel_size = [2, 3, 1]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @w_shape = [@out_channels, @in_channels].concat(@kernel_size)
        # Shape of x.conv(w) for these dimensions (see conv_nd above).
        @y_shape = [@batch_size, @out_channels, 3, 1, 2]
        @x = dtype.ones(*@x_shape)
        @dy = dtype.ones(*@y_shape)
      end

      test "x.conv_grad_w(dy, w_shape) #{dtype}" do
        dw = @x.conv_grad_w(@dy, @w_shape)
        assert { dw.shape == @w_shape }
        # TODO: assert values
      end
    end

    sub_test_case "batch_norm" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        # Default reduction is over axis 0 only, so statistics keep every other axis.
        @reduced_shape = [1].concat(@x_shape[1..-1])
        @x = dtype.ones(*@x_shape) * 3
        @gamma = dtype.ones(*@reduced_shape) * 2
        @beta = dtype.ones(*@reduced_shape)
      end

      # x is constant, so its normalized value is 0 and the output equals beta (ones).
      test "x.batch_norm(gamma, beta) #{dtype}" do
        y = @x.batch_norm(@gamma, @beta)
        assert { y.shape == @x_shape }
        assert { y == dtype.ones(*@x_shape) }
      end

      test "x.batch_norm(gamma, beta, axis: [0]) #{dtype}" do
        assert { @x.batch_norm(@gamma, @beta) == @x.batch_norm(@gamma, @beta, axis: [0]) }
      end

      test "x.batch_norm(gamma, beta, axis: [0, 2, 3]) #{dtype}" do
        reduced_shape = [1, @x_shape[1], 1, 1]
        gamma = dtype.ones(reduced_shape) * 2
        beta = dtype.ones(reduced_shape)
        y = @x.batch_norm(gamma, beta, axis: [0, 2, 3])
        assert { y.shape == @x_shape }
      end

      test "x.batch_norm(gamma, beta, running_mean, running_var) #{dtype}" do
        running_mean = dtype.ones(*@reduced_shape)
        running_var = dtype.ones(*@reduced_shape)
        y = @x.batch_norm(@gamma, @beta, running_mean: running_mean, running_var: running_var)
        assert { y.shape == @x_shape }
        assert { y == dtype.ones(*@x_shape) }
      end

      test "x.batch_norm(gamma, beta, mean, inv_std) #{dtype}" do
        mean = dtype.new(*@reduced_shape)
        inv_std = dtype.new(*@reduced_shape)
        y = @x.batch_norm(@gamma, @beta, mean: mean, inv_std: inv_std)
        assert { y.shape == @x_shape }
        assert { mean.shape == @reduced_shape }
        assert { inv_std.shape == @reduced_shape }
      end
    end

    sub_test_case "batch_norm_backward" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @reduced_shape = [1].concat(@x_shape[1..-1])
        @x = dtype.ones(*@x_shape) * 3
        @gamma = dtype.ones(*@reduced_shape) * 2
        @beta = dtype.ones(*@reduced_shape)
        @gy = dtype.ones(*@x_shape)
      end

      test "x.batch_norm_backward(gamma, gy) #{dtype}" do
        @x.batch_norm(@gamma, @beta)
        gx, ggamma, gbeta = @x.batch_norm_backward(@gamma, @gy)
        assert { gx.shape == @x_shape }
        assert { ggamma.shape == @reduced_shape }
        assert { gbeta.shape == @reduced_shape }
      end

      test "x.batch_norm_backward(gamma, gy, axis: [0,2,3]) #{dtype}" do
        @reduced_shape = [1, @x_shape[1], 1, 1]
        @gamma = dtype.ones(@reduced_shape) * 2
        @beta = dtype.ones(@reduced_shape)
        @x.batch_norm(@gamma, @beta, axis: [0, 2, 3])
        gx, ggamma, gbeta = @x.batch_norm_backward(@gamma, @gy, axis: [0, 2, 3])
        assert { gx.shape == @x_shape }
        assert { ggamma.shape == @reduced_shape }
        assert { gbeta.shape == @reduced_shape }
      end

      test "x.batch_norm_backward(gamma, gy, mean:, inv_std:) #{dtype}" do
        mean = dtype.new(*@reduced_shape)
        inv_std = dtype.new(*@reduced_shape)
        @x.batch_norm(@gamma, @beta, mean: mean, inv_std: inv_std)
        gx, ggamma, gbeta = @x.batch_norm_backward(@gamma, @gy, mean: mean, inv_std: inv_std)
        assert { gx.shape == @x_shape }
        assert { ggamma.shape == @reduced_shape }
        assert { gbeta.shape == @reduced_shape }
      end
    end

    sub_test_case "fixed_batch_norm" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @reduced_shape = [1].concat(@x_shape[1..-1])
        @x = dtype.ones(*@x_shape) * 3
        @gamma = dtype.ones(*@reduced_shape) * 2
        @beta = dtype.ones(*@reduced_shape)
        @mean = dtype.ones(*@reduced_shape)
        @var = dtype.ones(*@reduced_shape)
      end

      test "x.fixed_batch_norm(gamma, beta, mean, var) #{dtype}" do
        y = @x.fixed_batch_norm(@gamma, @beta, @mean, @var)
        assert { y.shape == @x_shape }
        # TODO: check output values
      end

      test "x.fixed_batch_norm(gamma, beta, mean, var, axis: [0]) #{dtype}" do
        assert { @x.fixed_batch_norm(@gamma, @beta, @mean, @var) == @x.fixed_batch_norm(@gamma, @beta, @mean, @var, axis: [0]) }
      end

      test "x.fixed_batch_norm(gamma, beta, mean, var, axis: [0, 2, 3]) #{dtype}" do
        reduced_shape = [1, @x_shape[1], 1, 1]
        gamma = dtype.ones(reduced_shape) * 2
        beta = dtype.ones(reduced_shape)
        mean = dtype.ones(reduced_shape) * 2
        var = dtype.ones(reduced_shape)
        y = @x.fixed_batch_norm(gamma, beta, mean, var, axis: [0, 2, 3])
        assert { y.shape == @x_shape }
        # TODO: check output values
      end
    end

    sub_test_case "max_pool" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @ksize = [3] * @in_dims.size
        @x = dtype.ones(*@x_shape) * 3
      end

      # x is constant 3, so every pooled window's maximum is 3.
      test "x.max_pool(ksize) #{dtype}" do
        y = @x.max_pool(@ksize)
        assert { y.shape == [@batch_size, @in_channels, 1, 1] }
        assert y.to_a.flatten.all? { |e| e.to_i == 3 }
      end

      test "x.max_pool(ksize, stride:, pad:) #{dtype}" do
        stride = [2] * @in_dims.size
        pad = [1] * @in_dims.size
        y = @x.max_pool(@ksize, stride: stride, pad: pad)
        assert { y.shape == [@batch_size, @in_channels, 3, 2] }
        assert y.to_a.flatten.all? { |e| e.to_i == 3 }
      end
    end

    sub_test_case "avg_pool(pad_value: nil)" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @ksize = [3] * @in_dims.size
        @x = dtype.ones(*@x_shape) * 3
      end

      test "x.avg_pool(ksize) #{dtype}" do
        y = @x.avg_pool(@ksize)
        assert { y.shape == [@batch_size, @in_channels, 1, 1] }
        assert y.to_a.flatten.all? { |e| e.to_i == 3 }
      end

      test "x.avg_pool(ksize, stride:, pad:) #{dtype}" do
        stride = [2] * @in_dims.size
        pad = [1] * @in_dims.size
        y = @x.avg_pool(@ksize, stride: stride, pad: pad)
        assert { y.shape == [@batch_size, @in_channels, 3, 2] }
        # TODO: assert values
      end
    end

    sub_test_case "avg_pool(pad_value: 0)" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @ksize = [3] * @in_dims.size
        @x = dtype.ones(*@x_shape) * 3
      end

      # The two padding modes must agree on shape but differ in value
      # for windows that touch the padded border.
      test "x.avg_pool(ksize, stride:, pad:) #{dtype}" do
        stride = [2] * @in_dims.size
        pad = [1] * @in_dims.size
        y_pad_0 = @x.avg_pool(@ksize, pad_value: 0, stride: stride, pad: pad)
        y_pad_nil = @x.avg_pool(@ksize, pad_value: nil, stride: stride, pad: pad)
        assert { y_pad_0.shape == y_pad_nil.shape }
        assert { y_pad_0 != y_pad_nil }
      end
    end

    sub_test_case "max_pool_backward" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @ksize = [3] * @in_dims.size
        @x = dtype.ones(*@x_shape) * 3
      end

      test "x.max_pool_backward(ksize) #{dtype}" do
        y = @x.max_pool(@ksize)
        gy = dtype.ones(*y.shape)
        gx = @x.max_pool_backward(y, gy, @ksize)
        assert { gx.shape == @x.shape }
        # TODO: assert values
      end

      test "x.max_pool_backward(ksize, stride:, pad:) #{dtype}" do
        stride = [2] * @in_dims.size
        pad = [1] * @in_dims.size
        y = @x.max_pool(@ksize, stride: stride, pad: pad)
        gy = dtype.ones(*y.shape)
        gx = @x.max_pool_backward(y, gy, @ksize, stride: stride, pad: pad)
        assert { gx.shape == @x.shape }
        # TODO: assert values
      end
    end

    sub_test_case "avg_pool_backward" do
      setup do
        @batch_size = 2
        @in_channels = 3
        @in_dims = [5, 3]
        @x_shape = [@batch_size, @in_channels].concat(@in_dims)
        @ksize = [3] * @in_dims.size
        @x = dtype.ones(*@x_shape) * 3
      end

      test "x.avg_pool_backward(ksize) #{dtype}" do
        y = @x.avg_pool(@ksize)
        gy = dtype.ones(*y.shape)
        gx = @x.avg_pool_backward(y, gy, @ksize)
        assert { gx.shape == @x.shape }
        # TODO: assert values
      end

      test "x.avg_pool_backward(ksize, stride:, pad:) #{dtype}" do
        stride = [2] * @in_dims.size
        pad = [1] * @in_dims.size
        y = @x.avg_pool(@ksize, stride: stride, pad: pad)
        gy = dtype.ones(*y.shape)
        gx = @x.avg_pool_backward(y, gy, @ksize, stride: stride, pad: pad)
        assert { gx.shape == @x.shape }
        # TODO: assert values
      end
    end
  end
end
data/test/cumo_test.rb ADDED
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test_helper"
4
+
# Tests for the top-level Cumo module: the global compatible-mode switch and
# the presence of the VERSION constant.
class CumoTest < Test::Unit::TestCase
  # Capture the global compatible-mode flag so individual tests may toggle it.
  def setup
    @compatible_mode_before = Cumo.compatible_mode_enabled?
  end

  # Put the flag back to whatever #setup observed, so tests do not leak
  # global state into one another.
  def teardown
    if @compatible_mode_before
      Cumo.enable_compatible_mode
    else
      Cumo.disable_compatible_mode
    end
  end

  def test_enable_compatible_mode
    Cumo.enable_compatible_mode
    assert { Cumo.compatible_mode_enabled? }
  end

  def test_disable_compatible_mode
    Cumo.disable_compatible_mode
    assert { !Cumo.compatible_mode_enabled? }
  end

  # Referencing the constant must not raise (i.e. Cumo::VERSION is defined).
  def test_version
    assert_nothing_raised do
      Cumo::VERSION
    end
  end
end