numo-narray 0.9.1.4 → 0.9.1.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (54)
  1. checksums.yaml +4 -4
  2. data/README.md +10 -4
  3. data/ext/numo/narray/array.c +17 -7
  4. data/ext/numo/narray/data.c +39 -36
  5. data/ext/numo/narray/extconf.rb +1 -0
  6. data/ext/numo/narray/gen/narray_def.rb +4 -0
  7. data/ext/numo/narray/gen/spec.rb +5 -1
  8. data/ext/numo/narray/gen/tmpl/accum.c +2 -2
  9. data/ext/numo/narray/gen/tmpl/accum_arg.c +88 -0
  10. data/ext/numo/narray/gen/tmpl/accum_binary.c +1 -1
  11. data/ext/numo/narray/gen/tmpl/accum_index.c +25 -14
  12. data/ext/numo/narray/gen/tmpl/aref.c +5 -35
  13. data/ext/numo/narray/gen/tmpl/aset.c +7 -37
  14. data/ext/numo/narray/gen/tmpl/bincount.c +7 -7
  15. data/ext/numo/narray/gen/tmpl/clip.c +11 -15
  16. data/ext/numo/narray/gen/tmpl/cum.c +1 -1
  17. data/ext/numo/narray/gen/tmpl/each.c +4 -2
  18. data/ext/numo/narray/gen/tmpl/each_with_index.c +5 -2
  19. data/ext/numo/narray/gen/tmpl/lib.c +2 -2
  20. data/ext/numo/narray/gen/tmpl/logseq.c +6 -5
  21. data/ext/numo/narray/gen/tmpl/map_with_index.c +5 -6
  22. data/ext/numo/narray/gen/tmpl/median.c +2 -2
  23. data/ext/numo/narray/gen/tmpl/minmax.c +1 -1
  24. data/ext/numo/narray/gen/tmpl/poly.c +4 -4
  25. data/ext/numo/narray/gen/tmpl/qsort.c +1 -1
  26. data/ext/numo/narray/gen/tmpl/rand.c +8 -6
  27. data/ext/numo/narray/gen/tmpl/rand_norm.c +18 -16
  28. data/ext/numo/narray/gen/tmpl/seq.c +5 -4
  29. data/ext/numo/narray/gen/tmpl/sort.c +3 -3
  30. data/ext/numo/narray/gen/tmpl/sort_index.c +2 -2
  31. data/ext/numo/narray/gen/tmpl/store_array.c +14 -2
  32. data/ext/numo/narray/gen/tmpl/unary_s.c +55 -31
  33. data/ext/numo/narray/gen/tmpl_bit/aref.c +22 -30
  34. data/ext/numo/narray/gen/tmpl_bit/aset.c +20 -34
  35. data/ext/numo/narray/gen/tmpl_bit/binary.c +42 -14
  36. data/ext/numo/narray/gen/tmpl_bit/bit_count.c +5 -0
  37. data/ext/numo/narray/gen/tmpl_bit/bit_reduce.c +5 -0
  38. data/ext/numo/narray/gen/tmpl_bit/store_array.c +14 -2
  39. data/ext/numo/narray/gen/tmpl_bit/store_bit.c +21 -7
  40. data/ext/numo/narray/gen/tmpl_bit/unary.c +21 -7
  41. data/ext/numo/narray/index.c +369 -59
  42. data/ext/numo/narray/math.c +2 -2
  43. data/ext/numo/narray/narray.c +45 -27
  44. data/ext/numo/narray/ndloop.c +2 -2
  45. data/ext/numo/narray/numo/intern.h +3 -2
  46. data/ext/numo/narray/numo/narray.h +24 -5
  47. data/ext/numo/narray/numo/ndloop.h +2 -2
  48. data/ext/numo/narray/numo/template.h +4 -6
  49. data/ext/numo/narray/numo/types/complex.h +2 -2
  50. data/ext/numo/narray/step.c +58 -252
  51. data/ext/numo/narray/struct.c +2 -2
  52. data/lib/numo/narray/extra.rb +172 -212
  53. data/numo-narray.gemspec +9 -5
  54. metadata +18 -17
