cumo 0.4.3 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. checksums.yaml +4 -4
  2. data/.gitignore +3 -0
  3. data/.rubocop.yml +15 -0
  4. data/.rubocop_todo.yml +1272 -0
  5. data/3rd_party/mkmf-cu/Gemfile +2 -0
  6. data/3rd_party/mkmf-cu/Rakefile +2 -1
  7. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +2 -0
  8. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +36 -7
  9. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +51 -45
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +2 -0
  11. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +3 -1
  12. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +5 -3
  13. data/CHANGELOG.md +69 -0
  14. data/Gemfile +6 -1
  15. data/README.md +2 -10
  16. data/Rakefile +8 -11
  17. data/bench/broadcast_fp32.rb +28 -26
  18. data/bench/cumo_bench.rb +18 -16
  19. data/bench/numo_bench.rb +18 -16
  20. data/bench/reduction_fp32.rb +14 -12
  21. data/bin/console +1 -0
  22. data/cumo.gemspec +5 -8
  23. data/ext/cumo/cuda/cudnn.c +2 -2
  24. data/ext/cumo/cumo.c +7 -3
  25. data/ext/cumo/depend.erb +15 -13
  26. data/ext/cumo/extconf.rb +32 -46
  27. data/ext/cumo/include/cumo/cuda/cudnn.h +3 -1
  28. data/ext/cumo/include/cumo/intern.h +1 -0
  29. data/ext/cumo/include/cumo/narray.h +13 -1
  30. data/ext/cumo/include/cumo/template.h +2 -4
  31. data/ext/cumo/include/cumo/types/complex_macro.h +1 -1
  32. data/ext/cumo/include/cumo/types/float_macro.h +2 -2
  33. data/ext/cumo/include/cumo/types/xint_macro.h +3 -2
  34. data/ext/cumo/include/cumo.h +2 -2
  35. data/ext/cumo/narray/array.c +3 -3
  36. data/ext/cumo/narray/data.c +23 -2
  37. data/ext/cumo/narray/gen/cogen.rb +8 -7
  38. data/ext/cumo/narray/gen/cogen_kernel.rb +8 -7
  39. data/ext/cumo/narray/gen/def/bit.rb +3 -1
  40. data/ext/cumo/narray/gen/def/dcomplex.rb +2 -0
  41. data/ext/cumo/narray/gen/def/dfloat.rb +2 -0
  42. data/ext/cumo/narray/gen/def/int16.rb +2 -0
  43. data/ext/cumo/narray/gen/def/int32.rb +2 -0
  44. data/ext/cumo/narray/gen/def/int64.rb +2 -0
  45. data/ext/cumo/narray/gen/def/int8.rb +2 -0
  46. data/ext/cumo/narray/gen/def/robject.rb +2 -0
  47. data/ext/cumo/narray/gen/def/scomplex.rb +2 -0
  48. data/ext/cumo/narray/gen/def/sfloat.rb +2 -0
  49. data/ext/cumo/narray/gen/def/uint16.rb +2 -0
  50. data/ext/cumo/narray/gen/def/uint32.rb +2 -0
  51. data/ext/cumo/narray/gen/def/uint64.rb +2 -0
  52. data/ext/cumo/narray/gen/def/uint8.rb +2 -0
  53. data/ext/cumo/narray/gen/erbln.rb +9 -7
  54. data/ext/cumo/narray/gen/erbpp2.rb +26 -24
  55. data/ext/cumo/narray/gen/narray_def.rb +13 -11
  56. data/ext/cumo/narray/gen/spec.rb +58 -55
  57. data/ext/cumo/narray/gen/tmpl/alloc_func.c +1 -1
  58. data/ext/cumo/narray/gen/tmpl/at.c +34 -0
  59. data/ext/cumo/narray/gen/tmpl/batch_norm.c +1 -1
  60. data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +2 -2
  61. data/ext/cumo/narray/gen/tmpl/conv.c +1 -1
  62. data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +3 -1
  63. data/ext/cumo/narray/gen/tmpl/conv_transpose.c +1 -1
  64. data/ext/cumo/narray/gen/tmpl/fixed_batch_norm.c +1 -1
  65. data/ext/cumo/narray/gen/tmpl/init_class.c +1 -0
  66. data/ext/cumo/narray/gen/tmpl/pooling_backward.c +1 -1
  67. data/ext/cumo/narray/gen/tmpl/pooling_forward.c +1 -1
  68. data/ext/cumo/narray/gen/tmpl/qsort.c +1 -5
  69. data/ext/cumo/narray/gen/tmpl/sort.c +1 -1
  70. data/ext/cumo/narray/gen/tmpl_bit/binary.c +42 -14
  71. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +5 -0
  72. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +5 -0
  73. data/ext/cumo/narray/gen/tmpl_bit/mask.c +27 -7
  74. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +21 -7
  75. data/ext/cumo/narray/gen/tmpl_bit/unary.c +21 -7
  76. data/ext/cumo/narray/index.c +243 -39
  77. data/ext/cumo/narray/index_kernel.cu +84 -0
  78. data/ext/cumo/narray/narray.c +38 -1
  79. data/ext/cumo/narray/ndloop.c +1 -1
  80. data/ext/cumo/narray/struct.c +1 -1
  81. data/lib/cumo/cuda/compile_error.rb +1 -1
  82. data/lib/cumo/cuda/compiler.rb +23 -22
  83. data/lib/cumo/cuda/cudnn.rb +1 -1
  84. data/lib/cumo/cuda/device.rb +1 -1
  85. data/lib/cumo/cuda/link_state.rb +2 -2
  86. data/lib/cumo/cuda/module.rb +1 -2
  87. data/lib/cumo/cuda/nvrtc_program.rb +3 -2
  88. data/lib/cumo/cuda.rb +2 -0
  89. data/lib/cumo/linalg.rb +2 -0
  90. data/lib/cumo/narray/extra.rb +137 -185
  91. data/lib/cumo/narray.rb +2 -0
  92. data/lib/cumo.rb +3 -1
  93. data/test/bit_test.rb +157 -0
  94. data/test/cuda/compiler_test.rb +69 -0
  95. data/test/cuda/device_test.rb +30 -0
  96. data/test/cuda/memory_pool_test.rb +45 -0
  97. data/test/cuda/nvrtc_test.rb +51 -0
  98. data/test/cuda/runtime_test.rb +28 -0
  99. data/test/cudnn_test.rb +498 -0
  100. data/test/cumo_test.rb +27 -0
  101. data/test/narray_test.rb +745 -0
  102. data/test/ractor_test.rb +52 -0
  103. data/test/test_helper.rb +31 -0
  104. metadata +31 -54
  105. data/.travis.yml +0 -5
  106. data/numo-narray-version +0 -1
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  def_id "cast"
2
4
  def_id "eq"
