cumo 0.4.3 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. checksums.yaml +4 -4
  2. data/.gitignore +3 -0
  3. data/.rubocop.yml +15 -0
  4. data/.rubocop_todo.yml +1272 -0
  5. data/3rd_party/mkmf-cu/Gemfile +2 -0
  6. data/3rd_party/mkmf-cu/Rakefile +2 -1
  7. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +2 -0
  8. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +36 -7
  9. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +51 -45
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +2 -0
  11. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +3 -1
  12. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +5 -3
  13. data/CHANGELOG.md +69 -0
  14. data/Gemfile +6 -1
  15. data/README.md +2 -10
  16. data/Rakefile +8 -11
  17. data/bench/broadcast_fp32.rb +28 -26
  18. data/bench/cumo_bench.rb +18 -16
  19. data/bench/numo_bench.rb +18 -16
  20. data/bench/reduction_fp32.rb +14 -12
  21. data/bin/console +1 -0
  22. data/cumo.gemspec +5 -8
  23. data/ext/cumo/cuda/cudnn.c +2 -2
  24. data/ext/cumo/cumo.c +7 -3
  25. data/ext/cumo/depend.erb +15 -13
  26. data/ext/cumo/extconf.rb +32 -46
  27. data/ext/cumo/include/cumo/cuda/cudnn.h +3 -1
  28. data/ext/cumo/include/cumo/intern.h +1 -0
  29. data/ext/cumo/include/cumo/narray.h +13 -1
  30. data/ext/cumo/include/cumo/template.h +2 -4
  31. data/ext/cumo/include/cumo/types/complex_macro.h +1 -1
  32. data/ext/cumo/include/cumo/types/float_macro.h +2 -2
  33. data/ext/cumo/include/cumo/types/xint_macro.h +3 -2
  34. data/ext/cumo/include/cumo.h +2 -2
  35. data/ext/cumo/narray/array.c +3 -3
  36. data/ext/cumo/narray/data.c +23 -2
  37. data/ext/cumo/narray/gen/cogen.rb +8 -7
  38. data/ext/cumo/narray/gen/cogen_kernel.rb +8 -7
  39. data/ext/cumo/narray/gen/def/bit.rb +3 -1
  40. data/ext/cumo/narray/gen/def/dcomplex.rb +2 -0
  41. data/ext/cumo/narray/gen/def/dfloat.rb +2 -0
  42. data/ext/cumo/narray/gen/def/int16.rb +2 -0
  43. data/ext/cumo/narray/gen/def/int32.rb +2 -0
  44. data/ext/cumo/narray/gen/def/int64.rb +2 -0
  45. data/ext/cumo/narray/gen/def/int8.rb +2 -0
  46. data/ext/cumo/narray/gen/def/robject.rb +2 -0
  47. data/ext/cumo/narray/gen/def/scomplex.rb +2 -0
  48. data/ext/cumo/narray/gen/def/sfloat.rb +2 -0
  49. data/ext/cumo/narray/gen/def/uint16.rb +2 -0
  50. data/ext/cumo/narray/gen/def/uint32.rb +2 -0
  51. data/ext/cumo/narray/gen/def/uint64.rb +2 -0
  52. data/ext/cumo/narray/gen/def/uint8.rb +2 -0
  53. data/ext/cumo/narray/gen/erbln.rb +9 -7
  54. data/ext/cumo/narray/gen/erbpp2.rb +26 -24
  55. data/ext/cumo/narray/gen/narray_def.rb +13 -11
  56. data/ext/cumo/narray/gen/spec.rb +58 -55
  57. data/ext/cumo/narray/gen/tmpl/alloc_func.c +1 -1
  58. data/ext/cumo/narray/gen/tmpl/at.c +34 -0
  59. data/ext/cumo/narray/gen/tmpl/batch_norm.c +1 -1
  60. data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +2 -2
  61. data/ext/cumo/narray/gen/tmpl/conv.c +1 -1
  62. data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +3 -1
  63. data/ext/cumo/narray/gen/tmpl/conv_transpose.c +1 -1
  64. data/ext/cumo/narray/gen/tmpl/fixed_batch_norm.c +1 -1
  65. data/ext/cumo/narray/gen/tmpl/init_class.c +1 -0
  66. data/ext/cumo/narray/gen/tmpl/pooling_backward.c +1 -1
  67. data/ext/cumo/narray/gen/tmpl/pooling_forward.c +1 -1
  68. data/ext/cumo/narray/gen/tmpl/qsort.c +1 -5
  69. data/ext/cumo/narray/gen/tmpl/sort.c +1 -1
  70. data/ext/cumo/narray/gen/tmpl_bit/binary.c +42 -14
  71. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +5 -0
  72. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +5 -0
  73. data/ext/cumo/narray/gen/tmpl_bit/mask.c +27 -7
  74. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +21 -7
  75. data/ext/cumo/narray/gen/tmpl_bit/unary.c +21 -7
  76. data/ext/cumo/narray/index.c +243 -39
  77. data/ext/cumo/narray/index_kernel.cu +84 -0
  78. data/ext/cumo/narray/narray.c +38 -1
  79. data/ext/cumo/narray/ndloop.c +1 -1
  80. data/ext/cumo/narray/struct.c +1 -1
  81. data/lib/cumo/cuda/compile_error.rb +1 -1
  82. data/lib/cumo/cuda/compiler.rb +23 -22
  83. data/lib/cumo/cuda/cudnn.rb +1 -1
  84. data/lib/cumo/cuda/device.rb +1 -1
  85. data/lib/cumo/cuda/link_state.rb +2 -2
  86. data/lib/cumo/cuda/module.rb +1 -2
  87. data/lib/cumo/cuda/nvrtc_program.rb +3 -2
  88. data/lib/cumo/cuda.rb +2 -0
  89. data/lib/cumo/linalg.rb +2 -0
  90. data/lib/cumo/narray/extra.rb +137 -185
  91. data/lib/cumo/narray.rb +2 -0
  92. data/lib/cumo.rb +3 -1
  93. data/test/bit_test.rb +157 -0
  94. data/test/cuda/compiler_test.rb +69 -0
  95. data/test/cuda/device_test.rb +30 -0
  96. data/test/cuda/memory_pool_test.rb +45 -0
  97. data/test/cuda/nvrtc_test.rb +51 -0
  98. data/test/cuda/runtime_test.rb +28 -0
  99. data/test/cudnn_test.rb +498 -0
  100. data/test/cumo_test.rb +27 -0
  101. data/test/narray_test.rb +745 -0
  102. data/test/ractor_test.rb +52 -0
  103. data/test/test_helper.rb +31 -0
  104. metadata +31 -54
  105. data/.travis.yml +0 -5
  106. data/numo-narray-version +0 -1