@@ -1,7 +1,7 @@
1
1
  /*
2
2
  index.c
3
- Numerical Array Extension for Ruby
4
- (C) Copyright 1999-2017 by Masahiro TANAKA
3
+ Ruby/Numo::NArray - Numerical Array class for Ruby
4
+ Copyright (C) 1999-2019 Masahiro TANAKA
5
5
  */
6
6
  //#define NARRAY_C
7
7
 
@@ -16,24 +16,6 @@
16
16
  #define cIndex numo_cInt32
17
17
  #endif
18
18
 
19
- // from ruby/enumerator.c
20
- struct enumerator {
21
- VALUE obj;
22
- ID meth;
23
- VALUE args;
24
- // use only above in this source
25
- VALUE fib;
26
- VALUE dst;
27
- VALUE lookahead;
28
- VALUE feedvalue;
29
- VALUE stop_exc;
30
- VALUE size;
31
- // incompatible below depending on ruby version
32
- //VALUE procs; // ruby 2.4
33
- //rb_enumerator_size_func *size_fn; // ruby 2.1-2.4
34
- //VALUE (*size_fn)(ANYARGS); // ruby 2.0
35
- };
36
-
37
19
  // note: the memory refed by this pointer is not freed and causes memroy leak.
38
20
  typedef struct {
39
21
  size_t n; // the number of elements of the dimesnion
@@ -80,6 +62,7 @@ static ID id_dup;
80
62
  static ID id_bracket;
81
63
  static ID id_shift_left;
82
64
  static ID id_mask;
65
+ static ID id_where;
83
66
 
84
67
 
85
68
  static void
@@ -142,25 +125,66 @@ na_parse_array(VALUE ary, int orig_dim, ssize_t size, na_index_arg_t *q)
142
125
  static void
143
126
  na_parse_narray_index(VALUE a, int orig_dim, ssize_t size, na_index_arg_t *q)
144
127
  {
145
- VALUE idx;
146
- narray_t *na;
147
- narray_data_t *nidx;
128
+ VALUE idx, cls;
129
+ narray_t *na, *nidx;
148
130
  size_t k, n;
149
- ssize_t *nidxp;
150
131
 
151
132
  GetNArray(a,na);
152
133
  if (NA_NDIM(na) != 1) {
153
134
  rb_raise(rb_eIndexError, "should be 1-d NArray");
154
135
  }
155
- n = NA_SIZE(na);
156
- idx = nary_new(cIndex,1,&n);
157
- na_store(idx,a);
158
-
159
- GetNArrayData(idx,nidx);
160
- nidxp = (ssize_t*)nidx->ptr;
161
- q->idx = ALLOC_N(size_t, n);
162
- for (k=0; k<n; k++) {
163
- q->idx[k] = na_range_check(nidxp[k], size, orig_dim);
136
+ cls = rb_obj_class(a);
137
+ if (cls==numo_cBit) {
138
+ if (NA_SIZE(na) != (size_t)size) {
139
+ rb_raise(rb_eIndexError, "Bit-NArray size mismatch");
140
+ }
141
+ idx = rb_funcall(a,id_where,0);
142
+ GetNArray(idx,nidx);
143
+ n = NA_SIZE(nidx);
144
+ q->idx = ALLOC_N(size_t, n);
145
+ if (na->type!=NARRAY_DATA_T) {
146
+ rb_bug("NArray#where returned wrong type of NArray");
147
+ }
148
+ if (rb_obj_class(idx)==numo_cInt32) {
149
+ int32_t *p = (int32_t*)NA_DATA_PTR(nidx);
150
+ for (k=0; k<n; k++) {
151
+ q->idx[k] = (size_t)p[k];
152
+ }
153
+ } else
154
+ if (rb_obj_class(idx)==numo_cInt64) {
155
+ int64_t *p = (int64_t*)NA_DATA_PTR(nidx);
156
+ for (k=0; k<n; k++) {
157
+ q->idx[k] = (size_t)p[k];
158
+ }
159
+ } else {
160
+ rb_bug("NArray#where should return Int32 or Int64");
161
+ }
162
+ RB_GC_GUARD(idx);
163
+ } else {
164
+ n = NA_SIZE(na);
165
+ q->idx = ALLOC_N(size_t, n);
166
+ if (cls==numo_cInt32 && na->type==NARRAY_DATA_T) {
167
+ int32_t *p = (int32_t*)NA_DATA_PTR(na);
168
+ for (k=0; k<n; k++) {
169
+ q->idx[k] = na_range_check(p[k], size, orig_dim);
170
+ }
171
+ } else
172
+ if (cls==numo_cInt64 && na->type==NARRAY_DATA_T) {
173
+ int64_t *p = (int64_t*)NA_DATA_PTR(na);
174
+ for (k=0; k<n; k++) {
175
+ q->idx[k] = na_range_check(p[k], size, orig_dim);
176
+ }
177
+ } else {
178
+ ssize_t *p;
179
+ idx = nary_new(cIndex,1,&n);
180
+ na_store(idx,a);
181
+ GetNArray(idx,nidx);
182
+ p = (ssize_t*)NA_DATA_PTR(nidx);
183
+ for (k=0; k<n; k++) {
184
+ q->idx[k] = na_range_check(p[k], size, orig_dim);
185
+ }
186
+ RB_GC_GUARD(idx);
187
+ }
164
188
  }
165
189
  q->n = n;
166
190
  q->beg = 0;
@@ -173,10 +197,46 @@ static void
173
197
  na_parse_range(VALUE range, ssize_t step, int orig_dim, ssize_t size, na_index_arg_t *q)
174
198
  {
175
199
  int n;
176
- VALUE excl_end;
177
200
  ssize_t beg, end, beg_orig, end_orig;
178
201
  const char *dot = "..", *edot = "...";
179
202
 
203
+ #ifdef HAVE_RB_ARITHMETIC_SEQUENCE_EXTRACT
204
+ rb_arithmetic_sequence_components_t x;
205
+ rb_arithmetic_sequence_extract(range, &x);
206
+ step = NUM2SSIZET(x.step);
207
+
208
+ beg = beg_orig = NUM2SSIZET(x.begin);
209
+ if (beg < 0) {
210
+ beg += size;
211
+ }
212
+ if (T_NIL == TYPE(x.end)) { // endless range
213
+ end = size - 1;
214
+ if (RTEST(x.exclude_end)) {
215
+ dot = edot;
216
+ }
217
+ if (beg < 0 || beg >= size) {
218
+ rb_raise(rb_eRangeError,
219
+ "%"SZF"d%s is out of range for size=%"SZF"d",
220
+ beg_orig, dot, size);
221
+ }
222
+ } else {
223
+ end = end_orig = NUM2SSIZET(x.end);
224
+ if (end < 0) {
225
+ end += size;
226
+ }
227
+ if (RTEST(x.exclude_end)) {
228
+ end--;
229
+ dot = edot;
230
+ }
231
+ if (beg < 0 || beg >= size || end < 0 || end >= size) {
232
+ rb_raise(rb_eRangeError,
233
+ "%"SZF"d%s%"SZF"d is out of range for size=%"SZF"d",
234
+ beg_orig, dot, end_orig, size);
235
+ }
236
+ }
237
+ #else
238
+ VALUE excl_end;
239
+
180
240
  beg = beg_orig = NUM2SSIZET(rb_funcall(range,id_beg,0));
181
241
  if (beg < 0) {
182
242
  beg += size;
@@ -195,17 +255,18 @@ na_parse_range(VALUE range, ssize_t step, int orig_dim, ssize_t size, na_index_a
195
255
  "%"SZF"d%s%"SZF"d is out of range for size=%"SZF"d",
196
256
  beg_orig, dot, end_orig, size);
197
257
  }
258
+ #endif
198
259
  n = (end-beg)/step+1;
199
260
  if (n<0) n=0;
200
261
  na_index_set_step(q,orig_dim,n,beg,step);
201
262
 
202
263
  }
203
264
 
204
- static void
205
- na_parse_enumerator(VALUE enum_obj, int orig_dim, ssize_t size, na_index_arg_t *q)
265
+ void
266
+ na_parse_enumerator_step(VALUE enum_obj, VALUE *pstep )
206
267
  {
207
268
  int len;
208
- ssize_t step;
269
+ VALUE step;
209
270
  struct enumerator *e;
210
271
 
211
272
  if (!RB_TYPE_P(enum_obj, T_DATA)) {
@@ -213,26 +274,40 @@ na_parse_enumerator(VALUE enum_obj, int orig_dim, ssize_t size, na_index_arg_t *
213
274
  }
214
275
  e = (struct enumerator *)DATA_PTR(enum_obj);
215
276
 
216
- if (rb_obj_is_kind_of(e->obj, rb_cRange)) {
217
- if (e->meth == id_each) {
218
- na_parse_range(e->obj, 1, orig_dim, size, q);
277
+ if (!rb_obj_is_kind_of(e->obj, rb_cRange)) {
278
+ rb_raise(rb_eTypeError,"not Range object");
279
+ }
280
+
281
+ if (e->meth == id_each) {
282
+ step = INT2NUM(1);
283
+ }
284
+ else if (e->meth == id_step) {
285
+ if (TYPE(e->args) != T_ARRAY) {
286
+ rb_raise(rb_eArgError,"no argument for step");
219
287
  }
220
- else if (e->meth == id_step) {
221
- if (TYPE(e->args) != T_ARRAY) {
222
- rb_raise(rb_eArgError,"no argument for step");
223
- }
224
- len = RARRAY_LEN(e->args);
225
- if (len != 1) {
226
- rb_raise(rb_eArgError,"invalid number of step argument (1 for %d)",len);
227
- }
228
- step = NUM2SSIZET(RARRAY_AREF(e->args,0));
229
- na_parse_range(e->obj, step, orig_dim, size, q);
230
- } else {
231
- rb_raise(rb_eTypeError,"unknown Range method: %s",rb_id2name(e->meth));
288
+ len = RARRAY_LEN(e->args);
289
+ if (len != 1) {
290
+ rb_raise(rb_eArgError,"invalid number of step argument (1 for %d)",len);
232
291
  }
292
+ step = RARRAY_AREF(e->args,0);
233
293
  } else {
234
- rb_raise(rb_eTypeError,"not Range object");
294
+ rb_raise(rb_eTypeError,"unknown Range method: %s",rb_id2name(e->meth));
295
+ }
296
+ if (pstep) *pstep = step;
297
+ }
298
+
299
+ static void
300
+ na_parse_enumerator(VALUE enum_obj, int orig_dim, ssize_t size, na_index_arg_t *q)
301
+ {
302
+ VALUE step;
303
+ struct enumerator *e;
304
+
305
+ if (!RB_TYPE_P(enum_obj, T_DATA)) {
306
+ rb_raise(rb_eTypeError,"wrong argument type (not T_DATA)");
235
307
  }
308
+ na_parse_enumerator_step(enum_obj, &step);
309
+ e = (struct enumerator *)DATA_PTR(enum_obj);
310
+ na_parse_range(e->obj, NUM2SSIZET(step), orig_dim, size, q); // e->obj : Range Object
236
311
  }
237
312
 
238
313
  // Analyze *a* which is *i*-th index object and store the information to q
@@ -289,14 +364,15 @@ na_index_parse_each(volatile VALUE a, ssize_t size, int i, na_index_arg_t *q)
289
364
  if (rb_obj_is_kind_of(a, rb_cRange)) {
290
365
  na_parse_range(a, 1, i, size, q);
291
366
  }
367
+ #ifdef HAVE_RB_ARITHMETIC_SEQUENCE_EXTRACT
368
+ else if (rb_obj_is_kind_of(a, rb_cArithSeq)) {
369
+ //na_parse_arith_seq(a, i, size, q);
370
+ na_parse_range(a, 1, i, size, q);
371
+ }
372
+ #endif
292
373
  else if (rb_obj_is_kind_of(a, rb_cEnumerator)) {
293
374
  na_parse_enumerator(a, i, size, q);
294
375
  }
295
- else if (rb_obj_is_kind_of(a, na_cStep)) {
296
- ssize_t beg, step, n;
297
- nary_step_array_index(a, size, (size_t*)(&n), &beg, &step);
298
- na_index_set_step(q,i,n,beg,step);
299
- }
300
376
  // NArray index
301
377
  else if (NA_IsNArray(a)) {
302
378
  na_parse_narray_index(a, i, size, q);
@@ -308,6 +384,102 @@ na_index_parse_each(volatile VALUE a, ssize_t size, int i, na_index_arg_t *q)
308
384
  }
309
385
 
310
386
 
387
+ static void
388
+ na_at_parse_each(volatile VALUE a, ssize_t size, int i, VALUE *idx, ssize_t stride)
389
+ {
390
+ na_index_arg_t q;
391
+ size_t n, k;
392
+ ssize_t *index;
393
+
394
+ // NArray index
395
+ if (NA_IsNArray(a)) {
396
+ VALUE a2;
397
+ narray_t *na, *na2;
398
+ ssize_t *p2;
399
+ GetNArray(a,na);
400
+ if (NA_NDIM(na) != 1) {
401
+ rb_raise(rb_eIndexError, "should be 1-d NArray");
402
+ }
403
+ n = NA_SIZE(na);
404
+ a2 = nary_new(cIndex,1,&n);
405
+ na_store(a2,a);
406
+ GetNArray(a2,na2);
407
+ p2 = (ssize_t*)NA_DATA_PTR(na2);
408
+ if (*idx == Qnil) {
409
+ *idx = a2;
410
+ for (k=0; k<n; k++) {
411
+ na_range_check(p2[k],size,i);
412
+ }
413
+ } else {
414
+ narray_t *nidx;
415
+ GetNArray(*idx,nidx);
416
+ index = (ssize_t*)NA_DATA_PTR(nidx);
417
+ if (NA_SIZE(nidx) != n) {
418
+ rb_raise(nary_eShapeError, "index array sizes mismatch");
419
+ }
420
+ for (k=0; k<n; k++) {
421
+ index[k] += na_range_check(p2[k],size,i) * stride;
422
+ }
423
+ }
424
+ RB_GC_GUARD(a2);
425
+ return;
426
+ }
427
+ else if (TYPE(a) == T_ARRAY) {
428
+ n = RARRAY_LEN(a);
429
+ if (*idx == Qnil) {
430
+ *idx = nary_new(cIndex,1,&n);
431
+ index = (ssize_t*)na_get_pointer_for_write(*idx); // allocate memory
432
+ for (k=0; k<n; k++) {
433
+ index[k] = na_range_check(NUM2SSIZET(RARRAY_AREF(a,k)),size,i);
434
+ }
435
+ } else {
436
+ narray_t *nidx;
437
+ GetNArray(*idx,nidx);
438
+ index = (ssize_t*)NA_DATA_PTR(nidx);
439
+ if (NA_SIZE(nidx) != n) {
440
+ rb_raise(nary_eShapeError, "index array sizes mismatch");
441
+ }
442
+ for (k=0; k<n; k++) {
443
+ index[k] += na_range_check(NUM2SSIZET(RARRAY_AREF(a,k)),size,i) * stride;
444
+ }
445
+ }
446
+ return;
447
+ }
448
+ else if (rb_obj_is_kind_of(a, rb_cRange)) {
449
+ na_parse_range(a, 1, i, size, &q);
450
+ }
451
+ #ifdef HAVE_RB_ARITHMETIC_SEQUENCE_EXTRACT
452
+ else if (rb_obj_is_kind_of(a, rb_cArithSeq)) {
453
+ na_parse_range(a, 1, i, size, &q);
454
+ }
455
+ #endif
456
+ else if (rb_obj_is_kind_of(a, rb_cEnumerator)) {
457
+ na_parse_enumerator(a, i, size, &q);
458
+ }
459
+ else {
460
+ rb_raise(rb_eIndexError, "not allowed type");
461
+ }
462
+
463
+ if (*idx == Qnil) {
464
+ *idx = nary_new(cIndex,1,&q.n);
465
+ index = (ssize_t*)na_get_pointer_for_write(*idx); // allocate memory
466
+ for (k=0; k<q.n; k++) {
467
+ index[k] = q.beg + q.step*k;
468
+ }
469
+ } else {
470
+ narray_t *nidx;
471
+ GetNArray(*idx,nidx);
472
+ index = (ssize_t*)NA_DATA_PTR(nidx);
473
+ if (NA_SIZE(nidx) != q.n) {
474
+ rb_raise(nary_eShapeError, "index array sizes mismatch");
475
+ }
476
+ for (k=0; k<q.n; k++) {
477
+ index[k] += (q.beg + q.step*k) * stride;
478
+ }
479
+ }
480
+ }
481
+
482
+
311
483
  static size_t
312
484
  na_index_parse_args(VALUE args, narray_t *na, na_index_arg_t *q, int ndim)
313
485
  {
@@ -574,7 +746,7 @@ VALUE na_aref_md_protected(VALUE data_value)
574
746
 
575
747
  na_alloc_shape((narray_t*)na2, ndim_new);
576
748
 
577
- na2->stridx = ALLOC_N(stridx_t,ndim_new);
749
+ na2->stridx = ZALLOC_N(stridx_t,ndim_new);
578
750
 
579
751
  elmsz = nary_element_stride(self);
580
752
 
@@ -842,11 +1014,148 @@ static VALUE na_slice(int argc, VALUE *argv, VALUE self)
842
1014
  return na_aref_main(argc, argv, self, 1, nd);
843
1015
  }
844
1016
 
1017
+ /*
1018
+ Multi-dimensional element reference.
1019
+ Returns an element at `dim0`, `dim1`, ... are Numeric indices for each dimension, or returns a NArray View as a sliced array if `dim0`, `dim1`, ... includes other than Numeric index, e.g., Range or Array or true.
1020
+ @overload [](dim0,...,dimL)
1021
+ @param [Numeric,Range,Array,Numo::Int32,Numo::Int64,Numo::Bit,TrueClass,FalseClass,Symbol] dim0,...,dimL multi-dimensional indices.
1022
+ @return [Numeric,Numo::NArray] an element or NArray view.
1023
+ @see #[]=
1024
+ @see #at
1025
+
1026
+ @example
1027
+ a = Numo::DFloat.new(4,5).seq
1028
+ # => Numo::DFloat#shape=[4,5]
1029
+ # [[0, 1, 2, 3, 4],
1030
+ # [5, 6, 7, 8, 9],
1031
+ # [10, 11, 12, 13, 14],
1032
+ # [15, 16, 17, 18, 19]]
1033
+
1034
+ a[1,1]
1035
+ # => 6.0
1036
+
1037
+ a[1..3,1]
1038
+ # => Numo::DFloat#shape=[3]
1039
+ # [6, 11, 16]
1040
+
1041
+ a[1,[1,3,4]]
1042
+ # => Numo::DFloat#shape=[3]
1043
+ # [6, 8, 9]
1044
+
1045
+ a[true,2].fill(99)
1046
+ a
1047
+ # => Numo::DFloat#shape=[4,5]
1048
+ # [[0, 1, 99, 3, 4],
1049
+ # [5, 6, 99, 8, 9],
1050
+ # [10, 11, 99, 13, 14],
1051
+ # [15, 16, 99, 18, 19]]
1052
+ */
1053
+ static VALUE na_aref(int argc, VALUE *argv, VALUE self)
1054
+ {
1055
+ // implemented in subclasses
1056
+ return rb_f_notimplement(argc,argv,self);
1057
+ }
1058
+
1059
+ /*
1060
+ Multi-dimensional element assignment.
1061
+ Replace element(s) at `dim0`, `dim1`, ... .
1062
+ Broadcasting mechanism is applied.
1063
+ @overload []=(dim0,...,dimL,val)
1064
+ @param [Numeric,Range,Array,Numo::Int32,Numo::Int64,Numo::Bit,TrueClass,FalseClass,Symbol] dim0,...,dimL multi-dimensional indices.
1065
+ @param [Numeric,Numo::NArray,Array] val Value(s) to be set to self.
1066
+ @return [Numeric,Numo::NArray,Array] returns `val` (last argument).
1067
+ @see #[]
1068
+ @example
1069
+ a = Numo::DFloat.new(3,4).seq
1070
+ # => Numo::DFloat#shape=[3,4]
1071
+ # [[0, 1, 2, 3],
1072
+ # [4, 5, 6, 7],
1073
+ # [8, 9, 10, 11]]
1074
+
1075
+ a[1,2]=99
1076
+ a
1077
+ # => Numo::DFloat#shape=[3,4]
1078
+ # [[0, 1, 2, 3],
1079
+ # [4, 5, 99, 7],
1080
+ # [8, 9, 10, 11]]
1081
+
1082
+ a[1,[0,2]] = [101,102]
1083
+ a
1084
+ # => Numo::DFloat#shape=[3,4]
1085
+ # [[0, 1, 2, 3],
1086
+ # [101, 5, 102, 7],
1087
+ # [8, 9, 10, 11]]
1088
+
1089
+ a[1,true]=99
1090
+ a
1091
+ # => Numo::DFloat#shape=[3,4]
1092
+ # [[0, 1, 2, 3],
1093
+ # [99, 99, 99, 99],
1094
+ # [8, 9, 10, 11]]
1095
+
1096
+ */
1097
+ static VALUE na_aset(int argc, VALUE *argv, VALUE self)
1098
+ {
1099
+ // implemented in subclasses
1100
+ return rb_f_notimplement(argc,argv,self);
1101
+ }
1102
+
1103
+ /*
1104
+ Multi-dimensional array indexing.
1105
+ Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
1106
+ Same as Numo::NArray#[] for one-dimensional NArray.
1107
+ @overload at(dim0,...,dimL)
1108
+ @param [Range,Array,Numo::Int32,Numo::Int64] dim0,...,dimL multi-dimensional index arrays.
1109
+ @return [Numo::NArray] one-dimensional NArray view.
1110
+ @see #[]
1111
+
1112
+ @example
1113
+ x = Numo::DFloat.new(3,3,3).seq
1114
+ # => Numo::DFloat#shape=[3,3,3]
1115
+ # [[[0, 1, 2],
1116
+ # [3, 4, 5],
1117
+ # [6, 7, 8]],
1118
+ # [[9, 10, 11],
1119
+ # [12, 13, 14],
1120
+ # [15, 16, 17]],
1121
+ # [[18, 19, 20],
1122
+ # [21, 22, 23],
1123
+ # [24, 25, 26]]]
1124
+
1125
+ x.at([0,1,2],[0,1,2],[-1,-2,-3])
1126
+ # => Numo::DFloat(view)#shape=[3]
1127
+ # [2, 13, 24]
1128
+ */
1129
+ static VALUE na_at(int argc, VALUE *argv, VALUE self)
1130
+ {
1131
+ int i;
1132
+ size_t n;
1133
+ ssize_t stride=1;
1134
+ narray_t *na;
1135
+ VALUE idx=Qnil;
1136
+
1137
+ na_index_arg_to_internal_order(argc, argv, self);
1138
+
1139
+ GetNArray(self,na);
1140
+ if (NA_NDIM(na) != argc) {
1141
+ rb_raise(rb_eArgError,"the number of argument must be same as dimension");
1142
+ }
1143
+ for (i=argc; i>0; ) {
1144
+ i--;
1145
+ n = NA_SHAPE(na)[i];
1146
+ na_at_parse_each(argv[i], n, i, &idx, stride);
1147
+ stride *= n;
1148
+ }
1149
+ return na_aref_main(1, &idx, self, 1, 1);
1150
+ }
845
1151
 
846
1152
  void
847
1153
  Init_nary_index()
848
1154
  {
849
1155
  rb_define_method(cNArray, "slice", na_slice, -1);
1156
+ rb_define_method(cNArray, "[]", na_aref, -1);
1157
+ rb_define_method(cNArray, "[]=", na_aset, -1);
1158
+ rb_define_method(cNArray, "at", na_at, -1);
850
1159
 
851
1160
  sym_ast = ID2SYM(rb_intern("*"));
852
1161
  sym_all = ID2SYM(rb_intern("all"));
@@ -867,4 +1176,5 @@ Init_nary_index()
867
1176
  id_bracket = rb_intern("[]");
868
1177
  id_shift_left = rb_intern("<<");
869
1178
  id_mask = rb_intern("mask");
1179
+ id_where = rb_intern("where");
870
1180
  }