3
5
  def_id "ne"
@@ -14,8 +16,8 @@ if is_float
14
16
  def_id "copysign"
15
17
  end
16
18
  if is_int
17
- def_id "<<","left_shift"
18
- def_id ">>","right_shift"
19
+ def_id "<<", "left_shift"
20
+ def_id ">>", "right_shift"
19
21
  end
20
22
  if is_comparable && !is_object
21
23
  def_id "gt"
@@ -42,13 +44,13 @@ if is_object
42
44
  def_id "nan?"
43
45
  def_id "infinite?"
44
46
  def_id "finite?"
45
- def_id "==","eq"
46
- def_id "!=","ne"
47
- def_id ">" ,"gt"
48
- def_id ">=","ge"
49
- def_id "<" ,"lt"
50
- def_id "<=","le"
51
- def_id "<=>","ufo"
47
+ def_id "==", "eq"
48
+ def_id "!=", "ne"
49
+ def_id ">" , "gt"
50
+ def_id ">=", "ge"
51
+ def_id "<" , "lt"
52
+ def_id "<=", "le"
53
+ def_id "<=>", "ufo"
52
54
  end
53
55
  if (is_float || is_complex) && !is_object
54
56
  def_id "gemm"
@@ -119,18 +121,18 @@ def_method "store" do
119
121
  store_numeric
120
122
  store_from "Bit"
121
123
  if is_complex
122
- store_from "DComplex","cumo_dcomplex","m_from_dcomplex"
123
- store_from "SComplex","cumo_scomplex","m_from_scomplex"
124
+ store_from "DComplex", "cumo_dcomplex", "m_from_dcomplex"
125
+ store_from "SComplex", "cumo_scomplex", "m_from_scomplex"
124
126
  end
125
- store_from "DFloat","double", "m_from_real"
126
- store_from "SFloat","float", "m_from_real"
127
+ store_from "DFloat", "double", "m_from_real"
128
+ store_from "SFloat", "float", "m_from_real"
127
129
  store_from "Int64", "int64_t", "m_from_int64"