@@ -0,0 +1,745 @@
# frozen_string_literal: true

require_relative "test_helper"

# Behavioral tests for the Cumo::NArray numeric types.
#
# Almost every test is generated per dtype: the class body iterates over
# `types` and defines one test (or sub test case) for each dtype/scenario pair.
# Set the DTYPE environment variable (e.g. DTYPE=sfloat) to restrict the run to
# dtypes whose class name contains that substring, case-insensitively.
class NArrayTest < Test::Unit::TestCase
  types = [
    Cumo::DFloat,
    Cumo::SFloat,
    Cumo::DComplex,
    Cumo::SComplex,
    Cumo::Int64,
    Cumo::Int32,
    Cumo::Int16,
    Cumo::Int8,
    Cumo::UInt64,
    Cumo::UInt32,
    Cumo::UInt16,
    Cumo::UInt8,
  ]
  # Types whose mean/var/stddev/rms results are compared against exact
  # double-precision literals below; single-precision types are left out,
  # presumably because their results would not match those literals exactly.
  float_types = [
    Cumo::DFloat,
    Cumo::DComplex,
  ]

  # Dtype groups used to enable or skip type-specific assertions.
  complex_types = [Cumo::DComplex, Cumo::SComplex].freeze
  real_float_types = [Cumo::DFloat, Cumo::SFloat].freeze
  float_and_complex_types = (real_float_types + complex_types).freeze
  unsigned_types = [Cumo::UInt64, Cumo::UInt32, Cumo::UInt16, Cumo::UInt8].freeze

  if ENV['DTYPE']
    dtype_filter = ENV['DTYPE'].downcase
    types.select! { |type| type.to_s.downcase.include?(dtype_filter) }
    float_types.select! { |type| type.to_s.downcase.include?(dtype_filter) }
  end

  # Reseed Cumo's random generator before every test so runs are deterministic.
  def setup
    Cumo::NArray.srand(0)
  end

  types.each do |dtype|
    test dtype do
      assert { dtype < Cumo::NArray }
    end

    test "#{dtype}[]" do
      a = dtype[]

      assert_raise(Cumo::NArray::ShapeError) { a[true] }
      assert_raise(Cumo::NArray::ShapeError) { a[1..-1] }

      assert { a.size == 0 }
      assert { a.ndim == 1 }
      assert { a.shape == [0] }
      assert { !a.inplace? }
      assert { a.row_major? }
      assert { !a.column_major? }
      assert { a.host_order? }
      assert { !a.byte_swapped? }
      assert { a == [] }
      assert { a.to_a == [] }
      assert { a.to_a.is_a?(Array) }
      assert { a.dup == a }
      assert { a.clone == a }
      assert { a.dup.object_id != a.object_id }
      assert { a.clone.object_id != a.object_id }
    end

    types.each do |other_dtype|
      next if dtype == other_dtype

      test "#{dtype}[] == #{other_dtype}[]" do
        assert { dtype[] == other_dtype[] }
      end
    end

    test "#{dtype},free" do
      a = dtype[1, 2, 3, 5, 7, 11]
      assert { a.free }
      assert { !a.free } # return false if already freed
    end

    # Each initializer builds the same 1-D data either directly or through a
    # full-range view, so every assertion below also runs against views.
    procs = [
      [proc { |tp, a| tp[*a] }, ""],
      [proc { |tp, a| tp[*a][true] }, "[true]"],
      [proc { |tp, a| tp[*a][0..-1] }, "[0..-1]"]
    ]
    procs.each do |init, ref|
      test "#{dtype},[1,2,3,5,7,11]#{ref}" do
        src = [1, 2, 3, 5, 7, 11]
        a = init.call(dtype, src)

        assert { a.is_a?(dtype) }
        assert { a.size == 6 }
        assert { a.ndim == 1 }
        assert { a.shape == [6] }
        assert { !a.inplace? }
        assert { a.row_major? }
        assert { !a.column_major? }
        assert { a.host_order? }
        assert { !a.byte_swapped? }
        assert { a == [1, 2, 3, 5, 7, 11] }
        assert { a.to_a == [1, 2, 3, 5, 7, 11] }
        assert { a.to_a.is_a?(Array) }
        assert { a.dup == a }
        assert { a.clone == a }
        assert { a.dup.object_id != a.object_id }
        assert { a.clone.object_id != a.object_id }

        assert { a.eq([1, 1, 3, 3, 7, 7]) == [1, 0, 1, 0, 1, 0] }
        assert { a[3..4] == [5, 7] }
        assert { a[5] == 11 }
        assert { a[5].size == 1 }
        assert { a[-1] == 11 }

        assert { a.at([3, 4]) == [5, 7] }
        assert { a.view.at([3, 4]) == [5, 7] }
        assert { a[2..-1].at([1, 2]) == [5, 7] }
        assert { a.at(Cumo::Int32.cast([3, 4])) == [5, 7] }
        assert { a.view.at(Cumo::Int32.cast([3, 4])) == [5, 7] }
        assert { a.at(3..4) == [5, 7] }
        assert { a.view.at(3..4) == [5, 7] }
        assert { a.at([5]) == [11] }
        assert { a.view.at([5]) == [11] }
        assert { a.at([-1]) == [11] }
        assert { a.view.at([-1]) == [11] }

        assert { a[(0..-1).each] == [1, 2, 3, 5, 7, 11] }
        assert { a[(0...-1).each] == [1, 2, 3, 5, 7] }

        if Enumerator.const_defined?(:ArithmeticSequence)
          assert { a[0.step(-1)] == [1, 2, 3, 5, 7, 11] }
          assert { a[0.step(4)] == [1, 2, 3, 5, 7] }
          assert { a[-5.step(-1)] == [2, 3, 5, 7, 11] }
          assert { a[0.step(-1, 2)] == [1, 3, 7] }
          assert { a[0.step(4, 2)] == [1, 3, 7] }
          assert { a[-5.step(-1, 2)] == [2, 5, 11] }

          assert { a[0.step] == [1, 2, 3, 5, 7, 11] }
          assert { a[-5.step] == [2, 3, 5, 7, 11] }
          # Endless ranges are wrapped in eval, presumably so this file still
          # parses on Ruby versions that lack endless-range syntax.
          assert { eval('a[(0..).step(2)]') == [1, 3, 7] }
          assert { eval('a[(0...).step(2)]') == [1, 3, 7] }
          assert { eval('a[(-5..).step(2)]') == [2, 5, 11] }
          assert { eval('a[(-5...).step(2)]') == [2, 5, 11] }
          assert { eval('a[(0..) % 2]') == [1, 3, 7] }
          assert { eval('a[(0...) % 2]') == [1, 3, 7] }
          assert { eval('a[(-5..) % 2]') == [2, 5, 11] }
          assert { eval('a[(-5...) % 2]') == [2, 5, 11] }
        end

        assert { a[(0..-1).step(2)] == [1, 3, 7] }
        assert { a[(0...-1).step(2)] == [1, 3, 7] }
        assert { a[(0..4).step(2)] == [1, 3, 7] }
        assert { a[(0...4).step(2)] == [1, 3] }
        assert { a[(-5..-1).step(2)] == [2, 5, 11] }
        assert { a[(-5...-1).step(2)] == [2, 5] }
        assert { a[(0..-1) % 2] == [1, 3, 7] }
        assert { a[(0...-1) % 2] == [1, 3, 7] }
        assert { a[(0..4) % 2] == [1, 3, 7] }
        assert { a[(0...4) % 2] == [1, 3] }
        assert { a[(-5..-1) % 2] == [2, 5, 11] }
        assert { a[(-5...-1) % 2] == [2, 5] }
        assert { a[[4, 3, 0, 1, 5, 2]] == [7, 5, 1, 2, 11, 3] }
        assert { a.reverse == [11, 7, 5, 3, 2, 1] }
        assert { a.sum == 29 }
        if float_types.include?(dtype)
          assert { a.mean == 29.0 / 6 }
          assert { a.var == 13.766666666666666 }
          assert { a.stddev == 3.7103458958251676 }
          assert { a.rms == 5.901977069875258 }
        end
        assert { a.dup.fill(12) == [12] * 6 }
        assert { (a + 1) == [2, 3, 4, 6, 8, 12] }
        assert { (a - 1) == [0, 1, 2, 4, 6, 10] }
        assert { (a * 3) == [3, 6, 9, 15, 21, 33] }
        assert { (a / 0.5) == [2, 4, 6, 10, 14, 22] }
        assert { (-a) == [-1, -2, -3, -5, -7, -11] }
        assert { (a**2) == [1, 4, 9, 25, 49, 121] }
        assert { a.swap_byte.swap_byte == [1, 2, 3, 5, 7, 11] }

        assert { a.contiguous? }
        assert { a.transpose.contiguous? }

        if complex_types.include?(dtype)
          assert { a.real == src }
          assert { a.imag == [0] * 6 }
          assert { a.conj == src }
          assert { a.angle == [0] * 6 }
        else
          assert { a.min == 1 }
          assert { a.max == 11 }
          assert { a.min_index == 0 }
          assert { a.max_index == 5 }
          assert { (a >= 3) == [0, 0, 1, 1, 1, 1] }
          assert { (a > 3) == [0, 0, 0, 1, 1, 1] }
          assert { (a <= 3) == [1, 1, 1, 0, 0, 0] }
          assert { (a < 3) == [1, 1, 0, 0, 0, 0] }
          assert { (a.eq 3) == [0, 0, 1, 0, 0, 0] }
          assert { a.sort == src }
          assert { a.sort_index == (0..5).to_a }
          assert { a.median == 4 }
          assert { dtype.maximum(a, 12 - a) == [11, 10, 9, 7, 7, 11] }
          assert { dtype.minimum(a, 12 - a) == [1, 2, 3, 5, 5, 1] }
          assert { dtype.maximum(a, 5) == [5, 5, 5, 5, 7, 11] }
          assert { dtype.minimum(a, 5) == [1, 2, 3, 5, 5, 5] }
        end
      end
    end

    test "#{dtype},[1..4]" do
      assert { dtype[1..4] == [1, 2, 3, 4] }
    end

    test "#{dtype},[-4..-1]" do
      assert { dtype[-4..-1] == [-4, -3, -2, -1] }
    end

    if Enumerator.const_defined?(:ArithmeticSequence)
      test "#{dtype},[1.step(4)]" do
        assert { dtype[1.step(4)] == [1, 2, 3, 4] }
      end

      test "#{dtype},[-4.step(-1)]" do
        assert { dtype[-4.step(-1)] == [-4, -3, -2, -1] }
      end

      test "#{dtype},[1.step(4, 2)]" do
        assert { dtype[1.step(4, 2)] == [1, 3] }
      end

      test "#{dtype},[-4.step(-1, 2)]" do
        assert { dtype[-4.step(-1, 2)] == [-4, -2] }
      end

      test "#{dtype},[(-4..-1).step(2)]" do
        assert { dtype[(-4..-1).step(2)] == [-4, -2] }
      end
    end

    test "#{dtype},[(1..4) % 2]" do
      assert { dtype[(1..4) % 2] == [1, 3] }
    end

    test "#{dtype},[(-4..-1) % 2]" do
      assert { dtype[(-4..-1) % 2] == [-4, -2] }
    end

    # Disabled in the source; the reason is not recorded here.
    # test "#{dtype}.seq(5)" do
    #   assert { dtype.seq(5) == [0,1,2,3,4] }
    # end

    # 2-D counterpart of `procs`: direct construction plus two full-range views.
    procs2 = [
      [proc { |tp, src| tp[*src] }, ""],
      [proc { |tp, src| tp[*src][true, true] }, "[true,true]"],
      [proc { |tp, src| tp[*src][0..-1, 0..-1] }, "[0..-1,0..-1]"]
    ]

    procs2.each do |init, ref|
      test "#{dtype},[[1,2,3],[5,7,11]]#{ref}" do
        src = [[1, 2, 3], [5, 7, 11]]
        a = init.call(dtype, src)

        assert { a.is_a?(dtype) }
        assert { a.size == 6 }
        assert { a.ndim == 2 }
        assert { a.shape == [2, 3] }
        assert { !a.inplace? }
        assert { a.row_major? }
        assert { !a.column_major? }
        assert { a.host_order? }
        assert { !a.byte_swapped? }
        assert { a == src }
        assert { a.to_a == src }
        assert { a.to_a.is_a?(Array) }

        assert { a.eq([[1, 1, 3], [3, 7, 7]]) == [[1, 0, 1], [0, 1, 0]] }
        assert { a[5] == 11 }
        assert { a[-1] == 11 }
        assert { a[1, 0] == src[1][0] }
        assert { a[1, 1] == src[1][1] }
        assert { a[1, 2] == src[1][2] }
        assert { a[3..4] == [5, 7] }
        assert { a[0, 1..2] == [2, 3] }

        assert { a.at([0, 1], [1, 2]) == [2, 11] }
        assert { a.view.at([0, 1], [1, 2]) == [2, 11] }
        assert { a.at([0, 1], (0..2) % 2) == [1, 11] }
        assert { a.view.at([0, 1], (0..2) % 2) == [1, 11] }
        assert { a.at((0..1) % 1, [0, 2]) == [1, 11] }
        assert { a.view.at((0..1) % 1, [0, 2]) == [1, 11] }
        assert { a.at(Cumo::Int32.cast([0, 1]), Cumo::Int32.cast([1, 2])) == [2, 11] }
        assert { a.view.at(Cumo::Int32.cast([0, 1]), Cumo::Int32.cast([1, 2])) == [2, 11] }
        assert { a[[0, 1], [0, 2]].at([0, 1], [0, 1]) == [1, 11] }
        assert { a[[0, 1], (0..2) % 2].at([0, 1], [0, 1]) == [1, 11] }
        assert { a[(0..1) % 1, [0, 2]].at([0, 1], [0, 1]) == [1, 11] }
        assert { a[(0..1) % 1, (0..2) % 2].at([0, 1], [0, 1]) == [1, 11] }

        assert { a[0, :*] == src[0] }
        assert { a[1, :*] == src[1] }
        assert { a[:*, 1] == [src[0][1], src[1][1]] }
        assert { a[true, [2, 0, 1]] == [[3, 1, 2], [11, 5, 7]] }
        assert { a.reshape(3, 2) == [[1, 2], [3, 5], [7, 11]] }
        assert { a.reshape(3, nil) == [[1, 2], [3, 5], [7, 11]] }
        assert { a.reshape(nil, 2) == [[1, 2], [3, 5], [7, 11]] }
        assert { a.transpose == [[1, 5], [2, 7], [3, 11]] }
        assert { a.transpose(1, 0) == [[1, 5], [2, 7], [3, 11]] }
        assert { a.triu == [[1, 2, 3], [0, 7, 11]] }
        assert { a.tril == [[1, 0, 0], [5, 7, 0]] }
        assert { a.reverse == [[11, 7, 5], [3, 2, 1]] }
        assert { a.reverse(0, 1) == [[11, 7, 5], [3, 2, 1]] }
        assert { a.reverse(1, 0) == [[11, 7, 5], [3, 2, 1]] }
        assert { a.reverse(0) == [[5, 7, 11], [1, 2, 3]] }
        assert { a.reverse(1) == [[3, 2, 1], [11, 7, 5]] }

        assert { a.sum == 29 }
        assert { a.sum(0) == [6, 9, 14] }
        assert { a.sum(1) == [6, 23] }
        assert { a.prod == 2310 }
        assert { a.prod(0) == [5, 14, 33] }
        assert { a.prod(1) == [6, 385] }
        if float_types.include?(dtype)
          assert { a.mean == 29.0 / 6 }
          assert { a.mean(0) == [3, 4.5, 7] }
          assert { a.mean(1) == [2, 23.0 / 3] }
        end

        assert { a.contiguous? }
        assert { a.reshape(3, 2).contiguous? }
        assert { a[true, 1..2].contiguous? == false }
        assert { a.transpose.contiguous? == false }
        assert { a.fortran_contiguous? == false }
        assert { a.transpose.fortran_contiguous? }
        assert { a.transpose.transpose.fortran_contiguous? == false }
        assert { a.reshape(3, 2).fortran_contiguous? == false }
        assert { a.reshape(3, 2).transpose.fortran_contiguous? }
        assert { a[true, 1..2].fortran_contiguous? == false }
        assert { a[true, 1..2].transpose.fortran_contiguous? == false }

        if complex_types.include?(dtype)
          assert { a.real == src }
          assert { a.imag == [[0] * 3] * 2 }
          assert { a.conj == src }
          assert { a.angle == [[0] * 3] * 2 }
        else
          assert { a.min == 1 }
          assert { a.max == 11 }
          assert { a.min_index == 0 }
          assert { a.min_index(axis: 1) == [0, 3] }
          assert { a.min_index(axis: 0) == [0, 1, 2] }
          assert { a.max_index(axis: 1) == [2, 5] }
          assert { a.max_index(axis: 0) == [3, 4, 5] }
          assert { (a >= 3) == [[0, 0, 1], [1, 1, 1]] }
          assert { (a > 3) == [[0, 0, 0], [1, 1, 1]] }
          assert { (a <= 3) == [[1, 1, 1], [0, 0, 0]] }
          assert { (a < 3) == [[1, 1, 0], [0, 0, 0]] }
          assert { (a.eq 3) == [[0, 0, 1], [0, 0, 0]] }
          assert { a.sort == src }
          assert { a.sort_index == [[0, 1, 2], [3, 4, 5]] }
        end
        assert { a.dup.fill(12) == [[12] * 3] * 2 }
        assert { (a + 1) == [[2, 3, 4], [6, 8, 12]] }
        assert { (a + [1, 2, 3]) == [[2, 4, 6], [6, 9, 14]] }
        assert { (a - 1) == [[0, 1, 2], [4, 6, 10]] }
        assert { (a - [1, 2, 3]) == [[0, 0, 0], [4, 5, 8]] }
        assert { (a * 3) == [[3, 6, 9], [15, 21, 33]] }
        assert { (a * [1, 2, 3]) == [[1, 4, 9], [5, 14, 33]] }
        assert { (a / 0.5) == [[2, 4, 6], [10, 14, 22]] }
        assert { (-a) == [[-1, -2, -3], [-5, -7, -11]] }
        assert { (a**2) == [[1, 4, 9], [25, 49, 121]] }
        assert { (dtype[[1, 0], [0, 1]].dot dtype[[4, 1], [2, 2]]) == [[4, 1], [2, 2]] }
        assert { a.swap_byte.swap_byte == src }
      end

      test "#{dtype},[[1,2,3],[5,7,11]]#{ref},aset[]=" do
        src = [[1, 2, 3], [5, 7, 11]]

        a = init.call(dtype, src)
        a[5] = 13
        assert { a[5] == 13 }

        a = init.call(dtype, src)
        a[-1] = 13
        assert { a[-1] == 13 }

        a = init.call(dtype, src)
        a[1, 0] = 13
        assert { a[1, 0] == 13 }

        a = init.call(dtype, src)
        a[1, 1] = 13
        assert { a[1, 1] == 13 }

        a = init.call(dtype, src)
        a[1, 2] = 13
        assert { a[1, 2] == 13 }

        a = init.call(dtype, src)
        a[3..4] = [13, 13]
        assert { a[3..4] == [13, 13] }

        a = init.call(dtype, src)
        a[0, 1..2] = [13, 13]
        assert { a[0, 1..2] == [13, 13] }

        a = init.call(dtype, src)
        a[0, :*] = [13, 13, 13]
        assert { a[0, :*] == [13, 13, 13] }

        a = init.call(dtype, src)
        a[1, :*] = [13, 13, 13]
        assert { a[1, :*] == [13, 13, 13] }

        a = init.call(dtype, src)
        a[:*, 1] = [13, 13]
        assert { a[:*, 1] == [13, 13] }

        a = init.call(dtype, src)
        a[5] = dtype.cast(13)
        assert { a[5] == 13 }
        assert { a[5] == dtype.cast(13) }

        a = init.call(dtype, src)
        a[1, 1] = dtype.cast(13)
        assert { a[1, 1] == 13 }
        assert { a[1, 1] == dtype.cast(13) }

        a = init.call(dtype, src)
        a[3..4] = dtype.cast([13, 13])
        assert { a[3..4] == [13, 13] }
        assert { a[3..4] == dtype.cast([13, 13]) }

        a = init.call(dtype, src)
        a[:*, 1] = dtype.cast([13, 13])
        assert { a[:*, 1] == [13, 13] }
        assert { a[:*, 1] == dtype.cast([13, 13]) }

        # Writes through a row view must be visible in the parent array.
        a = init.call(dtype, src)
        v = a[0, false]
        v[0] = 13
        assert { v == [13, 2, 3] }
        assert { a == [[13, 2, 3], [5, 7, 11]] }

        a = init.call(dtype, src)
        v = a[1, false]
        v[0] = 13
        assert { v == [13, 7, 11] }
        assert { a == [[1, 2, 3], [13, 7, 11]] }

        a = init.call(dtype, src)
        a[[1, 2, 3]] = 13
        assert { a[[1, 2, 3]] == [13, 13, 13] }
        assert { a == [[1, 13, 13], [13, 7, 11]] }

        a = init.call(dtype, src)
        a[1, [0, 2]] = [13, 13]
        assert { a[1, [0, 2]] == [13, 13] }
        assert { a == [[1, 2, 3], [13, 7, 13]] }

        a = init.call(dtype, src)
        a[1, true] = 13
        assert { a[1, true] == [13, 13, 13] }
        assert { a == [[1, 2, 3], [13, 13, 13]] }
      end
    end

    test "#{dtype},[[[1,2],[3,4]],[[5,6],[7,8]]]" do
      arr = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
      a = dtype[*arr]

      assert { a[0, 1, 1] == 4 }
      assert { a[:rest] == a }
      assert { a[0, :rest] == [[1, 2], [3, 4]] }
      assert { a[0, false] == [[1, 2], [3, 4]] }
      assert { a[0, 1, :rest] == [3, 4] }
      assert { a[0, 1, false] == [3, 4] }
      assert { a[:rest, 0] == [[1, 3], [5, 7]] }
      assert { a[:rest, 0, 1] == [2, 6] }
      assert { a[1, :rest, 0] == [5, 7] }
      assert { a[1, 1, :rest, 0] == 7 }
      assert_raise(IndexError) { a[1, 1, 1, 1, :rest] }
      assert_raise(IndexError) { a[1, 1, 1, :rest, 1] }
      assert_raise(IndexError) { a[:rest, 1, :rest, 0] }

      assert { a.transpose == [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] }
      assert { a.transpose(2, 1, 0) == [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] }
      assert { a.transpose(0, 2, 1) == [[[1, 3], [2, 4]], [[5, 7], [6, 8]]] }

      assert { a.contiguous? }
      assert { a.transpose.contiguous? == false }
      assert { a.fortran_contiguous? == false }
      assert { a.transpose.fortran_contiguous? }
      assert { a.transpose.transpose.fortran_contiguous? == false }
      assert { a.transpose(0, 2, 1).fortran_contiguous? == false }
      assert { a.reshape(2, 4).fortran_contiguous? == false }
      assert { a.reshape(2, 4).transpose.fortran_contiguous? }

      assert { a.at([0, 1], [1, 0], [0, 1]) == [3, 6] }
      assert { a.view.at([0, 1], [1, 0], [0, 1]) == [3, 6] }

      assert { a.reverse == [[[8, 7], [6, 5]], [[4, 3], [2, 1]]] }
      assert { a.reverse(0, 1, 2) == [[[8, 7], [6, 5]], [[4, 3], [2, 1]]] }
      assert { a.reverse(-3, -2, -1) == [[[8, 7], [6, 5]], [[4, 3], [2, 1]]] }
      assert { a.reverse(0..2) == [[[8, 7], [6, 5]], [[4, 3], [2, 1]]] }
      assert { a.reverse(-3..-1) == [[[8, 7], [6, 5]], [[4, 3], [2, 1]]] }
      assert { a.reverse(0...3) == [[[8, 7], [6, 5]], [[4, 3], [2, 1]]] }
      assert { a.reverse(0) == [[[5, 6], [7, 8]], [[1, 2], [3, 4]]] }
      assert { a.reverse(1) == [[[3, 4], [1, 2]], [[7, 8], [5, 6]]] }
      assert { a.reverse(2) == [[[2, 1], [4, 3]], [[6, 5], [8, 7]]] }
      assert { a.reverse(0, 1) == [[[7, 8], [5, 6]], [[3, 4], [1, 2]]] }
      assert { a.reverse(0..1) == [[[7, 8], [5, 6]], [[3, 4], [1, 2]]] }
      assert { a.reverse(0...2) == [[[7, 8], [5, 6]], [[3, 4], [1, 2]]] }
      assert { a.reverse(0, 2) == [[[6, 5], [8, 7]], [[2, 1], [4, 3]]] }
      assert { a.reverse((0..2) % 2) == [[[6, 5], [8, 7]], [[2, 1], [4, 3]]] }
      assert { a.reverse((0..2).step(2)) == [[[6, 5], [8, 7]], [[2, 1], [4, 3]]] }

      # each must visit elements in row-major order.
      enum = arr.flatten.to_enum
      a.each do |e|
        assert { e == enum.next }
      end
      a.each_with_index do |e, *i|
        assert { e == a[*i] }
      end
    end

    sub_test_case "#{dtype}, #mulsum" do
      test "vector.mulsum(vector)" do
        a = dtype[1..3]
        b = dtype[2..4]
        assert { a.mulsum(b) == (1 * 2 + 2 * 3 + 3 * 4) }
      end

      if float_and_complex_types.include?(dtype)
        test "vector.mulsum(vector, nan: true)" do
          a = dtype[1..3]
          a[0] = Float::NAN
          b = dtype[2..4]
          # With nan: true the NaN term is expected to be skipped.
          assert { a.mulsum(b, nan: true) == (0 + 2 * 3 + 3 * 4) }
        end
      end
    end

    sub_test_case "#{dtype}, #dot" do
      test "scalar.dot(scalar)" do
        a = dtype[1].sum
        b = dtype[3].sum
        assert { a.dot(b) == 1 * 3 }
      end
      test "vector.dot(vector) of 1-elem" do
        a = dtype[1]
        b = dtype[3]
        assert { a.dot(b) == 1 * 3 }
      end
      test "vector.dot(vector)" do
        a = dtype[1..3]
        b = dtype[2..4]
        assert { a.dot(b) == (1 * 2 + 2 * 3 + 3 * 4) }
      end
      test "matrix.dot(vector)" do
        a = dtype[1..6].reshape(3, 2)
        b = dtype[1..2]
        assert { a.dot(b) == [5, 11, 17] }
      end
      test "vector.dot(matrix)" do
        a = dtype[1..2]
        b = dtype[1..6].reshape(2, 3)
        assert { a.dot(b) == [9, 12, 15] }
      end
      test "matrix.dot(matrix)" do
        a = dtype[1..6].reshape(3, 2)
        b = dtype[1..6].reshape(2, 3)
        assert { a.dot(b) == [[9, 12, 15], [19, 26, 33], [29, 40, 51]] }
        assert { b.dot(a) == [[22, 28], [49, 64]] }
      end
      test "matrix.dot(matrix.transpose)" do
        a = dtype[1..6].reshape(3, 2)
        b = dtype[1..6].reshape(3, 2).transpose
        assert { a.dot(b) == [[5, 11, 17], [11, 25, 39], [17, 39, 61]] }
        assert { b.dot(a) == [[35, 44], [44, 56]] }
      end
      test "matrix.dot(matrix) of contiguous view" do
        a = dtype.new(4, 3).seq(0)[1..2, 0..2] # 2x3
        b = dtype.new(3, 2).seq(0)
        assert { a.dot(b) == [[28, 40], [46, 67]] }
        assert { b.dot(a) == [[6, 7, 8], [24, 29, 34], [42, 51, 60]] }
      end
      test "matrix.dot(matrix) of non-contiguous view" do
        a = dtype.new(4, 4).seq(0)[1..2, 0..2] # 2x3
        b = dtype.new(3, 2).seq(0)
        assert { a.dot(b) == [[34, 49], [58, 85]] }
        assert { b.dot(a) == [[8, 9, 10], [32, 37, 42], [56, 65, 74]] }
      end
      test "matrix.dot(matrix) >= 3 dimensions" do
        a = dtype[1..6 * 2].reshape(2, 3, 2)
        b = dtype[1..6 * 2].reshape(2, 2, 3)
        assert {
          a.dot(b) ==
            [[[9, 12, 15],
              [19, 26, 33],
              [29, 40, 51]],
             [[129, 144, 159],
              [163, 182, 201],
              [197, 220, 243]]]
        }
        assert {
          b.dot(a) ==
            [[[22, 28],
              [49, 64]],
             [[220, 244],
              [301, 334]]]
        }
      end
      test "matrix.dot(matrix) >= 4 dimensions" do
        a = dtype[1..6 * 2].reshape(1, 2, 3, 2)
        b = dtype[1..6 * 2].reshape(1, 2, 2, 3)
        assert {
          a.dot(b) ==
            [[[[9, 12, 15],
               [19, 26, 33],
               [29, 40, 51]],
              [[129, 144, 159],
               [163, 182, 201],
               [197, 220, 243]]]]
        }
        assert {
          b.dot(a) ==
            [[[[22, 28],
               [49, 64]],
              [[220, 244],
               [301, 334]]]]
        }
      end
      test "matrix.dot(matrix.transpose) >= 3 dimensions" do
        a = dtype[1..6 * 2].reshape(2, 3, 2)
        b = dtype[1..6 * 2].reshape(3, 2, 2).transpose
        assert {
          a.dot(b) ==
            [[[7, 19, 31],
              [15, 43, 71],
              [23, 67, 111]],
             [[46, 106, 166],
              [58, 134, 210],
              [70, 162, 254]]]
        }
        assert {
          b.dot(a) ==
            [[[61, 76],
              [79, 100]],
             [[178, 196],
              [232, 256]]]
        }
      end
      test "matrix.dot(matrix) with incorrect shape" do
        a = dtype[1..6].reshape(3, 2)
        b = dtype[1..9].reshape(3, 3)
        assert_raise(Cumo::NArray::ShapeError) { a.dot(b) }
      end
    end

    if float_and_complex_types.include?(dtype)
      sub_test_case "#{dtype}, #gemm" do
        test "matrix.gemm(matrix) with alpha" do
          a = dtype[1..6].reshape(2, 3)
          b = dtype[1..6].reshape(2, 3)
          alpha = complex_types.include?(dtype) ? Complex(3) : 3
          assert { a.gemm(b.transpose) * alpha == a.gemm(b.transpose, alpha: alpha) }
        end
      end
    end

    test "#{dtype},eye" do
      assert { dtype.new(3, 3).eye(1) == [[1, 0, 0], [0, 1, 0], [0, 0, 1]] }
      assert { dtype.new(3, 3).eye(2) == [[2, 0, 0], [0, 2, 0], [0, 0, 2]] }
      assert { dtype.new(3, 3).eye(1, 1) == [[0, 1, 0], [0, 0, 1], [0, 0, 0]] }
      assert { dtype.new(3, 3).eye(1, -1) == [[0, 0, 0], [1, 0, 0], [0, 1, 0]] }
      assert { dtype.new(2, 2, 2).eye(1) == [[[1, 0], [0, 1]], [[1, 0], [0, 1]]] }
      assert { dtype.new(3, 1).eye(1) == [[1], [0], [0]] }
      assert { dtype.new(1, 3).eye(1) == [[1, 0, 0]] }
      assert { dtype.eye(3) == [[1, 0, 0], [0, 1, 0], [0, 0, 1]] }
      assert { dtype.eye(3, 1) == [[1], [0], [0]] }
      assert { dtype.eye(1, 3) == [[1, 0, 0]] }
    end

    test "#{dtype},element-wise" do
      x = dtype[[1, 2, 3], [5, 7, 11]]
      assert { x + x == [[2, 4, 6], [10, 14, 22]] }
      assert { x + 1 == [[2, 3, 4], [6, 8, 12]] }
      assert { x + dtype[1] == [[2, 3, 4], [6, 8, 12]] }
      assert { x + dtype[[1], [2]] == [[2, 3, 4], [7, 9, 13]] }
      assert { x + dtype[1, 2, 3] == [[2, 4, 6], [6, 9, 14]] }
      assert { x + dtype[[1, 2], [3, 4], [5, 6]].transpose == [[2, 5, 8], [7, 11, 17]] }
      assert { x[0, 1..2] + x[1, 0..1] == [7, 10] }
      unless complex_types.include?(dtype)
        y = x[x > 6] # [7,11]
        assert { y + y == [14, 22] }
        assert { y + 1 == [8, 12] }
        assert { y + dtype[1] == [8, 12] }
        assert { y + dtype[[1, 1], [2, 2]] == [[8, 12], [9, 13]] }
        assert { y.reshape(2, 1) + dtype[[1, 1], [2, 2]] == [[8, 8], [13, 13]] }
      end
    end

    test "#{dtype},reduction" do
      assert { dtype.ones(2, 2, 3, 2).sum(axis: [0, 2, 3]) == [12, 12] }
      assert { dtype.ones(5, 3, 4, 2, 1).sum(axis: [0, 3, 4]) == [[10, 10, 10, 10], [10, 10, 10, 10], [10, 10, 10, 10]] }
      assert { dtype[[1, 2, 3], [4, 5, 6]].sum(axis: 1) == [6, 15] }
      assert { dtype[[1, 2, 3], [4, 5, 6]].sum(axis: 1, keepdims: true) == [[6], [15]] }

      unless complex_types.include?(dtype)
        assert_nothing_raised { dtype.ones(2, 3, 9, 4, 2).max_index(2) }

        a = dtype[[[6, 8, 5],
                   [2, 5, 6],
                   [4, 5, 5]],
                  [[7, 4, 3],
                   [9, 1, 0],
                   [4, 1, 6]]]
        assert { a.max_index(2) == [[1, 5, 8], [9, 12, 17]] }
        assert { a.max(2) == [[8, 6, 5], [7, 9, 6]] }

        unless unsigned_types.include?(dtype)
          a = dtype[[[-6, -8, -5],
                     [-2, 5, 6],
                     [4, -5, 5]],
                    [[-7, -4, -3],
                     [9, 1, 0],
                     [4, -1, -6]]]
          assert { a.max_index(2) == [[2, 5, 8], [11, 12, 15]] }
          assert { a.max(2) == [[-5, 6, 5], [-3, 9, 4]] }
        end

        if real_float_types.include?(dtype)
          assert { dtype[[-Float::INFINITY, 0, 1, Float::INFINITY]].max_index(0) == [0, 1, 2, 3] }
        end
      end
    end

    test "#{dtype},advanced indexing" do
      a = dtype[[1, 2, 3], [4, 5, 6]]
      assert { a[[0, 1], [0, 1]].dup == [[1, 2], [4, 5]] }
      assert { a[[0, 1], [0, 1]].sum == 12 }
      assert { a[[0, 1], [0, 1]].diagonal == [1, 5] }
      diag = a.dup[[0, 1], [0, 1]].diagonal
      # The result of `inplace - 1` is discarded on purpose: the subtraction
      # writes back into diag, as the next assertion checks.
      diag.inplace - 1
      assert { diag == [0, 4] }

      assert { a.at([0, 1], [0, 1]).dup == [1, 5] }
      at = a.dup
      # at() returns a view here, so the in-place subtraction must update `at`.
      at.at([0, 1], [0, 1]).inplace - 1
      assert { at == [[0, 2, 3], [4, 4, 6]] }
    end
  end

  test "Cumo::DFloat.cast(Cumo::RObject[1, nil, 3])" do
    assert_equal(Cumo::DFloat[1, Float::NAN, 3].format_to_a,
                 Cumo::DFloat.cast(Cumo::RObject[1, nil, 3]).format_to_a)
  end
end