128
130
  store_from "Int32", "int32_t", "m_from_int32"
129
131
  store_from "Int16", "int16_t", "m_from_sint"
130
132
  store_from "Int8", "int8_t", "m_from_sint"
131
- store_from "UInt64","u_int64_t","m_from_uint64"
132
- store_from "UInt32","u_int32_t","m_from_uint32"
133
- store_from "UInt16","u_int16_t","m_from_sint"
133
+ store_from "UInt64", "u_int64_t", "m_from_uint64"
134
+ store_from "UInt32", "u_int32_t", "m_from_uint32"
135
+ store_from "UInt16", "u_int16_t", "m_from_sint"
134
136
  store_from "UInt8", "u_int8_t", "m_from_sint"
135
137
  store_from "RObject", "VALUE", "m_num_to_data"
136
138
  store_array
@@ -144,6 +146,7 @@ def_singleton_method "cast"
144
146
  def_method "aref", op:"[]"
145
147
  def_method "aref_cpu"
146
148
  def_method "aset", op:"[]="
149
+ def_method "at"
147
150
 
148
151
  def_method "coerce_cast"
149
152
  def_method "to_a"
@@ -167,15 +170,15 @@ if is_bit
167
170
  binary "xor", "^"
168
171
  binary "eq"
169
172
  bit_count "count_true"
170
- def_alias "count_1","count_true"
171
- def_alias "count","count_true"
173
+ def_alias "count_1", "count_true"
174
+ def_alias "count", "count_true"
172
175
  bit_count "count_false"
173
- def_alias "count_0","count_false"
176
+ def_alias "count_0", "count_false"
174
177
  bit_count_cpu "count_true_cpu"
175
- def_alias "count_1_cpu","count_true_cpu"
176
- def_alias "count_cpu","count_true_cpu"
178
+ def_alias "count_1_cpu", "count_true_cpu"
179
+ def_alias "count_cpu", "count_true_cpu"
177
180
  bit_count_cpu "count_false_cpu"
178
- def_alias "count_0_cpu","count_false_cpu"
181
+ def_alias "count_0_cpu", "count_false_cpu"
179
182
  bit_reduce "all?", 1
180
183
  bit_reduce "any?", 0
181
184
  def_method "none?", "none_p"
@@ -215,17 +218,17 @@ if is_complex
215
218
  unary2 "real", "rtype", "cRT"
216
219
  unary2 "imag", "rtype", "cRT"
217
220
  unary2 "arg", "rtype", "cRT"
218
- def_alias "angle","arg"
221
+ def_alias "angle", "arg"
219
222
  set2 "set_imag", "rtype", "cRT"
220
223
  set2 "set_real", "rtype", "cRT"
221
- def_alias "imag=","set_imag"
222
- def_alias "real=","set_real"
224
+ def_alias "imag=", "set_imag"
225
+ def_alias "real=", "set_real"
223
226
  else
224
227
  def_alias "conj", "view"
225
228
  def_alias "im", "view"
226
229
  end
227
230
 
228
- def_alias "conjugate","conj"
231
+ def_alias "conjugate", "conj"
229
232
 
230
233
  # base_cond
231
234
 
@@ -278,9 +281,9 @@ if is_comparable
278
281
  cond_binary "lt"
279
282
  cond_binary "le"
280
283
  def_alias ">", "gt"
281
- def_alias ">=","ge"
284
+ def_alias ">=", "ge"
282
285
  def_alias "<", "lt"
283
- def_alias "<=","le"
286
+ def_alias "<=", "le"
284
287
  def_method "clip"
285
288
  end
286
289
 
@@ -296,32 +299,32 @@ end
296
299
 
297
300
  if is_int
298
301
  if is_unsigned
299
- accum "sum","u_int64_t","cumo_cUInt64"
300
- accum "prod","u_int64_t","cumo_cUInt64"
302
+ accum "sum", "u_int64_t", "cumo_cUInt64"
303
+ accum "prod", "u_int64_t", "cumo_cUInt64"
301
304
  else
302
- accum "sum","int64_t","cumo_cInt64"
303
- accum "prod","int64_t","cumo_cInt64"
305
+ accum "sum", "int64_t", "cumo_cInt64"
306
+ accum "prod", "int64_t", "cumo_cInt64"
304
307
  end
305
308
  else
306
- accum "sum","dtype","cT"
307
- accum "prod","dtype","cT"
309
+ accum "sum", "dtype", "cT"
310
+ accum "prod", "dtype", "cT"
308
311
  end
309
312
 
310
313
  if is_double_precision
311
- accum "kahan_sum","dtype","cT"
314
+ accum "kahan_sum", "dtype", "cT"
312
315
  end
313
316
 
314
317
  if is_float
315
- accum "mean","dtype","cT"
316
- accum "stddev","rtype","cRT"
317
- accum "var","rtype","cRT"
318
- accum "rms","rtype","cRT"
318
+ accum "mean", "dtype", "cT"
319
+ accum "stddev", "rtype", "cRT"
320
+ accum "var", "rtype", "cRT"
321
+ accum "rms", "rtype", "cRT"
319
322
  end
320
323
 
321
324
  if is_comparable
322
- accum "min","dtype","cT"
323
- accum "max","dtype","cT"
324
- accum "ptp","dtype","cT"
325
+ accum "min", "dtype", "cT"
326
+ accum "max", "dtype", "cT"
327
+ accum "ptp", "dtype", "cT"
325
328
  accum_index "max_index"
326
329
  accum_index "min_index"
327
330
  def_method "minmax"
@@ -333,8 +336,8 @@ if is_int && !is_object
333
336
  def_method "bincount"
334
337
  end
335
338
 
336
- cum "cumsum","add"
337
- cum "cumprod","mul"
339
+ cum "cumsum", "add"
340
+ cum "cumprod", "mul"
338
341
 
339
342
  # dot
340
343
  accum_binary "mulsum"
@@ -377,17 +380,17 @@ def_method "poly"
377
380
 
378
381
  if is_comparable && !is_object
379
382
  if is_float
380
- qsort type_name,"dtype","*(dtype*)","_prnan"
381
- qsort type_name,"dtype","*(dtype*)","_ignan"
383
+ qsort type_name, "dtype", "*(dtype*)", "_prnan"
384
+ qsort type_name, "dtype", "*(dtype*)", "_ignan"
382
385
  else
383
- qsort type_name,"dtype","*(dtype*)"
386
+ qsort type_name, "dtype", "*(dtype*)"
384
387
  end
385
388
  def_method "sort"
386
389
  if is_float
387
- qsort type_name+"_index","dtype*","**(dtype**)","_prnan"
388
- qsort type_name+"_index","dtype*","**(dtype**)","_ignan"
390
+ qsort type_name + "_index", "dtype*", "**(dtype**)", "_prnan"
391
+ qsort type_name + "_index", "dtype*", "**(dtype**)", "_ignan"
389
392
  else
390
- qsort type_name+"_index","dtype*","**(dtype**)"
393
+ qsort type_name + "_index", "dtype*", "**(dtype**)"
391
394
  end
392
395
  def_method "sort_index"
393
396
  def_method "median"
@@ -407,7 +410,7 @@ def_module do
407
410
  set ns_var: "cT"
408
411
  set class_name: cn
409
412
  set name: "#{nm}_math"
410
- set full_module_name: fn+"::NMath"
413
+ set full_module_name: fn + "::NMath"
411
414
  set module_name: "Math"
412
415
  set module_var: "mTM"
413
416
 
@@ -433,14 +436,14 @@ def_module do
433
436
  math "atanh"
434
437
  math "sinc"
435
438
  if !is_c
436
- math "atan2",2
437
- math "hypot",2
439
+ math "atan2", 2
440
+ math "hypot", 2
438
441
  math "erf"
439
442
  math "erfc"
440
443
  math "log1p"
441
444
  math "expm1"
442
- math "ldexp",2
443
- math "frexp",1,"frexp"
445
+ math "ldexp", 2
446
+ math "frexp", 1, "frexp"
444
447
  end
445
448
  end
446
449
  end
@@ -85,7 +85,7 @@ static const rb_data_type_t <%=type_name%>_data_type = {
85
85
  {0, <%=type_name%>_free, <%=type_name%>_memsize,},
86
86
  &cumo_na_data_type,
87
87
  &<%=type_name%>_info,
88
- 0, // flags
88
+ RUBY_TYPED_FROZEN_SHAREABLE, // flags
89
89
  };
90
90
 
91
91
  <% end %>
@@ -0,0 +1,34 @@
1
+ /*
2
+ Multi-dimensional array indexing.
3
+ Same as [] for one-dimensional NArray.
4
+ Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
5
+ @overload at(*indices)
6
+ @param [Numeric,Range,etc] *indices Multi-dimensional Index Arrays.
7
+ @return [Cumo::NArray::<%=class_name%>] one-dimensional NArray view.
8
+
9
+ @example
10
+ x = Cumo::DFloat.new(3,3,3).seq
11
+ => Cumo::DFloat#shape=[3,3,3]
12
+ [[[0, 1, 2],
13
+ [3, 4, 5],
14
+ [6, 7, 8]],
15
+ [[9, 10, 11],
16
+ [12, 13, 14],
17
+ [15, 16, 17]],
18
+ [[18, 19, 20],
19
+ [21, 22, 23],
20
+ [24, 25, 26]]]
21
+
22
+ x.at([0,1,2],[0,1,2],[-1,-2,-3])
23
+ => Cumo::DFloat(view)#shape=[3]
24
+ [2, 13, 24]
25
+ */
26
+ static VALUE
27
+ <%=c_func(-1)%>(int argc, VALUE *argv, VALUE self)
28
+ {
29
+ int result_nd;
30
+ size_t pos;
31
+
32
+ result_nd = cumo_na_get_result_dimension(self, argc, argv, sizeof(dtype), &pos);
33
+ return cumo_na_at_main(argc, argv, self, 0, result_nd, pos);
34
+ }
@@ -193,7 +193,7 @@ BATCH_NORM_ERROR:
193
193
  }
194
194
 
195
195
  #else // CUDNN_FOUND
196
- VALUE cumo_cuda_eCUDNNError;
196
+ #include "cumo/cuda/cudnn.h"
197
197
 
198
198
  static VALUE
199
199
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -178,11 +178,11 @@ BATCH_NORM_BACKWARD_ERROR:
178
178
  }
179
179
 
180
180
  #else // CUDNN_FOUND
181
- VALUE cumo_cuda_eCudnnError;
181
+ #include "cumo/cuda/cudnn.h"
182
182
 
183
183
  static VALUE
184
184
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
185
185
  {
186
- rb_raise(cumo_cuda_eCudnnError, "cuDNN is not available");
186
+ rb_raise(cumo_cuda_eCUDNNError, "cuDNN is not available");
187
187
  }
188
188
  #endif // CUDNN_FOUND
@@ -206,7 +206,7 @@ CONV_ERROR:
206
206
  }
207
207
 
208
208
  #else // CUDNN_FOUND
209
- VALUE cumo_cuda_eCUDNNError;
209
+ #include "cumo/cuda/cudnn.h"
210
210
 
211
211
  static VALUE
212
212
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -95,6 +95,7 @@ static VALUE
95
95
  CUMO_CUDA_CUDNN_CHECK_DIM_EQ(sizet_w_shape[0], ngy->shape[1]);
96
96
  CUMO_CUDA_CUDNN_CHECK_DIM_EQ(sizet_w_shape[1], nx->shape[1]);
97
97
 
98
+ #if !defined(NDEBUG)
98
99
  {
99
100
  // shape check of gy
100
101
  size_t *y_shape = ngy->shape;
@@ -105,6 +106,7 @@ static VALUE
105
106
  x_shape[i + 2], sizet_w_shape[i + 2], int_stride[i], int_pad[i]));
106
107
  }
107
108
  }
109
+ #endif
108
110
 
109
111
  x_cont = cumo_na_as_contiguous_array(x);
110
112
  gy_cont = cumo_na_as_contiguous_array(gy);
@@ -173,7 +175,7 @@ CONV_GRAD_W_ERROR:
173
175
  }
174
176
 
175
177
  #else // CUDNN_FOUND
176
- VALUE cumo_cuda_eCUDNNError;
178
+ #include "cumo/cuda/cudnn.h"
177
179
 
178
180
  static VALUE
179
181
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -234,7 +234,7 @@ CONV_TRANSPOSE_ERROR:
234
234
  }
235
235
 
236
236
  #else // CUDNN_FOUND
237
- VALUE cumo_cuda_eCUDNNError;
237
+ #include "cumo/cuda/cudnn.h"
238
238
 
239
239
  static VALUE
240
240
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -139,7 +139,7 @@ FIXED_BATCH_NORM_ERROR:
139
139
  }
140
140
 
141
141
  #else // CUDNN_FOUND
142
- VALUE cumo_cuda_eCUDNNError;
142
+ #include "cumo/cuda/cudnn.h"
143
143
 
144
144
  static VALUE
145
145
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -14,6 +14,7 @@
14
14
  rb_hash_aset(hCast, rb_cArray, cT);
15
15
  <% for x in upcast %>
16
16
  <%= x %><% end %>
17
+ rb_obj_freeze(hCast);
17
18
 
18
19
  <% @children.each do |m| %>
19
20
  <%= m.init_def %><% end %>
@@ -126,7 +126,7 @@ POOLING_BACKAWARD_ERROR:
126
126
  }
127
127
 
128
128
  #else // CUDNN_FOUND
129
- VALUE cumo_cuda_eCUDNNError;
129
+ #include "cumo/cuda/cudnn.h"
130
130
 
131
131
  static VALUE
132
132
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -126,7 +126,7 @@ POLLING_FORWARD_ERROR:
126
126
  }
127
127
 
128
128
  #else // CUDNN_FOUND
129
- VALUE cumo_cuda_eCUDNNError;
129
+ #include "cumo/cuda/cudnn.h"
130
130
 
131
131
  static VALUE
132
132
  <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
@@ -76,11 +76,7 @@
76
76
  (es) % sizeof(long) ? 2 : (es) == sizeof(long)? 0 : 1;
77
77
 
78
78
  static inline void
79
- swapfunc(a, b, n, swaptype)
80
- char *a,
81
- *b;
82
- size_t n;
83
- int swaptype;
79
+ swapfunc(char *a, char *b, size_t n, int swaptype)
84
80
  {
85
81
  if (swaptype <= 1)
86
82
  swapcode(long, a, b, n);
@@ -32,7 +32,7 @@ static VALUE
32
32
  {
33
33
  VALUE reduce;
34
34
  cumo_ndfunc_arg_in_t ain[2] = {{CUMO_OVERWRITE,0},{cumo_sym_reduce,0}};
35
- cumo_ndfunc_t ndf = {0, CUMO_STRIDE_LOOP|CUMO_NDF_FLAT_REDUCE, 2,0, ain,0};
35
+ cumo_ndfunc_t ndf = {0, CUMO_NDF_HAS_LOOP|CUMO_NDF_FLAT_REDUCE, 2,0, ain,0};
36
36
 
37
37
  if (!CUMO_TEST_INPLACE(self)) {
38
38
  self = cumo_na_copy(self);
@@ -25,10 +25,8 @@ static void
25
25
  CUMO_STORE_BIT_STEP(a3, p3, s3, idx3, x);
26
26
  }
27
27
  } else {
28
- o1 = p1 % CUMO_NB;
29
- o1 -= p3;
30
- o2 = p2 % CUMO_NB;
31
- o2 -= p3;
28
+ o1 = p1-p3;
29
+ o2 = p2-p3;
32
30
  l1 = CUMO_NB+o1;
33
31
  r1 = CUMO_NB-o1;
34
32
  l2 = CUMO_NB+o2;
@@ -58,23 +56,53 @@ static void
58
56
  }
59
57
  } else {
60
58
  for (; n>=CUMO_NB; n-=CUMO_NB) {
61
- x = *a1>>o1;
62
- if (o1<0) x |= *(a1-1)>>l1;
63
- if (o1>0) x |= *(a1+1)<<r1;
59
+ if (o1==0) {
60
+ x = *a1;
61
+ } else if (o1>0) {
62
+ x = *a1>>o1 | *(a1+1)<<r1;
63
+ } else {
64
+ x = *a1<<-o1 | *(a1-1)>>l1;
65
+ }
64
66
  a1++;
65
- y = *a2>>o2;
66
- if (o2<0) y |= *(a2-1)>>l2;
67
- if (o2>0) y |= *(a2+1)<<r2;
67
+ if (o2==0) {
68
+ y = *a2;
69
+ } else if (o2>0) {
70
+ y = *a2>>o2 | *(a2+1)<<r2;
71
+ } else {
72
+ y = *a2<<-o2 | *(a2-1)>>l2;
73
+ }
68
74
  a2++;
69
75
  x = m_<%=name%>(x,y);
70
76
  *(a3++) = x;
71
77
  }
72
78
  }
73
79
  if (n>0) {
74
- x = *a1>>o1;
75
- if (o1<0) x |= *(a1-1)>>l1;
76
- y = *a2>>o2;
77
- if (o2<0) y |= *(a2-1)>>l2;
80
+ if (o1==0) {
81
+ x = *a1;
82
+ } else if (o1>0) {
83
+ x = *a1>>o1;
84
+ if ((int)n>r1) {
85
+ x |= *(a1+1)<<r1;
86
+ }
87
+ } else {
88
+ x = *(a1-1)>>l1;
89
+ if ((int)n>-o1) {
90
+ x |= *a1<<-o1;
91
+ }
92
+ }
93
+ if (o2==0) {
94
+ y = *a2;
95
+ } else if (o2>0) {
96
+ y = *a2>>o2;
97
+ if ((int)n>r2) {
98
+ y |= *(a2+1)<<r2;
99
+ }
100
+ } else {
101
+ y = *(a2-1)>>l2;
102
+ if ((int)n>-o2) {
103
+ y |= *a2<<-o2;
104
+ }
105
+ }
78
106
  x = m_<%=name%>(x,y);
79
107
  *a3 = (x & CUMO_SLB(n)) | (*a3 & CUMO_BALL<<n);
80
108
  }
@@ -53,10 +53,15 @@ static VALUE
53
53
  return <%=c_func(-1)%>_cpu(argc, argv, self);
54
54
  } else {
55
55
  VALUE v, reduce;
56
+ cumo_narray_t *na;
56
57
  cumo_ndfunc_arg_in_t ain[3] = {{cT,0},{cumo_sym_reduce,0},{cumo_sym_init,0}};
57
58
  cumo_ndfunc_arg_out_t aout[1] = {{cumo_cUInt64,0}};
58
59
  cumo_ndfunc_t ndf = { <%=c_iter%>, CUMO_FULL_LOOP_NIP, 3, 1, ain, aout };
59
60
 
61
+ CumoGetNArray(self,na);
62
+ if (CUMO_NA_SIZE(na)==0) {
63
+ return INT2FIX(0);
64
+ }
60
65
  reduce = cumo_na_reduce_dimension(argc, argv, 1, &self, &ndf, 0);
61
66
  v = cumo_na_ndloop(&ndf, 3, self, reduce, INT2FIX(0));
62
67
  return v;
@@ -111,10 +111,15 @@ static VALUE
111
111
  <%=c_func(-1)%>(int argc, VALUE *argv, VALUE self)
112
112
  {
113
113
  VALUE v, reduce;
114
+ cumo_narray_t *na;
114
115
  cumo_ndfunc_arg_in_t ain[3] = {{cT,0},{cumo_sym_reduce,0},{cumo_sym_init,0}};
115
116
  cumo_ndfunc_arg_out_t aout[1] = {{cumo_cBit,0}};
116
117
  cumo_ndfunc_t ndf = {<%=c_iter%>, CUMO_FULL_LOOP_NIP, 3,1, ain,aout};
117
118
 
119
+ CumoGetNArray(self,na);
120
+ if (CUMO_NA_SIZE(na)==0) {
121
+ return INT2FIX(0);
122
+ }
118
123
  reduce = cumo_na_reduce_dimension(argc, argv, 1, &self, &ndf, 0);
119
124
  v = cumo_na_ndloop(&ndf, 3, self, reduce, INT2FIX(<%=init_bit%>));
120
125
  if (argc > 0) {
@@ -78,6 +78,10 @@ static void
78
78
  #define cIndex cumo_cInt32
79
79
  #endif
80
80
 
81
+ static void shape_error(void) {
82
+ rb_raise(cumo_na_eShapeError,"mask and masked arrays must have the same shape");
83
+ }
84
+
81
85
  /*
82
86
  Return subarray of argument masked with self bit array.
83
87
  @overload <%=op_map%>(array)
@@ -87,17 +91,33 @@ static void
87
91
  static VALUE
88
92
  <%=c_func(1)%>(VALUE mask, VALUE val)
89
93
  {
90
- volatile VALUE idx_1, view;
94
+ int i;
95
+ VALUE idx_1, view;
91
96
  cumo_narray_data_t *nidx;
92
- cumo_narray_view_t *nv;
93
- cumo_narray_t *na;
94
- cumo_narray_view_t *na1;
97
+ cumo_narray_view_t *nv, *nv_val;
98
+ cumo_narray_t *na, *na_mask;
95
99
  cumo_stridx_t stridx0;
96
100
  size_t n_1;
97
101
  where_opt_t g;
98
102
  cumo_ndfunc_arg_in_t ain[2] = {{cT,0},{Qnil,0}};
99
103
  cumo_ndfunc_t ndf = {<%=c_iter%>, CUMO_FULL_LOOP, 2, 0, ain, 0};
100
104
 
105
+ // cast val to NArray
106
+ if (!rb_obj_is_kind_of(val, cumo_cNArray)) {
107
+ val = rb_funcall(cumo_cNArray, cumo_id_cast, 1, val);
108
+ }
109
+ // shapes of mask and val must be same
110
+ CumoGetNArray(val, na);
111
+ CumoGetNArray(mask, na_mask);
112
+ if (na_mask->ndim != na->ndim) {
113
+ shape_error();
114
+ }
115
+ for (i=0; i<na->ndim; i++) {
116
+ if (na_mask->shape[i] != na->shape[i]) {
117
+ shape_error();
118
+ }
119
+ }
120
+
101
121
  // TODO(sonots): bit_count_true synchronizes with CPU. Avoid.
102
122
  n_1 = NUM2SIZET(<%=find_tmpl("count_true_cpu").c_func%>(0, NULL, mask));
103
123
  idx_1 = cumo_na_new(cIndex, 1, &n_1);
@@ -114,19 +134,19 @@ static VALUE
114
134
  CumoGetNArrayData(idx_1,nidx);
115
135
  CUMO_SDX_SET_INDEX(stridx0,(size_t*)nidx->ptr);
116
136
  nidx->ptr = NULL;
137
+ RB_GC_GUARD(idx_1);
117
138
 
118
139
  nv->stridx = ALLOC_N(cumo_stridx_t,1);
119
140
  nv->stridx[0] = stridx0;
120
141
  nv->offset = 0;
121
142
 
122
- CumoGetNArray(val, na);
123
143
  switch(CUMO_NA_TYPE(na)) {
124
144
  case CUMO_NARRAY_DATA_T:
125
145
  nv->data = val;
126
146
  break;
127
147
  case CUMO_NARRAY_VIEW_T:
128
- CumoGetNArrayView(val, na1);
129
- nv->data = na1->data;
148
+ CumoGetNArrayView(val, nv_val);
149
+ nv->data = nv_val->data;
130
150
  break;
131
151
  default:
132
152
  rb_raise(rb_eRuntimeError,"invalid CUMO_NA_TYPE: %d",CUMO_NA_TYPE(na));
@@ -22,8 +22,7 @@ static void
22
22
  CUMO_STORE_BIT_STEP(a3, p3, s3, idx3, x);
23
23
  }
24
24
  } else {
25
- o1 = p1 % CUMO_NB;
26
- o1 -= p3;
25
+ o1 = p1-p3;
27
26
  l1 = CUMO_NB+o1;
28
27
  r1 = CUMO_NB-o1;
29
28
  if (p3>0 || n<CUMO_NB) {
@@ -44,16 +43,31 @@ static void
44
43
  }
45
44
  } else {
46
45
  for (; n>=CUMO_NB; n-=CUMO_NB) {
47
- x = *a1>>o1;
48
- if (o1<0) x |= *(a1-1)>>l1;
49
- if (o1>0) x |= *(a1+1)<<r1;
46
+ if (o1==0) {
47
+ x = *a1;
48
+ } else if (o1>0) {
49
+ x = *a1>>o1 | *(a1+1)<<r1;
50
+ } else {
51
+ x = *a1<<-o1 | *(a1-1)>>l1;
52
+ }
50
53
  a1++;
51
54
  *(a3++) = x;
52
55
  }
53
56
  }
54
57
  if (n>0) {
55
- x = *a1>>o1;
56
- if (o1<0) x |= *(a1-1)>>l1;
58
+ if (o1==0) {
59
+ x = *a1;
60
+ } else if (o1>0) {
61
+ x = *a1>>o1;
62
+ if ((int)n>r1) {
63
+ x |= *(a1+1)<<r1;
64
+ }
65
+ } else {
66
+ x = *(a1-1)>>l1;
67
+ if ((int)n>-o1) {
68
+ x |= *a1<<-o1;
69
+ }
70
+ }
57
71
  *a3 = (x & CUMO_SLB(n)) | (*a3 & CUMO_BALL<<n);
58
72
  }
59
73
  }